Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

"weights" added to solver parameters, "snapshot_prefix" field default initialization #6123

Merged
merged 2 commits into from
Feb 12, 2018
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Weight parameter in solver is used in caffe.exe
Loading weights is moved from caffe.exe to solver class, so new "weights" solver parameter is used not only from command line but when caffe is used as library (including python)

corrected formatting

fixed line length

more formatting corrected
  • Loading branch information
IlyaOvodov committed Feb 10, 2018
commit c32629435a1e5dacc8a90a309d8b255d7b629379
12 changes: 11 additions & 1 deletion src/caffe/proto/caffe.proto
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ message NetParameter {
// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
// SolverParameter next available ID: 42 (last added: layer_wise_reduce)
// SolverParameter next available ID: 43 (last added: weights)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
Expand Down Expand Up @@ -241,6 +241,16 @@ message SolverParameter {

// Overlap compute and communication for data parallel training
optional bool layer_wise_reduce = 41 [default = true];

// Path to caffemodel file(s) with pretrained weights to initialize finetuning.
// Tha same as command line --weights parameter for caffe train command.
// If command line --weights parameter if specified, it has higher priority
// and owerwrites this one(s).
// If --snapshot command line parameter is specified, this one(s) are ignored.
// If several model files are expected, they can be listed in a one
// weights parameter separated by ',' (like in a command string) or
// in repeated weights parameters separately.
repeated string weights = 42;
}

// A message that stores the solver snapshots
Expand Down
21 changes: 21 additions & 0 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <string>
#include <vector>

#include "boost/algorithm/string.hpp"
#include "caffe/solver.hpp"
#include "caffe/util/format.hpp"
#include "caffe/util/hdf5.hpp"
Expand Down Expand Up @@ -59,6 +60,20 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
current_step_ = 0;
}

// Load weights from the caffemodel(s) specified in "weights" solver parameter
// into the train and test nets.
template <typename Dtype>
void LoadNetWeights(shared_ptr<Net<Dtype> > net,
const std::string& model_list) {
std::vector<std::string> model_names;
boost::split(model_names, model_list, boost::is_any_of(","));
for (int i = 0; i < model_names.size(); ++i) {
boost::trim(model_names[i]);
LOG(INFO) << "Finetuning from " << model_names[i];
net->CopyTrainedLayersFrom(model_names[i]);
}
}

template <typename Dtype>
void Solver<Dtype>::InitTrainNet() {
const int num_train_nets = param_.has_net() + param_.has_net_param() +
Expand Down Expand Up @@ -98,6 +113,9 @@ void Solver<Dtype>::InitTrainNet() {
net_state.MergeFrom(param_.train_state());
net_param.mutable_state()->CopyFrom(net_state);
net_.reset(new Net<Dtype>(net_param));
for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
LoadNetWeights(net_, param_.weights(w_idx));
}
}

template <typename Dtype>
Expand Down Expand Up @@ -173,6 +191,9 @@ void Solver<Dtype>::InitTestNets() {
<< "Creating test net (#" << i << ") specified by " << sources[i];
test_nets_[i].reset(new Net<Dtype>(net_params[i]));
test_nets_[i]->set_debug_info(param_.debug_info());
for (int w_idx = 0; w_idx < param_.weights_size(); ++w_idx) {
LoadNetWeights(test_nets_[i], param_.weights(w_idx));
}
}
}

Expand Down
4 changes: 4 additions & 0 deletions src/caffe/test/test_upgrade_proto.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2952,6 +2952,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) {
for (int i = 0; i < 6; ++i) {
const string& input_proto =
"net: 'examples/mnist/lenet_train_test.prototxt' "
"weights: 'examples/mnist/lenet_train_test1.caffemodel' "
"weights: 'examples/mnist/lenet_train_test2.caffemodel' "
"test_iter: 100 "
"test_interval: 500 "
"base_lr: 0.01 "
Expand All @@ -2968,6 +2970,8 @@ TEST_F(SolverTypeUpgradeTest, TestSimple) {
"solver_type: " + std::string(old_type_vec[i]) + " ";
const string& expected_output_proto =
"net: 'examples/mnist/lenet_train_test.prototxt' "
"weights: 'examples/mnist/lenet_train_test1.caffemodel' "
"weights: 'examples/mnist/lenet_train_test2.caffemodel' "
"test_iter: 100 "
"test_interval: 500 "
"base_lr: 0.01 "
Expand Down
23 changes: 7 additions & 16 deletions tools/caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,20 +146,6 @@ int device_query() {
}
RegisterBrewFunction(device_query);

// Load the weights from the specified caffemodel(s) into the train and
// test nets.
void CopyLayers(caffe::Solver<float>* solver, const std::string& model_list) {
std::vector<std::string> model_names;
boost::split(model_names, model_list, boost::is_any_of(",") );
for (int i = 0; i < model_names.size(); ++i) {
LOG(INFO) << "Finetuning from " << model_names[i];
solver->net()->CopyTrainedLayersFrom(model_names[i]);
for (int j = 0; j < solver->test_nets().size(); ++j) {
solver->test_nets()[j]->CopyTrainedLayersFrom(model_names[i]);
}
}
}

// Translate the signal effect the user specified on the command-line to the
// corresponding enumeration.
caffe::SolverAction::Enum GetRequestedAction(
Expand Down Expand Up @@ -233,6 +219,13 @@ int train() {
GetRequestedAction(FLAGS_sigint_effect),
GetRequestedAction(FLAGS_sighup_effect));

if (FLAGS_snapshot.size()) {
solver_param.clear_weights();
} else if (FLAGS_weights.size()) {
solver_param.clear_weights();
solver_param.add_weights(FLAGS_weights);
}

shared_ptr<caffe::Solver<float> >
solver(caffe::SolverRegistry<float>::CreateSolver(solver_param));

Expand All @@ -241,8 +234,6 @@ int train() {
if (FLAGS_snapshot.size()) {
LOG(INFO) << "Resuming from " << FLAGS_snapshot;
solver->Restore(FLAGS_snapshot.c_str());
} else if (FLAGS_weights.size()) {
CopyLayers(solver.get(), FLAGS_weights);
}

LOG(INFO) << "Starting Optimization";
Expand Down