This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 645fcf9f0a [Relax][ONNX] Add frontend support for QuantizeLinear, DequantizeLinear, and DynamicQuantizeLinear (#19391)
645fcf9f0a is described below

commit 645fcf9f0aac0c38fc9f46bcc159486cc19fb635
Author: WANG HUNG-HSIANG <[email protected]>
AuthorDate: Sun Apr 12 09:52:09 2026 +0800

    [Relax][ONNX] Add frontend support for QuantizeLinear, DequantizeLinear, and DynamicQuantizeLinear (#19391)
    
    ## Summary
    
    This PR adds Relax ONNX frontend support for:
    - `QuantizeLinear`
    - `DequantizeLinear`
    - `DynamicQuantizeLinear`
    
    The implementation follows existing TVM ONNX frontend patterns and keeps
    QDQ handling consistent for singleton quantization parameters and
    optional zero-point inputs.
    
    ## Changes
    
    - add ONNX frontend converters for `QuantizeLinear`, `DequantizeLinear`,
    and `DynamicQuantizeLinear`
    - register Q/DQ-related ops in the ONNX converter map
    - handle optional zero-point inputs consistently during import
    - preserve singleton quantization parameter semantics in the QDQ
    legalization path
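    - improve QDQ legalization behavior for imported ONNX models
    - accept `uint8`, `int16`, `uint16`, `int32`, and `uint32` zero points (in
    addition to `int8` and `float16`) in `relax.quantize`/`relax.dequantize`
    struct-info inference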
    - add and update frontend tests for Q/DQ and `DynamicQuantizeLinear`
    
    ## Tests
    
    Added or updated tests in `tests/python/relax/test_frontend_onnx.py` to
    cover:
    - singleton-qparam `QuantizeLinear` in opset 10
    - singleton-qparam `DequantizeLinear` in opset 10
    - optional-zero-point `QuantizeLinear` in opset 13
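    - `DynamicQuantizeLinear` in opset 11
    - per-channel qparams with the default `axis=1` for `QuantizeLinear` and
    `DequantizeLinear` in opset 10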
    
    ## Validation
    
    Validated with:
    - `python -m pytest -n 1 tests/python/relax/test_frontend_onnx.py -k
    "quantizelinear or dequantizelinear or dynamicquantizelinear" -v`
    
    Result:
    - `4 passed`
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py |  71 ++++++++++++++++
 python/tvm/relax/transform/legalize_ops/qdq.py  |  57 +++++++++++--
 src/relax/op/tensor/qdq.cc                      |  54 +++++++++---
 tests/python/relax/test_frontend_onnx.py        | 108 ++++++++++++++++++++++++
 4 files changed, 272 insertions(+), 18 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 2707f6ff1c..5397f2c309 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -311,6 +311,73 @@ class OnnxOpConverter:
             return getattr(cls, f"_impl_v{version}")
         raise NotImplementedError(f"opset version {version} of {cls.__name__} 
not implemented")
 
+
+class QuantizeLinear(OnnxOpConverter):
+    """Converts an onnx QuantizeLinear node into an equivalent Relax expression."""
+
+    @classmethod
+    def _convert(cls, inputs, attr):
+        """Shared lowering for every supported opset.
+
+        inputs is (x, y_scale[, y_zero_point]). When y_zero_point is absent (or an
+        empty optional input), a uint8 zero is used and the result is uint8;
+        otherwise the output dtype follows the zero point's dtype.
+        """
+        x, scale = inputs[0], inputs[1]
+        zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None
+        axis = attr.get("axis", 1)
+        # The ONNX default axis=1 is out of range for inputs of rank <= 1, so fall
+        # back to axis 0 there. ndim == -1 means the rank is unknown; remapping in
+        # that case would silently apply per-channel qparams along axis 0.
+        ndim = getattr(x.struct_info, "ndim", -1)
+        if 0 <= ndim <= 1 and axis == 1:
+            axis = 0
+        out_dtype = "uint8" if zp is None else zp.struct_info.dtype
+        if zp is None:
+            zp = relax.const(0, out_dtype)
+        return relax.op.quantize(x, scale, zp, axis=axis, out_dtype=out_dtype)
+
+    @classmethod
+    def _impl_v10(cls, bb, inputs, attr, params):
+        return cls._convert(inputs, attr)
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        return cls._convert(inputs, attr)
+
+
+class DequantizeLinear(OnnxOpConverter):
+    """Converts an onnx DequantizeLinear node into an equivalent Relax expression."""
+
+    @classmethod
+    def _convert(cls, inputs, attr):
+        """Shared lowering for every supported opset.
+
+        inputs is (x, x_scale[, x_zero_point]). A missing (or empty optional)
+        x_zero_point becomes a zero of x's dtype.
+        """
+        x, scale = inputs[0], inputs[1]
+        zp = inputs[2] if len(inputs) > 2 and inputs[2] is not None else None
+        axis = attr.get("axis", 1)
+        # The ONNX default axis=1 is out of range for inputs of rank <= 1, so fall
+        # back to axis 0 there. ndim == -1 means the rank is unknown; remapping in
+        # that case would silently apply per-channel qparams along axis 0.
+        ndim = getattr(x.struct_info, "ndim", -1)
+        if 0 <= ndim <= 1 and axis == 1:
+            axis = 0
+        if zp is None:
+            zp = relax.const(0, x.struct_info.dtype)
+        # ONNX types the output like x_scale: always float32 up to opset 13, while
+        # newer opsets also allow float16 scales. NOTE(review): assumes
+        # get_converter dispatches those newer opsets to the newest _impl_v* here.
+        # Any other scale dtype keeps the previous float32 output.
+        scale_dtype = getattr(scale.struct_info, "dtype", "float32")
+        out_dtype = scale_dtype if scale_dtype in ("float16", "float32") else "float32"
+        return relax.op.dequantize(x, scale, zp, axis=axis, out_dtype=out_dtype)
+
+    @classmethod
+    def _impl_v10(cls, bb, inputs, attr, params):
+        return cls._convert(inputs, attr)
+
+    @classmethod
+    def _impl_v13(cls, bb, inputs, attr, params):
+        return cls._convert(inputs, attr)
+
+
+class DynamicQuantizeLinear(OnnxOpConverter):
+    """Converts an onnx DynamicQuantizeLinear node into an equivalent Relax expression.
+
+    Returns the tuple (y, y_scale, y_zero_point). y is uint8 and uses one per-tensor
+    scale/zero point, derived from the input's min/max widened to include 0.
+    """
+
+    @classmethod
+    def _impl_v11(cls, bb, inputs, attr, params):
+        x = inputs[0]
+        x_dtype = x.struct_info.dtype
+        qmin, qmax = 0, 255  # uint8 output range
+        zero = relax.const(qmin, x_dtype)
+
+        # Widen the observed range so that 0 is always inside it.
+        x_max = relax.op.maximum(zero, relax.op.max(x))
+        x_min = relax.op.minimum(zero, relax.op.min(x))
+        x_range = relax.op.subtract(x_max, x_min)
+        # For an all-zero input the range is 0, and the zero point would become
+        # 0 / 0 = NaN before the uint8 cast. Use a unit scale there instead
+        # (onnxruntime's convention -- TODO confirm against the ONNX reference).
+        y_scale = relax.op.where(
+            relax.op.equal(x_range, zero),
+            relax.const(1.0, x_dtype),
+            relax.op.divide(x_range, relax.const(qmax - qmin, x_dtype)),
+        )
+
+        zp_fp = relax.op.subtract(zero, relax.op.divide(x_min, y_scale))
+        y_zero_point = relax.op.astype(
+            relax.op.round(relax.op.clip(zp_fp, qmin, qmax)), "uint8"
+        )
+
+        y = relax.op.quantize(x, y_scale, y_zero_point, axis=0, out_dtype="uint8")
+        return relax.Tuple([y, y_scale, y_zero_point])
+
 
 class MatMul(OnnxOpConverter):
     """Converts an onnx MatMul node into an equivalent Relax expression."""
@@ -4812,6 +4879,10 @@ def _get_convert_map():
         "ConcatFromSequence": ConcatFromSequence,
         "SplitToSequence": SplitToSequence,
         "SequenceAt": SequenceAt,
+        # Quantization
+        "QuantizeLinear": QuantizeLinear,
+        "DequantizeLinear": DequantizeLinear,
+        "DynamicQuantizeLinear": DynamicQuantizeLinear,
     }
 
 
diff --git a/python/tvm/relax/transform/legalize_ops/qdq.py 
b/python/tvm/relax/transform/legalize_ops/qdq.py
index caec63ffa8..5e28d1b291 100644
--- a/python/tvm/relax/transform/legalize_ops/qdq.py
+++ b/python/tvm/relax/transform/legalize_ops/qdq.py
@@ -17,6 +17,7 @@
 # pylint: disable=invalid-name
 """Default legalization function for quantize/dequantize operators."""
 
+from typing import Union
 import tvm
 from tvm import te, tirx
 
@@ -35,6 +36,18 @@ def is_const_scalar(x):
     return isinstance(x, tvm.tirx.IntImm | tvm.tirx.FloatImm)
 
 
+def _is_singleton_qparam(qparam: te.Tensor) -> bool:
+    """Tell whether a quantization parameter holds exactly one element.
+
+    Only ``te.Tensor`` inputs can qualify. A rank-0 tensor does, as does any tensor
+    whose extents are all the literal ``IntImm`` 1. Anything else returns False:
+    non-tensors, symbolic extents, or extents other than 1.
+    """
+    if isinstance(qparam, te.Tensor):
+        # all() over an empty shape is True, which covers the rank-0 case.
+        return all(
+            isinstance(extent, tirx.IntImm) and extent.value == 1 for extent in qparam.shape
+        )
+    return False
+
+
 @register_legalize("relax.quantize")
 def _quantize(bb: BlockBuilder, call: Call) -> Expr:
     """
@@ -46,12 +59,26 @@ def _quantize(bb: BlockBuilder, call: Call) -> Expr:
 
     def te_quantize(
         data: te.Tensor,
-        scale: te.Tensor | tirx.IntImm | tirx.FloatImm,
-        zp: te.Tensor | tirx.IntImm | tirx.FloatImm,
+        scale: Union[te.Tensor, tirx.IntImm, tirx.FloatImm],
+        zp: Union[te.Tensor, tirx.IntImm, tirx.FloatImm],
     ):
+        scale_singleton = _is_singleton_qparam(scale) if isinstance(scale, 
te.Tensor) else False
+        zp_singleton = _is_singleton_qparam(zp) if isinstance(zp, te.Tensor) 
else False
+
         def quantize_compute(*indices):
-            scale_value = scale if is_const_scalar(scale) else 
scale[indices[axis]]
-            zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
+            if is_const_scalar(scale):
+                scale_value = scale
+            elif scale_singleton:
+                scale_value = scale[(0,) * len(scale.shape)]
+            else:
+                scale_value = scale[indices[axis]]
+
+            if is_const_scalar(zp):
+                zp_value = zp
+            elif zp_singleton:
+                zp_value = zp[(0,) * len(zp.shape)]
+            else:
+                zp_value = zp[indices[axis]]
             scaled = data[indices] / scale_value
             round_val = (te.round(scaled) if "int" in out_dtype else scaled) + 
zp_value
             return clip_cast(round_val, out_dtype)
@@ -94,12 +121,26 @@ def _dequantize(bb: BlockBuilder, call: Call) -> Expr:
 
     def te_dequantize(
         data: te.Tensor,
-        scale: te.Tensor | tirx.IntImm | tirx.FloatImm,
-        zp: te.Tensor | tirx.IntImm | tirx.FloatImm,
+        scale: Union[te.Tensor, tirx.IntImm, tirx.FloatImm],
+        zp: Union[te.Tensor, tirx.IntImm, tirx.FloatImm],
     ):
+        scale_singleton = _is_singleton_qparam(scale) if isinstance(scale, 
te.Tensor) else False
+        zp_singleton = _is_singleton_qparam(zp) if isinstance(zp, te.Tensor) 
else False
+
         def dequantize_compute(*indices):
-            scale_value = scale if is_const_scalar(scale) else 
scale[indices[axis]]
-            zp_value = zp if is_const_scalar(zp) else zp[indices[axis]]
+            if is_const_scalar(scale):
+                scale_value = scale
+            elif scale_singleton:
+                scale_value = scale[(0,) * len(scale.shape)]
+            else:
+                scale_value = scale[indices[axis]]
+
+            if is_const_scalar(zp):
+                zp_value = zp
+            elif zp_singleton:
+                zp_value = zp[(0,) * len(zp.shape)]
+            else:
+                zp_value = zp[indices[axis]]
             dtype = "float32" if "float" in data.dtype else "int32"
             sub = te.subtract(data[indices].astype(dtype), zp_value)
             out = te.multiply(sub, scale_value.astype("float32"))
diff --git a/src/relax/op/tensor/qdq.cc b/src/relax/op/tensor/qdq.cc
index 406868ab4b..3a7a9f164a 100644
--- a/src/relax/op/tensor/qdq.cc
+++ b/src/relax/op/tensor/qdq.cc
@@ -79,10 +79,14 @@ StructInfo InferStructInfoQuantize(const Call& call, const 
BlockBuilder& ctx) {
   }
 
   // Check datatype of zero_point param:
-  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != 
DataType::Float(16)) {
+  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != 
DataType::UInt(8) &&
+      zp_sinfo->dtype != DataType::Int(16) && zp_sinfo->dtype != 
DataType::UInt(16) &&
+      zp_sinfo->dtype != DataType::Int(32) && zp_sinfo->dtype != 
DataType::UInt(32) &&
+      zp_sinfo->dtype != DataType::Float(16)) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << "zero_point param datatype should be 'int8' or 
'float16', but got "
-                     << zp_sinfo->dtype);
+                     << "zero_point param datatype should be one of "
+                     << "['int8', 'uint8', 'int16', 'uint16', 'int32', 
'uint32', 'float16'], "
+                     << "but got " << zp_sinfo->dtype);
   }
 
   // Check that "axis" attribute is not out of range:
@@ -104,9 +108,22 @@ StructInfo InferStructInfoQuantize(const Call& call, const 
BlockBuilder& ctx) {
     }
   };
 
+  auto is_scalar_or_singleton_vector = [&](const TensorStructInfo& 
param_sinfo) {
+    if (IsScalarTensor(param_sinfo)) return true;
+    if (param_sinfo->shape.defined() && 
param_sinfo->shape->IsInstance<ShapeExprNode>()) {
+      const auto& values = param_sinfo->shape.as<ShapeExprNode>()->values;
+      if (!values.empty()) {
+        return std::all_of(values.begin(), values.end(), [&](const PrimExpr& 
dim) {
+          return ctx->GetAnalyzer()->CanProveEqual(dim, 1);
+        });
+      }
+    }
+    return false;
+  };
+
   // Check size matching of scale/zp params with input shape at dim = 
attrs->axis.
-  if (!IsScalarTensor(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, 
"scale");
-  if (!IsScalarTensor(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, 
"zero_point");
+  if (!is_scalar_or_singleton_vector(scale_sinfo)) 
check_param_size(scale_sinfo, input_sinfo, "scale");
+  if (!is_scalar_or_singleton_vector(zp_sinfo)) check_param_size(zp_sinfo, 
input_sinfo, "zero_point");
 
   auto output_sinfo = 
ffi::make_object<TensorStructInfoNode>(*input_sinfo.get());
   output_sinfo->dtype = attrs->out_dtype;
@@ -167,10 +184,14 @@ StructInfo InferStructInfoDequantize(const Call& call, 
const BlockBuilder& ctx)
   }
 
   // Check datatype of zero_point param:
-  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != 
DataType::Float(16)) {
+  if (zp_sinfo->dtype != DataType::Int(8) && zp_sinfo->dtype != 
DataType::UInt(8) &&
+      zp_sinfo->dtype != DataType::Int(16) && zp_sinfo->dtype != 
DataType::UInt(16) &&
+      zp_sinfo->dtype != DataType::Int(32) && zp_sinfo->dtype != 
DataType::UInt(32) &&
+      zp_sinfo->dtype != DataType::Float(16)) {
     ctx->ReportFatal(Diagnostic::Error(call)
-                     << "zero_point param datatype should be 'int8' or 
'float16', but got "
-                     << zp_sinfo->dtype);
+                     << "zero_point param datatype should be one of "
+                     << "['int8', 'uint8', 'int16', 'uint16', 'int32', 
'uint32', 'float16'], "
+                     << "but got " << zp_sinfo->dtype);
   }
 
   // Check that "axis" attribute is not out of range:
@@ -192,9 +213,22 @@ StructInfo InferStructInfoDequantize(const Call& call, 
const BlockBuilder& ctx)
     }
   };
 
+  auto is_scalar_or_singleton_vector = [&](const TensorStructInfo& 
param_sinfo) {
+    if (IsScalarTensor(param_sinfo)) return true;
+    if (param_sinfo->shape.defined() && 
param_sinfo->shape->IsInstance<ShapeExprNode>()) {
+      const auto& values = param_sinfo->shape.as<ShapeExprNode>()->values;
+      if (!values.empty()) {
+        return std::all_of(values.begin(), values.end(), [&](const PrimExpr& 
dim) {
+          return ctx->GetAnalyzer()->CanProveEqual(dim, 1);
+        });
+      }
+    }
+    return false;
+  };
+
   // Check size matching of scale/zp params with input shape at dim = 
attrs->axis.
-  if (!IsScalarTensor(scale_sinfo)) check_param_size(scale_sinfo, input_sinfo, 
"scale");
-  if (!IsScalarTensor(zp_sinfo)) check_param_size(zp_sinfo, input_sinfo, 
"zero_point");
+  if (!is_scalar_or_singleton_vector(scale_sinfo)) 
check_param_size(scale_sinfo, input_sinfo, "scale");
+  if (!is_scalar_or_singleton_vector(zp_sinfo)) check_param_size(zp_sinfo, 
input_sinfo, "zero_point");
 
   auto output_sinfo = 
ffi::make_object<TensorStructInfoNode>(*input_sinfo.get());
   output_sinfo->dtype = attrs->out_dtype;
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index db7c3da25a..7e434d2659 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -5599,6 +5599,114 @@ def test_split_to_sequence_uneven_last_chunk(axis: int):
     model = helper.make_model(graph, 
producer_name="test_split_to_sequence_uneven")
     check_correctness(model)
 
+
+def test_quantizelinear_singleton_qparams_opset10():
+    """QuantizeLinear must treat shape-[1] scale/zp as scalar in opset10."""
+    node = helper.make_node("QuantizeLinear", ["x", "scale", "zero_point"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "quantizelinear_singleton_qparams_opset10",
+        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [4, 3, 2, 2])],
+        [helper.make_tensor_value_info("y", TensorProto.UINT8, [4, 3, 2, 2])],
+        initializer=[
+            # Per-tensor qparams stored as shape-[1] tensors rather than scalars.
+            helper.make_tensor("scale", TensorProto.FLOAT, [1], [0.03125]),
+            helper.make_tensor("zero_point", TensorProto.UINT8, [1], [127]),
+        ],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
+
+    x = rg.standard_normal((4, 3, 2, 2)).astype("float32")
+    check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)
+
+
+def test_dequantizelinear_singleton_qparams_opset10():
+    """DequantizeLinear must treat shape-[1] scale/zp as scalar in opset10."""
+    node = helper.make_node("DequantizeLinear", ["x", "scale", "zero_point"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "dequantizelinear_singleton_qparams_opset10",
+        [helper.make_tensor_value_info("x", TensorProto.UINT8, [64])],
+        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [64])],
+        initializer=[
+            helper.make_tensor("scale", TensorProto.FLOAT, [1], [0.125]),
+            helper.make_tensor("zero_point", TensorProto.UINT8, [1], [1]),
+        ],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
+
+    # Generator.integers excludes `high`; use 256 so 255 can appear in the input.
+    x = rg.integers(low=0, high=256, size=(64,), dtype=np.uint8)
+    check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)
+
+
+def test_quantizelinear_optional_zero_point_opset13():
+    """A QuantizeLinear without zero_point must default to a uint8 zero point of 0."""
+    shape = [2, 5]
+    opset = 13
+    scale_init = helper.make_tensor("scale", TensorProto.FLOAT, [], [0.2])
+    graph = helper.make_graph(
+        [helper.make_node("QuantizeLinear", ["x", "scale"], ["y"])],
+        "quantizelinear_optional_zero_point_opset13",
+        [helper.make_tensor_value_info("x", TensorProto.FLOAT, shape)],
+        [helper.make_tensor_value_info("y", TensorProto.UINT8, shape)],
+        initializer=[scale_init],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", opset)])
+
+    x = rg.standard_normal(tuple(shape)).astype("float32")
+    check_correctness(model, inputs={"x": x}, opset=opset, check_dtypes=True)
+
+
+def test_dynamicquantizelinear_opset11():
+    """DynamicQuantizeLinear must match ORT on y, y_scale and y_zero_point."""
+    in_shape = [2, 3, 4]
+    out_infos = [
+        helper.make_tensor_value_info("y", TensorProto.UINT8, in_shape),
+        helper.make_tensor_value_info("y_scale", TensorProto.FLOAT, []),
+        helper.make_tensor_value_info("y_zero_point", TensorProto.UINT8, []),
+    ]
+    graph = helper.make_graph(
+        [helper.make_node("DynamicQuantizeLinear", ["x"], [info.name for info in out_infos])],
+        "dynamicquantizelinear_opset11",
+        [helper.make_tensor_value_info("x", TensorProto.FLOAT, in_shape)],
+        out_infos,
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)])
+
+    x = rg.standard_normal(tuple(in_shape)).astype("float32")
+    check_correctness(model, inputs={"x": x}, opset=11, atol=1e-5, rtol=1e-5, check_dtypes=True)
+
+
+def test_quantizelinear_default_axis_opset10():
+    """opset10 QuantizeLinear should honor default axis=1 (not hardcode axis=0)."""
+    # NOTE(review): the opset-10 ONNX spec only defines scalar (per-tensor) scales;
+    # this per-channel case relies on onnxruntime accepting 1-D qparams at opset 10.
+    node = helper.make_node("QuantizeLinear", ["x", "scale", "zero_point"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "quantizelinear_axis_opset10",
+        [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 3, 4])],
+        [helper.make_tensor_value_info("y", TensorProto.UINT8, [2, 3, 4])],
+        initializer=[
+            helper.make_tensor("scale", TensorProto.FLOAT, [3], [0.05, 0.1, 0.2]),
+            helper.make_tensor("zero_point", TensorProto.UINT8, [3], [1, 127, 250]),
+        ],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
+
+    x = rg.standard_normal((2, 3, 4)).astype("float32")
+    check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)
+
+
+def test_dequantizelinear_default_axis_opset10():
+    """opset10 DequantizeLinear should honor default axis=1 (not hardcode axis=0)."""
+    # NOTE(review): the opset-10 ONNX spec only defines scalar (per-tensor) scales;
+    # this per-channel case relies on onnxruntime accepting 1-D qparams at opset 10.
+    node = helper.make_node("DequantizeLinear", ["x", "scale", "zero_point"], ["y"])
+    graph = helper.make_graph(
+        [node],
+        "dequantizelinear_axis_opset10",
+        [helper.make_tensor_value_info("x", TensorProto.UINT8, [2, 3, 4])],
+        [helper.make_tensor_value_info("y", TensorProto.FLOAT, [2, 3, 4])],
+        initializer=[
+            helper.make_tensor("scale", TensorProto.FLOAT, [3], [0.05, 0.1, 0.2]),
+            helper.make_tensor("zero_point", TensorProto.UINT8, [3], [1, 127, 250]),
+        ],
+    )
+    model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 10)])
+
+    # Generator.integers excludes `high`; use 256 so 255 can appear in the input.
+    x = rg.integers(low=0, high=256, size=(2, 3, 4), dtype=np.uint8)
+    check_correctness(model, inputs={"x": x}, opset=10, check_dtypes=True)
+
 
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to