https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/86099
This commit changes the API of `ValueBoundsConstraintSet`: the stop condition is now passed to the constructor instead of `processWorklist`. That makes it easier to add items to the worklist multiple times and process them in a consistent manner. The current `ValueBoundsConstraintSet` is passed by reference to the stop function, so that the stop function can be defined before the `ValueBoundsConstraintSet` is constructed. This change is in preparation for adding support for branches. >From db3dde1d9c6e3eb1b85083d1a3545691f47acb7c Mon Sep 17 00:00:00 2001 From: Matthias Springer <spring...@google.com> Date: Thu, 21 Mar 2024 08:04:11 +0000 Subject: [PATCH] [mlir][Interfaces][NFC] `ValueBoundsConstraintSet`: Pass stop condition in the constructor This commit changes the API of `ValueBoundsConstraintSet`: the stop condition is now passed to the constructor instead of `processWorklist`. That makes it easier to add items to the worklist multiple times and process them in a consistent manner. The current `ValueBoundsConstraintSet` is passed by reference to the stop function, so that the stop function can be defined before the `ValueBoundsConstraintSet` is constructed. This change is in preparation for adding support for branches. 
--- .../mlir/Interfaces/ValueBoundsOpInterface.h | 16 +++-- .../Affine/Transforms/ReifyValueBounds.cpp | 6 +- .../Arith/Transforms/ReifyValueBounds.cpp | 6 +- .../Linalg/Transforms/HoistPadding.cpp | 2 +- .../SCF/IR/ValueBoundsOpInterfaceImpl.cpp | 2 +- .../lib/Interfaces/ValueBoundsOpInterface.cpp | 60 +++++++++++-------- .../Dialect/Affine/TestReifyValueBounds.cpp | 9 ++- 7 files changed, 62 insertions(+), 39 deletions(-) diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index 94a8a8b429c801..b79c44162ea8ef 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -113,8 +113,9 @@ class ValueBoundsConstraintSet { /// /// The first parameter of the function is the shaped value/index-typed /// value. The second parameter is the dimension in case of a shaped value. - using StopConditionFn = - function_ref<bool(Value, std::optional<int64_t> /*dim*/)>; + /// The third parameter is this constraint set. + using StopConditionFn = function_ref<bool( + Value, std::optional<int64_t> /*dim*/, ValueBoundsConstraintSet &cstr)>; /// Compute a bound for the given index-typed value or shape dimension size. /// The computed bound is stored in `resultMap`. The operands of the bound are @@ -263,12 +264,12 @@ class ValueBoundsConstraintSet { /// An index-typed value or the dimension of a shaped-type value. using ValueDim = std::pair<Value, int64_t>; - ValueBoundsConstraintSet(MLIRContext *ctx); + ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition); /// Iteratively process all elements on the worklist until an index-typed - /// value or shaped value meets `stopCondition`. Such values are not processed - /// any further. - void processWorklist(StopConditionFn stopCondition); + /// value or shaped value meets `currentStopCondition`. Such values are not + /// processed any further. 
+ void processWorklist(); /// Bound the given column in the underlying constraint set by the given /// expression. @@ -316,6 +317,9 @@ class ValueBoundsConstraintSet { /// Builder for constructing affine expressions. Builder builder; + + /// The current stop condition function. + StopConditionFn stopCondition = nullptr; }; } // namespace mlir diff --git a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp index 37b36f76d4465d..117ee8e8701ad7 100644 --- a/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Affine/Transforms/ReifyValueBounds.cpp @@ -84,7 +84,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound( OpBuilder &b, Location loc, presburger::BoundType type, Value value, int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { - auto reifyToOperands = [&](Value v, std::optional<int64_t> d) { + auto reifyToOperands = [&](Value v, std::optional<int64_t> d, + ValueBoundsConstraintSet &cstr) { // We are trying to reify a bound for `value` in terms of the owning op's // operands. Construct a stop condition that evaluates to "true" for any SSA // value except for `value`. 
I.e., the bound will be computed in terms of @@ -100,7 +101,8 @@ FailureOr<OpFoldResult> mlir::affine::reifyShapedValueDimBound( FailureOr<OpFoldResult> mlir::affine::reifyIndexValueBound( OpBuilder &b, Location loc, presburger::BoundType type, Value value, ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { - auto reifyToOperands = [&](Value v, std::optional<int64_t> d) { + auto reifyToOperands = [&](Value v, std::optional<int64_t> d, + ValueBoundsConstraintSet &cstr) { return v != value; }; return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt, diff --git a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp index 8d9fd1478aa9e6..fad221288f190e 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ReifyValueBounds.cpp @@ -119,7 +119,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound( OpBuilder &b, Location loc, presburger::BoundType type, Value value, int64_t dim, ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { - auto reifyToOperands = [&](Value v, std::optional<int64_t> d) { + auto reifyToOperands = [&](Value v, std::optional<int64_t> d, + ValueBoundsConstraintSet &cstr) { // We are trying to reify a bound for `value` in terms of the owning op's // operands. Construct a stop condition that evaluates to "true" for any SSA // value expect for `value`. 
I.e., the bound will be computed in terms of @@ -135,7 +136,8 @@ FailureOr<OpFoldResult> mlir::arith::reifyShapedValueDimBound( FailureOr<OpFoldResult> mlir::arith::reifyIndexValueBound( OpBuilder &b, Location loc, presburger::BoundType type, Value value, ValueBoundsConstraintSet::StopConditionFn stopCondition, bool closedUB) { - auto reifyToOperands = [&](Value v, std::optional<int64_t> d) { + auto reifyToOperands = [&](Value v, std::optional<int64_t> d, + ValueBoundsConstraintSet &cstr) { return v != value; }; return reifyValueBound(b, loc, type, value, /*dim=*/std::nullopt, diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp index b32ea8eebaecb9..c3a08ce86082a8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp @@ -468,7 +468,7 @@ HoistPaddingAnalysis::getHoistedPackedTensorSizes(RewriterBase &rewriter, FailureOr<OpFoldResult> loopUb = affine::reifyIndexValueBound( rewriter, loc, presburger::BoundType::UB, forOp.getUpperBound(), /*stopCondition=*/ - [&](Value v, std::optional<int64_t> d) { + [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) { if (v == forOp.getUpperBound()) return false; // Compute a bound that is independent of any affine op results. diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp index cb36e0cecf0d24..1e13e60068ee7f 100644 --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -58,7 +58,7 @@ struct ForOpInterface ValueDimList boundOperands; LogicalResult status = ValueBoundsConstraintSet::computeBound( bound, boundOperands, BoundType::EQ, yieldedValue, dim, - [&](Value v, std::optional<int64_t> d) { + [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) { // Stop when reaching a block argument of the loop body. 
if (auto bbArg = llvm::dyn_cast<BlockArgument>(v)) return bbArg.getOwner()->getParentOp() == forOp; diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index f2f732f3a21d25..ec710bbacc758f 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -67,8 +67,9 @@ static std::optional<int64_t> getConstantIntValue(OpFoldResult ofr) { return std::nullopt; } -ValueBoundsConstraintSet::ValueBoundsConstraintSet(MLIRContext *ctx) - : builder(ctx) {} +ValueBoundsConstraintSet::ValueBoundsConstraintSet( + MLIRContext *ctx, StopConditionFn stopCondition) + : builder(ctx), stopCondition(stopCondition) {} #ifndef NDEBUG static void assertValidValueDim(Value value, std::optional<int64_t> dim) { @@ -228,7 +229,8 @@ static Operation *getOwnerOfValue(Value value) { return value.getDefiningOp(); } -void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) { +void ValueBoundsConstraintSet::processWorklist() { + LLVM_DEBUG(llvm::dbgs() << "Processing value bounds worklist...\n"); while (!worklist.empty()) { int64_t pos = worklist.front(); worklist.pop(); @@ -249,13 +251,19 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) { // Do not process any further if the stop condition is met. auto maybeDim = dim == kIndexValue ? std::nullopt : std::make_optional(dim); - if (stopCondition(value, maybeDim)) + if (stopCondition(value, maybeDim, *this)) { + LLVM_DEBUG(llvm::dbgs() << "Stop condition met for: " << value + << " (dim: " << maybeDim << ")\n"); continue; + } // Query `ValueBoundsOpInterface` for constraints. New items may be added to // the worklist. 
auto valueBoundsOp = dyn_cast<ValueBoundsOpInterface>(getOwnerOfValue(value)); + LLVM_DEBUG(llvm::dbgs() + << "Query value bounds for: " << value + << " (owner: " << getOwnerOfValue(value)->getName() << ")\n"); if (valueBoundsOp) { if (dim == kIndexValue) { valueBoundsOp.populateBoundsForIndexValue(value, *this); @@ -264,6 +272,7 @@ void ValueBoundsConstraintSet::processWorklist(StopConditionFn stopCondition) { } continue; } + LLVM_DEBUG(llvm::dbgs() << "--> ValueBoundsOpInterface not implemented\n"); // If the op does not implement `ValueBoundsOpInterface`, check if it // implements the `DestinationStyleOpInterface`. OpResults of such ops are @@ -313,8 +322,6 @@ LogicalResult ValueBoundsConstraintSet::computeBound( bool closedUB) { #ifndef NDEBUG assertValidValueDim(value, dim); - assert(!stopCondition(value, dim) && - "stop condition should not be satisfied for starting point"); #endif // NDEBUG int64_t ubAdjustment = closedUB ? 0 : 1; @@ -324,9 +331,11 @@ LogicalResult ValueBoundsConstraintSet::computeBound( // Process the backward slice of `value` (i.e., reverse use-def chain) until // `stopCondition` is met. ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue)); - ValueBoundsConstraintSet cstr(value.getContext()); + ValueBoundsConstraintSet cstr(value.getContext(), stopCondition); + assert(!stopCondition(value, dim, cstr) && + "stop condition should not be satisfied for starting point"); int64_t pos = cstr.insert(value, dim, /*isSymbol=*/false); - cstr.processWorklist(stopCondition); + cstr.processWorklist(); // Project out all variables (apart from `valueDim`) that do not match the // stop condition. @@ -336,7 +345,7 @@ LogicalResult ValueBoundsConstraintSet::computeBound( return false; auto maybeDim = p.second == kIndexValue ? std::nullopt : std::make_optional(p.second); - return !stopCondition(p.first, maybeDim); + return !stopCondition(p.first, maybeDim, cstr); }); // Compute lower and upper bounds for `valueDim`. 
@@ -442,7 +451,7 @@ LogicalResult ValueBoundsConstraintSet::computeDependentBound( bool closedUB) { return computeBound( resultMap, mapOperands, type, value, dim, - [&](Value v, std::optional<int64_t> d) { + [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) { return llvm::is_contained(dependencies, std::make_pair(v, d)); }, closedUB); @@ -478,7 +487,9 @@ LogicalResult ValueBoundsConstraintSet::computeIndependentBound( // Reify bounds in terms of any independent values. return computeBound( resultMap, mapOperands, type, value, dim, - [&](Value v, std::optional<int64_t> d) { return isIndependent(v); }, + [&](Value v, std::optional<int64_t> d, ValueBoundsConstraintSet &cstr) { + return isIndependent(v); + }, closedUB); } @@ -500,8 +511,18 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound( presburger::BoundType type, AffineMap map, ValueDimList operands, StopConditionFn stopCondition, bool closedUB) { assert(map.getNumResults() == 1 && "expected affine map with one result"); - ValueBoundsConstraintSet cstr(map.getContext()); - int64_t pos = cstr.insert(/*isSymbol=*/false); + + // Default stop condition if none was specified: Keep adding constraints until + // a bound could be computed. + int64_t pos; + auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim, + ValueBoundsConstraintSet &cstr) { + return cstr.cstr.getConstantBound64(type, pos).has_value(); + }; + + ValueBoundsConstraintSet cstr( + map.getContext(), stopCondition ? stopCondition : defaultStopCondition); + pos = cstr.insert(/*isSymbol=*/false); // Add map and operands to the constraint set. Dimensions are converted to // symbols. All operands are added to the worklist. @@ -517,17 +538,8 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound( map.getResult(0).replaceDimsAndSymbols(dimReplacements, symReplacements)); // Process the backward slice of `operands` (i.e., reverse use-def chain) - // until `stopCondition` is met. 
- if (stopCondition) { - cstr.processWorklist(stopCondition); - } else { - // No stop condition specified: Keep adding constraints until a bound could - // be computed. - cstr.processWorklist( - /*stopCondition=*/[&](Value v, std::optional<int64_t> dim) { - return cstr.cstr.getConstantBound64(type, pos).has_value(); - }); - } + // until the stop condition is met. + cstr.processWorklist(); // Compute constant bound for `valueDim`. int64_t ubAdjustment = closedUB ? 0 : 1; diff --git a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp index 39671a930f2e21..e99a13cdca2f3c 100644 --- a/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp +++ b/mlir/test/lib/Dialect/Affine/TestReifyValueBounds.cpp @@ -112,14 +112,17 @@ static LogicalResult testReifyValueBounds(func::FuncOp funcOp, // Prepare stop condition. By default, reify in terms of the op's // operands. No stop condition is used when a constant was requested. - std::function<bool(Value, std::optional<int64_t>)> stopCondition = - [&](Value v, std::optional<int64_t> d) { + std::function<bool(Value, std::optional<int64_t>, + ValueBoundsConstraintSet & cstr)> + stopCondition = [&](Value v, std::optional<int64_t> d, + ValueBoundsConstraintSet &cstr) { // Reify in terms of SSA values that are different from `value`. return v != value; }; if (reifyToFuncArgs) { // Reify in terms of function block arguments. - stopCondition = stopCondition = [](Value v, std::optional<int64_t> d) { + stopCondition = stopCondition = [](Value v, std::optional<int64_t> d, + ValueBoundsConstraintSet &cstr) { auto bbArg = dyn_cast<BlockArgument>(v); if (!bbArg) return false; _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits