skatrak created this revision.
skatrak added reviewers: ftynse, kiranchandramohan, jsjodin, domada, agozillon, 
raghavendhra, TIFitis, shraiysh.
Herald added subscribers: bviyer, Moerafaat, zero9178, bzcheeseman, sdasgup3, 
wenzhicui, wrengr, cota, teijeong, rdzhabarov, tatianashp, msifontes, jurahul, 
Kayjukh, grosul1, Joonsoo, liufengdb, aartbik, mgester, arpith-jacob, 
antiagainst, shauheen, rriddle, mehdi_amini.
Herald added a project: All.
skatrak requested review of this revision.
Herald added a reviewer: nicolasvasilache.
Herald added subscribers: cfe-commits, stephenneuendorffer, nicolasvasilache.
Herald added projects: clang, MLIR.

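This patch adds the `DialectCSEInterface`, which dialects can implement and 
register to prevent the common sub-expression elimination (CSE) pass from 
replacing operations nested in the regions of certain operations with 
equivalent values defined outside of those regions. The contents of such 
regions are still simplified by CSE, but independently of their parent regions.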

As a result, CSE treats these operations as if they were 
`IsolatedFromAbove`, but without the restrictions that come with that trait.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D159212

Files:
  clang/docs/tools/clang-formatted-files.txt
  mlir/include/mlir/Interfaces/CSEInterfaces.h
  mlir/lib/Transforms/CSE.cpp
  mlir/test/Transforms/cse.mlir
  mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
  mlir/test/lib/Dialect/Test/TestOps.td

Index: mlir/test/lib/Dialect/Test/TestOps.td
===================================================================
--- mlir/test/lib/Dialect/Test/TestOps.td
+++ mlir/test/lib/Dialect/Test/TestOps.td
@@ -2703,6 +2703,10 @@
   }];
 }
 
+def NoCSEOneRegionOp : TEST_Op<"no_cse_one_region_op", []> {
+  let regions = (region AnyRegion);
+}
+
 //===----------------------------------------------------------------------===//
 // Test Ops to upgrade base on the dialect versions
 //===----------------------------------------------------------------------===//
Index: mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
===================================================================
--- mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
+++ mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "TestDialect.h"
+#include "mlir/Interfaces/CSEInterfaces.h"
 #include "mlir/Interfaces/FoldInterfaces.h"
 #include "mlir/Reducer/ReductionPatternInterface.h"
 #include "mlir/Transforms/InliningUtils.h"
@@ -273,6 +274,16 @@
   }
 };
 
+struct TestDialectCSEInterface : public DialectCSEInterface {
+  using DialectCSEInterface::DialectCSEInterface;
+
+  bool subexpressionExtractionAllowed(Operation *op) const final {
+    // Don't allow extracting common subexpressions from the region of these
+    // operations.
+    return !isa<NoCSEOneRegionOp>(op);
+  }
+};
+
 /// This class defines the interface for handling inlining with standard
 /// operations.
 struct TestInlinerInterface : public DialectInlinerInterface {
@@ -385,6 +396,7 @@
   auto &blobInterface = addInterface<TestResourceBlobManagerInterface>();
   addInterface<TestOpAsmInterface>(blobInterface);
 
-  addInterfaces<TestDialectFoldInterface, TestInlinerInterface,
-                TestReductionPatternInterface, TestBytecodeDialectInterface>();
+  addInterfaces<TestDialectFoldInterface, TestDialectCSEInterface,
+                TestInlinerInterface, TestReductionPatternInterface,
+                TestBytecodeDialectInterface>();
 }
Index: mlir/test/Transforms/cse.mlir
===================================================================
--- mlir/test/Transforms/cse.mlir
+++ mlir/test/Transforms/cse.mlir
@@ -520,3 +520,23 @@
   %2 = "test.op_with_memread"() : () -> (i32)
   return %0, %2, %1 : i32, i32, i32
 }
+
+// CHECK-LABEL: @no_cse_across_disabled_op
+func.func @no_cse_across_disabled_op() -> (i32) {
+  // CHECK-NEXT: %[[CONST1:.+]] = arith.constant 1 : i32
+  %0 = arith.constant 1 : i32
+
+  // CHECK-NEXT: test.no_cse_one_region_op
+  "test.no_cse_one_region_op"() ({
+    %1 = arith.constant 1 : i32
+    %2 = arith.addi %1, %1 : i32
+    "foo.yield"(%2) : (i32) -> ()
+
+    // CHECK-NEXT: %[[CONST2:.+]] = arith.constant 1 : i32
+    // CHECK-NEXT: %[[SUM:.+]] = arith.addi %[[CONST2]], %[[CONST2]] : i32
+    // CHECK-NEXT: "foo.yield"(%[[SUM]]) : (i32) -> ()
+  }) : () -> ()
+
+  // CHECK: return %[[CONST1]] : i32
+  return %0 : i32
+}
Index: mlir/lib/Transforms/CSE.cpp
===================================================================
--- mlir/lib/Transforms/CSE.cpp
+++ mlir/lib/Transforms/CSE.cpp
@@ -15,6 +15,7 @@
 
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/CSEInterfaces.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/Passes.h"
@@ -61,7 +62,8 @@
 class CSEDriver {
 public:
   CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo)
-      : rewriter(rewriter), domInfo(domInfo) {}
+      : rewriter(rewriter), domInfo(domInfo),
+        interfaces(rewriter.getContext()) {}
 
   /// Simplify all operations within the given op.
   void simplify(Operation *op, bool *changed = nullptr);
@@ -122,6 +124,9 @@
   DominanceInfo *domInfo = nullptr;
   MemEffectsCache memEffectsCache;
 
+  /// CSE interfaces in the present context that can modify CSE behavior.
+  DialectInterfaceCollection<DialectCSEInterface> interfaces;
+
   // Various statistics.
   int64_t numCSE = 0;
   int64_t numDCE = 0;
@@ -289,7 +294,12 @@
       // If this operation is isolated above, we can't process nested regions
       // with the given 'knownValues' map. This would cause the insertion of
       // implicit captures in explicit capture only regions.
-      if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) {
+      // Also, avoid capturing known values from parent regions if this behavior
+      // is explicitly disabled for the given operation.
+      const DialectCSEInterface *cseInterface = interfaces.getInterfaceFor(&op);
+      if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
+          LLVM_UNLIKELY(cseInterface &&
+                        !cseInterface->subexpressionExtractionAllowed(&op))) {
         ScopedMapTy nestedKnownValues;
         for (auto &region : op.getRegions())
           simplifyRegion(nestedKnownValues, region);
Index: mlir/include/mlir/Interfaces/CSEInterfaces.h
===================================================================
--- /dev/null
+++ mlir/include/mlir/Interfaces/CSEInterfaces.h
@@ -0,0 +1,32 @@
+//===- CSEInterfaces.h ------------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_INTERFACES_CSEINTERFACES_H_
+#define MLIR_INTERFACES_CSEINTERFACES_H_
+
+#include "mlir/IR/DialectInterface.h"
+
+namespace mlir {
+class Operation;
+
+/// Define an interface to allow for dialects to control specific aspects of
+/// common subexpression elimination behavior for operations they define.
+class DialectCSEInterface : public DialectInterface::Base<DialectCSEInterface> {
+public:
+  DialectCSEInterface(Dialect *dialect) : Base(dialect) {}
+
+  /// Registered hook to check if an operation that is *not* isolated from
+  /// above should allow common subexpressions to be extracted out of its
+  /// regions.
+  virtual bool subexpressionExtractionAllowed(Operation *op) const {
+    return true;
+  }
+};
+
+} // namespace mlir
+
+#endif // MLIR_INTERFACES_CSEINTERFACES_H_
Index: clang/docs/tools/clang-formatted-files.txt
===================================================================
--- clang/docs/tools/clang-formatted-files.txt
+++ clang/docs/tools/clang-formatted-files.txt
@@ -7766,6 +7766,7 @@
 mlir/include/mlir/Interfaces/CastInterfaces.h
 mlir/include/mlir/Interfaces/ControlFlowInterfaces.h
 mlir/include/mlir/Interfaces/CopyOpInterface.h
+mlir/include/mlir/Interfaces/CSEInterfaces.h
 mlir/include/mlir/Interfaces/DataLayoutInterfaces.h
 mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h
 mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.h
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to