skatrak created this revision. skatrak added reviewers: ftynse, kiranchandramohan, jsjodin, domada, agozillon, raghavendhra, TIFitis, shraiysh. Herald added subscribers: bviyer, Moerafaat, zero9178, bzcheeseman, sdasgup3, wenzhicui, wrengr, cota, teijeong, rdzhabarov, tatianashp, msifontes, jurahul, Kayjukh, grosul1, Joonsoo, liufengdb, aartbik, mgester, arpith-jacob, antiagainst, shauheen, rriddle, mehdi_amini. Herald added a project: All. skatrak requested review of this revision. Herald added a reviewer: nicolasvasilache. Herald added subscribers: cfe-commits, stephenneuendorffer, nicolasvasilache. Herald added projects: clang, MLIR.
This patch adds the `DialectCSEInterface`, which dialects can implement and register to stop the common subexpression elimination (CSE) pass from reusing values defined outside the regions of certain operations when simplifying the operations nested inside those regions. CSE still simplifies such regions, but it starts from an empty set of known values. As a result, these operations are treated by CSE as if they were `IsolatedFromAbove`, but without the restrictions that come with that trait. Repository: rG LLVM Github Monorepo https://reviews.llvm.org/D159212 Files: clang/docs/tools/clang-formatted-files.txt mlir/include/mlir/Interfaces/CSEInterfaces.h mlir/lib/Transforms/CSE.cpp mlir/test/Transforms/cse.mlir mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp mlir/test/lib/Dialect/Test/TestOps.td
Index: mlir/test/lib/Dialect/Test/TestOps.td =================================================================== --- mlir/test/lib/Dialect/Test/TestOps.td +++ mlir/test/lib/Dialect/Test/TestOps.td @@ -2703,6 +2703,10 @@ }]; } +def NoCSEOneRegionOp : TEST_Op<"no_cse_one_region_op", []> { + let regions = (region AnyRegion); +} + //===----------------------------------------------------------------------===// // Test Ops to upgrade base on the dialect versions //===----------------------------------------------------------------------===// Index: mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp =================================================================== --- mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp +++ mlir/test/lib/Dialect/Test/TestDialectInterfaces.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "TestDialect.h" +#include "mlir/Interfaces/CSEInterfaces.h" #include "mlir/Interfaces/FoldInterfaces.h" #include "mlir/Reducer/ReductionPatternInterface.h" #include "mlir/Transforms/InliningUtils.h" @@ -273,6 +274,16 @@ } }; +struct TestDialectCSEInterface : public DialectCSEInterface { + using DialectCSEInterface::DialectCSEInterface; + + bool subexpressionExtractionAllowed(Operation *op) const final { + // Don't allow extracting common subexpressions from the region of these + // operations. + return !isa<NoCSEOneRegionOp>(op); + } +}; + /// This class defines the interface for handling inlining with standard /// operations. 
struct TestInlinerInterface : public DialectInlinerInterface { @@ -385,6 +396,7 @@ auto &blobInterface = addInterface<TestResourceBlobManagerInterface>(); addInterface<TestOpAsmInterface>(blobInterface); - addInterfaces<TestDialectFoldInterface, TestInlinerInterface, - TestReductionPatternInterface, TestBytecodeDialectInterface>(); + addInterfaces<TestDialectFoldInterface, TestDialectCSEInterface, + TestInlinerInterface, TestReductionPatternInterface, + TestBytecodeDialectInterface>(); } Index: mlir/test/Transforms/cse.mlir =================================================================== --- mlir/test/Transforms/cse.mlir +++ mlir/test/Transforms/cse.mlir @@ -520,3 +520,23 @@ %2 = "test.op_with_memread"() : () -> (i32) return %0, %2, %1 : i32, i32, i32 } + +// CHECK-LABEL: @no_cse_across_disabled_op +func.func @no_cse_across_disabled_op() -> (i32) { + // CHECK-NEXT: %[[CONST1:.+]] = arith.constant 1 : i32 + %0 = arith.constant 1 : i32 + + // CHECK-NEXT: test.no_cse_one_region_op + "test.no_cse_one_region_op"() ({ + %1 = arith.constant 1 : i32 + %2 = arith.addi %1, %1 : i32 + "foo.yield"(%2) : (i32) -> () + + // CHECK-NEXT: %[[CONST2:.+]] = arith.constant 1 : i32 + // CHECK-NEXT: %[[SUM:.+]] = arith.addi %[[CONST2]], %[[CONST2]] : i32 + // CHECK-NEXT: "foo.yield"(%[[SUM]]) : (i32) -> () + }) : () -> () + + // CHECK: return %[[CONST1]] : i32 + return %0 : i32 +} Index: mlir/lib/Transforms/CSE.cpp =================================================================== --- mlir/lib/Transforms/CSE.cpp +++ mlir/lib/Transforms/CSE.cpp @@ -15,6 +15,7 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/PatternMatch.h" +#include "mlir/Interfaces/CSEInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/Passes.h" @@ -61,7 +62,8 @@ class CSEDriver { public: CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo) - : rewriter(rewriter), domInfo(domInfo) {} + : rewriter(rewriter), domInfo(domInfo), + 
interfaces(rewriter.getContext()) {} /// Simplify all operations within the given op. void simplify(Operation *op, bool *changed = nullptr); @@ -122,6 +124,9 @@ DominanceInfo *domInfo = nullptr; MemEffectsCache memEffectsCache; + /// CSE interfaces in the present context that can modify CSE behavior. + DialectInterfaceCollection<DialectCSEInterface> interfaces; + // Various statistics. int64_t numCSE = 0; int64_t numDCE = 0; @@ -289,7 +294,12 @@ // If this operation is isolated above, we can't process nested regions // with the given 'knownValues' map. This would cause the insertion of // implicit captures in explicit capture only regions. - if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) { + // Also, avoid capturing known values from parent regions if this behavior + // is explicitly disabled for the given operation. + const DialectCSEInterface *cseInterface = interfaces.getInterfaceFor(&op); + if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>() || + LLVM_UNLIKELY(cseInterface && + !cseInterface->subexpressionExtractionAllowed(&op))) { ScopedMapTy nestedKnownValues; for (auto &region : op.getRegions()) simplifyRegion(nestedKnownValues, region); Index: mlir/include/mlir/Interfaces/CSEInterfaces.h =================================================================== --- /dev/null +++ mlir/include/mlir/Interfaces/CSEInterfaces.h @@ -0,0 +1,32 @@ +//===- CSEInterfaces.h ------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_INTERFACES_CSEINTERFACES_H_ +#define MLIR_INTERFACES_CSEINTERFACES_H_ + +#include "mlir/IR/DialectInterface.h" + +namespace mlir { +class Operation; + +/// Define an interface to allow for dialects to control specific aspects of +/// common subexpression elimination behavior for operations they define. +class DialectCSEInterface : public DialectInterface::Base<DialectCSEInterface> { +public: + DialectCSEInterface(Dialect *dialect) : Base(dialect) {} + + /// Registered hook to check if an operation that is *not* isolated from + /// above should allow common subexpressions to be extracted out of its + /// regions. The default implementation always allows it. + virtual bool subexpressionExtractionAllowed(Operation *op) const { + return true; + } +}; + +} // namespace mlir + +#endif // MLIR_INTERFACES_CSEINTERFACES_H_ Index: clang/docs/tools/clang-formatted-files.txt =================================================================== --- clang/docs/tools/clang-formatted-files.txt +++ clang/docs/tools/clang-formatted-files.txt @@ -7766,6 +7766,7 @@ mlir/include/mlir/Interfaces/CastInterfaces.h mlir/include/mlir/Interfaces/ControlFlowInterfaces.h mlir/include/mlir/Interfaces/CopyOpInterface.h +mlir/include/mlir/Interfaces/CSEInterfaces.h mlir/include/mlir/Interfaces/DataLayoutInterfaces.h mlir/include/mlir/Interfaces/DecodeAttributesInterfaces.h mlir/include/mlir/Interfaces/DerivedAttributeOpInterface.h
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits