gemini-code-assist[bot] commented on code in PR #18593:
URL: https://github.com/apache/tvm/pull/18593#discussion_r2637141869
##########
src/relax/op/tensor/manipulate.cc:
##########
@@ -1929,12 +1929,136 @@ StructInfo InferStructInfoTile(const Call& call, const BlockBuilder& ctx) {
   return TensorStructInfo(ShapeExpr(out_shape), data_sinfo->dtype, data_sinfo->vdevice);
 }
-// TODO(relax-team): implement FRelaxInferLayout for tile
+InferLayoutOutput InferLayoutTile(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  const auto* attrs = call->attrs.as<TileAttrs>();
+  ICHECK(attrs != nullptr) << "Invalid Call";
+  const auto* tensor_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+
+  LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  int ndim = tensor_sinfo->ndim;
+  int l = attrs->repeats.size();
+  int out_ndim = std::max(l, ndim);
+
+  // Can't handle sub indexed layouts.
+  if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) {
+    existing_layout = LayoutDecision(InitialLayout(ndim));
+  }
+
+  // Tile operation repeats data along each axis.
+  // When layout changes, we need to transform the repeats array to match the new layout.
+  Layout initial_layout = InitialLayout(ndim);
+  Layout existing_layout_obj = existing_layout->layout;
+
+  // Transform repeats array according to layout change.
+  // The repeats array corresponds to axes in the initial layout order (ABCD...).
+  // We need to reorder it to match the existing layout.
+  // The key insight: for each position in existing_layout, find which position in
+  // initial_layout it corresponds to, and use the repeat value from that position.
+  ffi::Array<Integer> new_repeats;
+
+  if (out_ndim == ndim) {
+    // Same dimension: reorder repeats according to layout transformation.
+    // Use TransposeStrLike approach similar to repeat operator:
+    // Build a string representation where each position j has the repeat value,
+    // then transpose it from initial_layout to existing_layout.
+    // This correctly handles the axis name mapping.
+ // This correctly handles the axis name mapping.
+
+    // Build a string representation of repeats for TransposeStrLike.
+    // We encode repeat values as characters (0-9 for values 0-9, and use direct
+    // mapping for larger values).
+    std::string repeats_str;
+    for (int j = 0; j < ndim; ++j) {
+      if (j < l) {
+        int repeat_val = attrs->repeats[j]->value;
+        if (repeat_val >= 0 && repeat_val <= 9) {
+          repeats_str.push_back('0' + repeat_val);
+        } else {
+          // For values > 9, we'll handle them separately after TransposeStrLike.
+          repeats_str.push_back('X');
+        }
+      } else {
+        repeats_str.push_back('1');  // Default repeat of 1
+      }
+    }
+
+    // Transpose the repeats string from initial layout to existing layout.
+    // Note: TransposeStrLike(input, src, dst) maps from src to dst.
+    // For tile, we need to map repeats from initial_layout to existing_layout,
+    // so we use TransposeStrLike(repeats_str, initial_layout, existing_layout_obj).
+    // This is the same approach as the repeat operator uses for axis mapping.
+    ffi::String transposed_repeats_str =
+        TransposeStrLike(repeats_str, initial_layout, existing_layout_obj);
+
+    // Convert back to Integer array, handling placeholders for values > 9.
+    for (int i = 0; i < ndim; ++i) {
+      char c = transposed_repeats_str.at(i);
+      if (c >= '0' && c <= '9') {
+        new_repeats.push_back(Integer(c - '0'));
+      } else {
+        // For placeholder or out-of-range, find the original value via direct mapping.
+        // This handles values > 9 or when l < ndim.
+        const tir::LayoutAxis& axis = existing_layout_obj[i];
+        int pos_in_initial = initial_layout.IndexOf(axis);
+        if (pos_in_initial >= 0 && pos_in_initial < l) {
+          new_repeats.push_back(attrs->repeats[pos_in_initial]);
+        } else {
+          new_repeats.push_back(Integer(1));
+        }
+      }
+    }
+  } else {
+    // Different dimension: handle dimension expansion.
+    int l_delta = out_ndim - l;
+    int ndim_delta = out_ndim - ndim;
+
+    // Build new repeats array for output dimensions.
+    for (int i = 0; i < out_ndim; ++i) {
+      if (i < l_delta) {
+        // New dimensions from repeats (at front, before input dimensions).
+        new_repeats.push_back(attrs->repeats[i]);
+      } else if (i < ndim_delta) {
+        // New dimensions from input expansion (at front).
+        new_repeats.push_back(Integer(1));
+      } else {
+        // Existing dimensions: map from initial to existing layout.
+        int orig_axis = i - ndim_delta;
+        // Get the axis at position orig_axis in existing layout.
+        const tir::LayoutAxis& axis = existing_layout_obj[orig_axis];
+        // Find its position in initial layout.
+        int axis_in_initial = initial_layout.IndexOf(axis);
+        // The repeat index in the original repeats array.
+        int repeat_idx = axis_in_initial + l_delta;
+        if (axis_in_initial >= 0 && repeat_idx < l) {
+          new_repeats.push_back(attrs->repeats[repeat_idx]);
+        } else {
+          new_repeats.push_back(Integer(1));
+        }
+      }
+    }
+  }
Review Comment:

The current logic for transforming the `repeats` array is overly complex and contains a few bugs:

1. **Incorrect Padding**: When `len(repeats) < ndim`, the `repeats` array should be padded with 1s at the beginning. The current implementation pads at the end.
2. **Incorrect Dimension Expansion Logic**: When `len(repeats) > ndim`, the logic for constructing `new_repeats` is incorrect for both new and existing dimensions.
3. **Complexity**: The use of `TransposeStrLike` with string encoding for repeat values is hard to follow and has limitations (e.g., special handling for repeats > 9). A direct mapping approach is much simpler and more robust.

I've provided a simplified and corrected implementation below that addresses all these points, making the code more readable and correct.
```cpp
  if (out_ndim == ndim) {
    // Same dimension: reorder repeats according to layout transformation.
    // If len(repeats) < ndim, it's padded with 1s at the beginning.
    for (int i = 0; i < ndim; ++i) {
      const tir::LayoutAxis& axis = existing_layout_obj[i];
      int pos_in_initial = initial_layout.IndexOf(axis);
      if (pos_in_initial >= ndim - l) {
        new_repeats.push_back(attrs->repeats[pos_in_initial - (ndim - l)]);
      } else {
        new_repeats.push_back(Integer(1));
      }
    }
  } else {
    // Different dimension: handle dimension expansion.
    // This case only happens when l > ndim.
    ICHECK_GT(l, ndim);
    int num_new_dims = l - ndim;
    // Repeats for new dimensions are not affected by layout change.
    for (int i = 0; i < num_new_dims; ++i) {
      new_repeats.push_back(attrs->repeats[i]);
    }
    // Repeats for existing dimensions need to be permuted.
    for (int i = 0; i < ndim; ++i) {
      const tir::LayoutAxis& axis = existing_layout_obj[i];
      int pos_in_initial = initial_layout.IndexOf(axis);
      ICHECK_NE(pos_in_initial, -1) << "Axis not found in initial layout";
      new_repeats.push_back(attrs->repeats[pos_in_initial + num_new_dims]);
    }
  }
```
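For context on point 1, NumPy's `np.tile` follows the same convention the suggested fix implements: when `reps` is shorter than the input's `ndim`, it is promoted by prepending 1s, so the given repeats align with the trailing axes. A quick check:

```python
import numpy as np

x = np.ones((2, 3))

# len(reps) = 1 < ndim = 2: reps is promoted to (1, 2) by prepending 1s,
# so only the last axis is tiled.
assert np.tile(x, (2,)).shape == (2, 6)

# len(reps) = 3 > ndim = 2: the input is instead promoted to 3-D by
# prepending axes, matching the l > ndim branch of the suggestion.
assert np.tile(x, (2, 1, 2)).shape == (2, 2, 6)
```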
##########
tests/python/relax/test_transform_convert_layout.py:
##########
@@ -5077,5 +5077,49 @@ def main(
     verify(Input, Expected)
+def test_conv2d_tile():
+    @I.ir_module
+    class Input:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), "float32"), w: R.Tensor((4, 3, 3, 3), "float32")
+        ) -> R.Tensor(None, "float32", ndim=4):
+            with R.dataflow():
+                gv: R.Tensor((2, 4, 26, 26), "float32") = R.nn.conv2d(x, w, out_dtype="float32")
+                gv2: R.Tensor((2, 8, 26, 26), "float32") = R.tile(gv, repeats=[1, 2, 1, 1])
+                R.output(gv2)
+            return gv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((2, 3, 28, 28), dtype="float32"), w: R.Tensor((4, 3, 3, 3), dtype="float32")
+        ) -> R.Tensor(None, dtype="float32", ndim=4):
+            with R.dataflow():
+                lv: R.Tensor((2, 28, 28, 3), dtype="float32") = R.permute_dims(x, axes=[0, 2, 3, 1])
+                lv1: R.Tensor((4, 3, 3, 3), dtype="float32") = R.permute_dims(w, axes=[0, 2, 3, 1])
+                gv: R.Tensor((2, 26, 26, 4), dtype="float32") = R.nn.conv2d(
+                    lv,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[0, 0, 0, 0],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NHWC",
+                    kernel_layout="OHWI",
+                    out_layout="NHWC",
+                    out_dtype="float32",
+                )
+                lv2: R.Tensor((2, 26, 26, 8), dtype="float32") = R.tile(gv, repeats=[1, 1, 1, 2])
+                gv2: R.Tensor((2, 8, 26, 26), dtype="float32") = R.permute_dims(
+                    lv2, axes=[0, 3, 1, 2]
+                )
+                R.output(gv2)
+            return gv2
+
+    verify(Input, Expected)
Review Comment:

The added test case is good for verifying the same-dimension tiling. To ensure the implementation is robust, it would be beneficial to add more test cases to cover other scenarios, especially those where the current implementation has issues. Specifically:

1. `len(repeats) < ndim`: To test the padding logic.
2. `len(repeats) > ndim`: To test the dimension expansion logic.
3. `repeats` with values > 9: To test the placeholder logic in the original implementation.
Here's a sketch of what these tests could look like:
```python
def test_conv2d_tile_more_cases():
    # Setup similar to test_conv2d_tile.
    # Case 1: len(repeats) < ndim
    #   gv has ndim=4. Use repeats with len < 4,
    #   e.g., R.tile(gv, repeats=[2, 1]).
    #   Expected repeats for NHWC should be [1, 2, 1, 1] if original was NCHW.
    # Case 2: len(repeats) > ndim
    #   e.g., R.tile(gv, repeats=[2, 1, 2, 1, 1]).
    #   Expected repeats for NHWC should be [2, 1, 1, 1, 2].
    # Case 3: repeats value > 9
    #   e.g., R.tile(gv, repeats=[1, 10, 1, 1]).
    #   Expected repeats for NHWC should be [1, 1, 1, 10].
    pass
```
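The expected repeats in these sketches follow mechanically from the suggested C++ logic. A small Python model of that logic (a hypothetical helper, with the NCHW-to-NHWC permutation hard-coded purely for illustration) reproduces all three expectations:

```python
def permute_repeats(repeats, ndim, perm):
    """Sketch of the suggested repeats transformation for R.tile.

    perm[i] gives the position, in the initial (NCHW-like) layout, of the
    axis at position i in the new layout; e.g. NCHW -> NHWC is [0, 2, 3, 1].
    """
    l = len(repeats)
    if l <= ndim:
        # Pad with 1s at the beginning, then permute with the layout.
        padded = [1] * (ndim - l) + list(repeats)
        return [padded[p] for p in perm]
    # l > ndim: repeats for the leading new dimensions are kept as-is;
    # only the trailing ndim entries are permuted with the layout.
    num_new = l - ndim
    head, tail = list(repeats[:num_new]), list(repeats[num_new:])
    return head + [tail[p] for p in perm]


nchw_to_nhwc = [0, 2, 3, 1]
print(permute_repeats([2, 1], 4, nchw_to_nhwc))           # [1, 2, 1, 1]
print(permute_repeats([2, 1, 2, 1, 1], 4, nchw_to_nhwc))  # [2, 1, 1, 1, 2]
print(permute_repeats([1, 10, 1, 1], 4, nchw_to_nhwc))    # [1, 1, 1, 10]
```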
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]