OpenMP improvements, parallel processing in data loader to keep up with GPU speeds.
naibaf7 committed May 8, 2018
1 parent 243d5a5 commit 52513a8
Showing 19 changed files with 57 additions and 65 deletions.
1 change: 1 addition & 0 deletions cmake/Dependencies.cmake
@@ -303,6 +303,7 @@ if(USE_OPENMP)
     list(APPEND Caffe_LINKER_LIBS PRIVATE ${OpenMP_CXX_FLAGS})
     list(APPEND Caffe_COMPILE_OPTIONS PRIVATE ${OpenMP_CXX_FLAGS})
   endif()
+  list(APPEND Caffe_DEFINITIONS PUBLIC -DUSE_OPENMP)
 endif()

 # ---[ BLAS
2 changes: 1 addition & 1 deletion include/caffe/backend/device_program.hpp
@@ -110,7 +110,7 @@ class DeviceProgram {
   string define_vector_type(string name, int_tp from, int_tp to);

   template<typename Dtype>
-  KernelArg create_kernel_arg(string name, uint64_t flags);
+  KernelArg create_kernel_arg(string name, uint64_t flags = KERNEL_ARG_NONE);

  protected:
   DeviceProgram(Device* dev);
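With the new default, call sites that need no special flags can drop the second argument. A minimal sketch of the two now-equivalent calls (the argument name "src" and the program pointer are hypothetical, not from this commit):

    // Before this change, both arguments were required:
    KernelArg a = program->create_kernel_arg<float>("src", KERNEL_ARG_NONE);
    // After it, the flags parameter defaults to KERNEL_ARG_NONE:
    KernelArg b = program->create_kernel_arg<float>("src");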
3 changes: 2 additions & 1 deletion include/caffe/data_transformer.hpp
@@ -37,7 +37,8 @@ class DataTransformer {
    * This is destination blob. It can be part of top blob's data if
    * set_cpu_data() is used. See data_layer.cpp for an example.
    */
-  void Transform(const Datum& datum, Blob<Dtype>* transformed_blob);
+  void Transform(const Datum& datum, Blob<Dtype>* transformed_blob,
+                 int_tp offset);

   /**
    * @brief Applies the transformation defined in the data layer's
8 changes: 8 additions & 0 deletions include/caffe/definitions.hpp
@@ -10,6 +10,9 @@
 #include <iostream> // NOLINT(readability/streams)
 #include <map>
 #include <memory>
+#ifdef USE_OPENMP
+#include <omp.h>
+#endif // USE_OPENMP
 #include <set>
 #include <sstream>
 #include <string>
@@ -18,6 +21,7 @@
 #include <vector>
 #include <boost/variant.hpp>

+
 #include "caffe/trait_helper.hpp"
 #include "caffe/util/half_fp.hpp"
 #include "caffe/util/inline_math.hpp"
@@ -49,6 +53,10 @@
 #define CAFFE_MALLOC_CACHE_ALIGN 64
 #endif // CAFFE_MALLOC_CACHE_ALIGN

+#ifndef CAFFE_OMP_BYTE_STRIDE
+#define CAFFE_OMP_BYTE_STRIDE 8
+#endif // CAFFE_OMP_BYTE_STRIDE
+
 namespace caffe {

 // Common functions and classes from std and boost that Caffe often uses.
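Because Dependencies.cmake now exports -DUSE_OPENMP publicly, any translation unit can guard OpenMP-specific code on that macro and still build cleanly without OpenMP. A minimal sketch of the pattern (the helper name caffe_max_threads is hypothetical; CAFFE_OMP_BYTE_STRIDE is presumably a work-splitting tunable, but its consumers are in files not shown here):

    #ifdef USE_OPENMP
    #include <omp.h>
    #endif // USE_OPENMP

    inline int caffe_max_threads() {
    #ifdef USE_OPENMP
      return omp_get_max_threads();  // thread count chosen by the OpenMP runtime
    #else
      return 1;                      // serial fallback when built without OpenMP
    #endif // USE_OPENMP
    }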
18 changes: 12 additions & 6 deletions src/caffe/data_transformer.cpp
@@ -132,7 +132,8 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,

 template<typename Dtype>
 void DataTransformer<Dtype>::Transform(const Datum& datum,
-                                       Blob<Dtype>* transformed_blob) {
+                                       Blob<Dtype>* transformed_blob,
+                                       int_tp offset) {
   // If datum is encoded, decode and transform the cv::image.
   if (datum.encoded()) {
 #ifdef USE_OPENCV
@@ -180,7 +181,7 @@ void DataTransformer<Dtype>::Transform(const Datum& datum,
     CHECK_EQ(datum_width, width);
   }

-  Dtype* transformed_data = transformed_blob->mutable_cpu_data();
+  Dtype* transformed_data = transformed_blob->mutable_cpu_data() + offset;
   Transform(datum, transformed_data);
 }

@@ -196,11 +197,12 @@ void DataTransformer<Dtype>::Transform(const vector<Datum> & datum_vector,
   CHECK_GT(datum_num, 0)<< "There is no datum to add";
   CHECK_LE(datum_num, num)<<
     "The size of datum_vector must be no greater than transformed_blob->num()";
-  Blob<Dtype> uni_blob(1, channels, height, width, device_);
+#pragma omp parallel for
   for (int_tp item_id = 0; item_id < datum_num; ++item_id) {
+    Blob<Dtype> uni_blob(1, channels, height, width, device_);
     int_tp offset = transformed_blob->offset(item_id);
-    uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset);
-    Transform(datum_vector[item_id], &uni_blob);
+    uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data());
+    Transform(datum_vector[item_id], &uni_blob, offset);
   }
 }
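Moving uni_blob into the loop body matters once the loop is parallel: an object declared inside an OpenMP for body is private to each thread, while the old hoisted blob would have been shared and raced on set_cpu_data(). Note also that the destination offset is now applied inside Transform() rather than baked into the wrapper blob's pointer. A self-contained sketch of the thread-privacy rule (illustrative only, not code from the commit):

    #include <cstddef>
    #include <vector>

    void scale_rows(std::vector<float>& data, long long rows, std::size_t cols) {
    #ifdef USE_OPENMP
    #pragma omp parallel for
    #endif
      for (long long r = 0; r < rows; ++r) {
        std::vector<float> scratch(cols);            // private: one per iteration
        for (std::size_t c = 0; c < cols; ++c)
          scratch[c] = 2.0f * data[r * cols + c];    // independent reads
        for (std::size_t c = 0; c < cols; ++c)
          data[r * cols + c] = scratch[c];           // disjoint output slice
      }
    }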

@@ -217,8 +219,9 @@ void DataTransformer<Dtype>::Transform(const vector<cv::Mat> & mat_vector,
   CHECK_GT(mat_num, 0)<< "There is no MAT to add";
   CHECK_EQ(mat_num, num)<<
     "The size of mat_vector must be equals to transformed_blob->num()";
-  Blob<Dtype> uni_blob(1, channels, height, width, device_);
+#pragma omp parallel for
   for (int_tp item_id = 0; item_id < mat_num; ++item_id) {
+    Blob<Dtype> uni_blob(1, channels, height, width, device_);
     int_tp offset = transformed_blob->offset(item_id);
     uni_blob.set_cpu_data(transformed_blob->mutable_cpu_data() + offset);
     Transform(mat_vector[item_id], &uni_blob);
@@ -396,6 +399,7 @@ void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
     CHECK_EQ(input_channels, data_mean_.channels());
     CHECK_EQ(input_height, data_mean_.height());
     CHECK_EQ(input_width, data_mean_.width());
+#pragma omp parallel for
     for (int_tp n = 0; n < input_num; ++n) {
       int_tp offset = input_blob->offset(n);
       caffe_sub<Dtype>(data_mean_.count(), input_data + offset,
@@ -411,6 +415,7 @@ void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,
       caffe_add_scalar<Dtype>(input_blob->count(), -(mean_values_[0]),
                               input_data);
     } else {
+#pragma omp parallel for
       for (int_tp n = 0; n < input_num; ++n) {
         for (int_tp c = 0; c < input_channels; ++c) {
           int_tp offset = input_blob->offset(n, c);
@@ -423,6 +428,7 @@ void DataTransformer<Dtype>::Transform(Blob<Dtype>* input_blob,

   Dtype* transformed_data = transformed_blob->mutable_cpu_data();

+#pragma omp parallel for
   for (int_tp n = 0; n < input_num; ++n) {
     int_tp top_index_n = n * channels;
     int_tp data_index_n = n * channels;
1 change: 0 additions & 1 deletion src/caffe/layers/affinity_layer.cpp
@@ -64,7 +64,6 @@ void AffinityLayer<Dtype, MItype, MOtype>::Forward_cpu(
   int_tp xmin, ymin;

   // Construct affinity graph
-#pragma omp parallel for
   for (int_tp i = 0; i < bottom[bidx]->height() - 1; ++i) {
     for (int_tp j = 0; j < bottom[bidx]->width() - 1; ++j) {
       // Center
1 change: 0 additions & 1 deletion src/caffe/layers/connected_component_layer.cpp
@@ -75,7 +75,6 @@ void ConnectedComponentLayer<Dtype, MItype, MOtype>::Forward_cpu(
       }
     }
     cv::Mat seg = FindBlobs(maxlabel, img);
-#pragma omp parallel for
     for (int_tp Y = 0; Y < seg.rows; ++Y) {
       for (int_tp X = 0; X < seg.cols; ++X) {
         top_data[nc * bottom[0]->width() * bottom[0]->height()
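The pragma removals in these two layers run the other way from the additions above, presumably because these loops now execute under an already-parallel caller: an inner parallel region either serializes (OpenMP disables nested parallelism by default) or oversubscribes cores. A hypothetical illustration, not from the commit:

    #include <cstdio>

    int main() {
    #ifdef USE_OPENMP
    #pragma omp parallel for
    #endif
      for (int outer = 0; outer < 4; ++outer) {
        // An inner "#pragma omp parallel for" here would nest inside the
        // outer region; with nesting disabled (the default) it runs
        // serially anyway, so dropping it loses nothing.
        for (int inner = 0; inner < 4; ++inner)
          std::printf("outer %d, inner %d\n", outer, inner);
      }
      return 0;
    }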
35 changes: 21 additions & 14 deletions src/caffe/layers/data_layer.cpp
@@ -85,22 +85,21 @@ template<typename Dtype, typename MItype, typename MOtype>
 void DataLayer<Dtype, MItype, MOtype>::load_batch(Batch<Dtype>* batch) {
   CPUTimer batch_timer;
   batch_timer.Start();
-  double read_time = 0;
-  double trans_time = 0;
+  double read_time = 0.0;
+  double trans_time = 0.0;
   CPUTimer timer;
   CHECK(batch->data_.count());
   CHECK(this->transformed_data_.count());
   const int_tp batch_size = this->layer_param_.data_param().batch_size();

-  Datum datum;
+  vector<Datum> datum(batch_size);

   for (int_tp item_id = 0; item_id < batch_size; ++item_id) {
     timer.Start();
     while (Skip()) {
       Next();
     }
-    datum.ParseFromString(cursor_->value());
-    read_time += timer.MicroSeconds();
+    datum[item_id].ParseFromString(cursor_->value());
     if (item_id == 0) {
       // Reshape according to the first datum of each batch
       // on single input batches allows for inputs of varying dimension.
@@ -112,21 +111,29 @@ void DataLayer<Dtype, MItype, MOtype>::load_batch(Batch<Dtype>* batch) {
       top_shape[0] = batch_size;
       batch->data_.Reshape(top_shape);
     }
+    read_time += timer.MicroSeconds();
+    Next();
+  }
+
+  timer.Start();
+  Dtype* top_data = batch->data_.mutable_cpu_data();
+  Dtype* top_label = batch->label_.mutable_cpu_data();
+
+  this->transformed_data_.set_cpu_data(top_data);
+#pragma omp parallel for
+  for (int_tp item_id = 0; item_id < batch_size; ++item_id) {
     // Apply data transformations (mirror, scale, crop...)
-    timer.Start();
     int_tp offset = batch->data_.offset(item_id);
-    Dtype* top_data = batch->data_.mutable_cpu_data();
-    this->transformed_data_.set_cpu_data(top_data + offset);
-    this->data_transformer_->Transform(datum, &(this->transformed_data_));
+    this->data_transformer_->Transform(datum[item_id],
+                                       &(this->transformed_data_),
+                                       offset);
     // Copy label.
     if (this->output_labels_) {
-      Dtype* top_label = batch->label_.mutable_cpu_data();
-      top_label[item_id] = datum.label();
+      top_label[item_id] = datum[item_id].label();
     }
-    trans_time += timer.MicroSeconds();
-    Next();
   }
+  trans_time += timer.MicroSeconds();

   timer.Stop();
   batch_timer.Stop();
   DLOG(INFO)<< "Prefetch batch: " << batch_timer.MilliSeconds() << " ms.";
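The restructuring splits the old single loop into two phases: a serial pass that walks the stateful DB cursor (cursor_/Next() cannot be shared across threads) and buffers one Datum per batch item, then a parallel pass in which each item transforms and writes only its own slice of the batch blob. The read_time/trans_time counters now time the two phases separately, and the per-item timer calls leave the parallel loop. A self-contained distillation of that shape (all names here are hypothetical, not the layer's API; assumes a non-empty db):

    #include <cstddef>
    #include <string>
    #include <vector>

    struct Record { std::string payload; };

    std::vector<float> load_batch_sketch(const std::vector<Record>& db,
                                         std::size_t batch_size,
                                         std::size_t item_elems) {
      // Phase 1: sequential read. A real DB cursor is stateful, so all
      // I/O stays on one thread.
      std::vector<Record> batch(batch_size);
      for (std::size_t i = 0; i < batch_size; ++i)
        batch[i] = db[i % db.size()];

      // Phase 2: parallel transform. Iteration i writes only the range
      // [i * item_elems, (i + 1) * item_elems), so no locking is needed.
      std::vector<float> out(batch_size * item_elems, 0.0f);
    #ifdef USE_OPENMP
    #pragma omp parallel for
    #endif
      for (long long i = 0; i < static_cast<long long>(batch_size); ++i)
        for (std::size_t e = 0; e < item_elems; ++e)
          out[i * item_elems + e] =
              static_cast<float>(batch[i].payload.size() + e);
      return out;
    }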
4 changes: 1 addition & 3 deletions src/caffe/layers/malis_loss_layer.cpp
@@ -370,7 +370,6 @@ void MalisLossLayer<Dtype, MItype, MOtype>::Forward_cpu(const vector<Blob<MItype

   // Affinity graph must be in the range (0,1)
   // square loss (euclidean) is used by MALIS
-#pragma omp parallel for
   for (int_tp i = 0; i < bottom[0]->count(); ++i) {
     affinity_data_pos[i] = std::min(affinity_prob[i], affinity[i]);
     affinity_data_neg[i] = std::max(affinity_prob[i], affinity[i]);
@@ -390,7 +389,7 @@ void MalisLossLayer<Dtype, MItype, MOtype>::Forward_cpu(const vector<Blob<MItype
     }
   }

-  Dtype loss = 0;
+  float loss = 0;

 #pragma omp parallel for reduction(+:loss)
   for (int_tp batch = 0; batch < bottom[0]->shape()[0]; ++batch) {
@@ -442,7 +441,6 @@ void MalisLossLayer<Dtype, MItype, MOtype>::Backward_cpu(const vector<Blob<MOtyp
   // Clear the diff
   caffe_set(bottom[0]->count(), Dtype(0.0), bottom_diff);

-#pragma omp parallel for
   for (int_tp i = 0; i < bottom[0]->count(); ++i) {
     bottom_diff[i] = -(dloss_neg_data[i] + dloss_pos_data[i]) / 2.0;
   }
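Two things happen here: the cheap elementwise min/max and diff loops lose their pragmas, and the reduction variable changes from Dtype to float — plausibly because OpenMP's reduction clause wants a type the compiler can accumulate natively, which a templated Dtype (for example half_fp) may not be; the commit itself does not say. A minimal sketch of the reduction pattern that stays:

    // Each thread accumulates a private partial sum; OpenMP combines
    // the partials into 'total' when the loop ends.
    float sum_of_squares(const float* data, long long n) {
      float total = 0.0f;
    #ifdef USE_OPENMP
    #pragma omp parallel for reduction(+:total)
    #endif
      for (long long i = 0; i < n; ++i)
        total += data[i] * data[i];
      return total;
    }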
2 changes: 0 additions & 2 deletions src/caffe/layers/moe_layer.cpp
@@ -70,7 +70,6 @@ void MOELayer<Dtype, MItype, MOtype>::Forward_cpu(
   const MOtype* gating_data = gating->cpu_data();
   MOtype eps = MOtype(0);
   size_t j = 0;
-#pragma omp parallel for
   for (size_t i = 0; i < this->expert_nets_.size(); ++i) {
     vector<Blob<MOtype>*> result_vec;
     // If the gating network selects this expert, preload blobs and forward
@@ -101,7 +100,6 @@ void MOELayer<Dtype, MItype, MOtype>::Forward_cpu(
   }

   // Loop over all top blobs
-#pragma omp parallel for
   for (size_t i = 0; i < top.size(); ++i) {
     MOtype* top_data = top[i]->mutable_cpu_data();
     caffe_set(top[i]->count(), MOtype(0), top_data);
2 changes: 0 additions & 2 deletions src/caffe/layers/moe_layer.cu
@@ -34,7 +34,6 @@ void MOELayer<Dtype, MItype, MOtype>::Forward_gpu(
   const MOtype* gating_data = gating->cpu_data();
   MOtype eps = MOtype(0);
   size_t j = 0;
-#pragma omp parallel for
   for (size_t i = 0; i < this->expert_nets_.size(); ++i) {
     vector<Blob<MOtype>*> result_vec;
     // If the gating network selects this expert, preload blobs and forward
@@ -66,7 +65,6 @@ void MOELayer<Dtype, MItype, MOtype>::Forward_gpu(
   }

   // Loop over all top blobs
-#pragma omp parallel for
   for (size_t i = 0; i < top.size(); ++i) {
     vptr<MOtype> top_data = top[i]->mutable_gpu_data();
     this->device_->template set<MOtype>(top[i]->count(), MOtype(0), top_data);
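Dropping the pragma from the expert loop (in both the CPU and GPU paths) looks like more than a tuning choice: the body advances the shared counter j declared just above and forwards entire expert networks, so a naive parallel-for would race. A hypothetical sketch of that hazard and its serial fix, mirroring the shared counter but not the layer's actual code:

    #include <cstddef>
    #include <vector>

    // Serial on purpose: "j++" below is an order-dependent shared
    // counter; under "#pragma omp parallel for" it would be a data race.
    std::vector<std::size_t> assign_slots(const std::vector<bool>& selected) {
      std::vector<std::size_t> slot(selected.size(), 0);
      std::size_t j = 0;
      for (std::size_t i = 0; i < selected.size(); ++i)
        if (selected[i])
          slot[i] = j++;
      return slot;
    }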
2 changes: 0 additions & 2 deletions src/caffe/layers/relu_layer.cpp
@@ -37,7 +37,6 @@ inline void forward(int_tp count, const Dtype* bottom_data, Dtype* top_data,
   Acctype top_zero = top_qv->get_zero<Acctype>();
   Acctype top_min = top_qv->get_min<Acctype>();
   Acctype top_max = top_qv->get_max<Acctype>();
-#pragma omp parallel for
   for (int_tp i = 0; i < count; ++i) {
     Difftype relu = std::max(static_cast<Difftype>(
         static_cast<Difftype>(bottom_data[i]) - bottom_zero), Difftype(0));
@@ -59,7 +58,6 @@ template<typename Dtype,
 inline void forward(int_tp count, const Dtype* bottom_data, Dtype* top_data,
     Dtype negative_slope, const QuantizerValues* const bottom_qv = nullptr,
     const QuantizerValues* const top_qv = nullptr) {
-#pragma omp parallel for
   for (int_tp i = 0; i < count; ++i) {
     top_data[i] = std::max(bottom_data[i], Dtype(0))
         + negative_slope * std::min(bottom_data[i], Dtype(0));
2 changes: 0 additions & 2 deletions src/caffe/layers/sigmoid_layer.cpp
@@ -28,7 +28,6 @@ void SigmoidLayer<Dtype, MItype, MOtype>::Forward_cpu(
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = top[0]->mutable_cpu_data();
   const int_tp count = bottom[0]->count();
-#pragma omp parallel for
   for (int_tp i = 0; i < count; ++i) {
     top_data[i] = sigmoid<Dtype, MItype, MOtype>(bottom_data[i]);
   }
@@ -44,7 +43,6 @@ void SigmoidLayer<Dtype, MItype, MOtype>::Backward_cpu(
   const Dtype* top_diff = top[0]->cpu_diff();
   Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
   const int_tp count = bottom[0]->count();
-#pragma omp parallel for
   for (int_tp i = 0; i < count; ++i) {
     const Dtype sigmoid_x = top_data[i];
     bottom_diff[i] = top_diff[i] * sigmoid_x * (1. - sigmoid_x);
2 changes: 0 additions & 2 deletions src/caffe/layers/swish_layer.cpp
@@ -37,7 +37,6 @@ void SwishLayer<Dtype, MItype, MOtype>::Forward_cpu(
   MOtype* top_data = top[0]->mutable_cpu_data();
   const int_tp count = bottom[0]->count();
   Dtype beta = this->layer_param_.swish_param().beta();
-#pragma omp parallel for
   for (int_tp i = 0; i < count; ++i) {
     top_data[i] = bottom_data[i] *
         sigmoid<Dtype, MItype, MOtype>(beta * bottom_data[i]);
@@ -56,7 +55,6 @@ void SwishLayer<Dtype, MItype, MOtype>::Backward_cpu(
   MItype* bottom_diff = bottom[0]->mutable_cpu_diff();
   const int_tp count = bottom[0]->count();
   Dtype beta = this->layer_param_.swish_param().beta();
-#pragma omp parallel for
   for (int_tp i = 0; i < count; ++i) {
     const Dtype swish_x = top_data[i];
     bottom_diff[i] = top_diff[i] * (beta * swish_x +
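For the activation layers (ReLU, sigmoid, swish), the loops are memory-bound elementwise ops, so the commit drops their pragmas outright — presumably because fork/join overhead outweighs the gain and the newly parallel data loader already keeps cores busy. An alternative the commit does not take, shown only for contrast, is OpenMP's if clause, which parallelizes a loop only past a size threshold (the threshold value here is hypothetical):

    // Parallelize only when "count" is large enough to amortize thread
    // startup; below the threshold the loop runs serially.
    void relu_inplace(float* data, long long count) {
    #ifdef USE_OPENMP
    #pragma omp parallel for if (count > (1 << 20))
    #endif
      for (long long i = 0; i < count; ++i)
        data[i] = data[i] > 0.0f ? data[i] : 0.0f;
    }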
(Diffs for the remaining 5 changed files were not loaded.)
