https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/85895
>From 8057ddd7f467891b5fec9c1f7426fd06012453fb Mon Sep 17 00:00:00 2001 From: Matthias Springer <spring...@google.com> Date: Fri, 22 Mar 2024 02:03:32 +0000 Subject: [PATCH] [mlir][SCF] `ValueBoundsConstraintSet`: Add preliminary support for branches This commit adds support for `scf.if` to `ValueBoundsConstraintSet`. Example: ``` %0 = scf.if ... -> index { scf.yield %a : index } else { scf.yield %b : index } ``` The following constraints hold for %0: * %0 >= min(%a, %b) * %0 <= max(%a, %b) Such constraints cannot be added to the constraint set; min/max is not supported by `IntegerRelation`. However, if we know which one of %a and %b is larger, we can add constraints for %0. E.g., if %a <= %b: * %0 >= %a * %0 <= %b This commit required a few minor changes to the `ValueBoundsConstraintSet` infrastructure, so that values can be compared while we are still in the process of traversing the IR/adding constraints. --- .../mlir/Interfaces/ValueBoundsOpInterface.h | 22 ++++ .../SCF/IR/ValueBoundsOpInterfaceImpl.cpp | 63 ++++++++++ .../lib/Interfaces/ValueBoundsOpInterface.cpp | 62 +++++++++ .../SCF/value-bounds-op-interface-impl.mlir | 119 +++++++++++++++++- 4 files changed, 264 insertions(+), 2 deletions(-) diff --git a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h index 77e1af070c3fe9..ef074bcfe0be87 100644 --- a/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h +++ b/mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h @@ -199,6 +199,28 @@ class ValueBoundsConstraintSet { std::optional<int64_t> dim1 = std::nullopt, std::optional<int64_t> dim2 = std::nullopt); + /// Traverse the IR starting from the given value/dim and populate + /// constraints as long as the currently set stop condition holds. Also + /// processes all values/dims that are already on the worklist. 
+ void populateConstraints(Value value, std::optional<int64_t> dim); + + /// Comparison operator for `ValueBoundsConstraintSet::compare`. + enum ComparisonOperator { LT, LE, EQ, GT, GE }; + + /// Try to prove that, based on the current state of this constraint set + /// (i.e., without analyzing additional IR or adding new constraints), it can + /// be deduced that the first given value/dim is LE/LT/EQ/GT/GE than the + /// second given value/dim. + /// + /// Return "true" if the specified relation between the two values/dims was + /// proven to hold. Return "false" if the specified relation could not be + /// proven. This could be because the specified relation does in fact not hold + /// or because there is not enough information in the constraint set. In other + /// words, if we do not know for sure, this function returns "false". + bool compare(Value value1, std::optional<int64_t> dim1, + ComparisonOperator cmp, Value value2, + std::optional<int64_t> dim2); + /// Compute whether the given values/dimensions are equal. Return "failure" if /// equality could not be determined. 
/// diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp index 1e13e60068ee7f..72a25d0f0b30b0 100644 --- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp @@ -111,6 +111,68 @@ struct ForOpInterface } }; +struct IfOpInterface + : public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> { + + void populateBoundsForIndexValue(Operation *op, Value value, + ValueBoundsConstraintSet &cstr) const { + auto ifOp = cast<IfOp>(op); + unsigned int resultNum = cast<OpResult>(value).getResultNumber(); + Value thenValue = ifOp.thenYield().getResults()[resultNum]; + Value elseValue = ifOp.elseYield().getResults()[resultNum]; + + // Populate constraints for the yielded value (and all values on the + // backward slice, as long as the current stop condition is not satisfied). + cstr.populateConstraints(thenValue, /*valueDim=*/std::nullopt); + cstr.populateConstraints(elseValue, /*valueDim=*/std::nullopt); + + // Compare yielded values. + // If thenValue <= elseValue: + // * result <= elseValue + // * result >= thenValue + if (cstr.compare(thenValue, /*dim1=*/std::nullopt, + ValueBoundsConstraintSet::ComparisonOperator::LE, + elseValue, /*dim2=*/std::nullopt)) { + cstr.bound(value) >= thenValue; + cstr.bound(value) <= elseValue; + } + // If elseValue <= thenValue: + // * result <= thenValue + // * result >= elseValue + if (cstr.compare(elseValue, /*dim1=*/std::nullopt, + ValueBoundsConstraintSet::ComparisonOperator::LE, + thenValue, /*dim2=*/std::nullopt)) { + cstr.bound(value) >= elseValue; + cstr.bound(value) <= thenValue; + } + } + + void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim, + ValueBoundsConstraintSet &cstr) const { + // See `populateBoundsForIndexValue` for documentation. 
+ auto ifOp = cast<IfOp>(op); + unsigned int resultNum = cast<OpResult>(value).getResultNumber(); + Value thenValue = ifOp.thenYield().getResults()[resultNum]; + Value elseValue = ifOp.elseYield().getResults()[resultNum]; + + cstr.populateConstraints(thenValue, dim); + cstr.populateConstraints(elseValue, dim); + + if (cstr.compare(thenValue, dim, + ValueBoundsConstraintSet::ComparisonOperator::LE, + elseValue, dim)) { + cstr.bound(value)[dim] >= cstr.getExpr(thenValue, dim); + cstr.bound(value)[dim] <= cstr.getExpr(elseValue, dim); + } + if (cstr.compare(elseValue, dim, + ValueBoundsConstraintSet::ComparisonOperator::LE, + thenValue, dim)) { + cstr.bound(value)[dim] >= cstr.getExpr(elseValue, dim); + cstr.bound(value)[dim] <= cstr.getExpr(thenValue, dim); + } + } +}; + } // namespace } // namespace scf } // namespace mlir @@ -119,5 +181,6 @@ void mlir::scf::registerValueBoundsOpInterfaceExternalModels( DialectRegistry &registry) { registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) { scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx); + scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx); }); } diff --git a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp index ec710bbacc758f..c88532d2325f0c 100644 --- a/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp +++ b/mlir/lib/Interfaces/ValueBoundsOpInterface.cpp @@ -575,6 +575,68 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2, {{value1, dim1}, {value2, dim2}}); } +void ValueBoundsConstraintSet::populateConstraints(Value value, + std::optional<int64_t> dim) { + // `getExpr` pushes the value/dim onto the worklist (unless it was already + // analyzed). + (void)getExpr(value, dim); + // Process all values/dims on the worklist. This may traverse and analyze + // additional IR, depending on the current stop function. 
+ processWorklist(); +} + +bool ValueBoundsConstraintSet::compare(Value value1, + std::optional<int64_t> dim1, + ComparisonOperator cmp, Value value2, + std::optional<int64_t> dim2) { + // This function returns "true" if value1/dim1 CMP value2/dim2 is proved to + // hold. + // + // Example for ComparisonOperator::LE and index-typed values: We would like to + // prove that value1 <= value2. Proof by contradiction: add the inverse + // relation (value1 > value2) to the constraint set and check if the resulting + // constraint set is "empty" (i.e. has no solution). In that case, + // value1 > value2 must be incorrect and we can deduce that value1 <= value2 + // holds. + + // We cannot prove anything if the constraint set is already empty. + if (cstr.isEmpty()) { + LLVM_DEBUG( + llvm::dbgs() + << "cannot compare value/dims: constraint system is already empty"); + return false; + } + + // EQ can be expressed as LE and GE. + if (cmp == EQ) + return compare(value1, dim1, ComparisonOperator::LE, value2, dim2) && + compare(value1, dim1, ComparisonOperator::GE, value2, dim2); + + // Construct inequality. For the above example: value1 > value2. + // `IntegerRelation` inequalities are expressed in the "flattened" form and + // with ">= 0". I.e., value1 - value2 - 1 >= 0. + SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0); + if (cmp == LT || cmp == LE) { + eq[getPos(value1, dim1)]++; + eq[getPos(value2, dim2)]--; + } else if (cmp == GT || cmp == GE) { + eq[getPos(value1, dim1)]--; + eq[getPos(value2, dim2)]++; + } else { + llvm_unreachable("unsupported comparison operator"); + } + if (cmp == LE || cmp == GE) + eq[cstr.getNumDimAndSymbolVars()] -= 1; + + // Add inequality to the constraint set and check if it made the constraint + // set empty. 
+ int64_t ineqPos = cstr.getNumInequalities(); + cstr.addInequality(eq); + bool isEmpty = cstr.isEmpty(); + cstr.removeInequality(ineqPos); + return isEmpty; +} + FailureOr<bool> ValueBoundsConstraintSet::areEqual(Value value1, Value value2, std::optional<int64_t> dim1, diff --git a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir index e4d71415924994..0ea06737886d41 100644 --- a/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir +++ b/mlir/test/Dialect/SCF/value-bounds-op-interface-impl.mlir @@ -1,5 +1,5 @@ -// RUN: mlir-opt %s -test-affine-reify-value-bounds -verify-diagnostics \ -// RUN: -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-affine-reify-value-bounds="reify-to-func-args" \ +// RUN: -verify-diagnostics -split-input-file | FileCheck %s // CHECK-LABEL: func @scf_for( // CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index, %[[c:.*]]: index @@ -104,3 +104,118 @@ func.func @scf_for_swapping_yield(%t1: tensor<?xf32>, %t2: tensor<?xf32>, %a: in "test.some_use"(%reify1) : (index) -> () return } + +// ----- + +// CHECK-LABEL: func @scf_if_constant( +func.func @scf_if_constant(%c : i1) { + // CHECK: arith.constant 4 : index + // CHECK: arith.constant 9 : index + %c4 = arith.constant 4 : index + %c9 = arith.constant 9 : index + %r = scf.if %c -> index { + scf.yield %c4 : index + } else { + scf.yield %c9 : index + } + + // CHECK: %[[c4:.*]] = arith.constant 4 : index + // CHECK: %[[c10:.*]] = arith.constant 10 : index + %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index) + %reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index) + // CHECK: "test.some_use"(%[[c4]], %[[c10]]) + "test.some_use"(%reify1, %reify2) : (index, index) -> () + return +} + +// ----- + +// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK: #[[$map1:.*]] = affine_map<()[s0, s1] -> (s0 + s1 + 5)> +// CHECK-LABEL: func @scf_if_dynamic( +// CHECK-SAME: %[[a:.*]]: index, 
%[[b:.*]]: index, %{{.*}}: i1) +func.func @scf_if_dynamic(%a: index, %b: index, %c : i1) { + %c4 = arith.constant 4 : index + %r = scf.if %c -> index { + %add1 = arith.addi %a, %b : index + scf.yield %add1 : index + } else { + %add2 = arith.addi %b, %c4 : index + %add3 = arith.addi %add2, %a : index + scf.yield %add3 : index + } + + // CHECK: %[[lb:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]] + // CHECK: %[[ub:.*]] = affine.apply #[[$map1]]()[%[[a]], %[[b]]] + %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index) + %reify2 = "test.reify_bound"(%r) {type = "UB"} : (index) -> (index) + // CHECK: "test.some_use"(%[[lb]], %[[ub]]) + "test.some_use"(%reify1, %reify2) : (index, index) -> () + return +} + +// ----- + +func.func @scf_if_no_affine_bound(%a: index, %b: index, %c : i1) { + %r = scf.if %c -> index { + scf.yield %a : index + } else { + scf.yield %b : index + } + // The reified bound would be min(%a, %b). min/max expressions are not + // supported in reified bounds. + // expected-error @below{{could not reify bound}} + %reify1 = "test.reify_bound"(%r) {type = "LB"} : (index) -> (index) + "test.some_use"(%reify1) : (index) -> () + return +} + +// ----- + +// CHECK-LABEL: func @scf_if_tensor_dim( +func.func @scf_if_tensor_dim(%c : i1) { + // CHECK: arith.constant 4 : index + // CHECK: arith.constant 9 : index + %c4 = arith.constant 4 : index + %c9 = arith.constant 9 : index + %t1 = tensor.empty(%c4) : tensor<?xf32> + %t2 = tensor.empty(%c9) : tensor<?xf32> + %r = scf.if %c -> tensor<?xf32> { + scf.yield %t1 : tensor<?xf32> + } else { + scf.yield %t2 : tensor<?xf32> + } + + // CHECK: %[[c4:.*]] = arith.constant 4 : index + // CHECK: %[[c10:.*]] = arith.constant 10 : index + %reify1 = "test.reify_bound"(%r) {type = "LB", dim = 0} + : (tensor<?xf32>) -> (index) + %reify2 = "test.reify_bound"(%r) {type = "UB", dim = 0} + : (tensor<?xf32>) -> (index) + // CHECK: "test.some_use"(%[[c4]], %[[c10]]) + "test.some_use"(%reify1, %reify2) : (index, index) -> 
() + return +} + +// ----- + +// CHECK: #[[$map:.*]] = affine_map<()[s0, s1] -> (s0 + s1)> +// CHECK-LABEL: func @scf_if_eq( +// CHECK-SAME: %[[a:.*]]: index, %[[b:.*]]: index, %{{.*}}: i1) +func.func @scf_if_eq(%a: index, %b: index, %c : i1) { + %c0 = arith.constant 0 : index + %r = scf.if %c -> index { + %add1 = arith.addi %a, %b : index + scf.yield %add1 : index + } else { + %add2 = arith.addi %b, %c0 : index + %add3 = arith.addi %add2, %a : index + scf.yield %add3 : index + } + + // CHECK: %[[eq:.*]] = affine.apply #[[$map]]()[%[[a]], %[[b]]] + %reify1 = "test.reify_bound"(%r) {type = "EQ"} : (index) -> (index) + // CHECK: "test.some_use"(%[[eq]]) + "test.some_use"(%reify1) : (index) -> () + return +} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits