Skip to content

Commit

Permalink
Merge pull request onnx#1081 from peri044/qdq_per_channel
Browse files Browse the repository at this point in the history
Add functionality for QDQ per channel
  • Loading branch information
guschmue committed Aug 31, 2020
2 parents 6ec695b + b18e633 commit 0b15fe1
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 17 deletions.
13 changes: 13 additions & 0 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2044,6 +2044,19 @@ def func(x):
return tf.identity(x_, name=_TFOUTPUT)
_ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@check_tf_min_version("2.0")
@check_opset_min_version(13, "quantize_and_dequantize")
def test_qdq_per_channel_signed_input(self):
    """Per-channel quantize/dequantize round-trip with signed int8 ranges.

    Uses a distinct [min, max] range per channel along the last axis, so
    the ONNX rewriter must emit QuantizeLinear/DequantizeLinear with an
    `axis` attribute (available from opset 13).
    """
    shape = [3, 3, 2]
    # Values centered around zero so the signed range is exercised on both sides.
    half = np.prod(shape) / 2
    x_val = np.arange(-half, half).astype("float32").reshape(shape)

    def func(x):
        # One (min, max) pair per channel of the last axis (size 2).
        channel_mins = np.array([-1.72, -3.89]).astype(np.float32)
        channel_maxs = np.array([5.12, 2.36]).astype(np.float32)
        x_ = quantize_and_dequantize(
            x, channel_mins, channel_maxs,
            signed_input=True, narrow_range=False,
            range_given=True, axis=-1)
        return tf.identity(x_, name=_TFOUTPUT)

    _ = self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})

@skip_caffe2_backend()
@check_opset_min_version(7, "resize_nearest_neighbor")
def test_resize_nearest_neighbor(self):
Expand Down
52 changes: 35 additions & 17 deletions tf2onnx/rewriter/quantization_ops_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the MIT license.

"""
tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV3 op
tf2onnx.rewriter - rewrite tensorflow QuantizeAndDequantizeV2|QuantizeAndDequantizeV3 op
"""

import numpy as np
Expand Down Expand Up @@ -32,47 +32,65 @@ def create_qdq_nodes(g, match_results):
if not signed_input:
min_quantized, max_quantized = [0, 255]

# Get axis attribute for per channel implementation.
if 'axis' in qdq_node.attr:
axis = qdq_node.attr['axis'].i

# Get the min and max value of the inputs to QDQ op
min_value = extract_numpy_array(qdq_node.inputs[1])
max_value = extract_numpy_array(qdq_node.inputs[2])

# Calculate scales from the min and max values
scale_from_min_side = min_quantized/min_value if min_quantized*min_value > 0 else max_quantized
scale_from_max_side = max_quantized/max_value if max_quantized*max_value > 0 else max_quantized

if scale_from_min_side < scale_from_max_side:
scale = scale_from_min_side
else:
scale = scale_from_max_side

utils.make_sure(scale > 0, "Quantize/Dequantize scale must be greater than zero")

if signed_input:
zero_point = np.int8(0)
num_channels = min_value.shape[0]
scales = np.zeros(num_channels, dtype=np.float32)
zero_point_dtype = np.int8 if signed_input else np.uint8
zero_point = np.zeros(num_channels, dtype=zero_point_dtype)

for i in range(num_channels):
# Calculate scales from the min and max values
scale_from_min_side = min_quantized/min_value[i] if min_quantized*min_value[i] > 0 else max_quantized
scale_from_max_side = max_quantized/max_value[i] if max_quantized*max_value[i] > 0 else max_quantized

if scale_from_min_side < scale_from_max_side:
scale = scale_from_min_side
else:
scale = scale_from_max_side

utils.make_sure(scale > 0, "Quantize/Dequantize scale must be greater than zero")
scales[i] = np.float32(scale)

# Set scalars for scale and zero point for per layer quantization
if num_channels == 1:
scales = scales[0]
zero_point = zero_point[0]
attrs = {}
else:
zero_point = np.uint8(0)
utils.make_sure(axis, "Axis must be specified for per channel quantization")
attrs = {'axis': axis}

# Split it into QuantizeLinear and DequantizeLinear and remove the QDQ node reference
y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=1/scale)
inverse_scale = (1/scales).astype(np.float32)
y_quant_scale = g.make_const(name=utils.make_name("y_quant_scale"), np_val=inverse_scale)
y_zero_point = g.make_const(name=utils.make_name("y_zero_point"), np_val=zero_point)
quant_node = g.make_node(op_type="QuantizeLinear",
inputs=[qdq_node.input[0], y_quant_scale.output[0],
y_zero_point.output[0]],
shapes=[qdq_node_output_shape],
attr=attrs,
dtypes=[qdq_node_output_dtype],
name=utils.make_name("QuantLinearNode"))

g.set_shape(quant_node.output[0], qdq_node_output_shape)

g.remove_node(qdq_node.name)

y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=1/scale)
y_dequant_scale = g.make_const(name=utils.make_name("y_dequant_scale"), np_val=inverse_scale)
y_inv_zero_point = g.make_const(name=utils.make_name("y_inv_zero_point"), np_val=zero_point)
dequant_node = g.make_node(op_type="DequantizeLinear",
inputs=[quant_node.output[0], y_dequant_scale.output[0],
y_inv_zero_point.output[0]],
outputs=[qdq_node.output[0]],
shapes=[qdq_node_output_shape],
attr=attrs,
dtypes=[qdq_node_output_dtype],
name=utils.make_name("DequantLinearNode"))
g.set_shape(dequant_node.output[0], qdq_node_output_shape)
Expand Down

0 comments on commit 0b15fe1

Please sign in to comment.