
Commit

Bugfixes.
naibaf7 committed Jun 12, 2018
1 parent bc4153a commit 8f37a14
Showing 22 changed files with 153 additions and 83 deletions.
2 changes: 1 addition & 1 deletion include/caffe/layer.hpp
@@ -343,7 +343,7 @@ class Layer : public LayerBase {
quant_param.set_input_data_type(layer_param_.compute_data_type());
quant_param.set_output_data_type(layer_param_.top_data_type());
if (!quant_param.has_name()) {
quant_param.set_name(layer_param_.bottom_size() > i ?
quant_param.set_name(layer_param_.top_size() > i ?
layer_param_.top(i) : this->layer_param_.name() + "_top_"
+ std::to_string(i) + "_quant"); }
top_quants_.push_back(
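
Note on the fix above: the fallback name for a top quantizer must be guarded by top_size(), not bottom_size(); otherwise layer_param_.top(i) can be read for an index that does not exist when a layer has more tops than bottoms. A minimal sketch of the corrected naming logic, using a simplified stand-in for LayerParameter (the struct and helper below are illustrative, not Caffe's API):

#include <iostream>
#include <string>
#include <vector>

struct FakeLayerParam {
  std::string name;
  std::vector<std::string> tops;
  int top_size() const { return static_cast<int>(tops.size()); }
  const std::string& top(int i) const { return tops[i]; }
};

// Fall back to a synthetic "<layer>_top_<i>_quant" name only when no
// explicit top blob name exists for index i.
std::string top_quant_name(const FakeLayerParam& p, int i) {
  return p.top_size() > i ? p.top(i)
                          : p.name + "_top_" + std::to_string(i) + "_quant";
}

int main() {
  FakeLayerParam p{"conv1", {"conv1_out"}};
  std::cout << top_quant_name(p, 0) << "\n";  // conv1_out
  std::cout << top_quant_name(p, 1) << "\n";  // conv1_top_1_quant
}
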
2 changes: 1 addition & 1 deletion include/caffe/layers/loss_layer.hpp
@@ -35,7 +35,7 @@ class LossLayer : public Layer<Dtype, MItype, MOtype> {

/**
* @brief For convenience and backwards compatibility, instruct the Net to
* automatically allocate a single top Blob for LossLayers, int_tpo which
* automatically allocate a single top Blob for LossLayers, into which
* they output their singleton loss, (even if the user didn't specify
* one in the prototxt, etc.).
*/
6 changes: 3 additions & 3 deletions src/caffe/layers/bnll_layer.cpp
@@ -17,8 +17,8 @@ void BNLLLayer<Dtype, MItype, MOtype>::Forward_cpu(
const int_tp count = bottom[0]->count();
for (int_tp i = 0; i < count; ++i) {
top_data[i] = bottom_data[i] > 0 ?
bottom_data[i] + std::log(1. + exp(-bottom_data[i])) :
std::log(1. + exp(bottom_data[i]));
bottom_data[i] + std::log(1. + std::exp(-bottom_data[i])) :
std::log(1. + std::exp(bottom_data[i]));
}
}

@@ -34,7 +34,7 @@ void BNLLLayer<Dtype, MItype, MOtype>::Backward_cpu(
const int_tp count = bottom[0]->count();
Dtype expval;
for (int_tp i = 0; i < count; ++i) {
expval = exp(std::min(bottom_data[i], Dtype(kBNLL_THRESHOLD)));
expval = std::exp(std::min(bottom_data[i], Dtype(kBNLL_THRESHOLD)));
bottom_diff[i] = top_diff[i] * expval / (expval + 1.);
}
}
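
The BNLL forward above is the softplus log(1 + exp(x)) written in its branch-stable form x + log(1 + exp(-x)) for x > 0, so the exponential never overflows; the backward multiplies the top gradient by the sigmoid expval / (expval + 1). A self-contained check of the stable forward form (plain double arithmetic, not Caffe code):

#include <cmath>
#include <cstdio>

// Numerically stable softplus, as in BNLL's Forward_cpu.
double softplus_stable(double x) {
  return x > 0 ? x + std::log(1. + std::exp(-x))  // avoids exp(+large)
               : std::log(1. + std::exp(x));
}

int main() {
  for (double x : {-1000.0, -1.0, 0.0, 1.0, 1000.0}) {
    std::printf("x = %g  softplus = %g\n", x, softplus_stable(x));
  }
}
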
4 changes: 2 additions & 2 deletions src/caffe/layers/contrastive_loss_layer.cpp
@@ -66,7 +66,7 @@ void ContrastiveLossLayer<Dtype, MItype, MOtype>::Forward_cpu(
if (legacy_version) {
loss += fmax(margin - dist_sq_.cpu_data()[i], Dtype(0.0));
} else {
Dtype dist = std::max<Dtype>(margin - sqrt(dist_sq_.cpu_data()[i]),
Dtype dist = std::max<Dtype>(margin - std::sqrt(dist_sq_.cpu_data()[i]),
Dtype(0.0));
loss += dist*dist;
}
@@ -105,7 +105,7 @@ void ContrastiveLossLayer<Dtype, MItype, MOtype>::Backward_cpu(const vector<Blob
mdist = margin - dist_sq_.cpu_data()[j];
beta = -alpha;
} else {
Dtype dist = sqrt(dist_sq_.cpu_data()[j]);
Dtype dist = std::sqrt(dist_sq_.cpu_data()[j]);
mdist = margin - dist;
beta = -alpha * mdist / (dist + Dtype(1e-4));
}
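
In the non-legacy branch shown above, the per-pair loss for dissimilar pairs is max(margin - sqrt(dist_sq), 0)^2, and the backward scales the gradient by mdist / (dist + 1e-4) to avoid dividing by a zero distance. A standalone sketch of the per-pair forward term (my simplification of the batched layer code):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Contrastive loss for one pair: similar pairs are pulled together,
// dissimilar pairs are pushed out beyond the margin.
double contrastive_pair_loss(double dist_sq, bool similar, double margin) {
  if (similar) return dist_sq;
  double d = std::max(margin - std::sqrt(dist_sq), 0.0);
  return d * d;
}

int main() {
  std::printf("%f\n", contrastive_pair_loss(0.25, true, 1.0));   // 0.250000
  std::printf("%f\n", contrastive_pair_loss(0.25, false, 1.0));  // (1 - 0.5)^2 = 0.250000
}
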
5 changes: 3 additions & 2 deletions src/caffe/layers/dropout_layer.cu
@@ -85,7 +85,7 @@ void DropoutLayer<Dtype, MItype, MOtype>::Forward_gpu(
vector<size_t> local;
this->device_->get_threads(&work_size, &group, &local, kernel.get(), true);
} else {
this->device_->copy(count, bottom_data, top_data);
this->device_->template copy<Dtype>(count, bottom_data, top_data);
}
}

@@ -115,7 +115,8 @@ void DropoutLayer<Dtype, MItype, MOtype>::Backward_gpu(
this->device_->get_threads(&work_size, &group, &local, kernel.get(),
true);
} else {
this->device_->copy(top[0]->count(), top_diff, bottom_diff);
this->device_->template copy<Dtype>(top[0]->count(), top_diff,
bottom_diff);
}
}
}
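
The dropout change adds the template keyword because copy is a member template reached through this->device_, which is a dependent expression inside the layer template; without the disambiguator the compiler parses the < in copy<Dtype> as a comparison. A minimal reproduction of the rule with toy types (not Caffe's device API):

#include <cstddef>
#include <iostream>
#include <vector>

struct ToyDevice {
  template <typename T>
  void copy(std::size_t n, const T* src, T* dst) {
    for (std::size_t i = 0; i < n; ++i) dst[i] = src[i];
  }
};

template <typename Dtype, typename Dev>
void passthrough(Dev* device, std::size_t n, const Dtype* src, Dtype* dst) {
  // Dev is a template parameter, so copy is a dependent name here and the
  // call must be written with the template disambiguator:
  device->template copy<Dtype>(n, src, dst);
}

int main() {
  ToyDevice dev;
  std::vector<float> a{1.f, 2.f, 3.f}, b(3);
  passthrough(&dev, a.size(), a.data(), b.data());
  std::cout << b[2] << "\n";  // 3
}
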
2 changes: 1 addition & 1 deletion src/caffe/layers/elu_layer.cpp
@@ -26,7 +26,7 @@ void ELULayer<Dtype, MItype, MOtype>::Forward_cpu(
Dtype alpha = this->layer_param_.elu_param().alpha();
for (int i = 0; i < count; ++i) {
top_data[i] = std::max(bottom_data[i], Dtype(0))
+ alpha * (exp(std::min(bottom_data[i], Dtype(0))) - Dtype(1));
+ alpha * (std::exp(std::min(bottom_data[i], Dtype(0))) - Dtype(1));
}
}

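
The ELU loop above evaluates f(x) = x for x > 0 and alpha * (exp(x) - 1) otherwise; the max/min split is just a branch-free way of writing the two cases. A one-function check in plain doubles:

#include <algorithm>
#include <cmath>
#include <cstdio>

double elu(double x, double alpha) {
  return std::max(x, 0.0) + alpha * (std::exp(std::min(x, 0.0)) - 1.0);
}

int main() {
  std::printf("%f %f\n", elu(2.0, 1.0), elu(-2.0, 1.0));  // 2.000000, exp(-2) - 1
}
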
3 changes: 2 additions & 1 deletion src/caffe/layers/exp_layer.cpp
@@ -26,7 +26,8 @@ void ExpLayer<Dtype, MItype, MOtype>::LayerSetUp(
const Dtype input_shift = this->layer_param_.exp_param().shift();
inner_scale_ = log_base * input_scale;
outer_scale_ = (input_shift == Dtype(0)) ? Dtype(1) :
( (base != Dtype(-1)) ? pow(base, input_shift) : exp(input_shift) );
((base != Dtype(-1)) ? pow(base, input_shift) :
static_cast<Dtype>(std::exp(input_shift)));

this->InitializeQuantizers(bottom, top);
}
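
ExpLayer computes y = base^(shift + scale * x), folded as y = outer_scale * exp(inner_scale * x) with inner_scale = log(base) * scale and outer_scale = base^shift (base == -1 selects the natural base); the new cast only keeps the std::exp result in Dtype. A quick consistency check of the folding with illustrative values:

#include <cmath>
#include <cstdio>

int main() {
  double base = 2.0, scale = 0.5, shift = 3.0, x = 1.7;
  double inner_scale = std::log(base) * scale;
  double outer_scale = std::pow(base, shift);
  double direct = std::pow(base, shift + scale * x);
  double folded = outer_scale * std::exp(inner_scale * x);
  std::printf("direct = %f  folded = %f\n", direct, folded);  // equal up to rounding
}
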
2 changes: 1 addition & 1 deletion src/caffe/layers/lstm_unit_layer.cpp
@@ -9,7 +9,7 @@ namespace caffe {

template<typename Dtype, typename MItype, typename MOtype>
inline Dtype sigmoid(Dtype X) {
return 1. / (1. + exp(-X));
return 1. / (1. + std::exp(-X));
}

template<typename Dtype, typename MItype, typename MOtype>
85 changes: 64 additions & 21 deletions src/caffe/layers/moe_layer.cpp
@@ -59,9 +59,16 @@ void MOELayer<Dtype, MItype, MOtype>::LayerSetUp(
gating_net_ = make_shared<Net<float> >(gating_net_param,
this->device_);
vector<shared_ptr<BlobBase> > gating_net_params = gating_net_->params();
const vector<float>& params_lr =
gating_net_->params_lr();
const vector<float>& params_weight_decay =
gating_net_->params_weight_decay();
for (size_t i = 0; i < gating_net_params.size(); ++i) {
this->blobs_.push_back(
std::static_pointer_cast<Blob<Dtype> >(gating_net_params[i]));
ParamSpec *param_spec = this->layer_param_.add_param();
param_spec->set_lr_mult(params_lr[i]);
param_spec->set_decay_mult(params_weight_decay[i]);
}
} else {
LOG(FATAL) << "MOE Layer requires a gating network.";
@@ -75,7 +82,9 @@ void MOELayer<Dtype, MItype, MOtype>::LayerSetUp(
vector<shared_ptr<Net<float> > > expert_nets;
vector<shared_ptr<BlobBase> > expert_net_params_zero;
for (size_t k = 0;
k < (this->phase_ == caffe::TEST ? this->parallel_nets_ : 1); ++k) {
k < ((this->phase_ == caffe::TEST
&& !this->layer_param().moe_param().full_forward()) ?
this->parallel_nets_ : 1); ++k) {
NetParameter expert_net_param = moe_param.expert_net(i);
expert_net_param.mutable_state()->set_phase(this->phase_);
expert_net_param.set_force_backward(this->phase_ == caffe::TRAIN);
Expand All @@ -85,9 +94,16 @@ void MOELayer<Dtype, MItype, MOtype>::LayerSetUp(
if (k == 0) {
// If multiple copies of an expert exists, register the first and
// copy to the others (shared parameters)
const vector<float>& params_lr =
expert_net->params_lr();
const vector<float>& params_weight_decay =
expert_net->params_weight_decay();
for (size_t i = 0; i < expert_net_params.size(); ++i) {
this->blobs_.push_back(
std::static_pointer_cast<Blob<Dtype> >(expert_net_params[i]));
ParamSpec *param_spec = this->layer_param_.add_param();
param_spec->set_lr_mult(params_lr[i]);
param_spec->set_decay_mult(params_weight_decay[i]);
}
expert_net_params_zero = expert_net_params;
} else {
@@ -141,7 +157,8 @@ void MOELayer<Dtype, MItype, MOtype>::Reshape(
this->layer_param().moe_param().map_bottom(l) ==
MOEParameter_BottomMapping_GATING_AND_EXPERT) {
vector<int_tp> shape = bottom[l]->shape();
shape[0] = this->phase_ == caffe::TEST ? 1 : shape[0];
shape[0] = (this->phase_ == caffe::TEST &&
!this->layer_param().moe_param().full_forward()) ? 1 : shape[0];
expert_input_blobs[k]->Reshape(shape);
++k;
}
Expand All @@ -150,14 +167,19 @@ void MOELayer<Dtype, MItype, MOtype>::Reshape(
if (j == 0 and i == 0) {
const vector<BlobBase*>& expert_output_blobs =
this->expert_nets_[j][i]->output_blobs();
for (size_t l = 0; l < top.size(); ++l) {
for (size_t l = 0; l < top.size() - 2; ++l) {
vector<int_tp> shape = expert_output_blobs[l]->shape();
shape[0] = bottom[0]->shape()[0];
top[l]->Reshape(shape);
}
}
}
}
vector<int_tp> shape(2);
shape[0] = bottom[0]->shape()[0];
shape[1] = this->expert_nets_.size();
top[top.size()-2]->Reshape(shape);
top[top.size()-1]->Reshape(shape);
}

template<typename Dtype, typename MItype, typename MOtype>
@@ -187,6 +209,7 @@ void MOELayer<Dtype, MItype, MOtype>::Forward_cpu(
gating_ = static_cast<Blob<MOtype>*>(this->gating_net_->Forward(&loss)[0]);
MOtype* gating_data = gating_->mutable_cpu_data();

vector<int_tp> select_count(gating_->shape()[1], 0);
vector<vector<int_tp> > batch_selectors;

// Reset all top blobs
@@ -211,9 +234,28 @@
}
}
}
for(size_t k = 0; k < select_experts; ++k) {
select_count[expert_selectors[k]] += 1;
}
batch_selectors.push_back(expert_selectors);
}

// Generate load balancing loss
if (this->phase_ == caffe::TRAIN) {
MOtype* observed_count = top[top.size()-2]->mutable_cpu_data();
MOtype* expected_count = top[top.size()-1]->mutable_cpu_data();
for (size_t j = 0; j < gating_->shape()[1]; ++j) {
MOtype norm_observed = static_cast<MOtype>(select_count[j])
/ static_cast<MOtype>(gating_->shape()[0]);
MOtype norm_expected = static_cast<MOtype>(select_experts)
/ static_cast<MOtype>(gating_->shape()[1]);
for (size_t i = 0; i < gating_->shape()[0]; ++i) {
observed_count[i * select_count.size() + j] = norm_observed;
expected_count[i * select_count.size() + j] = norm_expected;
}
}
}
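
The block above fills the two new top blobs with, per expert j, the observed selection frequency select_count[j] / batch_size and the uniform target select_experts / num_experts, broadcast across the batch dimension; a loss attached to these tops can then penalize unbalanced expert usage (Backward now copies the diff of the observed-frequency top into the gating diff). A standalone sketch of the same statistics, with made-up counts:

#include <cstdio>
#include <vector>

int main() {
  int batch = 4, num_experts = 3, select_experts = 1;
  // How often each expert appeared among the selected top-k across the batch:
  std::vector<int> select_count = {3, 1, 0};
  for (int j = 0; j < num_experts; ++j) {
    double observed = static_cast<double>(select_count[j]) / batch;       // 0.75, 0.25, 0.00
    double expected = static_cast<double>(select_experts) / num_experts;  // 1/3 for every expert
    std::printf("expert %d: observed = %.3f  expected = %.3f\n", j, observed, expected);
  }
}
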

// Make gating data sparse and renormalize
for (size_t i = 0; i < gating_->shape()[0]; ++i) {
MOtype norm = MOtype(0);
@@ -236,7 +278,8 @@
}

// Forward experts
if (this->phase_ == caffe::TEST) {
if (this->phase_ == caffe::TEST &&
!this->layer_param().moe_param().full_forward()) {
// Forward expert networks (partial, only forward selected experts
// per batch item)
#pragma omp parallel for num_threads(this->parallel_nets_)
@@ -325,11 +368,11 @@ void MOELayer<Dtype, MItype, MOtype>::Backward_cpu(
caffe_set(bottom[i]->count(), MItype(0), bottom_diff);
}

// Reset gating diff
// Set gating diff to load balancing diff
const MOtype* gating_data = gating_->cpu_data();
MOtype* gating_diff = gating_->mutable_cpu_diff();
caffe_set(gating_->count(), MOtype(0), gating_diff);

const MOtype* observed_diff = top[top.size()-2]->cpu_diff();
caffe_copy(gating_->count(), observed_diff, gating_diff);

// Backward all experts
for (size_t j = 0; j < this->expert_nets_.size(); ++j) {
@@ -400,21 +443,21 @@
}


INSTANTIATE_CLASS_3T_GUARDED(MOELayer, (half_fp), (half_fp), PROTO_TYPES);
INSTANTIATE_CLASS_3T_GUARDED(MOELayer, (float), (float), PROTO_TYPES);
INSTANTIATE_CLASS_3T_GUARDED(MOELayer, (double), (double), PROTO_TYPES);
INSTANTIATE_CLASS_3T_GUARDED(MOELayer, (uint8_t), (uint8_t), PROTO_TYPES);
INSTANTIATE_CLASS_3T_GUARDED(MOELayer, (uint16_t), (uint16_t), PROTO_TYPES);
INSTANTIATE_CLASS_3T_GUARDED(MOELayer, (uint32_t), (uint32_t), PROTO_TYPES);
INSTANTIATE_CLASS_3T_GUARDED(MOELayer, (uint64_t), (uint64_t), PROTO_TYPES);
INSTANTIATE_CLASS_3T_GUARDED(MOELayer, (half_fp), (half_fp), (float));
INSTANTIATE_CLASS_3T_GUARDED(MOELayer, (float), (float), (float));
INSTANTIATE_CLASS_3T_GUARDED(MOELayer, (double), (double), (float));
INSTANTIATE_CLASS_3T_GUARDED(MOELayer, (uint8_t), (uint8_t), (float));
INSTANTIATE_CLASS_3T_GUARDED(MOELayer, (uint16_t), (uint16_t), (float));
INSTANTIATE_CLASS_3T_GUARDED(MOELayer, (uint32_t), (uint32_t), (float));
INSTANTIATE_CLASS_3T_GUARDED(MOELayer, (uint64_t), (uint64_t), (float));

REGISTER_LAYER_CLASS(MOE);
REGISTER_LAYER_CLASS_INST(MOE, (half_fp), (half_fp), PROTO_TYPES);
REGISTER_LAYER_CLASS_INST(MOE, (float), (float), PROTO_TYPES);
REGISTER_LAYER_CLASS_INST(MOE, (double), (double), PROTO_TYPES);
REGISTER_LAYER_CLASS_INST(MOE, (uint8_t), (uint8_t), PROTO_TYPES);
REGISTER_LAYER_CLASS_INST(MOE, (uint16_t), (uint16_t), PROTO_TYPES);
REGISTER_LAYER_CLASS_INST(MOE, (uint32_t), (uint32_t), PROTO_TYPES);
REGISTER_LAYER_CLASS_INST(MOE, (uint64_t), (uint64_t), PROTO_TYPES);
REGISTER_LAYER_CLASS_INST(MOE, (half_fp), (half_fp), (float));
REGISTER_LAYER_CLASS_INST(MOE, (float), (float), (float));
REGISTER_LAYER_CLASS_INST(MOE, (double), (double), (float));
REGISTER_LAYER_CLASS_INST(MOE, (uint8_t), (uint8_t), (float));
REGISTER_LAYER_CLASS_INST(MOE, (uint16_t), (uint16_t), (float));
REGISTER_LAYER_CLASS_INST(MOE, (uint32_t), (uint32_t), (float));
REGISTER_LAYER_CLASS_INST(MOE, (uint64_t), (uint64_t), (float));

} // namespace caffe
29 changes: 25 additions & 4 deletions src/caffe/layers/moe_layer.cu
@@ -32,6 +32,7 @@ void MOELayer<Dtype, MItype, MOtype>::Forward_gpu(
gating_ = static_cast<Blob<MOtype>*>(this->gating_net_->Forward(&loss)[0]);
MOtype* gating_data = gating_->mutable_cpu_data();

vector<int_tp> select_count(gating_->shape()[1], 0);
vector<vector<int_tp> > batch_selectors;

// Reset all top blobs
@@ -56,9 +57,28 @@ void MOELayer<Dtype, MItype, MOtype>::Forward_gpu(
}
}
}
for(size_t k = 0; k < select_experts; ++k) {
select_count[expert_selectors[k]] += 1;
}
batch_selectors.push_back(expert_selectors);
}

// Generate load balancing loss
if (this->phase_ == caffe::TRAIN) {
MOtype* observed_count = top[top.size()-2]->mutable_cpu_data();
MOtype* expected_count = top[top.size()-1]->mutable_cpu_data();
for (size_t j = 0; j < gating_->shape()[1]; ++j) {
MOtype norm_observed = static_cast<MOtype>(select_count[j])
/ static_cast<MOtype>(gating_->shape()[0]);
MOtype norm_expected = static_cast<MOtype>(select_experts)
/ static_cast<MOtype>(gating_->shape()[1]);
for (size_t i = 0; i < gating_->shape()[0]; ++i) {
observed_count[i * select_count.size() + j] = norm_observed;
expected_count[i * select_count.size() + j] = norm_expected;
}
}
}

// Make gating data sparse and renormalize
for (size_t i = 0; i < gating_->shape()[0]; ++i) {
MOtype norm = MOtype(0);
@@ -81,7 +101,8 @@ void MOELayer<Dtype, MItype, MOtype>::Forward_gpu(
}

// Forward experts
if (this->phase_ == caffe::TEST) {
if (this->phase_ == caffe::TEST &&
!this->layer_param().moe_param().full_forward()) {
// Forward expert networks (partial, only forward selected experts
// per batch item)
#pragma omp parallel for num_threads(this->parallel_nets_)
@@ -176,11 +197,11 @@ void MOELayer<Dtype, MItype, MOtype>::Backward_gpu(
caffe_set(bottom[i]->count(), MItype(0), bottom_diff);
}

// Reset gating diff
// Set gating diff to load balancing diff
const MOtype* gating_data = gating_->cpu_data();
MOtype* gating_diff = gating_->mutable_cpu_diff();
caffe_set(gating_->count(), MOtype(0), gating_diff);

const MOtype* observed_diff = top[top.size()-2]->cpu_diff();
caffe_copy(gating_->count(), observed_diff, gating_diff);

// Backward all experts
for (size_t j = 0; j < this->expert_nets_.size(); ++j) {
12 changes: 6 additions & 6 deletions src/caffe/libdnn/libdnn.cpp
@@ -32,12 +32,12 @@ std::map<string, int64_t> LibDNN<MItype, MOtype>
params["workgroup_size_0"] = 8;
params["workgroup_size_1"] = 8;
params["workgroup_size_2"] = 1;
params["TSK"] = 16 / safe_sizeof<MItype>();;
params["TSK_UNROLL"] = 16 / safe_sizeof<MItype>();
params["WPTM"] = 16 / safe_sizeof<MItype>();
params["WPTN"] = 16 / safe_sizeof<MItype>();
params["VWM"] = 16 / safe_sizeof<MItype>();
params["VWN"] = 16 / safe_sizeof<MItype>();
params["TSK"] = 8 / safe_sizeof<MItype>();;
params["TSK_UNROLL"] = 8 / safe_sizeof<MItype>();
params["WPTM"] = 8 / safe_sizeof<MItype>();
params["WPTN"] = 8 / safe_sizeof<MItype>();
params["VWM"] = 8 / safe_sizeof<MItype>();
params["VWN"] = 8 / safe_sizeof<MItype>();
}

if (this->dev_ptr_->name().find("VideoCore IV") != string::npos) {
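
The default tuning parameters above drop from 16 / safe_sizeof<MItype>() to 8 / safe_sizeof<MItype>(), halving the tile, work-per-thread and vector widths for this device branch (the surrounding context is truncated here). A quick arithmetic check using plain sizeof in place of Caffe's safe_sizeof helper:

#include <cstdint>
#include <cstdio>

int main() {
  std::printf("float:  old = %zu  new = %zu\n", 16 / sizeof(float), 8 / sizeof(float));       // 4 -> 2
  std::printf("half:   old = %zu  new = %zu\n", 16 / sizeof(uint16_t), 8 / sizeof(uint16_t)); // 8 -> 4
  std::printf("double: old = %zu  new = %zu\n", 16 / sizeof(double), 8 / sizeof(double));     // 2 -> 1
}
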
1 change: 1 addition & 0 deletions src/caffe/proto/caffe.proto
@@ -1099,6 +1099,7 @@ message MOEParameter {
repeated NetParameter expert_net = 2;
repeated int64 expert_instances = 3;
optional int64 select_experts = 4;
optional bool full_forward = 6 [default = false];

enum BottomMapping {
GATING = 1;
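
The new full_forward flag (default false) makes the TEST-phase forward run every expert on the whole batch instead of routing each item only through its selected experts; the added "&& !full_forward()" guards in LayerSetUp, Reshape and Forward switch between the two paths. A hedged sketch of setting it through the protoc-generated C++ accessors (assuming the usual generated setters for these fields):

#include "caffe/proto/caffe.pb.h"

void configure_moe(caffe::LayerParameter* layer_param) {
  caffe::MOEParameter* moe = layer_param->mutable_moe_param();
  moe->set_select_experts(2);   // route each item through its top-2 experts
  moe->set_full_forward(true);  // but still forward all experts at TEST time
}
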
2 changes: 1 addition & 1 deletion src/caffe/solvers/adam_solver.cpp
@@ -39,7 +39,7 @@ void AdamSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
Blob<Dtype>* val_t = this->temp_[param_id].get();

const uint_tp t = this->iter_ + 1;
const Dtype correction = sqrt(Dtype(1) - pow(beta2, Dtype(t))) /
const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, Dtype(t))) /
(Dtype(1.) - pow(beta1, Dtype(t)));
const uint_tp n = net_params[param_id]->count();
const Dtype eps_hat = this->param_.delta();
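
The Adam step above scales by sqrt(1 - beta2^t) / (1 - beta1^t), the standard bias correction for the exponentially averaged moments; only the std::sqrt qualification changes here. A small numeric illustration with the usual beta values:

#include <cmath>
#include <cstdio>

int main() {
  double beta1 = 0.9, beta2 = 0.999;
  for (int t : {1, 10, 100, 1000}) {
    double correction = std::sqrt(1.0 - std::pow(beta2, t)) / (1.0 - std::pow(beta1, t));
    std::printf("t = %d  correction = %f\n", t, correction);  // tends to 1 as t grows
  }
}
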
2 changes: 1 addition & 1 deletion src/caffe/solvers/sgd_solver.cpp
@@ -61,7 +61,7 @@ Dtype SGDSolver<Dtype>::GetLearningRate() {
CHECK_GE(this->param_.gamma(), 0);
CHECK_GT(this->param_.stepsize(), 0);
rate = this->param_.base_lr() * (Dtype(1.) /
(Dtype(1.) + exp(-this->param_.gamma() * (Dtype(this->iter_) -
(Dtype(1.) + std::exp(-this->param_.gamma() * (Dtype(this->iter_) -
Dtype(this->param_.stepsize())))));
} else {
LOG(FATAL) << "Unknown learning rate policy: " << lr_policy;
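
The "sigmoid" policy above evaluates rate = base_lr / (1 + exp(-gamma * (iter - stepsize))), an S-shaped curve centered at stepsize; the change only qualifies exp. A numeric sketch with illustrative hyperparameters:

#include <cmath>
#include <cstdio>

int main() {
  double base_lr = 0.01, gamma = 0.05, stepsize = 1000;
  for (int iter : {0, 500, 1000, 1500, 2000}) {
    double rate = base_lr / (1.0 + std::exp(-gamma * (iter - stepsize)));
    std::printf("iter = %d  rate = %g\n", iter, rate);
  }
}
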
2 changes: 1 addition & 1 deletion src/caffe/test/test_contrastive_loss_layer.cpp
@@ -77,7 +77,7 @@ TYPED_TEST(ContrastiveLossLayerTest, TestForward) {
if (this->blob_bottom_y_->cpu_data()[i]) { // similar pairs
loss += dist_sq;
} else {
Dtype dist = fmax(Dtype(margin - sqrt(dist_sq)), Dtype(0.0));
Dtype dist = fmax(Dtype(margin - std::sqrt(dist_sq)), Dtype(0.0));
loss += dist*dist;
}
}
