https://github.com/adk9 updated https://github.com/llvm/llvm-project/pull/75960
>From a43ef7289cd7f5353fc4b365566011b93879e8f6 Mon Sep 17 00:00:00 2001 From: Abhishek Kulkarni <abkulka...@microsoft.com> Date: Tue, 19 Dec 2023 10:50:26 -0800 Subject: [PATCH] Fix generation of python bindings for async dialect --- .../mlir/Dialect/Async/IR/CMakeLists.txt | 1 + mlir/python/CMakeLists.txt | 9 ++-- .../mlir/dialects/async_dialect/__init__.py | 2 +- mlir/test/python/dialects/async_dialect.py | 16 ++++++- .../mlir/python/BUILD.bazel | 47 +++++++++++++++++++ 5 files changed, 68 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Dialect/Async/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Async/IR/CMakeLists.txt index ebbf2df760faa4..2525eee2a34ec9 100644 --- a/mlir/include/mlir/Dialect/Async/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Async/IR/CMakeLists.txt @@ -1,2 +1,3 @@ +set(LLVM_TARGET_DEFINITIONS AsyncOps.td) add_mlir_dialect(AsyncOps async) add_mlir_doc(AsyncOps AsyncDialect Dialects/ -gen-dialect-doc) diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 41d91cf6778338..550465f6e37467 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -72,7 +72,7 @@ declare_mlir_dialect_python_bindings( ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" TD_FILE dialects/AsyncOps.td SOURCES_GLOB dialects/async_dialect/*.py - DIALECT_NAME async_dialect) + DIALECT_NAME async) declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects @@ -522,7 +522,7 @@ declare_mlir_python_extension(MLIRPythonExtension.Dialects.Transform.Pybind declare_mlir_python_extension(MLIRPythonExtension.AsyncDialectPasses MODULE_NAME _mlirAsyncPasses - ADD_TO_PARENT MLIRPythonSources.Dialects.async_dialect + ADD_TO_PARENT MLIRPythonSources.Dialects.async ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES AsyncPasses.cpp @@ -664,11 +664,10 @@ if(NOT LLVM_ENABLE_IDE) COMPONENT mlir-python-sources ) endif() - -################################################################################ +# ############################################################################### # The fully assembled package of modules. # This must come last. -################################################################################ +# ############################################################################### add_mlir_python_modules(MLIRPythonModules ROOT_PREFIX "${MLIR_BINARY_DIR}/python_packages/mlir_core/mlir" diff --git a/mlir/python/mlir/dialects/async_dialect/__init__.py b/mlir/python/mlir/dialects/async_dialect/__init__.py index dcf9d6cb2638f6..6a5ecfc20956cf 100644 --- a/mlir/python/mlir/dialects/async_dialect/__init__.py +++ b/mlir/python/mlir/dialects/async_dialect/__init__.py @@ -2,4 +2,4 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from .._async_dialect_ops_gen import * +from .._async_ops_gen import * diff --git a/mlir/test/python/dialects/async_dialect.py b/mlir/test/python/dialects/async_dialect.py index f6181cc76118ed..13e3c42e57c21e 100644 --- a/mlir/test/python/dialects/async_dialect.py +++ b/mlir/test/python/dialects/async_dialect.py @@ -1,7 +1,8 @@ # RUN: %PYTHON %s | FileCheck %s from mlir.ir import * -import mlir.dialects.async_dialect +from mlir.dialects import arith +import mlir.dialects.async_dialect as async_dialect import mlir.dialects.async_dialect.passes from mlir.passmanager import * @@ -11,6 +12,19 @@ def run(f): f() +# CHECK-LABEL: TEST: testCreateGroupOp +@run +def testCreateGroupOp(): + with Context() as ctx, Location.unknown(): + module = Module.create() + with InsertionPoint(module.body): + i32 = IntegerType.get_signless(32) + group_size = arith.ConstantOp(i32, 4) + async_dialect.create_group(group_size) + # CHECK: %0 = "arith.constant"() <{value = 4 : i32}> : () -> i32 + # CHECK: %1 = "async.create_group"(%0) : (i32) -> !async.group + print(module) + def testAsyncPass(): with Context() as context: PassManager.parse("any(async-to-async-runtime)") diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel index 049098b158f294..18e84ac7b68103 100644 --- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel @@ -331,6 +331,53 @@ filegroup( ], ) +##---------------------------------------------------------------------------## +# Async dialect. +##---------------------------------------------------------------------------## + +gentbl_filegroup( + name = "AsyncOpsPyGen", + tbl_outs = [ + ( + [ + "-gen-python-enum-bindings", + "-bind-dialect=async", + ], + "mlir/dialects/_async_enum_gen.py", + ), + ( + [ + "-gen-python-op-bindings", + "-bind-dialect=async", + ], + "mlir/dialects/_async_ops_gen.py", + ), + ], + tblgen = "//mlir:mlir-tblgen", + td_file = "mlir/dialects/AsyncOps.td", + deps = [ + "//mlir:AsyncOpsTdFiles", + "//mlir:OpBaseTdFiles", + ], +) + +filegroup( + name = "AsyncOpsPyFiles", + srcs = [ + ":AsyncOpsPyGen", + ], +) + +filegroup( + name = "AsyncOpsPackagePyFiles", + srcs = glob(["mlir/dialects/async_dialect/*.py"]), +) + +filegroup( + name = "AsyncOpsPackagePassesPyFiles", + srcs = glob(["mlir/dialects/async_dialect/passes/*.py"]), +) + ##---------------------------------------------------------------------------## # Arith dialect. ##---------------------------------------------------------------------------## _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits