gemini-code-assist[bot] commented on code in PR #18579:
URL: https://github.com/apache/tvm/pull/18579#discussion_r2608863504


##########
src/relax/op/tensor/manipulate.cc:
##########
@@ -1805,12 +1805,70 @@ StructInfo InferStructInfoRepeat(const Call& call, 
const BlockBuilder& ctx) {
   return TensorStructInfo(ShapeExpr(shape_array), data_sinfo->dtype, 
data_sinfo->vdevice);
 }
 
-// TODO(relax-team): implement FRelaxInferLayout for repeat
+InferLayoutOutput InferLayoutRepeat(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& 
desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  ICHECK(NoDesiredLayout(call, desired_layouts));
+
+  const auto* attrs = call->attrs.as<RepeatAttrs>();
+  ICHECK(attrs != nullptr) << "Invalid Call";
+  const auto* tensor_sinfo = 
GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  ICHECK(tensor_sinfo != nullptr) << "Invalid Call";
+  ICHECK(!tensor_sinfo->IsUnknownNdim()) << "Only support static ndim for now";
+
+  LayoutDecision existing_layout = GetLayoutDecision(var_layout_map, 
call->args[0]);
+  int ndim = tensor_sinfo->ndim;
+
+  // Can't handle sub indexed layouts.
+  if (existing_layout->layout.ndim() != existing_layout->layout.ndim_primal()) 
{
+    existing_layout = LayoutDecision(InitialLayout(ndim));
+  }
+
+  // When axis is not specified, the output is 1D (flattened)
+  if (!attrs->axis.has_value()) {
+    return InferLayoutOutput({existing_layout}, {InitialLayoutDecision(1)}, 
Attrs(call->attrs));
+  }
+
+  // Transform the axis based on the layout
+  int axis = attrs->axis.value();
+  if (axis < 0) {
+    axis += ndim;
+  }
+
+  // Create a mapping from original layout to existing layout
+  std::string axis_str(ndim, '0');
+  axis_str[axis] = '1';
+  for (int i = 0, j = 0; i < ndim; ++i) {
+    if (axis_str[i] != '1') {
+      axis_str[i] = 'A' + j++;
+    }
+  }
+
+  ffi::String new_axis_str =
+      TransposeStrLike(axis_str, InitialLayout(ndim), existing_layout->layout);
+
+  int64_t new_axis = -1;
+  for (size_t i = 0; i < new_axis_str.size(); ++i) {
+    if (new_axis_str.at(i) == '1') {
+      new_axis = i;
+      break;
+    }
+  }
+  ICHECK_GE(new_axis, 0) << "Failed to find transformed axis";

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The loop that searches for the index of '1' can be replaced with a call to
`ffi::String::find`, which is more idiomatic and concise.
   
   ```cpp
     size_t pos = new_axis_str.find('1');
     ICHECK(pos != std::string::npos) << "Failed to find transformed axis";
     int64_t new_axis = static_cast<int64_t>(pos);
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to