Fix backend functions bug
ykim362 committed Jul 4, 2019
1 parent d82bd9d commit cb37661
Showing 4 changed files with 27 additions and 18 deletions.
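In short: rescorer.h and translator.h called the CPU-only setters setOptimized() and setGemmType() on every device, and the GPU backend implemented them with ABORT(), so rescoring or translating on a GPU device crashed. This commit guards the calls with a device-type check and softens the GPU overrides to log and return safe defaults instead of aborting.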
6 changes: 4 additions & 2 deletions src/rescorer/rescorer.h
@@ -69,8 +69,10 @@ class Rescore : public ModelTask {
   for(auto device : devices) {
     auto graph = New<ExpressionGraph>(true);
     graph->setDevice(device);
-    graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
-    graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
+    if (device.type == DeviceType::cpu) {
+      graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
+      graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
+    }
     graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
     graphs_.push_back(graph);
   }
8 changes: 4 additions & 4 deletions src/tensors/backend.h
@@ -36,12 +36,12 @@ class Backend {
   virtual void setClip(float clipValue) { clipValue_ = clipValue; }
   float getClip() { return clipValue_; }

-  // for CPU & inference only, sets to use optimized code for inference.
-  // The program aborts if these are called from GPU device.
+  // for CPU, sets whether to use optimized code for inference.
+  // for GPU, this is always false.
   virtual void setOptimized(bool optimize) = 0;
   virtual bool isOptimized() = 0;
-  // for CPU only, selects different GEMM types for the inference.
-  // The program aborts if these are called from GPU device.
+  // for CPU, selects the GEMM type used for inference.
+  // for GPU, there is no GEMM type, so this does nothing.
   virtual void setGemmType(std::string gemmType) = 0;
   virtual GemmType getGemmType() = 0;
 };
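For contrast with the GPU overrides below, here is a minimal sketch of how a CPU backend could satisfy this contract. This is an illustration only, not marian's actual src/tensors/cpu/backend.h; the stand-in Backend and GemmType are simplified from the diff above.

#include <string>

// Simplified stand-ins for the types in this diff.
enum class GemmType { Auto };

class Backend {
public:
  virtual ~Backend() = default;
  virtual void setOptimized(bool optimize) = 0;
  virtual bool isOptimized() = 0;
  virtual void setGemmType(std::string gemmType) = 0;
  virtual GemmType getGemmType() = 0;
};

// Hypothetical CPU backend: it actually stores the settings, which is
// why these setters only have an effect on CPU devices.
class CpuBackend : public Backend {
  bool optimized_{false};
  GemmType gemmType_{GemmType::Auto};
public:
  void setOptimized(bool optimize) override { optimized_ = optimize; }
  bool isOptimized() override { return optimized_; }
  void setGemmType(std::string /*gemmType*/) override {
    gemmType_ = GemmType::Auto; // real code would parse the string
  }
  GemmType getGemmType() override { return gemmType_; }
};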
19 changes: 11 additions & 8 deletions src/tensors/gpu/backend.h
@@ -34,18 +34,21 @@ class Backend : public marian::Backend {
   cublasHandle_t getCublasHandle() { return cublasHandle_; }
   cusparseHandle_t getCusparseHandle() { return cusparseHandle_; }

-  // for CPU & inference only, sets to use optimized code for inference.
-  // The program aborts if these are called from GPU device.
-  void setOptimized(bool optimize) override { ABORT("Not supported for GPU_{}", optimize); }
+  // for CPU, sets whether to use optimized code for inference.
+  // for GPU, this is always false.
+  void setOptimized(bool optimize) override {
+    LOG(info, "Not supported for GPU_{}", optimize);
+  }
   bool isOptimized() override {
-    ABORT("Not supported for GPU");
     return false;
   }
-  // for CPU only, selects different GEMM types for the inference.
-  // The program aborts if these are called from GPU device.
-  void setGemmType(std::string gemmType) override { ABORT("Not supported for GPU_{}", gemmType); }
+  // for CPU, selects the GEMM type used for inference.
+  // for GPU, there is no GEMM type, so this does nothing.
+  void setGemmType(std::string gemmType) override {
+    LOG(info, "Not supported for GPU_{}", gemmType);
+  }
   GemmType getGemmType() override {
-    ABORT("Not supported for GPU");
+    LOG(info, "Not supported for GPU");
     return GemmType::Auto;
   }
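The caller-visible effect of the softened overrides, as a self-contained toy. Names mirror the diff, but LOG is replaced with std::cerr here; this is a sketch, not marian's code.

#include <iostream>
#include <string>

enum class GemmType { Auto };

// Toy mirror of the GPU backend above: setters warn instead of aborting.
struct GpuBackendSketch {
  void setOptimized(bool optimize) {
    std::cerr << "info: Not supported for GPU_" << optimize << '\n'; // was ABORT
  }
  bool isOptimized() { return false; } // GPU: always false
  void setGemmType(const std::string& gemmType) {
    std::cerr << "info: Not supported for GPU_" << gemmType << '\n'; // was ABORT
  }
  GemmType getGemmType() { return GemmType::Auto; } // safe default
};

int main() {
  GpuBackendSketch backend;
  backend.setOptimized(true);  // before this commit: the process aborted here
  backend.setGemmType("auto"); // now: logs and continues
}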

12 changes: 8 additions & 4 deletions src/translator/translator.h
@@ -58,8 +58,10 @@ class Translate : public ModelTask {
   auto graph = New<ExpressionGraph>(true);
   graph->setDevice(device);
   graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
-  graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
-  graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
+  if (device.type == DeviceType::cpu) {
+    graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
+    graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
+  }
   graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
   graphs_[id] = graph;

@@ -172,8 +174,10 @@ class TranslateService : public ModelServiceTask {
   auto graph = New<ExpressionGraph>(true);
   graph->setDevice(device);
   graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
-  graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
-  graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
+  if (device.type == DeviceType::cpu) {
+    graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
+    graph->getBackend()->setGemmType(options_->get<std::string>("gemm-type"));
+  }
   graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
   graphs_.push_back(graph);
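The same guard now appears three times (once in rescorer.h, twice here). A hypothetical helper that would hoist it into one place; GraphPtr and OptionsPtr stand in for marian's pointer types and are not part of this commit:

#include <string>

// Hypothetical: apply the CPU-only inference options in one place.
template <class GraphPtr, class OptionsPtr>
void applyCpuBackendOptions(GraphPtr& graph, OptionsPtr& options, bool isCpuDevice) {
  if (!isCpuDevice)
    return; // GPU backends ignore these settings (see src/tensors/gpu/backend.h)
  graph->getBackend()->setOptimized(options->template get<bool>("optimize"));
  graph->getBackend()->setGemmType(options->template get<std::string>("gemm-type"));
}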
