[Inference] add config.enable_low_precision_io api and remove reliance on AnalysisConfig::Precision in trt (PaddlePaddle#52485)
yuanlehome committed May 22, 2023
1 parent 5ac8c04 commit d1bbd90
Showing 41 changed files with 468 additions and 240 deletions.
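For context before the diffs: after this commit, callers opt into low-precision feed/fetch through the inference config rather than through TRT's precision enum. A minimal usage sketch, assuming the C++ spelling of the new option is AnalysisConfig::EnableLowPrecisionIO (inferred from the commit title; verify against the actual header):

#include "paddle/fluid/inference/api/paddle_analysis_config.h"

int main() {
  paddle::AnalysisConfig config;
  config.SetModel("/path/to/model_dir");  // hypothetical model path
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/100, /*device_id=*/0);
  // New in this commit: let feed/fetch tensors stay in low precision
  // instead of forcing FP32 at the graph boundary.
  config.EnableLowPrecisionIO(true);  // assumed API name, per the commit title
  return 0;
}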
144 changes: 92 additions & 52 deletions paddle/fluid/framework/ir/auto_mixed_precision_pass.cc
@@ -215,9 +215,16 @@ void AutoMixedPrecisionPass::Init(Graph* graph) const {
"Cannot enable custom_device_mixed."));
#endif
}
skip_pass_ = backend_ == phi::Backend::UNDEFINED;

low_precision_ = static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
if (Has("mixed_precision_mode")) {
low_precision_ =
static_cast<phi::DataType>(Get<int>("mixed_precision_mode"));
}

skip_pass_ = (backend_ == phi::Backend::UNDEFINED) ||
(low_precision_ == phi::DataType::UNDEFINED);

if (skip_pass_) return;

black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
SetDefaultBlacklist();
@@ -226,8 +233,8 @@ void AutoMixedPrecisionPass::Init(Graph* graph) const {
VLOG(4) << " - " << name;
}

if (Has("keep_io_types")) {
keep_io_types_ = Get<bool>("keep_io_types");
if (Has("enable_low_precision_io")) {
enable_low_precision_io_ = Get<bool>("enable_low_precision_io");
}

auto graph_size = graph->SubGraphsSize();
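The hunk above is the heart of the Init() change: probe with Has() before Get() and fall back to an UNDEFINED sentinel that makes the pass skip itself. A minimal standalone sketch of that defensive-read pattern (hypothetical AttrStore type, not Paddle's framework::ir::Pass):

#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the pass attribute store: Get() on a missing
// attribute throws, so callers probe with Has() first and keep a sentinel.
class AttrStore {
 public:
  bool Has(const std::string& name) const { return attrs_.count(name) > 0; }
  int Get(const std::string& name) const { return attrs_.at(name); }
  void Set(const std::string& name, int v) { attrs_[name] = v; }

 private:
  std::unordered_map<std::string, int> attrs_;
};

int main() {
  AttrStore pass_attrs;  // "mixed_precision_mode" intentionally never set

  int mixed_precision_mode = -1;  // stand-in for phi::DataType::UNDEFINED
  if (pass_attrs.Has("mixed_precision_mode")) {
    mixed_precision_mode = pass_attrs.Get("mixed_precision_mode");
  }

  // Like the pass: bail out quietly instead of crashing inside Get().
  bool skip_pass = (mixed_precision_mode == -1);
  std::cout << "skip_pass = " << std::boolalpha << skip_pass << "\n";  // true
  return 0;
}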
@@ -290,8 +297,8 @@ void AutoMixedPrecisionPass::ApplyImpl(Graph* graph) const {
RestoreOpOriginType();
VLOG(4) << "RestoreOpOriginType done";
LOG(INFO) << "The number of ops run at low precision ["
<< op_run_low_precision_.size() << "/" << op_original_type_.size()
<< "]";
<< op_run_low_precision_.size() << "/"
<< op_original_type_.size() + 2 << "]";
}

void AutoMixedPrecisionPass::SetOpUniqueType() const {
@@ -385,61 +392,68 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
bool support_low_precision = true;
if (GetOpOriginalType(op_type) == "feed" ||
GetOpOriginalType(op_type) == "fetch") {
support_low_precision = !keep_io_types_;
support_low_precision = enable_low_precision_io_;
} else if (GetOpOriginalType(op_type) == "tensorrt_engine") {
auto enable_fp16 = op_node->Op()->GetAttrIfExists<bool>("enable_fp16");
auto enable_int8 = op_node->Op()->GetAttrIfExists<bool>("enable_int8");
auto low_precision_io =
op_node->Op()->GetAttrIfExists<bool>("enable_low_precision_io");
support_low_precision = enable_fp16 && !enable_int8 && low_precision_io;
} else {
support_low_precision = OpSupportPrecision(
GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
}

if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
support_low_precision =
support_low_precision &&
IsFP32AndFP64(static_cast<VarType::Type>(dtype));
} else if (op_node->Op()->HasAttr("out_dtype")) {
auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
support_low_precision =
support_low_precision &&
IsFP32AndFP64(static_cast<VarType::Type>(out_dtype));
}

// If scale op's "scale" and "bias" attr value exceed the range of fp16
// and bf16, it cannot run at low precision.
if (GetOpOriginalType(op_node->Op()->Type()) == "scale") {
auto scale = op_node->Op()->GetAttrIfExists<float>("scale");
auto bias = op_node->Op()->GetAttrIfExists<float>("bias");
if (low_precision_ == phi::DataType::FLOAT16) {
if (op_node->Op()->HasAttr("dtype")) {
auto dtype = op_node->Op()->GetAttrIfExists<int>("dtype");
support_low_precision =
support_low_precision &&
phi::dtype::isfinite(static_cast<phi::dtype::float16>(scale)) &&
phi::dtype::isfinite(static_cast<phi::dtype::float16>(bias));
} else if (low_precision_ == phi::DataType::BFLOAT16) {
IsFP32AndFP64(static_cast<VarType::Type>(dtype));
} else if (op_node->Op()->HasAttr("out_dtype")) {
auto out_dtype = op_node->Op()->GetAttrIfExists<int>("out_dtype");
support_low_precision =
support_low_precision &&
phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(scale)) &&
phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(bias));
IsFP32AndFP64(static_cast<VarType::Type>(out_dtype));
}
}

// if op's input var and output var is not dense tensor, the op should
// not run at low precision.
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue;
// If scale op's "scale" and "bias" attr value exceed the range of fp16
// and bf16, it cannot run at low precision.
if (GetOpOriginalType(op_node->Op()->Type()) == "scale") {
auto scale = op_node->Op()->GetAttrIfExists<float>("scale");
auto bias = op_node->Op()->GetAttrIfExists<float>("bias");
if (low_precision_ == phi::DataType::FLOAT16) {
support_low_precision =
support_low_precision &&
phi::dtype::isfinite(static_cast<phi::dtype::float16>(scale)) &&
phi::dtype::isfinite(static_cast<phi::dtype::float16>(bias));
} else if (low_precision_ == phi::DataType::BFLOAT16) {
support_low_precision =
support_low_precision &&
phi::dtype::isfinite(
static_cast<phi::dtype::bfloat16>(scale)) &&
phi::dtype::isfinite(static_cast<phi::dtype::bfloat16>(bias));
}
}

support_low_precision =
support_low_precision &&
(real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
}
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue;

support_low_precision =
support_low_precision &&
(real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
// if op's input var and output var is not dense tensor, the op should
// not run at low precision.
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
auto* real_in_var_node = real_vars_[in_var_node->Var()->Name()];
if (real_in_var_node->Var()->Persistable()) continue;

support_low_precision =
support_low_precision &&
(real_in_var_node->Var()->GetType() == VarType::LOD_TENSOR);
}
for (auto* out_var_node : op_node->outputs) {
CHECK_EQ(out_var_node->IsVar(), true);
auto* real_out_var_node = real_vars_[out_var_node->Var()->Name()];
if (real_out_var_node->Var()->Persistable()) continue;

support_low_precision =
support_low_precision &&
(real_out_var_node->Var()->GetType() == VarType::LOD_TENSOR);
}
}

if (support_low_precision) {
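The isfinite checks in the hunk above are why a scale op with a large "scale" or "bias" attribute is kept at FP32: casting such a value to float16 overflows to infinity, while bfloat16, which keeps roughly FP32's exponent range, stays finite. A numeric illustration in plain C++, approximating the casts rather than using phi::dtype types (65504 is float16's largest finite value):

#include <cmath>
#include <iostream>

// float16's largest finite value is 65504; bfloat16 keeps (roughly) FP32's
// exponent range, so the two types reject different scale/bias magnitudes.
bool FiniteAsFp16(float v) { return std::fabs(v) <= 65504.0f; }
bool FiniteAsBf16(float v) { return std::isfinite(v); }

int main() {
  const float scale = 1e6f;  // would become +inf when cast to float16
  std::cout << "finite as fp16: " << FiniteAsFp16(scale) << "\n";  // 0
  std::cout << "finite as bf16: " << FiniteAsBf16(scale) << "\n";  // 1
  return 0;
}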
@@ -572,7 +586,12 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
bool AutoMixedPrecisionPass::InputVarsNotConvert(
Node* op_node, const std::string& var_name) const {
auto* op_desc = op_node->Op();
if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
if (GetOpOriginalType(op_desc->Type()) == "tensorrt_engine") {
auto vecs = op_desc->Input("Xs");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "batch_norm") {
auto vecs = op_desc->Input("Bias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
@@ -589,6 +608,15 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert(
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "instance_norm") {
auto vecs = op_desc->Input("Bias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("Scale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) == "fused_multi_transformer") {
auto vecs = op_desc->Input("LnScale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
@@ -606,6 +634,16 @@ bool AutoMixedPrecisionPass::InputVarsNotConvert(
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
} else if (GetOpOriginalType(op_desc->Type()) ==
"fused_bias_dropout_residual_layer_norm") {
auto vecs = op_desc->Input("LnScale");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
vecs = op_desc->Input("LnBias");
if (std::find(vecs.begin(), vecs.end(), var_name) != vecs.end()) {
return true;
}
}

if (backend_ == phi::Backend::XPU) {
@@ -805,7 +843,9 @@ void AutoMixedPrecisionPass::InsertCastOp() const {
auto op_type = op_node->Op()->Type();

if (GetOpOriginalType(op_type) == "feed") continue;
if (op_node->Op()->HasAttr("sub_block")) continue;
if (op_node->Op()->HasAttr("sub_block") &&
GetOpOriginalType(op_type) != "tensorrt_engine")
continue;

VLOG(4) << "process op: " << op_type
<< " run low precision: " << op_run_low_precision_.count(op_type);
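One subtlety in the .cc changes above: the feed/fetch branch flips flag polarity rather than behavior. The removed keep_io_types_ (default true) and the new enable_low_precision_io_ (default false) encode the same decision. A small standalone check of that equivalence:

#include <cassert>
#include <initializer_list>

int main() {
  // Old flag: keep_io_types_ (default true)        -> IO tensors stay FP32.
  // New flag: enable_low_precision_io_ (default false) -> same default behavior.
  for (bool keep_io_types : {true, false}) {
    bool old_support = !keep_io_types;              // old: support = !keep_io_types_
    bool enable_low_precision_io = !keep_io_types;  // renamed, inverted flag
    bool new_support = enable_low_precision_io;     // new: support = enable_low_precision_io_
    assert(old_support == new_support);
  }
  return 0;
}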
4 changes: 2 additions & 2 deletions paddle/fluid/framework/ir/auto_mixed_precision_pass.h
@@ -68,9 +68,9 @@ class AutoMixedPrecisionPass : public FusePassBase {
private:
mutable bool skip_pass_{false};

mutable bool keep_io_types_{true};
mutable bool enable_low_precision_io_{false};
// float16 or bfloat16 now
mutable phi::DataType low_precision_{phi::DataType::FLOAT16};
mutable phi::DataType low_precision_{phi::DataType::UNDEFINED};

mutable phi::Backend backend_{phi::Backend::UNDEFINED};

1 change: 0 additions & 1 deletion paddle/fluid/framework/ir/dense_fc_to_sparse_pass.h
@@ -19,7 +19,6 @@
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

namespace paddle {
namespace framework {
@@ -19,7 +19,6 @@
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

namespace paddle {
namespace framework {
14 changes: 4 additions & 10 deletions paddle/fluid/inference/analysis/argument.h
@@ -34,7 +34,6 @@
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

#include "paddle/phi/common/data_type.h"

@@ -225,9 +224,7 @@ struct Argument {
DECL_ARGUMENT_FIELD(tensorrt_disabled_ops,
TensorRtDisabledOPs,
std::vector<std::string>);
DECL_ARGUMENT_FIELD(tensorrt_precision_mode,
TensorRtPrecisionMode,
AnalysisConfig::Precision);
DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode, int);
DECL_ARGUMENT_FIELD(tensorrt_use_static_engine,
TensorRtUseStaticEngine,
bool);
@@ -263,9 +260,7 @@ struct Argument {
DlnneDisableNodesByOutputs,
std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(dlnne_use_calib_mode, DlnneUseCalibMode, bool);
DECL_ARGUMENT_FIELD(dlnne_precision_mode,
DlnnePrecisionMode,
AnalysisConfig::Precision);
DECL_ARGUMENT_FIELD(dlnne_precision_mode, DlnnePrecisionMode, int);

using dlnne_input_shape_type = std::map<std::string, std::vector<int64_t>>;
DECL_ARGUMENT_FIELD(dlnne_input_shape_dict,
@@ -277,9 +272,7 @@ struct Argument {
LitePassesFilter,
std::vector<std::string>);
DECL_ARGUMENT_FIELD(lite_ops_filter, LiteOpsFilter, std::vector<std::string>);
DECL_ARGUMENT_FIELD(lite_precision_mode,
LitePrecisionMode,
AnalysisConfig::Precision);
DECL_ARGUMENT_FIELD(lite_precision_mode, LitePrecisionMode, int);
DECL_ARGUMENT_FIELD(lite_zero_copy, LiteZeroCopy, bool);

DECL_ARGUMENT_FIELD(use_xpu, UseXpu, bool);
@@ -372,6 +365,7 @@ struct Argument {
std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool);
DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
DECL_ARGUMENT_FIELD(enable_low_precision_io, EnableLowPrecisionIO, bool);

// cinn compiler related
DECL_ARGUMENT_FIELD(use_cinn_compiler, UseCinnCompiler, bool);
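The DECL_ARGUMENT_FIELD edits above retype the stored precision modes from AnalysisConfig::Precision to int, which is what lets this header drop its paddle_analysis_config.h include. As a rough illustration only (not Paddle's actual macro body; the real macro likely also records whether the field was ever set), DECL_ARGUMENT_FIELD(tensorrt_precision_mode, TensorRtPrecisionMode, int) generates a member plus accessors along these lines:

// Hypothetical sketch of the accessor pattern such a macro expands to.
struct ArgumentSketch {
  int tensorrt_precision_mode() const { return tensorrt_precision_mode_; }
  void SetTensorRtPrecisionMode(int x) { tensorrt_precision_mode_ = x; }

 private:
  int tensorrt_precision_mode_{0};
};

int main() {
  ArgumentSketch arg;
  arg.SetTensorRtPrecisionMode(5);  // e.g. an int-encoded phi::DataType
  return arg.tensorrt_precision_mode() == 5 ? 0 : 1;
}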
18 changes: 10 additions & 8 deletions paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -27,6 +27,7 @@
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/inference/analysis/argument.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/errors.h"

namespace paddle {
@@ -60,8 +61,9 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("tensorrt_transformer_maskid",
new std::string(argument->tensorrt_transformer_maskid()));
pass->Set("disable_logs", new bool(argument->disable_logs()));
auto precision_mode = argument->tensorrt_precision_mode();
bool enable_int8 = precision_mode == AnalysisConfig::Precision::kInt8;
auto trt_precision_mode = argument->tensorrt_precision_mode();
bool enable_int8 =
trt_precision_mode == static_cast<int>(phi::DataType::INT8);
pass->Set("enable_int8", new bool(enable_int8));
pass->Set("max_input_shape",
new std::map<std::string, std::vector<int>>(
@@ -104,6 +106,8 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("mixed_precision_mode",
new int(argument->mixed_precision_mode()));
pass->Set("model_precision", new int(argument->model_precision()));
pass->Set("enable_low_precision_io",
new bool(argument->enable_low_precision_io()));

// "use_xpu" is used for passes in subgraphs.
pass->Set("use_xpu", new bool(argument->use_xpu()));
@@ -161,8 +165,7 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("predictor_id", new int(argument->predictor_id()));
bool use_calib_mode = argument->tensorrt_use_calib_mode();
pass->Set("use_calib_mode", new bool(use_calib_mode));
pass->Set("precision_mode",
new AnalysisConfig::Precision(precision_mode));
pass->Set("trt_precision_mode", new int(trt_precision_mode));
pass->Set("context_memory_sharing",
new bool(argument->trt_engine_memory_sharing()));
pass->Set("use_cuda_graph",
@@ -242,8 +245,7 @@ void IRPassManager::CreatePasses(Argument *argument,
new std::unordered_set<std::string>(
argument->dlnne_disable_nodes_by_outputs()));
pass->Set("use_calib_mode", new bool(argument->dlnne_use_calib_mode()));
pass->Set("precision_mode",
new AnalysisConfig::Precision(precision_mode));
pass->Set("dlnne_precision_mode", new int(precision_mode));
pass->Set("input_shape_dict",
new std::map<std::string, std::vector<int64_t>>(
argument->dlnne_input_shape_dict()));
@@ -254,8 +256,8 @@ void IRPassManager::CreatePasses(Argument *argument,
} else if (pass_name == "build_cinn_pass") {
pass->Set("is_inference_stage", new bool(argument->use_cinn_compiler()));
} else if (pass_name == "lite_subgraph_pass") {
bool lite_enable_int8 =
argument->lite_precision_mode() == AnalysisConfig::Precision::kInt8;
bool lite_enable_int8 = argument->lite_precision_mode() ==
static_cast<int>(phi::DataType::INT8);
pass->Set("program",
new framework::ProgramDesc *(&argument->main_program()));
pass->Set("lite_ops_filter",
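Across the trt, dlnne, and lite branches above, the precision mode now travels through pass attributes as an int-encoded phi::DataType instead of an AnalysisConfig::Precision value. A self-contained sketch of that encode-and-compare pattern (the enum here is a stand-in; phi::DataType's real enumerator values differ):

#include <iostream>

// Stand-in for phi::DataType; enumerator numbering is illustrative only.
enum class DataType : int { UNDEFINED = 0, FLOAT32, FLOAT16, BFLOAT16, INT8 };

int main() {
  // Producer side (cf. IRPassManager::CreatePasses):
  //   pass->Set("trt_precision_mode", new int(trt_precision_mode));
  int trt_precision_mode = static_cast<int>(DataType::INT8);

  // Consumer side (cf. the subgraph passes): decode by comparing ints,
  // with no dependency on paddle_analysis_config.h.
  bool enable_int8 = trt_precision_mode == static_cast<int>(DataType::INT8);
  std::cout << "enable_int8 = " << enable_int8 << "\n";  // prints 1
  return 0;
}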
@@ -572,9 +572,9 @@ void DlnneSubgraphPass::CreateDlnneOp(
// is unique.
auto engine_key = GenerateEngineKey(
input_names_with_id, output_names_with_id, std::to_string(0));
auto precision_mode = Get<AnalysisConfig::Precision>("precision_mode");
auto precision_mode = Get<int>("dlnne_precision_mode");
bool enable_int8 = false;
if (precision_mode == AnalysisConfig::Precision::kInt8) {
if (precision_mode == static_cast<int>(phi::DataType::INT8)) {
enable_int8 = true;
}
auto use_calib_mode = Get<bool>("use_calib_mode");
@@ -21,7 +21,6 @@
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/inference/analysis/ir_passes/subgraph_util.h"
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

namespace paddle {
namespace framework {