This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3d1e402502 [Frontend][TFLite] Add test coverage for SHAPE and RANGE
operators (#19401)
3d1e402502 is described below
commit 3d1e4025020ab4def03a3f2b6058cf5e8fef0153
Author: Bana <[email protected]>
AuthorDate: Tue Apr 14 20:18:09 2026 +0300
[Frontend][TFLite] Add test coverage for SHAPE and RANGE operators (#19401)
The initial goal was to add SHAPE and RANGE tests, addressing part of #18971.
This PR achieves that and also includes the minimum frontend fixes discovered
during implementation, so that the tests reflect actually supported behavior
instead of relying on xfail markers or workarounds.
So this PR includes both:
1. **New SHAPE/RANGE tests**
2. **Targeted frontend fixes required to make those tests pass correctly**
## Why These Changes Were Needed
- SHAPE conversion previously produced symbolic shape info instead of a
tensor output aligned with TFLite SHAPE semantics.
- RANGE conversion passed tensor expressions to `arange` instead of scalar
values, even when the bounds were constant scalars.
- Zero-input TFLite subgraphs (valid for constant-only models, such as
RANGE without inputs) were rejected by an overly strict assertion.
- Model output collection was brittle for constant/prefetched outputs
and could fail when output expressions were not already in the expr
table.
- As a result, I could not add meaningful SHAPE/RANGE coverage without
fixing the frontend behavior.
## **Modifications**
### **Frontend Changes** (In tflite_frontend.py):
- Updated convert_shape: SHAPE now materializes its output as a tensor
using `shape_to_tensor(shape_of(...))`.
- Casts the output dtype according to `ShapeOptions.OutType`
(int32/int64).
- Updated convert_range: Extracts scalar values for start/limit/delta
from scalar constants
- Calls arange with scalar-like values
- Explicitly rejects RANGE with dynamic scalar inputs as unsupported (raises
`OpNotImplemented` with a clear message).
- Updated _input_type: Removed the assumption that every subgraph must have
at least one input.
- Supports valid zero-input subgraphs
- Updated from_tflite output assembly: Resolves outputs via tensor
wrappers and `get_tensor_expr` instead of a direct expression-table lookup by name.
---
**Main functional changes are localized to SHAPE/RANGE conversion and
model output/input handling.**
---------
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 36 +++++++++--
tests/python/relax/test_frontend_tflite.py | 72 ++++++++++++++++++++++
2 files changed, 103 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 584a65e1f4..0f9f168a13 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -925,15 +925,35 @@ class OperatorConverter:
start, limit, delta = input_tensors[0], input_tensors[1],
input_tensors[2]
- expressions = [self.get_tensor_expr(t) for t in [start, limit, delta]]
+ def get_scalar_value(tensor):
+ if self.has_expr(tensor.tensor_idx):
+ expr = self.get_expr(tensor.tensor_idx)
+ if isinstance(expr, relax.Constant):
+ value = expr.data.numpy()
+ else:
+ # relax.op.arange currently expects scalar-like values
here.
+ # Keep dynamic scalar RANGE explicit until frontend
support is added.
+ raise tvm.error.OpNotImplemented(
+ "TFLite RANGE with dynamic scalar inputs is not
supported in Relax frontend yet."
+ )
+ else:
+ value = self.get_tensor_value(tensor)
+ # TFLite RANGE operands are scalar tensors in the flatbuffer.
+ assert value.size == 1, "RANGE scalar input must have exactly one
element"
+ return value.item()
+
+ start_value = get_scalar_value(start)
+ limit_value = get_scalar_value(limit)
+ delta_value = get_scalar_value(delta)
+
# out type inference
if delta.tensor.Type() == TensorType.FLOAT32:
out_type = self.get_tensor_type_str(delta.tensor.Type())
else:
out_type = self.get_tensor_type_str(start.tensor.Type())
- out = relax.op.arange(expressions[0], expressions[1], expressions[2],
out_type)
+ out = relax.op.arange(start_value, limit_value, delta_value, out_type)
return out
@@ -942,6 +962,7 @@ class OperatorConverter:
from tflite.BuiltinOptions import BuiltinOptions
from tflite.ShapeOptions import ShapeOptions
+ from tflite.TensorType import TensorType
input_tensors = self.get_input_tensors(op)
assert len(input_tensors) == 1, "input tensors length should be 1"
@@ -951,7 +972,10 @@ class OperatorConverter:
shape_options = ShapeOptions()
shape_options.Init(op_options.Bytes, op_options.Pos)
- out = relax.op.shape_of(self.get_tensor_expr(input_tensors[0]))
+ # SHAPE must materialize as a tensor output in Relax, not just
symbolic shape info.
+ out =
relax.op.shape_to_tensor(relax.op.shape_of(self.get_tensor_expr(input_tensors[0])))
+ if shape_options.OutType() == TensorType.INT32:
+ out = relax.op.astype(out, "int32")
return out
@@ -4055,7 +4079,7 @@ def _input_type(model):
for subgraph_index in range(subgraph_count):
subgraph = model.Subgraphs(subgraph_index)
inputs_count = subgraph.InputsLength()
- assert inputs_count >= 1
+ # TFLite subgraphs can validly have zero inputs (e.g. constant-only
RANGE models).
for input_index in range(inputs_count):
input_ = subgraph.Inputs(input_index)
assert subgraph.TensorsLength() > input_
@@ -4209,7 +4233,9 @@ def from_tflite(
op_converter.convert_op_to_relax()
# params and outputs
- outputs = [exp_tab.get_expr(get_tensor_name(subgraph, i)) for i in
model_outputs]
+ # Resolve outputs through tensor wrappers so constant/prefetched
outputs are handled.
+ output_tensors = op_converter.get_tensors(model_outputs)
+ outputs = [op_converter.get_tensor_expr(tensor) for tensor in
output_tensors]
outputs = outputs[0] if len(outputs) == 1 else relax.Tuple(outputs)
output_var = bb.emit_output(outputs)
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index bf6ef8e819..8eb2c8e13b 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -279,6 +279,78 @@ def test_reshape():
verify(Reshape, Expected)
+@pytest.mark.parametrize(
+ "input_shape, out_type",
+ [
+ ((2, 3, 4), tf.int32),
+ ((5,), tf.int64),
+ ((1, 1, 1, 1), tf.int32),
+ ((), tf.int32),
+ ((0, 3), tf.int64),
+ ],
+)
+def test_shape(input_shape, out_type):
+ """SHAPE conversion for static-rank non-quantized tensors."""
+
+ class Shape(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=input_shape,
dtype=tf.float32)])
+ def func(self, x):
+ return tf.shape(x, out_type=out_type)
+
+ verify(Shape)
+
+
+def test_shape_dynamic_dim():
+ """SHAPE conversion with a dynamic input dimension."""
+
+ class ShapeDynamic(tf.Module):
+ @tf.function(input_signature=[tf.TensorSpec(shape=(None, 3),
dtype=tf.float32)])
+ def func(self, x):
+ return tf.shape(x, out_type=tf.int32)
+
+ verify(ShapeDynamic)
+
+
+@pytest.mark.parametrize(
+ "start, limit, delta, dtype",
+ [
+ (0, 8, 2, tf.int32),
+ (1, 9, 2, tf.int64),
+ (0.0, 1.0, 0.2, tf.float32),
+ (8, 0, -2, tf.int32),
+ (0, 0, 1, tf.int32),
+ (0, 7, 2, tf.int32),
+ (0.0, -1.0, -0.25, tf.float32),
+ ],
+)
+def test_range(start, limit, delta, dtype):
+ """RANGE conversion with non-quantized constant scalar bounds."""
+
+ class Range(tf.Module):
+ @tf.function(input_signature=[])
+ def func(self):
+ return tf.range(start, limit, delta, dtype=dtype)
+
+ verify(Range)
+
+
+def test_range_dynamic_scalar_inputs_not_supported():
+ """RANGE conversion currently rejects dynamic scalar inputs."""
+
+ class RangeDynamic(tf.Module):
+ @tf.function(
+ input_signature=[
+ tf.TensorSpec(shape=(), dtype=tf.int32),
+ tf.TensorSpec(shape=(), dtype=tf.int32),
+ tf.TensorSpec(shape=(), dtype=tf.int32),
+ ]
+ )
+ def func(self, start, limit, delta):
+ return tf.range(start, limit, delta, dtype=tf.int32)
+
+ with pytest.raises(tvm.error.OpNotImplemented, match="dynamic scalar
inputs"):
+ verify(RangeDynamic)
+
def test_tile_ir():
"""TILE conversion with explicit Relax IR structural check."""