"weights" added to solver parameters, "snapshot_prefix" field default initialization #6123

Merged · 2 commits · Feb 12, 2018
18 changes: 16 additions & 2 deletions src/caffe/proto/caffe.proto
@@ -98,7 +98,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 42 (last added: layer_wise_reduce)
// SolverParameter next available ID: 43 (last added: weights)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
@@ -186,7 +186,11 @@ message SolverParameter {
optional float clip_gradients = 35 [default = -1];

optional int32 snapshot = 14 [default = 0]; // The snapshot interval
optional string snapshot_prefix = 15; // The prefix for the snapshot.
// The prefix for the snapshot.
// If not set, it defaults to the solver prototxt file path without the
// extension. If set to a directory, the prototxt file name (without
// extension) is appended to it.
optional string snapshot_prefix = 15;
// whether to snapshot diff in the results or not. Snapshotting diff will help
// debugging but the final protocol buffer size will be much larger.
optional bool snapshot_diff = 16 [default = false];
@@ -241,6 +245,16 @@ message SolverParameter {

// Overlap compute and communication for data parallel training
optional bool layer_wise_reduce = 41 [default = true];

// Path to caffemodel file(s) with pretrained weights used to initialize
// finetuning. Same as the command-line --weights parameter of the caffe
// train command.
// If the command-line --weights parameter is specified, it takes priority
// and overwrites these entries.
// If the --snapshot command-line parameter is specified, these entries
// are ignored.
// Several model files may be listed either in a single weights parameter,
// separated by ',' (as on the command line), or as separate repeated
// weights parameters.
repeated string weights = 42;
}

// A message that stores the solver snapshots
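As a usage illustration (not part of this diff): a minimal sketch of populating the new repeated weights field through the protobuf-generated C++ API. The helper name and the .caffemodel file names are hypothetical.

#include "caffe/proto/caffe.pb.h"

void ConfigureFinetuning(caffe::SolverParameter* solver_param) {
  // Either list each pretrained model as its own repeated entry...
  solver_param->add_weights("pretrained_a.caffemodel");
  solver_param->add_weights("pretrained_b.caffemodel");
  // ...or put several comma-separated paths in one entry, as on the
  // command line; the solver splits and trims them when loading.
  solver_param->clear_weights();
  solver_param->add_weights("pretrained_a.caffemodel,pretrained_b.caffemodel");
}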
21 changes: 21 additions & 0 deletions src/caffe/solver.cpp
@@ -3,6 +3,7 @@
#include <string>
#include <vector>

#include "boost/algorithm/string.hpp"
#include "caffe/solver.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/hdf5.hpp"
@@ -59,6 +60,20 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
current_step_ = 0;
}

// Load weights from the caffemodel(s) specified in the "weights" solver
// parameter into the train and test nets.
template <typename Dtype>
void LoadNetWeights(shared_ptr<Net<Dtype> > net,
const std::string& model_list) {
std::vector<std::string> model_names;
boost::split(model_names, model_list, boost::is_any_of(","));
for (int i = 0; i < model_names.size(); ++i) {
boost::trim(model_names[i]);
LOG(INFO) << "Finetuning from " << model_names[i];
net->CopyTrainedLayersFrom(model_names[i]);
}
}

template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
const int num_train_nets = param_.has_net() + param_.has_net_param() +
@@ -98,6 +113,9 @@ void Solver<Dtype>::InitTrainNet() {
net_state.MergeFrom(param_.train_state());
net_param.mutable_state()->CopyFrom(net_state);
net_.reset(new Net<Dtype>(net_param));
for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
LoadNetWeights(net_, param_.weights(w_idx));
}
}

template <typename Dtype>
@@ -173,6 +191,9 @@ void Solver<Dtype>::InitTestNets() {
<< "Creating test net (#" << i << ") specified by " << sources[i];
test_nets_[i].reset(new Net<Dtype>(net_params[i]));
test_nets_[i]->set_debug_info(param_.debug_info());
for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
LoadNetWeights(test_nets_[i], param_.weights(w_idx));
}
}
}

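A self-contained sketch of just the comma-splitting behavior used by LoadNetWeights above, assuming the boost string algorithms are available; the file names and the standalone main are hypothetical.

#include <boost/algorithm/string.hpp>
#include <iostream>
#include <string>
#include <vector>

int main() {
  const std::string model_list = " a.caffemodel , b.caffemodel";
  std::vector<std::string> model_names;
  // Split on ',' and trim surrounding whitespace, as LoadNetWeights does.
  boost::split(model_names, model_list, boost::is_any_of(","));
  for (size_t i = 0; i < model_names.size(); ++i) {
    boost::trim(model_names[i]);
    std::cout << "would load: " << model_names[i] << std::endl;
  }
  return 0;
}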
4 changes: 4 additions & 0 deletions src/caffe/test/test_upgrade_proto.cpp
@@ -2952,6 +2952,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) {
for (int i = 0; i < 6; ++i) {
const string& input_proto =
"net: 'examples/mnist/lenet_train_test.prototxt' "
"weights: 'examples/mnist/lenet_train_test1.caffemodel' "
"weights: 'examples/mnist/lenet_train_test2.caffemodel' "
"test_iter: 100 "
"test_interval: 500 "
"base_lr: 0.01 "
@@ -2968,6 +2970,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) {
"solver_type: " + std::string(old_type_vec[i]) + " ";
const string& expected_output_proto =
"net: 'examples/mnist/lenet_train_test.prototxt' "
"weights: 'examples/mnist/lenet_train_test1.caffemodel' "
"weights: 'examples/mnist/lenet_train_test2.caffemodel' "
"test_iter: 100 "
"test_interval: 500 "
"base_lr: 0.01 "
21 changes: 21 additions & 0 deletions src/caffe/util/upgrade_proto.cpp
@@ -2,6 +2,8 @@
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/text_format.h>

#include <boost/filesystem.hpp>

#include <map>
#include <string>

@@ -1095,12 +1097,31 @@ bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) {
return success;
}

// Sets the snapshot_prefix of a SolverParameter when it is unset or
// points to a directory.
void UpgradeSnapshotPrefixProperty(const string& param_file,
SolverParameter* param) {
using boost::filesystem::path;
using boost::filesystem::is_directory;
if (!param->has_snapshot_prefix()) {
param->set_snapshot_prefix(path(param_file).replace_extension().string());
LOG(INFO) << "snapshot_prefix was not specified and is set to "
+ param->snapshot_prefix();
} else if (is_directory(param->snapshot_prefix())) {
param->set_snapshot_prefix((path(param->snapshot_prefix()) /
path(param_file).stem()).string());
LOG(INFO) << "snapshot_prefix was a directory and is replaced to "
+ param->snapshot_prefix();
}
}

// Read parameters from a file into a SolverParameter proto message.
void ReadSolverParamsFromTextFileOrDie(const string& param_file,
SolverParameter* param) {
CHECK(ReadProtoFromTextFile(param_file, param))
<< "Failed to parse SolverParameter file: " << param_file;
UpgradeSolverAsNeeded(param_file, param);
UpgradeSnapshotPrefixProperty(param_file, param);
}

} // namespace caffe
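A minimal standalone sketch of the two path derivations in UpgradeSnapshotPrefixProperty above; the example solver path and directory are hypothetical.

#include <boost/filesystem.hpp>
#include <iostream>
#include <string>

int main() {
  using boost::filesystem::path;
  const std::string param_file = "examples/mnist/lenet_solver.prototxt";
  // Case 1: snapshot_prefix unset -> solver file path without the extension.
  std::cout << path(param_file).replace_extension().string() << std::endl;
  // prints: examples/mnist/lenet_solver
  // Case 2: snapshot_prefix is a directory -> append the solver file stem.
  std::cout << (path("snapshots") / path(param_file).stem()).string()
            << std::endl;
  // prints: snapshots/lenet_solver
  return 0;
}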
23 changes: 7 additions & 16 deletions tools/caffe.cpp
@@ -146,20 +146,6 @@ int device_query() {
}
RegisterBrewFunction(device_query);

// Load the weights from the specified caffemodel(s) into the train and
// test nets.
void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) {
std::vector<std::string> model_names;
boost::split(model_names, model_list, boost::is_any_of(",") );
for (int i = 0; i < model_names.size(); ++i) {
LOG(INFO) << "Finetuning from " << model_names[i];
solver->net()->CopyTrainedLayersFrom(model_names[i]);
for (int j = 0; j < solver->test_nets().size(); ++j) {
solver->test_nets()[j]->CopyTrainedLayersFrom(model_names[i]);
}
}
}

// Translate the signal effect the user specified on the command-line to the
// corresponding enumeration.
caffe::SolverAction::Enum GetRequestedAction(
@@ -233,6 +219,13 @@ int train() {
GetRequestedAction(FLAGS_sigint_effect),
GetRequestedAction(FLAGS_sighup_effect));

if (FLAGS_snapshot.size()) {
solver_param.clear_weights();
} else if (FLAGS_weights.size()) {
solver_param.clear_weights();
solver_param.add_weights(FLAGS_weights);
}

shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

@@ -241,8 +234,6 @@
if (FLAGS_snapshot.size()) {
LOG(INFO) << "Resuming from " << FLAGS_snapshot;
solver->Restore(FLAGS_snapshot.c_str());
} else if (FLAGS_weights.size()) {
CopyLayers(solver.get(), FLAGS_weights);
}

LOG(INFO) << "Starting Optimization";
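The precedence this hunk establishes (--snapshot over --weights over the solver's weights field), restated as a standalone helper for clarity; the helper name is ours.

#include <string>
#include "caffe/proto/caffe.pb.h"

// Hypothetical helper mirroring the logic added to train().
void ResolveWeightSources(const std::string& flag_snapshot,
                          const std::string& flag_weights,
                          caffe::SolverParameter* solver_param) {
  if (!flag_snapshot.empty()) {
    // Resuming from a snapshot: pretrained weights are ignored entirely.
    solver_param->clear_weights();
  } else if (!flag_weights.empty()) {
    // Command-line --weights overrides any weights from the solver prototxt.
    solver_param->clear_weights();
    solver_param->add_weights(flag_weights);
  }
  // Otherwise any weights entries from the solver prototxt are kept.
}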