Author: River Riddle Date: 2021-01-14T11:35:49-08:00 New Revision: 00a61b327dd8a7071ce0baadd16ea4c7b7e31e73
URL: https://github.com/llvm/llvm-project/commit/00a61b327dd8a7071ce0baadd16ea4c7b7e31e73
DIFF: https://github.com/llvm/llvm-project/commit/00a61b327dd8a7071ce0baadd16ea4c7b7e31e73.diff

LOG:
[mlir][ODS] Add new RangedTypesMatchWith operation predicate

This is a variant of TypesMatchWith that provides support for variadic arguments. This is necessary because ranges generally can't use the default operator== comparators for checking equality.

Differential Revision: https://reviews.llvm.org/D94574

Added:

Modified:
    mlir/include/mlir/IR/OpBase.td
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/mlir-tblgen/op-format.mlir
    mlir/tools/mlir-tblgen/OpFormatGen.cpp

Removed:

################################################################################
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 73ddbc1d56eb..3b55e51d8178 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2191,16 +2191,28 @@ class AllTypesMatch<list<string> names> : AllMatchSameOperatorTrait<names, "$_self.getType()", "type">; // A type constraint that denotes `transform(lhs.getType()) == rhs.getType()`. +// An optional comparator function may be provided that changes the above form +// into: `comparator(transform(lhs.getType()), rhs.getType())`. class TypesMatchWith<string summary, string lhsArg, string rhsArg, - string transform> : - PredOpTrait<summary, CPred< - !subst("$_self", "$" # lhsArg # ".getType()", transform) - # " == $" # rhsArg # ".getType()">> { + string transform, string comparator = "std::equal_to<>()"> + : PredOpTrait<summary, CPred< + comparator # "(" # + !subst("$_self", "$" # lhsArg # ".getType()", transform) # + ", $" # rhsArg # ".getType())">> { string lhs = lhsArg; string rhs = rhsArg; string transformer = transform; } +// Special variant of `TypesMatchWith` that provides a comparator suitable for +// ranged arguments.
+class RangedTypesMatchWith<string summary, string lhsArg, string rhsArg, + string transform> + : TypesMatchWith<summary, lhsArg, rhsArg, transform, + "[](auto &&lhs, auto &&rhs) { " + "return std::equal(lhs.begin(), lhs.end(), rhs.begin(), rhs.end());" + " }">; + // Type Constraint operand `idx`'s Element type is `type`. class TCopVTEtIs<int idx, Type type> : And<[ CPred<"$_op.getNumOperands() > " # idx>, diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 1fc419cc375f..d1cbe77ac21b 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1733,6 +1733,15 @@ def FormatTypesMatchVarOp : TEST_Op<"format_types_match_var", [ let assemblyFormat = "attr-dict $value `:` type($value)"; } +def FormatTypesMatchVariadicOp : TEST_Op<"format_types_match_variadic", [ + RangedTypesMatchWith<"result type matches operand", "value", "result", + "llvm::make_range($_self.begin(), $_self.end())"> + ]> { + let arguments = (ins Variadic<AnyType>:$value); + let results = (outs Variadic<AnyType>:$result); + let assemblyFormat = "attr-dict $value `:` type($value)"; +} + def FormatTypesMatchAttrOp : TEST_Op<"format_types_match_attr", [ TypesMatchWith<"result type matches constant", "value", "result", "$_self"> ]> { diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir index 334313debda1..4eb64772aee2 100644 --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -308,5 +308,8 @@ test.format_infer_variadic_type_from_non_variadic %i64, %i64 : i64 // CHECK: test.format_types_match_var %[[I64]] : i64 %ignored_res3 = test.format_types_match_var %i64 : i64 +// CHECK: test.format_types_match_variadic %[[I64]], %[[I64]], %[[I64]] : i64, i64, i64 +%ignored_res4:3 = test.format_types_match_variadic %i64, %i64, %i64 : i64, i64, i64 + // CHECK: test.format_types_match_attr 1 : i64 -%ignored_res4 = test.format_types_match_attr 1 : i64 +%ignored_res5 = 
test.format_types_match_attr 1 : i64 diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index 749ef1613c14..bba796f9b492 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -1287,10 +1287,16 @@ void OperationFormat::genParserTypeResolution(Operator &op, if (Optional<int> val = resolver.getBuilderIdx()) { body << "odsBuildableType" << *val; } else if (const NamedTypeConstraint *var = resolver.getVariable()) { - if (Optional<StringRef> tform = resolver.getVarTransformer()) - body << tgfmt(*tform, &FmtContext().withSelf(var->name + "Types[0]")); - else + if (Optional<StringRef> tform = resolver.getVarTransformer()) { + FmtContext fmtContext; + if (var->isVariadic()) + fmtContext.withSelf(var->name + "Types"); + else + fmtContext.withSelf(var->name + "Types[0]"); + body << tgfmt(*tform, &fmtContext); + } else { body << var->name << "Types"; + } } else if (const NamedAttribute *attr = resolver.getAttribute()) { if (Optional<StringRef> tform = resolver.getVarTransformer()) body << tgfmt(*tform, _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits