gemini-code-assist[bot] commented on code in PR #18675:
URL: https://github.com/apache/tvm/pull/18675#discussion_r2706643382


##########
src/relax/transform/convert_layout.cc:
##########
@@ -201,15 +202,21 @@ class LayoutConvertMutator : public ExprMutator {
   ffi::Optional<InferLayoutOutput> GetInferLayoutInfo(
       const CallNode* call_node,
       const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
-      const VarLayoutMap& var_layout_map) {
+      const LayoutCb& layout_cb, const VarLayoutMap& var_layout_map) {
     const OpNode* op_node = call_node->op.as<OpNode>();
     if (op_node == nullptr) return std::nullopt;
     Op op = Downcast<Op>(ffi::GetRef<Op>(op_node));
     const auto attr_map = Op::GetAttrMap<FRelaxInferLayout>("FRelaxInferLayout");
     if (attr_map.count(op) && !HasUnknownDimTensor(call_node->args)) {
       // If the op has FRelaxInferLayout, and all the input tensors have known ndim
       FRelaxInferLayout f = attr_map[op];
-      return f(ffi::GetRef<Call>(call_node), desired_layouts, var_layout_map);
+      if (layout_cb != nullptr) {
+        ffi::Map<ffi::String, ffi::Array<ffi::String>> custom_layouts;
+        custom_layouts = layout_cb(ffi::GetRef<Call>(call_node));
+        return f(ffi::GetRef<Call>(call_node), custom_layouts, var_layout_map);
+      } else {
+        return f(ffi::GetRef<Call>(call_node), desired_layouts, var_layout_map);
+      }

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   To improve readability and avoid redundant calls to 
`ffi::GetRef<Call>(call_node)`, you could store the result in a local variable. 
This also makes the code slightly more efficient.
   
   ```suggestion
         FRelaxInferLayout f = attr_map[op];
         auto call = ffi::GetRef<Call>(call_node);
         if (layout_cb) {
           auto custom_layouts = layout_cb(call);
           return f(call, custom_layouts, var_layout_map);
         } else {
           return f(call, desired_layouts, var_layout_map);
         }
   ```



##########
python/tvm/relax/transform/transform.py:
##########
@@ -1377,13 +1380,16 @@ def ConvertLayout(desired_layouts: Dict[str, List[str]]) -> tvm.ir.transform.Pas
         of the desired feature map, weight and output. For example, if we want to convert the
         layout of conv2d from NCHW to NHWC, we can set the desired layout of conv2d to be
         ``{"relax.nn.conv2d": ["NHWC", "OHWI"]}``.
+    layout_cb : Callable
+        A user defined call back function that can dynamically handle operator layouts
+        based on Call description. desigred_layouts will be ignored if layout_cb is defined.

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   There's a small typo in the docstring. `desigred_layouts` should be 
`desired_layouts`.
   
   ```suggestion
           based on Call description. desired_layouts will be ignored if layout_cb is defined.
   ```
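
For context, the precedence rule described in that docstring (the callback result replaces `desired_layouts` when a callback is given) can be sketched in plain Python. This is only an illustration of the branch shown in `GetInferLayoutInfo` above, not TVM code: `FakeCall`, `select_layouts`, and `example_cb` are hypothetical names, and the real callback would receive a relax `Call` rather than this stand-in.

```python
from typing import Callable, Dict, List, Optional

LayoutMap = Dict[str, List[str]]


class FakeCall:
    """Hypothetical stand-in for a relax Call; it only carries the op name."""

    def __init__(self, op_name: str) -> None:
        self.op_name = op_name


def select_layouts(
    call: FakeCall,
    desired_layouts: LayoutMap,
    layout_cb: Optional[Callable[[FakeCall], LayoutMap]] = None,
) -> LayoutMap:
    """Pick the layout map the same way the C++ branch does.

    If a callback is supplied, its result is used and desired_layouts
    is ignored; otherwise desired_layouts is returned unchanged.
    """
    if layout_cb is not None:
        return layout_cb(call)
    return desired_layouts


def example_cb(call: FakeCall) -> LayoutMap:
    # Hypothetical policy: request NHWC/OHWI for conv2d, nothing for other ops.
    if call.op_name == "relax.nn.conv2d":
        return {"relax.nn.conv2d": ["NHWC", "OHWI"]}
    return {}
```

With this sketch, passing `example_cb` overrides whatever `desired_layouts` contains for the call, which is the behavior the docstring should state clearly.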



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

