This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fd9d9db130 [Relax][ONNX] Fix shape/dynamic restrictions for `Squeeze`/`Unsqueeze` and `Slice` (#18955)
fd9d9db130 is described below
commit fd9d9db130fb9ede300a062b95f81dd186dc5140
Author: HoYi <[email protected]>
AuthorDate: Tue Mar 31 22:33:59 2026 +0800
[Relax][ONNX] Fix shape/dynamic restrictions for `Squeeze`/`Unsqueeze` and `Slice` (#18955)
## Summary
Relates to #18945.
This PR improves ONNX frontend handling for dynamic
`Unsqueeze`/`Squeeze`/`Slice`, tightens validation paths, and adds
targeted structural/negative regression tests.
- Refactor constant-path `Unsqueeze` lowering to use a single `reshape` based on the computed target shape (see the sketch after this list).
- Remove scalar-specific branching and repeated `expand_dims` in the
constant path.
- Retain the shared structural helper used by ONNX frontend tests for Relax call-op checks.
- Add regression coverage for scalar-input `Unsqueeze`.
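
As a quick illustration of the refactor, a minimal NumPy sketch (the function name is hypothetical, not the converter code itself): the target shape is computed once by interleaving `1`s at the requested axes, so a single `reshape` covers scalar inputs without special-casing.

```python
import numpy as np

def unsqueeze_constant(data: np.ndarray, axes: list[int]) -> np.ndarray:
    """Unsqueeze via one reshape; axes are assumed normalized and unique."""
    axes = sorted(axes)
    output_rank = data.ndim + len(axes)
    input_dims = iter(data.shape)
    # Insert a 1 at each unsqueezed position; otherwise consume the next input dim.
    new_shape = [1 if i in axes else next(input_dims) for i in range(output_rank)]
    return data.reshape(new_shape)

assert unsqueeze_constant(np.array(3.0), [0, 1]).shape == (1, 1)  # scalar input
assert unsqueeze_constant(np.ones((32, 32)), [0, 3]).shape == (1, 32, 32, 1)
```
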
## Changes
- Add dynamic-axes conversion paths for `Unsqueeze` and `Squeeze`:
  - infer the output shape via runtime shape-tensor construction (see the first sketch after this list)
  - lower to `relax.reshape` with validated shape rank/length assumptions
- Improve `Slice` conversion robustness (see the second sketch after this list):
- support dynamic parameter forms with stricter rank/length validation
- reject invalid zero-step inputs when statically known
- fix docstring wording (`Splice` -> `Slice`)
- Strengthen ONNX frontend tests:
- negative test for duplicate `Unsqueeze` axes
- structural IR check for dynamic `Slice` (`relax.dynamic_strided_slice`
present, `relax.strided_slice` absent)
- negative test for zero-step `Slice`
- Refactor constant-path `Unsqueeze` scalar handling:
- replace scalar special-casing + repeated `expand_dims` with one
target-shape `reshape`
- add scalar-input regression test
- Restore shared test helper used by structural Relax call-op checks.
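
For the dynamic-axes paths, only the *number* of axes must be static; their values may arrive at runtime. A NumPy model of the mask arithmetic the `Unsqueeze` lowering emits (`where`/`cumsum`/`take`; the function name here is illustrative):

```python
import numpy as np

def unsqueezed_shape(data_shape: np.ndarray, axes: np.ndarray) -> np.ndarray:
    """Model of the runtime shape-tensor construction on plain int64 arrays."""
    output_ndim = data_shape.size + axes.size
    axes = np.where(axes < 0, axes + output_ndim, axes)  # normalize negatives
    positions = np.arange(output_ndim)[:, None]
    insert_mask = (positions == axes[None, :]).astype(np.int64).sum(axis=1)
    keep_mask = 1 - insert_mask
    input_indices = np.cumsum(keep_mask) - 1  # source dim for each kept slot
    safe_indices = np.where(input_indices < 0, 0, input_indices)
    kept_dims = data_shape[safe_indices]
    return np.where(insert_mask > 0, 1, kept_dims)

# shape (32, 32) with runtime axes [-1, 0] -> [1, 32, 32, 1]
assert unsqueezed_shape(np.array([32, 32]), np.array([-1, 0])).tolist() == [1, 32, 32, 1]
```

And a sketch of the dynamic `Slice` lowering under the same assumptions: per-axis `starts`/`ends`/`steps` are scattered into full-rank defaults (zeros, the data shape, ones) before a single `relax.dynamic_strided_slice`:

```python
import numpy as np

def full_slice_params(data_shape, starts, ends, axes, steps):
    """Hypothetical model of the scatter_elements preparation step."""
    ndim = len(data_shape)
    axes = np.where(axes < 0, axes + ndim, axes)  # normalize negative axes
    full_starts = np.zeros(ndim, dtype=np.int64)
    full_ends = np.array(data_shape, dtype=np.int64)
    full_steps = np.ones(ndim, dtype=np.int64)
    full_starts[axes] = starts
    full_ends[axes] = ends
    full_steps[axes] = steps
    return full_starts, full_ends, full_steps

# x: (20, 10, 5), starts=[0, 0], ends=[3, 10], axes=[0, 1], steps=[1, 1]
# -> starts [0 0 0], ends [3 10 5], steps [1 1 1]
```
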
## Validation
- `ruff check`: passed
- `pre-commit --files`: passed
- `pytest`: 8 passed
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 427 ++++++++++++++++++++----
tests/python/relax/test_frontend_onnx.py | 277 ++++++++++++++-
2 files changed, 630 insertions(+), 74 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index fbbcd68bc5..21b8f22a0c 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -735,38 +735,59 @@ class Unsqueeze(OnnxOpConverter):
def _impl_v13(cls, bb, inputs, attr, params):
data = inputs[0]
axes = get_constant(inputs[1], params)
+ data_ndim = _get_known_tensor_rank(data)
- # Handle ONNX shape inference
if isinstance(data, relax.PrimValue) and isinstance(axes, relax.Constant):
-        axes = axes.data.numpy().tolist()
-        if axes == [0]:
+            constant_axes = _normalize_constant_axes(
+                list(map(int, axes.data.numpy().tolist())), 1, "Unsqueeze"
+            )
+            if constant_axes == [0]:
return relax.ShapeExpr([data.value])
-        else:
-            raise NotImplementedError(
-                "Unsqueeze with symbolic axes and non-zero axes is not supported."
-            )
-        # If input is a constant, compute directly
+            raise NotImplementedError("Unsqueeze with symbolic scalar inputs only supports axis 0.")
if isinstance(data, relax.Constant) and isinstance(axes, relax.Constant):
- axes = axes.data.numpy().tolist()
+ constant_axes = _normalize_constant_axes(
+ list(map(int, axes.data.numpy().tolist())),
+ data.data.numpy().ndim + axes.data.numpy().size,
+ "Unsqueeze",
+ )
+ constant_axes = sorted(constant_axes)
expanded = data.data.numpy()
- if len(expanded.shape) == 0:
- # Special case implying input is a scalar, wrap it as a list.
- if 0 in axes:
- axes.remove(0)
- expanded = [expanded]
- for axis in axes:
- expanded = _np.expand_dims(expanded, axis=axis)
+ output_rank = expanded.ndim + len(constant_axes)
+ new_shape = []
+ input_dims_iter = iter(expanded.shape)
+ for i in range(output_rank):
+ if i in constant_axes:
+ new_shape.append(1)
+ else:
+ new_shape.append(next(input_dims_iter))
+ expanded = expanded.reshape(new_shape)
return relax.const(expanded, data.struct_info.dtype)
if isinstance(axes, relax.Constant):
- constant_axes = list(axes.data.numpy())
- constant_axes = list(map(int, constant_axes))
+ if data_ndim is None:
+ raise ValueError("Unsqueeze requires a statically known input
rank.")
+ constant_axes = _normalize_constant_axes(
+ list(map(int, axes.data.numpy().tolist())),
+ data_ndim + axes.data.numpy().size,
+ "Unsqueeze",
+ )
constant_axes = sorted(constant_axes)
for axis in constant_axes:
data = relax.op.expand_dims(data, axis=axis)
return data
- raise NotImplementedError("Unsqueeze with dynamic axes is not
supported.")
+ if data_ndim is None:
+ raise ValueError("Unsqueeze with dynamic axes requires a
statically known input rank.")
+ axes_len = _get_known_tensor_length(axes)
+ if axes_len is None:
+ raise ValueError("Unsqueeze requires a statically known axes
length.")
+ data_shape = bb.normalize(relax.op.shape_of(data))
+ data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape))
+ output_shape_tensor = _build_unsqueezed_shape_tensor(bb,
data_shape_tensor, axes, data_ndim)
+ output_shape = _tensor_to_shape_expr(
+ bb, output_shape_tensor, data_ndim + axes_len, "unsqueeze_dim"
+ )
+ return relax.op.reshape(data, output_shape)
class Concat(OnnxOpConverter):
@@ -1492,14 +1513,37 @@ class Squeeze(OnnxOpConverter):
return relax.const(out_data, data.struct_info.dtype)
if isinstance(data, relax.ShapeExpr):
- if axis == (0,):
+ shape_tensor_ndim = 1
+ if axis is None:
+ if len(data) == 1:
+ return relax.PrimValue(data[0])
+ return data
+        normalized_axes = _normalize_constant_axes(list(axis), shape_tensor_ndim, "Squeeze")
+ if normalized_axes == [0] and len(data) == 1:
return relax.PrimValue(data[0])
- else:
- raise NotImplementedError(
- "Squeeze with symbolic axes and non-zero axes is not
supported."
- )
+ raise NotImplementedError(
+ "Squeeze on symbolic shape tensors only supports removing the
sole axis."
+ )
- return relax.op.squeeze(data, axis)
+ if axis is None:
+ return relax.op.squeeze(data)
+
+ if isinstance(axis, tuple):
+ return relax.op.squeeze(data, list(axis))
+
+ data_ndim = _get_known_tensor_rank(data)
+ if data_ndim is None:
+ raise ValueError("Squeeze with dynamic axes requires a statically
known input rank.")
+ axes_len = _get_known_tensor_length(axis)
+ if axes_len is None:
+ raise ValueError("Squeeze requires a statically known axes
length.")
+ data_shape = bb.normalize(relax.op.shape_of(data))
+ data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape))
+ output_shape_tensor = _build_squeezed_shape_tensor(bb,
data_shape_tensor, axis, data_ndim)
+ output_shape = _tensor_to_shape_expr(
+ bb, output_shape_tensor, data_ndim - axes_len, "squeeze_dim"
+ )
+ return relax.op.reshape(data, output_shape)
class Constant(OnnxOpConverter):
@@ -1896,68 +1940,309 @@ def get_prim_value_list(values):
return new_values
+def _get_known_tensor_rank(expr: relax.Expr) -> int | None:
+ """Return the statically known rank of an expression when available."""
+
+ if isinstance(expr, relax.Constant):
+ return len(expr.data.numpy().shape)
+ if isinstance(expr, relax.ShapeExpr):
+ return 1
+ if isinstance(expr, relax.PrimValue):
+ return 0
+ struct_info = expr.struct_info
+ if isinstance(struct_info, relax.TensorStructInfo):
+ return None if struct_info.ndim == -1 else struct_info.ndim
+ return None
+
+
+def _get_known_tensor_length(expr: relax.Expr | None) -> int | None:
+ """Return the statically known length of a 1-D tensor-like expression."""
+
+ if expr is None:
+ return None
+ if isinstance(expr, relax.Constant):
+ np_value = expr.data.numpy()
+ if np_value.ndim != 1:
+ raise ValueError(f"Expected a 1-D tensor, but got
ndim={np_value.ndim}.")
+ return int(np_value.shape[0])
+ if isinstance(expr, relax.ShapeExpr):
+ return len(expr.values)
+ if isinstance(expr, relax.PrimValue):
+ return 1
+ struct_info = expr.struct_info
+ if not isinstance(struct_info, relax.TensorStructInfo):
+ return None
+ if struct_info.ndim == -1:
+ return None
+ if struct_info.ndim != 1:
+ raise ValueError(f"Expected a 1-D tensor, but got
ndim={struct_info.ndim}.")
+ if isinstance(struct_info.shape, relax.ShapeExpr):
+ dim = struct_info.shape.values[0]
+ if isinstance(dim, tirx.IntImm):
+ return int(dim.value)
+ if isinstance(dim, int):
+ return dim
+ return None
+
+
+def _normalize_constant_axes(axes: list[int], rank: int, op_name: str) -> list[int]:
+ """Normalize a list of constant axes and validate their uniqueness."""
+
+ normalized_axes = []
+ for axis in axes:
+ original_axis = axis
+ if axis < 0:
+ axis += rank
+ if axis < 0 or axis >= rank:
+ raise ValueError(f"{op_name} axis {original_axis} is out of range
for rank {rank}.")
+ normalized_axes.append(axis)
+ if len(normalized_axes) != len(set(normalized_axes)):
+ raise ValueError(f"{op_name} axes must be unique.")
+ return normalized_axes
+
+
+def _as_int64_tensor(bb: relax.BlockBuilder, expr: relax.Expr) -> relax.Expr:
+ """Convert a tensor-like expression to an int64 tensor expression."""
+
+ if isinstance(expr, relax.ShapeExpr):
+ return bb.normalize(relax.op.shape_to_tensor(expr))
+ if isinstance(expr, relax.PrimValue):
+ return bb.normalize(relax.op.full((1,), expr, dtype="int64"))
+ if isinstance(expr, relax.Constant):
+ if expr.struct_info.dtype == "int64":
+ return expr
+ return bb.normalize(relax.op.astype(expr, "int64"))
+    if isinstance(expr.struct_info, relax.TensorStructInfo) and expr.struct_info.dtype != "int64":
+ return bb.normalize(relax.op.astype(expr, "int64"))
+ return expr
+
+
+def _tensor_to_shape_expr(
+    bb: relax.BlockBuilder, shape_tensor: relax.Expr, shape_ndim: int, prefix: str
+) -> relax.ShapeExpr:
+ """Convert a statically sized int64 tensor into a ShapeExpr."""
+
+    shape_tensor = bb.match_cast(shape_tensor, relax.TensorStructInfo([shape_ndim], "int64"))
+    shape_dataflow_var = bb.emit(relax.op.tensor_to_shape(shape_tensor))
+    shape_vars = [tirx.Var(f"{prefix}_{i}", "int64") for i in range(shape_ndim)]
+ bb.match_cast(shape_dataflow_var, relax.ShapeStructInfo(shape_vars))
+ return relax.ShapeExpr(shape_vars)
+
+
+def _build_unsqueezed_shape_tensor(
+    bb: relax.BlockBuilder, data_shape_tensor: relax.Expr, axes: relax.Expr, data_ndim: int
+) -> relax.Expr:
+ """Build the output shape tensor for Unsqueeze with runtime axes."""
+
+ axes = _as_int64_tensor(bb, axes)
+ axes_len = _get_known_tensor_length(axes)
+ if axes_len is None:
+ raise ValueError("Unsqueeze requires a statically known axes length.")
+
+ output_ndim = data_ndim + axes_len
+ axes = bb.normalize(
+ relax.op.where(
+ relax.op.less(axes, relax.const(0, "int64")),
+ relax.op.add(axes, relax.const(output_ndim, "int64")),
+ axes,
+ )
+ )
+ positions = relax.op.arange(output_ndim, dtype="int64")
+ positions = bb.normalize(relax.op.expand_dims(positions, axis=1))
+ axes = bb.normalize(relax.op.expand_dims(axes, axis=0))
+ insert_mask = bb.normalize(
+        relax.op.sum(relax.op.astype(relax.op.equal(positions, axes), "int64"), axis=1)
+    )
+    keep_mask = bb.normalize(relax.op.subtract(relax.const(1, "int64"), insert_mask))
+    input_indices = bb.normalize(
+        relax.op.subtract(relax.op.cumsum(keep_mask, axis=0), relax.const(1, "int64"))
+ )
+ safe_indices = bb.normalize(
+ relax.op.where(
+ relax.op.less(input_indices, relax.const(0, "int64")),
+ relax.const(0, "int64"),
+ input_indices,
+ )
+ )
+    kept_dims = bb.normalize(relax.op.take(data_shape_tensor, safe_indices, axis=0))
+ return bb.normalize(
+ relax.op.where(
+ relax.op.greater(insert_mask, relax.const(0, "int64")),
+ relax.const(1, "int64"),
+ kept_dims,
+ )
+ )
+
+
+def _build_squeezed_shape_tensor(
+    bb: relax.BlockBuilder, data_shape_tensor: relax.Expr, axes: relax.Expr, data_ndim: int
+) -> relax.Expr:
+ """Build the output shape tensor for Squeeze with runtime axes."""
+
+ axes = _as_int64_tensor(bb, axes)
+ axes = bb.normalize(
+ relax.op.where(
+ relax.op.less(axes, relax.const(0, "int64")),
+ relax.op.add(axes, relax.const(data_ndim, "int64")),
+ axes,
+ )
+ )
+ positions = relax.op.arange(data_ndim, dtype="int64")
+ positions = bb.normalize(relax.op.expand_dims(positions, axis=1))
+ axes = bb.normalize(relax.op.expand_dims(axes, axis=0))
+ remove_mask = bb.normalize(
+        relax.op.sum(relax.op.astype(relax.op.equal(positions, axes), "int64"), axis=1)
+    )
+    keep_mask = bb.normalize(relax.op.equal(remove_mask, relax.const(0, "int64")))
+ keep_indices = bb.normalize(relax.op.nonzero(keep_mask))
+ num_keep_dims = tirx.Var("squeeze_num_keep_dims", "int64")
+    keep_indices = bb.match_cast(keep_indices, relax.TensorStructInfo([1, num_keep_dims], "int64"))
+ keep_indices = bb.normalize(relax.op.reshape(keep_indices, [-1]))
+ return bb.normalize(relax.op.take(data_shape_tensor, keep_indices, axis=0))
+
+
class Slice(OnnxOpConverter):
- """Converts an onnx Splice node into an equivalent Relax expression."""
+ """Converts an onnx Slice node into an equivalent Relax expression."""
@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
- # TODO (jwfromm) currently only supports constant parameters.
data = inputs[0]
starts = get_constant(inputs[1], params)
ends = get_constant(inputs[2], params)
axes = get_constant(inputs[3], params)
steps = get_constant(inputs[4], params)
- if not all(
- [
- (
-                    isinstance(param, relax.Constant | relax.ShapeExpr | relax.PrimValue)
-                    or param is None
+        all_constant_params = all(
+            isinstance(param, relax.Constant | relax.ShapeExpr | relax.PrimValue) or param is None
+ for param in [starts, ends, axes, steps]
+ )
+ if all_constant_params:
+ starts = get_prim_expr_list(starts)
+ ends = get_prim_expr_list(ends)
+ if len(starts) != len(ends):
+ raise ValueError(
+ f"Slice expects starts and ends to have the same length,
but got "
+ f"{len(starts)} and {len(ends)}."
)
- for param in [starts, ends, axes, steps]
- ]
- ):
- raise ValueError("Only constant Slice parameters are currently
supported.")
- # Convert parameters to constant lists.
- starts = get_prim_expr_list(starts)
- ends = get_prim_expr_list(ends)
- if axes is not None:
- axes = get_prim_expr_list(axes)
- else:
- axes = list(range(len(starts)))
- # Convert negative axis to positive if needed.
- for i, axis in enumerate(axes):
- if axis < 0:
- axes[i] = axis + len(data.struct_info.shape)
- if steps is not None:
- steps = get_prim_expr_list(steps)
- else:
- steps = [1] * len(axes)
- # If input is a shape tensor, we can directly extract it.
- if isinstance(data, relax.ShapeExpr):
- shape_data = list(data)
- # Starts, ends, and steps must be 1-d for shape operation.
- assert all(len(i) == 1 for i in [starts, ends, steps])
- sliced_values = shape_data[starts[0] : ends[0] : steps[0]]
-
-            if all([isinstance(val, tirx.IntImm | int) for val in sliced_values]):
- return relax.const([x.value for x in sliced_values], "int64")
+ if axes is not None:
+ axes = get_prim_expr_list(axes)
+ if len(axes) != len(starts):
+ raise ValueError(
+ f"Slice expects axes and starts to have the same
length, but got "
+ f"{len(axes)} and {len(starts)}."
+ )
else:
+ axes = list(range(len(starts)))
+
+ data_ndim = _get_known_tensor_rank(data)
+ if data_ndim is None:
+ raise ValueError("Slice requires a statically known input
rank.")
+ axes = _normalize_constant_axes(list(axes), data_ndim, "Slice")
+ if steps is not None:
+ steps = get_prim_expr_list(steps)
+ if len(steps) != len(starts):
+ raise ValueError(
+ f"Slice expects steps and starts to have the same
length, but got "
+ f"{len(steps)} and {len(starts)}."
+ )
+ else:
+ steps = [1] * len(axes)
+ if any(
+ (isinstance(step, int) and step == 0)
+ or (isinstance(step, tirx.IntImm) and int(step) == 0)
+ for step in steps
+ ):
+ raise ValueError("Slice step values must be non-zero.")
+ if isinstance(data, relax.ShapeExpr):
+ shape_data = list(data)
+ assert all(len(i) == 1 for i in [starts, ends, steps])
+ sliced_values = shape_data[starts[0] : ends[0] : steps[0]]
+
+                if all([isinstance(val, tirx.IntImm | int) for val in sliced_values]):
+                    return relax.const([x.value for x in sliced_values], "int64")
return relax.ShapeExpr(sliced_values)
- # If all `starts`, `ends`, and `steps` are constant, use strict mode
- # Otherwise, we assume the slice is inbound.
- assume_inbound = not all(
-            [isinstance(param, tirx.IntImm | int) for param in [*starts, *ends, *steps]]
- )
+ assume_inbound = not all(
+                [isinstance(param, tirx.IntImm | int) for param in [*starts, *ends, *steps]]
+ )
+ starts = get_prim_value_list(starts)
+ ends = get_prim_value_list(ends)
+ steps = get_prim_value_list(steps)
-        # Converting PrimExpr to PrimValue since relax.op.strided_slice does not accept PrimExpr
- starts = get_prim_value_list(starts)
- ends = get_prim_value_list(ends)
- steps = get_prim_value_list(steps)
+ return relax.op.strided_slice(
+ data, axes, starts, ends, steps, assume_inbound=assume_inbound
+ )
+
+ data_ndim = _get_known_tensor_rank(data)
+ if data_ndim is None:
+ raise ValueError(
+ "Slice with dynamic parameters requires a statically known
input rank."
+ )
+
+ if isinstance(data, relax.ShapeExpr):
+ raise ValueError("Slice with dynamic parameters does not support
ShapeExpr input.")
+ data_expr = data
+
+ starts_tensor = _as_int64_tensor(bb, starts)
+ ends_tensor = _as_int64_tensor(bb, ends)
+ axes_len = _get_known_tensor_length(starts_tensor)
+ if axes_len is None:
+ raise ValueError("Slice requires a statically known starts
length.")
+ ends_len = _get_known_tensor_length(ends_tensor)
+ if ends_len is None:
+ raise ValueError("Slice requires a statically known ends length.")
+ if ends_len != axes_len:
+ raise ValueError(
+ f"Slice expects starts and ends to have the same length, but
got "
+ f"{axes_len} and {ends_len}."
+ )
- return relax.op.strided_slice(
- data, axes, starts, ends, steps, assume_inbound=assume_inbound
+ if axes is None:
+ axes_tensor = relax.op.arange(axes_len, dtype="int64")
+ else:
+ axes_tensor = _as_int64_tensor(bb, axes)
+ axes_tensor_len = _get_known_tensor_length(axes_tensor)
+ if axes_tensor_len is None:
+ raise ValueError("Slice requires a statically known axes
length.")
+ if axes_tensor_len != axes_len:
+ raise ValueError(
+ f"Slice expects axes and starts to have the same length,
but got "
+ f"{axes_tensor_len} and {axes_len}."
+ )
+ if steps is None:
+            steps_tensor = relax.const(_np.ones((axes_len,), dtype="int64"), "int64")
+ else:
+ steps_tensor = _as_int64_tensor(bb, steps)
+ steps_len = _get_known_tensor_length(steps_tensor)
+ if steps_len is None:
+ raise ValueError("Slice requires a statically known steps
length.")
+ if steps_len != axes_len:
+ raise ValueError(
+ f"Slice expects steps and starts to have the same length,
but got "
+ f"{steps_len} and {axes_len}."
+ )
+        if isinstance(steps_tensor, relax.Constant) and _np.any(steps_tensor.data.numpy() == 0):
+ raise ValueError("Slice step values must be non-zero.")
+
+ axes_tensor = bb.normalize(
+ relax.op.where(
+ relax.op.less(axes_tensor, relax.const(0, "int64")),
+ relax.op.add(axes_tensor, relax.const(data_ndim, "int64")),
+ axes_tensor,
+ )
+ )
+
+ data_shape = bb.normalize(relax.op.shape_of(data_expr))
+ data_shape_tensor = bb.normalize(relax.op.shape_to_tensor(data_shape))
+        full_starts = relax.const(_np.zeros((data_ndim,), dtype="int64"), "int64")
+        full_steps = relax.const(_np.ones((data_ndim,), dtype="int64"), "int64")
+ full_starts = bb.normalize(
+ relax.op.scatter_elements(full_starts, axes_tensor, starts_tensor)
+ )
+ full_ends = bb.normalize(
+            relax.op.scatter_elements(data_shape_tensor, axes_tensor, ends_tensor)
)
+        full_steps = bb.normalize(relax.op.scatter_elements(full_steps, axes_tensor, steps_tensor))
+        return relax.op.dynamic_strided_slice(data_expr, full_starts, full_ends, full_steps)
class Pad(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index b68110425a..621ce43379 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -233,14 +233,14 @@ def run_in_tvm(
def collect_relax_call_ops(func: relax.Function) -> list[str]:
- op_names = []
+ op_names: list[str] = []
- def fvisit(expr):
+ def fvisit(expr: relax.Expr) -> None:
if isinstance(expr, relax.Call) and isinstance(expr.op, tvm.ir.Op):
op_names.append(expr.op.name)
relax.analysis.post_order_visit(func.body, fvisit)
- return op_names
+ return list(op_names)
def collect_scalar_constants(func: relax.Function) -> list[bool | int | float]:
@@ -1090,6 +1090,98 @@ def test_unsqueeze():
check_correctness(model)
+def test_unsqueeze_scalar_input():
+ unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"])
+
+ graph = helper.make_graph(
+ [unsqueeze_node],
+ "unsqueeze_scalar_input",
+ inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [])],
+ initializer=[helper.make_tensor("axes", TensorProto.INT64, [2],
vals=[0, 1])],
+ outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1,
1])],
+ )
+
+    model = helper.make_model(graph, producer_name="unsqueeze_scalar_input_test")
+ inputs = {"a": np.array(3.0, dtype="float32")}
+ check_correctness(model, inputs, opset=13)
+
+
+def test_unsqueeze_dynamic_axes():
+ unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"])
+
+ graph = helper.make_graph(
+ [unsqueeze_node],
+ "unsqueeze_dynamic_axes",
+ inputs=[
+ helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
+ helper.make_tensor_value_info("axes", TensorProto.INT64, [2]),
+ ],
+ outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32,
32, 1])],
+ )
+
+    model = helper.make_model(graph, producer_name="unsqueeze_dynamic_axes_test")
+ inputs = {
+ "a": rg.standard_normal(size=[32, 32]).astype("float32"),
+ "axes": np.array([-1, 0], dtype="int64"),
+ }
+ check_correctness(model, inputs, opset=13)
+
+
+def test_unsqueeze_dynamic_axes_ir():
+ unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"])
+
+ graph = helper.make_graph(
+ [unsqueeze_node],
+ "unsqueeze_dynamic_axes_ir",
+ inputs=[
+ helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
+ helper.make_tensor_value_info("axes", TensorProto.INT64, [2]),
+ ],
+ outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32,
32, 1])],
+ )
+
+    model = helper.make_model(graph, producer_name="unsqueeze_dynamic_axes_ir_test")
+ tvm_model = from_onnx(model, opset=13, keep_params_in_input=True)
+ call_ops = collect_relax_call_ops(tvm_model["main"])
+
+ assert "relax.tensor_to_shape" in call_ops
+ assert "relax.reshape" in call_ops
+
+
+def test_unsqueeze_dynamic_axes_rank_validation():
+ unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"])
+
+ graph = helper.make_graph(
+ [unsqueeze_node],
+ "unsqueeze_dynamic_axes_rank_validation",
+ inputs=[
+ helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
+ helper.make_tensor_value_info("axes", TensorProto.INT64, [1, 2]),
+ ],
+ outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 32,
32, 1])],
+ )
+
+    model = helper.make_model(graph, producer_name="unsqueeze_dynamic_axes_rank_validation_test")
+ with pytest.raises(ValueError, match="Expected a 1-D tensor"):
+ from_onnx(model, opset=13, keep_params_in_input=True)
+
+
+def test_unsqueeze_duplicate_axes_validation():
+ unsqueeze_node = helper.make_node("Unsqueeze", ["a", "axes"], ["b"])
+
+ graph = helper.make_graph(
+ [unsqueeze_node],
+ "unsqueeze_duplicate_axes_validation",
+ inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [32,
32])],
+ initializer=[helper.make_tensor("axes", TensorProto.INT64, [2],
vals=[0, 0])],
+ outputs=[helper.make_tensor_value_info("b", TensorProto.FLOAT, [1, 1,
32, 32])],
+ )
+
+    model = helper.make_model(graph, producer_name="unsqueeze_duplicate_axes_validation_test")
+ with pytest.raises(ValueError, match="axes must be unique"):
+ from_onnx(model, opset=13)
+
+
def test_unsqueeze_v1():
# https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Unsqueeze-1
unsqueeze_node = helper.make_node("Unsqueeze", ["a"], ["b"], axes=[0, 2,
3])
@@ -1577,6 +1669,70 @@ def test_dynamic_squeeze(axis, A, B):
check_correctness(model, inputs, opset=13)
+def test_squeeze_dynamic_axes():
+ squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"])
+ shape = [1, 32, 1, 32]
+
+ graph = helper.make_graph(
+ [squeeze_node],
+ "squeeze_dynamic_axes_test",
+ inputs=[
+ helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+ helper.make_tensor_value_info("axes", TensorProto.INT64, [2]),
+ ],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32,
32])],
+ )
+
+ model = helper.make_model(graph, producer_name="squeeze_dynamic_axes_test")
+ inputs = {
+ "x": rg.standard_normal(size=shape).astype("float32"),
+ "axes": np.array([-4, 2], dtype="int64"),
+ }
+ check_correctness(model, inputs, opset=13)
+
+
+def test_squeeze_dynamic_axes_ir():
+ squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"])
+ shape = [1, 32, 1, 32]
+
+ graph = helper.make_graph(
+ [squeeze_node],
+ "squeeze_dynamic_axes_ir",
+ inputs=[
+ helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+ helper.make_tensor_value_info("axes", TensorProto.INT64, [2]),
+ ],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32,
32])],
+ )
+
+    model = helper.make_model(graph, producer_name="squeeze_dynamic_axes_ir_test")
+ tvm_model = from_onnx(model, opset=13, keep_params_in_input=True)
+ call_ops = collect_relax_call_ops(tvm_model["main"])
+
+ assert "relax.tensor_to_shape" in call_ops
+ assert "relax.reshape" in call_ops
+ assert "relax.squeeze" not in call_ops
+
+
+def test_squeeze_dynamic_axes_rank_validation():
+ squeeze_node = helper.make_node("Squeeze", ["x", "axes"], ["y"])
+ shape = [1, 32, 1, 32]
+
+ graph = helper.make_graph(
+ [squeeze_node],
+ "squeeze_dynamic_axes_rank_validation",
+ inputs=[
+ helper.make_tensor_value_info("x", TensorProto.FLOAT, shape),
+ helper.make_tensor_value_info("axes", TensorProto.INT64, [1, 2]),
+ ],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [32,
32])],
+ )
+
+    model = helper.make_model(graph, producer_name="squeeze_dynamic_axes_rank_validation_test")
+ with pytest.raises(ValueError, match="Expected a 1-D tensor"):
+ from_onnx(model, opset=13, keep_params_in_input=True)
+
+
@pytest.mark.parametrize("axis", [[0]])
@pytest.mark.parametrize("A", [8, 16, 32])
def test_dynamic_shape_squeeze(axis, A):
@@ -2480,6 +2636,121 @@ def test_slice():
# )
+def test_slice_dynamic_inputs():
+ slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes",
"steps"], ["y"])
+
+ graph = helper.make_graph(
+ [slice_node],
+ "slice_dynamic_inputs_test",
+ inputs=[
+ helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]),
+ helper.make_tensor_value_info("starts", TensorProto.INT64, [2]),
+ helper.make_tensor_value_info("ends", TensorProto.INT64, [2]),
+ helper.make_tensor_value_info("axes", TensorProto.INT64, [2]),
+ helper.make_tensor_value_info("steps", TensorProto.INT64, [2]),
+ ],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10,
5])],
+ )
+
+ model = helper.make_model(graph, producer_name="slice_dynamic_inputs_test")
+ inputs = {
+ "x": rg.standard_normal(size=[20, 10, 5]).astype("float32"),
+ "starts": np.array([0, 0], dtype="int64"),
+ "ends": np.array([3, 10], dtype="int64"),
+ "axes": np.array([0, 1], dtype="int64"),
+ "steps": np.array([1, 1], dtype="int64"),
+ }
+ check_correctness(model, inputs, opset=13)
+
+
+def test_slice_dynamic_inputs_ir():
+ slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes",
"steps"], ["y"])
+
+ graph = helper.make_graph(
+ [slice_node],
+ "slice_dynamic_inputs_ir",
+ inputs=[
+ helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]),
+ helper.make_tensor_value_info("starts", TensorProto.INT64, [2]),
+ helper.make_tensor_value_info("ends", TensorProto.INT64, [2]),
+ helper.make_tensor_value_info("axes", TensorProto.INT64, [2]),
+ helper.make_tensor_value_info("steps", TensorProto.INT64, [2]),
+ ],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10,
5])],
+ )
+
+    model = helper.make_model(graph, producer_name="slice_dynamic_inputs_ir_test")
+ tvm_model = from_onnx(model, opset=13, keep_params_in_input=True)
+ call_ops = collect_relax_call_ops(tvm_model["main"])
+
+ assert "relax.dynamic_strided_slice" in call_ops
+ assert "relax.strided_slice" not in call_ops
+
+
+def test_slice_dynamic_inputs_length_validation():
+ slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes",
"steps"], ["y"])
+
+ graph = helper.make_graph(
+ [slice_node],
+ "slice_dynamic_inputs_length_validation",
+ inputs=[
+ helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]),
+ helper.make_tensor_value_info("starts", TensorProto.INT64, [2]),
+ helper.make_tensor_value_info("ends", TensorProto.INT64, [1]),
+ helper.make_tensor_value_info("axes", TensorProto.INT64, [2]),
+ helper.make_tensor_value_info("steps", TensorProto.INT64, [2]),
+ ],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10,
5])],
+ )
+
+    model = helper.make_model(graph, producer_name="slice_dynamic_inputs_length_validation_test")
+    with pytest.raises(ValueError, match="starts and ends to have the same length"):
+ from_onnx(model, opset=13, keep_params_in_input=True)
+
+
+def test_slice_dynamic_shape_expr_input_validation():
+ shape_node = helper.make_node("Shape", ["x"], ["y"])
+ slice_node = helper.make_node("Slice", ["y", "starts", "ends", "axes",
"steps"], ["z"])
+
+ graph = helper.make_graph(
+ [shape_node, slice_node],
+ "slice_dynamic_shape_expr_input_validation",
+ inputs=[
+ helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10, 5]),
+ helper.make_tensor_value_info("starts", TensorProto.INT64, [1]),
+ helper.make_tensor_value_info("ends", TensorProto.INT64, [1]),
+ helper.make_tensor_value_info("axes", TensorProto.INT64, [1]),
+ helper.make_tensor_value_info("steps", TensorProto.INT64, [1]),
+ ],
+ outputs=[helper.make_tensor_value_info("z", TensorProto.INT64, [1])],
+ )
+
+    model = helper.make_model(graph, producer_name="slice_dynamic_shape_expr_input_validation_test")
+ with pytest.raises(ValueError, match="does not support ShapeExpr input"):
+ from_onnx(model, opset=13, keep_params_in_input=True)
+
+
+def test_slice_zero_step_validation():
+ slice_node = helper.make_node("Slice", ["x", "starts", "ends", "axes",
"steps"], ["y"])
+
+ graph = helper.make_graph(
+ [slice_node],
+ "slice_zero_step_validation",
+ inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [20, 10,
5])],
+ initializer=[
+ helper.make_tensor("starts", TensorProto.INT64, [2], vals=[0, 0]),
+ helper.make_tensor("ends", TensorProto.INT64, [2], vals=[3, 10]),
+ helper.make_tensor("axes", TensorProto.INT64, [2], vals=[0, 1]),
+ helper.make_tensor("steps", TensorProto.INT64, [2], vals=[1, 0]),
+ ],
+ outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 10,
5])],
+ )
+
+    model = helper.make_model(graph, producer_name="slice_zero_step_validation_test")
+ with pytest.raises(ValueError, match="step values must be non-zero"):
+ from_onnx(model, opset=13)
+
+
def test_slice_dynamic_shape():
def verify_slice(
data_shape, data_instance_shape, output_shape, starts, ends, axes=None, steps=None