This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fd9d9db130 [Relax][ONNX] Fix shape/dynamic restrictions for `Squeeze`/`Unsqueeze` and `Slice` (#18955)
fd9d9db130 is described below

commit fd9d9db130fb9ede300a062b95f81dd186dc5140
Author: HoYi <[email protected]>
AuthorDate: Tue Mar 31 22:33:59 2026 +0800

    [Relax][ONNX] Fix shape/dynamic restrictions for `Squeeze`/`Unsqueeze` and `Slice` (#18955)
    
    ## Summary
    
    Relates to #18945.
    
    This PR improves ONNX frontend handling for dynamic
    `Unsqueeze`/`Squeeze`/`Slice`, tightens validation paths, and adds
    targeted structural/negative regression tests.
    
    - Refactor constant-path `Unsqueeze` lowering to use a single `reshape` based on the computed target shape.
    - Remove scalar-specific branching and repeated `expand_dims` from the constant path.
    - Use the shared `collect_relax_call_ops` helper in ONNX frontend tests for structural Relax call-op checks.
    - Add regression coverage for scalar-input `Unsqueeze`.
    
    ## Changes
    
    - Add dynamic-axes conversion paths for `Unsqueeze` and `Squeeze`:
      - infer the output shape by building a runtime shape tensor
      - lower to `relax.reshape`, requiring a statically known input rank and axes length
    - Improve `Slice` conversion robustness:
      - support dynamic (non-constant) parameters by lowering to `relax.dynamic_strided_slice`, with stricter rank/length validation
      - reject zero step values when they are statically known
      - fix docstring wording (`Splice` -> `Slice`)
    - Strengthen ONNX frontend tests:
      - negative test for duplicate `Unsqueeze` axes
      - structural IR check for dynamic `Slice` (`relax.dynamic_strided_slice` present, `relax.strided_slice` absent)
      - negative test for zero-step `Slice`
    - Refactor constant-path `Unsqueeze` scalar handling:
      - replace scalar special-casing and repeated `expand_dims` with a single target-shape `reshape`
      - add scalar-input regression test
    - Keep (and type-annotate) the shared `collect_relax_call_ops` test helper used by structural Relax call-op checks.
    
    ## Validation
    
    - `ruff check`: passed
    - `pre-commit --files`: passed
    - `pytest`: 8 passed
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 427 ++++++++++++++++++++----
 tests/python/relax/test_frontend_onnx.py        | 277 ++++++++++++++-
 2 files changed, 630 insertions(+), 74 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py 
b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index fbbcd68bc5..21b8f22a0c 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -735,38 +735,59 @@ class Unsqueeze(OnnxOpConverter):
     def _impl_v13(cls, bb, inputs, attr, params):
         data = inputs[0]
         axes = get_constant(inputs[1], params)
+        data_ndim = _get_known_tensor_rank(data)
 
-        # Handle ONNX shape inference
         if isinstance(data, relax.PrimValue) and isinstance(axes, 
relax.Constant):
-            axes = axes.data.numpy().tolist()
-            if axes == [0]:
+            constant_axes = _normalize_constant_axes(
+                list(map(int, axes.data.numpy().tolist())), 1, "Unsqueeze"
+            )
+            if constant_axes == [0]:
                 return relax.ShapeExpr([data.value])
-            else:
-                raise NotImplementedError(
-                    "Unsqueeze with symbolic axes and non-zero axes is not 
supported."
-                )
-        # If input is a constant, compute directly
+            raise NotImplementedError("Unsqueeze with symbolic scalar inputs 
only supports axis 0.")
         if isinstance(data, relax.Constant) and isinstance(axes, 
relax.Constant):
-            axes = axes.data.numpy().tolist()
+            constant_axes = _normalize_constant_axes(
+                list(map(int, axes.data.numpy().tolist())),
+                data.data.numpy().ndim + axes.data.numpy().size,
+                "Unsqueeze",
+            )
+            constant_axes = sorted(constant_axes)
             expanded = data.data.numpy()
-            if len(expanded.shape) == 0:
-                # Special case implying input is a scalar, wrap it as a list.
-                if 0 in axes:
-                    axes.remove(0)
-                expanded = [expanded]
-            for axis in axes:
-                expanded = _np.expand_dims(expanded, axis=axis)
+            output_rank = expanded.ndim + len(constant_axes)
+            new_shape = []
+            input_dims_iter = iter(expanded.shape)
+            for i in range(output_rank):
+                if i in constant_axes:
+                    new_shape.append(1)
+                else:
+                    new_shape.append(next(input_dims_iter))
+            expanded = expanded.reshape(new_shape)
             return relax.const(expanded, data.struct_info.dtype)
 
         if isinstance(axes, relax.Constant):
-            constant_axes = list(axes.data.numpy())
-            constant_axes = list(map(int, constant_axes))
+            if data_ndim is None:
+                raise ValueError("Unsqueeze requires a statically known input 
rank.")
+            constant_axes = _normalize_constant_axes(
+                list(map(int, axes.data.numpy().tolist())),
+                data_ndim + axes.data.numpy().size,
+                "Unsqueeze",
+            )
             constant_axes = sorted(constant_axes)
             for axis in constant_axes:
                 data = relax.op.expand_dims(data, axis=axis)
             return data
 
-        raise NotImplementedError("Unsqueeze with dynamic axes is not 
supported.")
+        if data_ndim is None:
+            raise ValueError("Unsqueeze with dynamic axes requires a 
statically known input rank.")
+        axes_len = _get_known_tensor_length(axes)
+        if axes_len is None:
+            raise ValueError("Unsqueeze requires a statically known axes 
length.")
+        data_shape = bb.normalize(relax.op.shape_of(data))
+        data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape))
+        output_shape_tensor = _build_unsqueezed_shape_tensor(bb, 
data_shape_tensor, axes, data_ndim)
+        output_shape = _tensor_to_shape_expr(
+            bb, output_shape_tensor, data_ndim + axes_len, "unsqueeze_dim"
+        )
+        return relax.op.reshape(data, output_shape)
 
 
 class Concat(OnnxOpConverter):
@@ -1492,14 +1513,37 @@ class Squeeze(OnnxOpConverter):
             return relax.const(out_data, data.struct_info.dtype)
 
         if isinstance(data, relax.ShapeExpr):
-            if axis == (0,):
+            shape_tensor_ndim = 1
+            if axis is None:
+                if len(data) == 1:
+                    return relax.PrimValue(data[0])
+                return data
+            normalized_axes = _normalize_constant_axes(list(axis), 
shape_tensor_ndim, "Squeeze")
+            if normalized_axes == [0] and len(data) == 1:
                 return relax.PrimValue(data[0])
-            else:
-                raise NotImplementedError(
-                    "Squeeze with symbolic axes and non-zero axes is not 
supported."
-                )
+            raise NotImplementedError(
+                "Squeeze on symbolic shape tensors only supports removing the 
sole axis."
+            )
 
-        return relax.op.squeeze(data, axis)
+        if axis is None:
+            return relax.op.squeeze(data)
+
+        if isinstance(axis, tuple):
+            return relax.op.squeeze(data, list(axis))
+
+        data_ndim = _get_known_tensor_rank(data)
+        if data_ndim is None:
+            raise ValueError("Squeeze with dynamic axes requires a statically 
known input rank.")
+        axes_len = _get_known_tensor_length(axis)
+        if axes_len is None:
+            raise ValueError("Squeeze requires a statically known axes 
length.")
+        data_shape = bb.normalize(relax.op.shape_of(data))
+        data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape))
+        output_shape_tensor = _build_squeezed_shape_tensor(bb, 
data_shape_tensor, axis, data_ndim)
+        output_shape = _tensor_to_shape_expr(
+            bb, output_shape_tensor, data_ndim - axes_len, "squeeze_dim"
+        )
+        return relax.op.reshape(data, output_shape)
 
 
 class Constant(OnnxOpConverter):
@@ -1896,68 +1940,309 @@ def get_prim_value_list(values):
     return new_values
 
 
+def _get_known_tensor_rank(expr: relax.Expr) -> int | None:
+    """Return the statically known rank of an expression when available."""
+
+    if isinstance(expr, relax.Constant):
+        return len(expr.data.numpy().shape)
+    if isinstance(expr, relax.ShapeExpr):
+        return 1
+    if isinstance(expr, relax.PrimValue):
+        return 0
+    struct_info = expr.struct_info
+    if isinstance(struct_info, relax.TensorStructInfo):
+        return None if struct_info.ndim == -1 else struct_info.ndim
+    return None
+
+
+def _get_known_tensor_length(expr: relax.Expr | None) -> int | None:
+    """Return the statically known length of a 1-D tensor-like expression."""
+
+    if expr is None:
+        return None
+    if isinstance(expr, relax.Constant):
+        np_value = expr.data.numpy()
+        if np_value.ndim != 1:
+            raise ValueError(f"Expected a 1-D tensor, but got 
ndim={np_value.ndim}.")
+        return int(np_value.shape[0])
+    if isinstance(expr, relax.ShapeExpr):
+        return len(expr.values)
+    if isinstance(expr, relax.PrimValue):
+        return 1
+    struct_info = expr.struct_info
+    if not isinstance(struct_info, relax.TensorStructInfo):
+        return None
+    if struct_info.ndim == -1:
+        return None
+    if struct_info.ndim != 1:
+        raise ValueError(f"Expected a 1-D tensor, but got 
ndim={struct_info.ndim}.")
+    if isinstance(struct_info.shape, relax.ShapeExpr):
+        dim = struct_info.shape.values[0]
+        if isinstance(dim, tirx.IntImm):
+            return int(dim.value)
+        if isinstance(dim, int):
+            return dim
+    return None
+
+
+def _normalize_constant_axes(axes: list[int], rank: int, op_name: str) -> 
list[int]:
+    """Normalize a list of constant axes and validate their uniqueness."""
+
+    normalized_axes = []
+    for axis in axes:
+        original_axis = axis
+        if axis < 0:
+            axis += rank
+        if axis < 0 or axis >= rank:
+            raise ValueError(f"{op_name} axis {original_axis} is out of range 
for rank {rank}.")
+        normalized_axes.append(axis)
+    if len(normalized_axes) != len(set(normalized_axes)):
+        raise ValueError(f"{op_name} axes must be unique.")
+    return normalized_axes
+
+
+def _as_int64_tensor(bb: relax.BlockBuilder, expr: relax.Expr) -> relax.Expr:
+    """Convert a tensor-like expression to an int64 tensor expression."""
+
+    if isinstance(expr, relax.ShapeExpr):
+        return bb.normalize(relax.op.shape_to_tensor(expr))
+    if isinstance(expr, relax.PrimValue):
+        return bb.normalize(relax.op.full((1,), expr, dtype="int64"))
+    if isinstance(expr, relax.Constant):
+        if expr.struct_info.dtype == "int64":
+            return expr
+        return bb.normalize(relax.op.astype(expr, "int64"))
+    if isinstance(expr.struct_info, relax.TensorStructInfo) and 
expr.struct_info.dtype != "int64":
+        return bb.normalize(relax.op.astype(expr, "int64"))
+    return expr
+
+
+def _tensor_to_shape_expr(
+    bb: relax.BlockBuilder, shape_tensor: relax.Expr, shape_ndim: int, prefix: 
str
+) -> relax.ShapeExpr:
+    """Convert a statically sized int64 tensor into a ShapeExpr."""
+
+    shape_tensor = bb.match_cast(shape_tensor, 
relax.TensorStructInfo([shape_ndim], "int64"))
+    shape_dataflow_var = bb.emit(relax.op.tensor_to_shape(shape_tensor))
+    shape_vars = [tirx.Var(f"{prefix}_{i}", "int64") for i in 
range(shape_ndim)]
+    bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars))
+    return relax.ShapeExpr(shape_vars)
+
+
+def _build_unsqueezed_shape_tensor(
+    bb: relax.BlockBuilder, data_shape_tensor: relax.Expr, axes: relax.Expr, 
data_ndim: int
+) -> relax.Expr:
+    """Build the output shape tensor for Unsqueeze with runtime axes."""
+
+    axes = _as_int64_tensor(bb, axes)
+    axes_len = _get_known_tensor_length(axes)
+    if axes_len is None:
+        raise ValueError("Unsqueeze requires a statically known axes length.")
+
+    output_ndim = data_ndim + axes_len
+    axes = bb.normalize(
+        relax.op.where(
+            relax.op.less(axes, relax.const(0, "int64")),
+            relax.op.add(axes, relax.const(output_ndim, "int64")),
+            axes,
+        )
+    )
+    positions = relax.op.arange(output_ndim, dtype="int64")
+    positions = bb.normalize(relax.op.expand_dims(positions, axis=1))
+    axes = bb.normalize(relax.op.expand_dims(axes, axis=0))
+    insert_mask = bb.normalize(
+        relax.op.sum(relax.op.astype(relax.op.equal(positions, axes), 
"int64"), axis=1)
+    )
+    keep_mask = bb.normalize(relax.op.subtract(relax.const(1, "int64"), 
insert_mask))
+    input_indices = bb.normalize(
+        relax.op.subtract(relax.op.cumsum(keep_mask, axis=0), relax.const(1, 
"int64"))
+    )
+    safe_indices = bb.normalize(
+        relax.op.where(
+            relax.op.less(input_indices, relax.const(0, "int64")),
+            relax.const(0, "int64"),
+            input_indices,
+        )
+    )
+    kept_dims = bb.normalize(relax.op.take(data_shape_tensor, safe_indices, 
axis=0))
+    return bb.normalize(
+        relax.op.where(
+            relax.op.greater(insert_mask, relax.const(0, "int64")),
+            relax.const(1, "int64"),
+            kept_dims,
+        )
+    )
+
+
+def _build_squeezed_shape_tensor(
+    bb: relax.BlockBuilder, data_shape_tensor: relax.Expr, axes: relax.Expr, 
data_ndim: int
+) -> relax.Expr:
+    """Build the output shape tensor for Squeeze with runtime axes."""
+
+    axes = _as_int64_tensor(bb, axes)
+    axes = bb.normalize(
+        relax.op.where(
+            relax.op.less(axes, relax.const(0, "int64")),
+            relax.op.add(axes, relax.const(data_ndim, "int64")),
+            axes,
+        )
+    )
+    positions = relax.op.arange(data_ndim, dtype="int64")
+    positions = bb.normalize(relax.op.expand_dims(positions, axis=1))
+    axes = bb.normalize(relax.op.expand_dims(axes, axis=0))
+    remove_mask = bb.normalize(
+        relax.op.sum(relax.op.astype(relax.op.equal(positions, axes), 
"int64"), axis=1)
+    )
+    keep_mask = bb.normalize(relax.op.equal(remove_mask, relax.const(0, 
"int64")))
+    keep_indices = bb.normalize(relax.op.nonzero(keep_mask))
+    num_keep_dims = tirx.Var("squeeze_num_keep_dims", "int64")
+    keep_indices = bb.match_cast(keep_indices, relax.TensorStructInfo([1, 
num_keep_dims], "int64"))
+    keep_indices = bb.normalize(relax.op.reshape(keep_indices, [-1]))
+    return bb.normalize(relax.op.take(data_shape_tensor, keep_indices, axis=0))
+
+
 class Slice(OnnxOpConverter):
-    """Converts an onnx Splice node into an equivalent Relax expression."""
+    """Converts an onnx Slice node into an equivalent Relax expression."""
 
     @classmethod
     def _impl_v13(cls, bb, inputs, attr, params):
-        # TODO (jwfromm) currently only supports constant parameters.
         data = inputs[0]
         starts = get_constant(inputs[1], params)
         ends = get_constant(inputs[2], params)
         axes = get_constant(inputs[3], params)
         steps = get_constant(inputs[4], params)
-        if not all(
-            [
-                (
-                    isinstance(param, relax.Constant | relax.ShapeExpr | 
relax.PrimValue)
-                    or param is None
+        all_constant_params = all(
+            isinstance(param, relax.Constant | relax.ShapeExpr | 
relax.PrimValue) or param is None
+            for param in [starts, ends, axes, steps]
+        )
+        if all_constant_params:
+            starts = get_prim_expr_list(starts)
+            ends = get_prim_expr_list(ends)
+            if len(starts) != len(ends):
+                raise ValueError(
+                    f"Slice expects starts and ends to have the same length, 
but got "
+                    f"{len(starts)} and {len(ends)}."
                 )
-                for param in [starts, ends, axes, steps]
-            ]
-        ):
-            raise ValueError("Only constant Slice parameters are currently 
supported.")
-        # Convert parameters to constant lists.
-        starts = get_prim_expr_list(starts)
-        ends = get_prim_expr_list(ends)
-        if axes is not None:
-            axes = get_prim_expr_list(axes)
-        else:
-            axes = list(range(len(starts)))
-        # Convert negative axis to positive if needed.
-        for i, axis in enumerate(axes):
-            if axis < 0:
-                axes[i] = axis + len(data.struct_info.shape)
-        if steps is not None:
-            steps = get_prim_expr_list(steps)
-        else:
-            steps = [1] * len(axes)
-        # If input is a shape tensor, we can directly extract it.
-        if isinstance(data, relax.ShapeExpr):
-            shape_data = list(data)
-            # Starts, ends, and steps must be 1-d for shape operation.
-            assert all(len(i) == 1 for i in [starts, ends, steps])
-            sliced_values = shape_data[starts[0] : ends[0] : steps[0]]
-
-            if all([isinstance(val, tirx.IntImm | int) for val in 
sliced_values]):
-                return relax.const([x.value for x in sliced_values], "int64")
+            if axes is not None:
+                axes = get_prim_expr_list(axes)
+                if len(axes) != len(starts):
+                    raise ValueError(
+                        f"Slice expects axes and starts to have the same 
length, but got "
+                        f"{len(axes)} and {len(starts)}."
+                    )
             else:
+                axes = list(range(len(starts)))
+
+            data_ndim = _get_known_tensor_rank(data)
+            if data_ndim is None:
+                raise ValueError("Slice requires a statically known input 
rank.")
+            axes = _normalize_constant_axes(list(axes), data_ndim, "Slice")
+            if steps is not None:
+                steps = get_prim_expr_list(steps)
+                if len(steps) != len(starts):
+                    raise ValueError(
+                        f"Slice expects steps and starts to have the same 
length, but got "
+                        f"{len(steps)} and {len(starts)}."
+                    )
+            else:
+                steps = [1] * len(axes)
+            if any(
+                (isinstance(step, int) and step == 0)
+                or (isinstance(step, tirx.IntImm) and int(step) == 0)
+                for step in steps
+            ):
+                raise ValueError("Slice step values must be non-zero.")
+            if isinstance(data, relax.ShapeExpr):
+                shape_data = list(data)
+                assert all(len(i) == 1 for i in [starts, ends, steps])
+                sliced_values = shape_data[starts[0] : ends[0] : steps[0]]
+
+                if all([isinstance(val, tirx.IntImm | int) for val in 
sliced_values]):
+                    return relax.const([x.value for x in sliced_values], 
"int64")
                 return relax.ShapeExpr(sliced_values)
 
-        # If all `starts`, `ends`, and `steps` are constant, use strict mode
-        # Otherwise, we assume the slice is inbound.
-        assume_inbound = not all(
-            [isinstance(param, tirx.IntImm | int) for param in [*starts, 
*ends, *steps]]
-        )
+            assume_inbound = not all(
+                [isinstance(param, tirx.IntImm | int) for param in [*starts, 
*ends, *steps]]
+            )
+            starts = get_prim_value_list(starts)
+            ends = get_prim_value_list(ends)
+            steps = get_prim_value_list(steps)
 
-        # Converting PrimExpr to PrimValue since relax.op.strided_slice does 
not accept PrimExpr
-        starts = get_prim_value_list(starts)
-        ends = get_prim_value_list(ends)
-        steps = get_prim_value_list(steps)
+            return relax.op.strided_slice(
+                data, axes, starts, ends, steps, assume_inbound=assume_inbound
+            )
+
+        data_ndim = _get_known_tensor_rank(data)
+        if data_ndim is None:
+            raise ValueError(
+                "Slice with dynamic parameters requires a statically known 
input rank."
+            )
+
+        if isinstance(data, relax.ShapeExpr):
+            raise ValueError("Slice with dynamic parameters does not support 
ShapeExpr input.")
+        data_expr = data
+
+        starts_tensor = _as_int64_tensor(bb, starts)
+        ends_tensor = _as_int64_tensor(bb, ends)
+        axes_len = _get_known_tensor_length(starts_tensor)
+        if axes_len is None:
+            raise ValueError("Slice requires a statically known starts 
length.")
+        ends_len = _get_known_tensor_length(ends_tensor)
+        if ends_len is None:
+            raise ValueError("Slice requires a statically known ends length.")
+        if ends_len != axes_len:
+            raise ValueError(
+                f"Slice expects starts and ends to have the same length, but 
got "
+                f"{axes_len} and {ends_len}."
+            )
 
-        return relax.op.strided_slice(
-            data, axes, starts, ends, steps, assume_inbound=assume_inbound
+        if axes is None:
+            axes_tensor = relax.op.arange(axes_len, dtype="int64")
+        else:
+            axes_tensor = _as_int64_tensor(bb, axes)
+            axes_tensor_len = _get_known_tensor_length(axes_tensor)
+            if axes_tensor_len is None:
+                raise ValueError("Slice requires a statically known axes 
length.")
+            if axes_tensor_len != axes_len:
+                raise ValueError(
+                    f"Slice expects axes and starts to have the same length, 
but got "
+                    f"{axes_tensor_len} and {axes_len}."
+                )
+        if steps is None:
+            steps_tensor = relax.const(_np.ones((axes_len,), dtype="int64"), 
"int64")
+        else:
+            steps_tensor = _as_int64_tensor(bb, steps)
+            steps_len = _get_known_tensor_length(steps_tensor)
+            if steps_len is None:
+                raise ValueError("Slice requires a statically known steps 
length.")
+            if steps_len != axes_len:
+                raise ValueError(
+                    f"Slice expects steps and starts to have the same length, 
but got "
+                    f"{steps_len} and {axes_len}."
+                )
+            if isinstance(steps_tensor, relax.Constant) and 
_np.any(steps_tensor.data.numpy() == 0):
+                raise ValueError("Slice step values must be non-zero.")
+
+        axes_tensor = bb.normalize(
+            relax.op.where(
+                relax.op.less(axes_tensor, relax.const(0, "int64")),
+                relax.op.add(axes_tensor, relax.const(data_ndim, "int64")),
+                axes_tensor,
+            )
+        )
+
+        data_shape = bb.normalize(relax.op.shape_of(data_expr))
+        data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape))
+        full_starts = relax.const(_np.zeros((data_ndim,), dtype="int64"), 
"int64")
+        full_steps = relax.const(_np.ones((data_ndim,), dtype="int64"), 
"int64")
+        full_starts = bb.normalize(
+            relax.op.scatter_elements(full_starts, axes_tensor, starts_tensor)
+        )
+        full_ends = bb.normalize(
+            relax.op.scatter_elements(data_shape_tensor, axes_tensor, 
ends_tensor)
         )
+        full_steps = bb.normalize(relax.op.scatter_elements(full_steps, 
axes_tensor, steps_tensor))
+        return relax.op.dynamic_strided_slice(data_expr, full_starts, 
full_ends, full_steps)
 
 
 class Pad(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py 
b/tests/python/relax/test_frontend_onnx.py
index b68110425a..621ce43379 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -233,14 +233,14 @@ def run_in_tvm(
 
 
 def collect_relax_call_ops(func: relax.Function) -> list[str]:
-    op_names = []
+    op_names: list[str] = []
 
-    def fvisit(expr):
+    def fvisit(expr: relax.Expr) -> None:
         if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op):
             op_names.append(expr.op.name)
 
     relax.analysis.post_order_visit(func.body, fvisit)
-    return op_names
+    return list(op_names)
 
 
 def collect_scalar_constants(func: relax.Function) -> list[bool | int | float]:
@@ -1090,6 +1090,98 @@ def test_unsqueeze():
     check_correctness(model)
 
 
+def test_unsqueeze_scalar_input():
+    unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"])
+
+    graph = helper.make_graph(
+        [unsqueeze_node],
+        "unsqueeze_scalar_input",
+        inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [])],
+        initializer=[helper.make_tensor("axes", TensorProto.INT64, [2], 
vals=[0, 1])],
+        outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 
1])],
+    )
+
+    model = helper.make_model(graph, 
producer_name="unsqueeze_scalar_input_test")
+    inputs = {"a": np.array(3.0, dtype="float32")}
+    check_correctness(model, inputs, opset=13)
+
+
+def test_unsqueeze_dynamic_axes():
+    unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"])
+
+    graph = helper.make_graph(
+        [unsqueeze_node],
+        "unsqueeze_dynamic_axes",
+        inputs=[
+            helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("axes", TensorProto.INT64, [2]),
+        ],
+        outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 
32, 1])],
+    )
+
+    model = helper.make_model(graph, 
producer_name="unsqueeze_dynamic_axes_test")
+    inputs = {
+        "a": rg.standard_normal(size=[32, 32]).astype("float32"),
+        "axes": np.array([-1, 0], dtype="int64"),
+    }
+    check_correctness(model, inputs, opset=13)
+
+
+def test_unsqueeze_dynamic_axes_ir():
+    unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"])
+
+    graph = helper.make_graph(
+        [unsqueeze_node],
+        "unsqueeze_dynamic_axes_ir",
+        inputs=[
+            helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("axes", TensorProto.INT64, [2]),
+        ],
+        outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 
32, 1])],
+    )
+
+    model = helper.make_model(graph, 
producer_name="unsqueeze_dynamic_axes_ir_test")
+    tvm_model = from_onnx(model, opset=13, keep_params_in_input=True)
+    call_ops = collect_relax_call_ops(tvm_model["main"])
+
+    assert "relax.tensor_to_shape" in call_ops
+    assert "relax.reshape" in call_ops
+
+
+def test_unsqueeze_dynamic_axes_rank_validation():
+    unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"])
+
+    graph = helper.make_graph(
+        [unsqueeze_node],
+        "unsqueeze_dynamic_axes_rank_validation",
+        inputs=[
+            helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("axes", TensorProto.INT64, [1, 2]),
+        ],
+        outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32, 
32, 1])],
+    )
+
+    model = helper.make_model(graph, 
producer_name="unsqueeze_dynamic_axes_rank_validation_test")
+    with pytest.raises(ValueError, match="Expected a 1-D tensor"):
+        from_onnx(model, opset=13, keep_params_in_input=True)
+
+
+def test_unsqueeze_duplicate_axes_validation():
+    unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"])
+
+    graph = helper.make_graph(
+        [unsqueeze_node],
+        "unsqueeze_duplicate_axes_validation",
+        inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 
32])],
+        initializer=[helper.make_tensor("axes", TensorProto.INT64, [2], 
vals=[0, 0])],
+        outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 1, 
32, 32])],
+    )
+
+    model = helper.make_model(graph, 
producer_name="unsqueeze_duplicate_axes_validation_test")
+    with pytest.raises(ValueError, match="axes must be unique"):
+        from_onnx(model, opset=13)
+
+
 def test_unsqueeze_v1():
     # https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-1
     unsqueeze_node = helper.make_node("Unsqueeze", ["a"], ["b"], axes=[0, 2, 
3])
@@ -1577,6 +1669,70 @@ def test_dynamic_squeeze(axis, A, B):
     check_correctness(model, inputs, opset=13)
 
 
+def test_squeeze_dynamic_axes():
+    squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"])
+    shape = [1, 32, 1, 32]
+
+    graph = helper.make_graph(
+        [squeeze_node],
+        "squeeze_dynamic_axes_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+            helper.make_tensor_value_info("axes", TensorProto.INT64, [2]),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 
32])],
+    )
+
+    model = helper.make_model(graph, producer_name="squeeze_dynamic_axes_test")
+    inputs = {
+        "x": rg.standard_normal(size=shape).astype("float32"),
+        "axes": np.array([-4, 2], dtype="int64"),
+    }
+    check_correctness(model, inputs, opset=13)
+
+
+def test_squeeze_dynamic_axes_ir():
+    squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"])
+    shape = [1, 32, 1, 32]
+
+    graph = helper.make_graph(
+        [squeeze_node],
+        "squeeze_dynamic_axes_ir",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+            helper.make_tensor_value_info("axes", TensorProto.INT64, [2]),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 
32])],
+    )
+
+    model = helper.make_model(graph, 
producer_name="squeeze_dynamic_axes_ir_test")
+    tvm_model = from_onnx(model, opset=13, keep_params_in_input=True)
+    call_ops = collect_relax_call_ops(tvm_model["main"])
+
+    assert "relax.tensor_to_shape" in call_ops
+    assert "relax.reshape" in call_ops
+    assert "relax.squeeze" not in call_ops
+
+
+def test_squeeze_dynamic_axes_rank_validation():
+    squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"])
+    shape = [1, 32, 1, 32]
+
+    graph = helper.make_graph(
+        [squeeze_node],
+        "squeeze_dynamic_axes_rank_validation",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+            helper.make_tensor_value_info("axes", TensorProto.INT64, [1, 2]),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32, 
32])],
+    )
+
+    model = helper.make_model(graph, 
producer_name="squeeze_dynamic_axes_rank_validation_test")
+    with pytest.raises(ValueError, match="Expected a 1-D tensor"):
+        from_onnx(model, opset=13, keep_params_in_input=True)
+
+
 @pytest.mark.parametrize("axis", [[0]])
 @pytest.mark.parametrize("A", [8, 16, 32])
 def test_dynamic_shape_squeeze(axis, A):
@@ -2480,6 +2636,121 @@ def test_slice():
     # )
 
 
+def test_slice_dynamic_inputs():
+    slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", 
"steps"], ["y"])
+
+    graph = helper.make_graph(
+        [slice_node],
+        "slice_dynamic_inputs_test",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]),
+            helper.make_tensor_value_info("starts", TensorProto.INT64, [2]),
+            helper.make_tensor_value_info("ends", TensorProto.INT64, [2]),
+            helper.make_tensor_value_info("axes", TensorProto.INT64, [2]),
+            helper.make_tensor_value_info("steps", TensorProto.INT64, [2]),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 
5])],
+    )
+
+    model = helper.make_model(graph, producer_name="slice_dynamic_inputs_test")
+    inputs = {
+        "x": rg.standard_normal(size=[20, 10, 5]).astype("float32"),
+        "starts": np.array([0, 0], dtype="int64"),
+        "ends": np.array([3, 10], dtype="int64"),
+        "axes": np.array([0, 1], dtype="int64"),
+        "steps": np.array([1, 1], dtype="int64"),
+    }
+    check_correctness(model, inputs, opset=13)
+
+
+def test_slice_dynamic_inputs_ir():
+    slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", 
"steps"], ["y"])
+
+    graph = helper.make_graph(
+        [slice_node],
+        "slice_dynamic_inputs_ir",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]),
+            helper.make_tensor_value_info("starts", TensorProto.INT64, [2]),
+            helper.make_tensor_value_info("ends", TensorProto.INT64, [2]),
+            helper.make_tensor_value_info("axes", TensorProto.INT64, [2]),
+            helper.make_tensor_value_info("steps", TensorProto.INT64, [2]),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 
5])],
+    )
+
+    model = helper.make_model(graph, 
producer_name="slice_dynamic_inputs_ir_test")
+    tvm_model = from_onnx(model, opset=13, keep_params_in_input=True)
+    call_ops = collect_relax_call_ops(tvm_model["main"])
+
+    assert "relax.dynamic_strided_slice" in call_ops
+    assert "relax.strided_slice" not in call_ops
+
+
+def test_slice_dynamic_inputs_length_validation():
+    slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", 
"steps"], ["y"])
+
+    graph = helper.make_graph(
+        [slice_node],
+        "slice_dynamic_inputs_length_validation",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]),
+            helper.make_tensor_value_info("starts", TensorProto.INT64, [2]),
+            helper.make_tensor_value_info("ends", TensorProto.INT64, [1]),
+            helper.make_tensor_value_info("axes", TensorProto.INT64, [2]),
+            helper.make_tensor_value_info("steps", TensorProto.INT64, [2]),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 
5])],
+    )
+
+    model = helper.make_model(graph, 
producer_name="slice_dynamic_inputs_length_validation_test")
+    with pytest.raises(ValueError, match="starts and ends to have the same 
length"):
+        from_onnx(model, opset=13, keep_params_in_input=True)
+
+
+def test_slice_dynamic_shape_expr_input_validation():
+    shape_node = helper.make_node("Shape", ["x"], ["y"])
+    slice_node = helper.make_node("Slice", ["y", "starts", "ends", "axes", 
"steps"], ["z"])
+
+    graph = helper.make_graph(
+        [shape_node, slice_node],
+        "slice_dynamic_shape_expr_input_validation",
+        inputs=[
+            helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]),
+            helper.make_tensor_value_info("starts", TensorProto.INT64, [1]),
+            helper.make_tensor_value_info("ends", TensorProto.INT64, [1]),
+            helper.make_tensor_value_info("axes", TensorProto.INT64, [1]),
+            helper.make_tensor_value_info("steps", TensorProto.INT64, [1]),
+        ],
+        outputs=[helper.make_tensor_value_info("z", TensorProto.INT64, [1])],
+    )
+
+    model = helper.make_model(graph, 
producer_name="slice_dynamic_shape_expr_input_validation_test")
+    with pytest.raises(ValueError, match="does not support ShapeExpr input"):
+        from_onnx(model, opset=13, keep_params_in_input=True)
+
+
+def test_slice_zero_step_validation():
+    slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes", 
"steps"], ["y"])
+
+    graph = helper.make_graph(
+        [slice_node],
+        "slice_zero_step_validation",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 
5])],
+        initializer=[
+            helper.make_tensor("starts", TensorProto.INT64, [2], vals=[0, 0]),
+            helper.make_tensor("ends", TensorProto.INT64, [2], vals=[3, 10]),
+            helper.make_tensor("axes", TensorProto.INT64, [2], vals=[0, 1]),
+            helper.make_tensor("steps", TensorProto.INT64, [2], vals=[1, 0]),
+        ],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10, 
5])],
+    )
+
+    model = helper.make_model(graph, 
producer_name="slice_zero_step_validation_test")
+    with pytest.raises(ValueError, match="step values must be non-zero"):
+        from_onnx(model, opset=13)
+
+
 def test_slice_dynamic_shape():
     def verify_slice(
         data_shape, data_instance_shape, output_shape, starts, ends, 
axes=None, steps=None


Reply via email to