Author: zhanghb97 Date: 2021-01-13T09:32:32+08:00 New Revision: c0f3ea8a08ca9a9ec473f6e9072ccf30dad5def8
URL: https://github.com/llvm/llvm-project/commit/c0f3ea8a08ca9a9ec473f6e9072ccf30dad5def8 DIFF: https://github.com/llvm/llvm-project/commit/c0f3ea8a08ca9a9ec473f6e9072ccf30dad5def8.diff LOG: [mlir][Python] Add a validity check before creating an AffineMap from a permutation. An invalid permutation will trigger a C++ assertion when attempting to create an AffineMap from the permutation. This patch adds an `isPermutation` function to check the given permutation before creating the AffineMap. Reviewed By: mehdi_amini Differential Revision: https://reviews.llvm.org/D94492 Added: Modified: mlir/lib/Bindings/Python/IRModules.cpp mlir/test/Bindings/Python/ir_affine_map.py Removed: ################################################################################ diff --git a/mlir/lib/Bindings/Python/IRModules.cpp b/mlir/lib/Bindings/Python/IRModules.cpp index 218099bedc6f..493ea5c1e47a 100644 --- a/mlir/lib/Bindings/Python/IRModules.cpp +++ b/mlir/lib/Bindings/Python/IRModules.cpp @@ -153,6 +153,21 @@ static MlirStringRef toMlirStringRef(const std::string &s) { return mlirStringRefCreate(s.data(), s.size()); } +template <typename PermutationTy> +static bool isPermutation(std::vector<PermutationTy> permutation) { + llvm::SmallVector<bool, 8> seen(permutation.size(), false); + for (auto val : permutation) { + if (val < permutation.size()) { + if (seen[val]) + return false; + seen[val] = true; + continue; + } + return false; + } + return true; +} + //------------------------------------------------------------------------------ // Collections.
//------------------------------------------------------------------------------ @@ -3914,6 +3929,9 @@ void mlir::python::populateIRSubmodule(py::module &m) { "get_permutation", [](std::vector<unsigned> permutation, DefaultingPyMlirContext context) { + if (!isPermutation(permutation)) + throw py::cast_error("Invalid permutation when attempting to " + "create an AffineMap"); MlirAffineMap affineMap = mlirAffineMapPermutationGet( context->get(), permutation.size(), permutation.data()); return PyAffineMap(context->getRef(), affineMap); diff --git a/mlir/test/Bindings/Python/ir_affine_map.py b/mlir/test/Bindings/Python/ir_affine_map.py index fe37eb971555..0c99722dbf04 100644 --- a/mlir/test/Bindings/Python/ir_affine_map.py +++ b/mlir/test/Bindings/Python/ir_affine_map.py @@ -73,6 +73,12 @@ def testAffineMapGet(): # CHECK: Invalid expression (None?) when attempting to create an AffineMap print(e) + try: + AffineMap.get_permutation([1, 0, 1]) + except RuntimeError as e: + # CHECK: Invalid permutation when attempting to create an AffineMap + print(e) + try: map3.get_submap([42]) except ValueError as e: _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits