================
@@ -124,63 +124,74 @@ xegpu::DistributeLayoutAttr 
xegpu::getDistributeLayoutAttr(const Value value) {
     Operation *defOp = result.getDefiningOp();
     assert(defOp && "result must have a defining op");
 
-    // For ConvertLayoutOp, the layout is stored in the targetLayoutAttr
-    if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(defOp))
-      return convertOp.getTargetLayoutAttr();
-
-    // for LoadNdOp, the layout is stored in the tensor descriptor
-    if (auto loadNd = dyn_cast<xegpu::LoadNdOp>(defOp))
-      return getDistributeLayoutAttr(loadNd.getTensorDesc());
-
-    // for LoadMatrixOp, the layout is attached to the property of the op
-    if (auto loadOp = dyn_cast<xegpu::LoadMatrixOp>(defOp))
-      return loadOp.getLayoutAttr();
-
-    // for StoreMatrixOp, the layout is attached to the property of the op
-    if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(defOp))
-      return storeOp.getLayoutAttr();
-    std::string layoutName = getLayoutName(result);
-    if (defOp->hasAttr(layoutName))
-      return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
-
-    // check for "permament" layout only after "temporary" layout name lookup
-    // for backward compatibility
-    if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp))
-      return loadGatherOp.getLayoutAttr();
+    if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(defOp)) {
+      auto layout = anchorOp.getAnchorLayout();
+      return layout;
+    }
+
+    std::string layoutName = getTempLayoutName(result);
+    if (defOp->hasAttr(layoutName)) {
+      auto layout =
+          defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+      return layout;
+    }
   }
 
   if (auto arg = dyn_cast<BlockArgument>(value)) {
     auto *parentOp = arg.getOwner()->getParentOp();
     if (auto loop = dyn_cast<LoopLikeOpInterface>(parentOp)) {
       OpOperand *tiedInit = loop.getTiedLoopInit(arg);
-      if (tiedInit)
-        return getDistributeLayoutAttr(tiedInit->get());
+      if (tiedInit) {
+        auto layout = getDistributeLayoutAttr(tiedInit->get());
+        return layout;
+      }
     }
   }
 
   return nullptr;
 }
-
 xegpu::DistributeLayoutAttr
 xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
   Operation *op = opr.getOwner();
+  unsigned idx = const_cast<OpOperand &>(opr).getOperandNumber();
+
+  if (auto anchorOp = dyn_cast<xegpu::AnchorLayoutInterface>(op)) {
+    if (auto dpasOp = dyn_cast<xegpu::DpasOp>(op)) {
+      if (idx == 0) {
+        return dpasOp.getLayoutAAttr();
+      } else if (idx == 1) {
+        return dpasOp.getLayoutBAttr();
+      } else if (idx == 2) {
+        return dpasOp.getLayoutCdAttr();
+      }
+    }
+    if (auto convertOp = dyn_cast<xegpu::ConvertLayoutOp>(op)) {
+      return convertOp.getInputLayoutAttr();
+    }
+    auto layout = anchorOp.getAnchorLayout();
+    // For store operations (StoreScatterOp, StoreNdOp, StoreMatrixOp),
+    // the layout is valid for the first two operands: value and memref/tdesc.
+    // For other operations, the layout applies to the first operand only.
+    if (isa<xegpu::StoreScatterOp, xegpu::StoreNdOp, xegpu::StoreMatrixOp>(
+            op)) {
+      if (idx < 2) {
+        return layout;
+      }
+    } else {
+      if (idx == 0) {
+        return layout;
+      }
+    }
+  }
 
-  if (auto loadOp = dyn_cast<xegpu::LoadMatrixOp>(op))
-    return loadOp.getLayoutAttr();
-
-  if (auto storeOp = dyn_cast<xegpu::StoreMatrixOp>(op))
-    return storeOp.getLayoutAttr();
-
-  std::string layoutName = xegpu::getLayoutName(opr);
-  if (op->hasAttr(layoutName))
-    return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
-
-  // check for "permament" layout only after "temporary" layout name lookup
-  if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
-    if (auto layout = storeScatterOp.getLayoutAttr())
-      return layout;
+  std::string layoutName = xegpu::getTempLayoutName(opr);
----------------
Jianhui-Li wrote:

Remember that refactoring is a long process.
Compared to the previous version, the code now has a much better structure:
first check the anchor layout, then the local layout, and then the def op.
Future work is to remove the lookup from the def op and further clean up the
special case for store operands.
The previous version mixed these lookups together and was position-sensitive
because the test cases have certain assumptions built in.

https://github.com/llvm/llvm-project/pull/172125
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to