https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/173939
>From dc24520fc192e2774509f15093d815910aeec1f4 Mon Sep 17 00:00:00 2001 From: makslevental <[email protected]> Date: Mon, 29 Dec 2025 16:57:05 -0800 Subject: [PATCH] [mlir][Python] move IRTypes and IRAttributes to public headers --- mlir/include/mlir/Bindings/Python/IRCore.h | 17 +- mlir/include/mlir/Bindings/Python/IRTypes.h | 465 ++++- mlir/lib/Bindings/Python/IRTypes.cpp | 1573 +++++++---------- mlir/python/CMakeLists.txt | 4 +- .../python/lib/PythonTestModuleNanobind.cpp | 129 +- 5 files changed, 1138 insertions(+), 1050 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h index 0f402b4ce15ff..340b16bcdf558 100644 --- a/mlir/include/mlir/Bindings/Python/IRCore.h +++ b/mlir/include/mlir/Bindings/Python/IRCore.h @@ -979,7 +979,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy { PyGlobals::get().registerTypeCaster( DerivedTy::getTypeIdFunction(), nanobind::cast<nanobind::callable>(nanobind::cpp_function( - [](PyType pyType) -> DerivedTy { return pyType; }))); + [](PyType pyType) -> DerivedTy { return pyType; })), + /*replace*/ true); } DerivedTy::bindDerived(cls); @@ -1123,7 +1124,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy { nanobind::cast<nanobind::callable>( nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy { return pyAttribute; - }))); + })), + /*replace*/ true); } DerivedTy::bindDerived(cls); @@ -1511,6 +1513,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue { // and redefine bindDerived. 
using ClassTy = nanobind::class_<DerivedTy, PyValue>; using IsAFunctionTy = bool (*)(MlirValue); + using GetTypeIDFunctionTy = MlirTypeID (*)(); + static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr; PyConcreteValue() = default; PyConcreteValue(PyOperationRef operationRef, MlirValue value) @@ -1553,6 +1557,15 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue { [](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> { return self.maybeDownCast(); }); + + if (DerivedTy::getTypeIdFunction) { + PyGlobals::get().registerValueCaster( + DerivedTy::getTypeIdFunction(), + nanobind::cast<nanobind::callable>(nanobind::cpp_function( + [](PyValue pyValue) -> DerivedTy { return pyValue; })), + /*replace*/ true); + } + DerivedTy::bindDerived(cls); } diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h index 87e0e10764bd8..db478e8d33f37 100644 --- a/mlir/include/mlir/Bindings/Python/IRTypes.h +++ b/mlir/include/mlir/Bindings/Python/IRTypes.h @@ -9,13 +9,14 @@ #ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H #define MLIR_BINDINGS_PYTHON_IRTYPES_H +#include "mlir-c/BuiltinTypes.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" namespace mlir { namespace python { namespace MLIR_BINDINGS_PYTHON_DOMAIN { /// Shaped Type Interface - ShapedType class MLIR_PYTHON_API_EXPORTED PyShapedType : public PyConcreteType<PyShapedType> { public: static const IsAFunctionTy isaFunction; @@ -27,6 +28,468 @@ class MLIR_PYTHON_API_EXPORTED PyShapedType private: void requireHasRank(); }; + +/// Checks whether the given type is an integer or float type.
+inline int mlirTypeIsAIntegerOrFloat(MlirType type) { + return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || + mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); +} + +class MLIR_PYTHON_API_EXPORTED PyIntegerType + : public PyConcreteType<PyIntegerType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIntegerTypeGetTypeID; + static constexpr const char *pyClassName = "IntegerType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Index Type subclass - IndexType. +class MLIR_PYTHON_API_EXPORTED PyIndexType + : public PyConcreteType<PyIndexType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirIndexTypeGetTypeID; + static constexpr const char *pyClassName = "IndexType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +class MLIR_PYTHON_API_EXPORTED PyFloatType + : public PyConcreteType<PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat; + static constexpr const char *pyClassName = "FloatType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float4E2M1FNType. +class MLIR_PYTHON_API_EXPORTED PyFloat4E2M1FNType + : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat4E2M1FNTypeGetTypeID; + static constexpr const char *pyClassName = "Float4E2M1FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float6E2M3FNType. 
+class MLIR_PYTHON_API_EXPORTED PyFloat6E2M3FNType + : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat6E2M3FNTypeGetTypeID; + static constexpr const char *pyClassName = "Float6E2M3FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float6E3M2FNType. +class MLIR_PYTHON_API_EXPORTED PyFloat6E3M2FNType + : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat6E3M2FNTypeGetTypeID; + static constexpr const char *pyClassName = "Float6E3M2FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E4M3FNType. +class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3FNType + : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3FNTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E4M3FNType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E5M2Type. +class MLIR_PYTHON_API_EXPORTED PyFloat8E5M2Type + : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E5M2TypeGetTypeID; + static constexpr const char *pyClassName = "Float8E5M2Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E4M3Type. 
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3Type + : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3TypeGetTypeID; + static constexpr const char *pyClassName = "Float8E4M3Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E4M3FNUZ. +class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3FNUZType + : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3FNUZTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E4M3FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E4M3B11FNUZ. +class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3B11FNUZType + : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E4M3B11FNUZTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E5M2FNUZ. +class MLIR_PYTHON_API_EXPORTED PyFloat8E5M2FNUZType + : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E5M2FNUZTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E5M2FNUZType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E3M4Type. 
+class MLIR_PYTHON_API_EXPORTED PyFloat8E3M4Type + : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E3M4TypeGetTypeID; + static constexpr const char *pyClassName = "Float8E3M4Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - Float8E8M0FNUType. +class MLIR_PYTHON_API_EXPORTED PyFloat8E8M0FNUType + : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat8E8M0FNUTypeGetTypeID; + static constexpr const char *pyClassName = "Float8E8M0FNUType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - BF16Type. +class MLIR_PYTHON_API_EXPORTED PyBF16Type + : public PyConcreteType<PyBF16Type, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirBFloat16TypeGetTypeID; + static constexpr const char *pyClassName = "BF16Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - F16Type. +class MLIR_PYTHON_API_EXPORTED PyF16Type + : public PyConcreteType<PyF16Type, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat16TypeGetTypeID; + static constexpr const char *pyClassName = "F16Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - TF32Type. 
+class MLIR_PYTHON_API_EXPORTED PyTF32Type + : public PyConcreteType<PyTF32Type, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloatTF32TypeGetTypeID; + static constexpr const char *pyClassName = "FloatTF32Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - F32Type. +class MLIR_PYTHON_API_EXPORTED PyF32Type + : public PyConcreteType<PyF32Type, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat32TypeGetTypeID; + static constexpr const char *pyClassName = "F32Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Floating Point Type subclass - F64Type. +class MLIR_PYTHON_API_EXPORTED PyF64Type + : public PyConcreteType<PyF64Type, PyFloatType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFloat64TypeGetTypeID; + static constexpr const char *pyClassName = "F64Type"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// None Type subclass - NoneType. +class MLIR_PYTHON_API_EXPORTED PyNoneType : public PyConcreteType<PyNoneType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirNoneTypeGetTypeID; + static constexpr const char *pyClassName = "NoneType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Complex Type subclass - ComplexType. 
+class MLIR_PYTHON_API_EXPORTED PyComplexType + : public PyConcreteType<PyComplexType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirComplexTypeGetTypeID; + static constexpr const char *pyClassName = "ComplexType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Vector Type subclass - VectorType. +class MLIR_PYTHON_API_EXPORTED PyVectorType + : public PyConcreteType<PyVectorType, PyShapedType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirVectorTypeGetTypeID; + static constexpr const char *pyClassName = "VectorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); + +private: + static PyVectorType + getChecked(std::vector<int64_t> shape, PyType &elementType, + std::optional<nanobind::list> scalable, + std::optional<std::vector<int64_t>> scalableDims, + DefaultingPyLocation loc) { + if (scalable && scalableDims) { + throw nanobind::value_error("'scalable' and 'scalable_dims' kwargs " + "are mutually exclusive."); + } + + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType type; + if (scalable) { + if (scalable->size() != shape.size()) + throw nanobind::value_error("Expected len(scalable) == len(shape)."); + + SmallVector<bool> scalableDimFlags = llvm::to_vector( + llvm::map_range(*scalable, [](const nanobind::handle &h) { + return nanobind::cast<bool>(h); + })); + type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), + scalableDimFlags.data(), + elementType); + } else if (scalableDims) { + SmallVector<bool> scalableDimFlags(shape.size(), false); + for (int64_t dim : *scalableDims) { + if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0) + throw nanobind::value_error( + "Scalable dimension index out of bounds."); + scalableDimFlags[dim] = true; + } + type 
= mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), + scalableDimFlags.data(), + elementType); + } else { + type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), + elementType); + } + if (mlirTypeIsNull(type)) + throw MLIRError("Invalid type", errors.take()); + return PyVectorType(elementType.getContext(), type); + } + + static PyVectorType get(std::vector<int64_t> shape, PyType &elementType, + std::optional<nanobind::list> scalable, + std::optional<std::vector<int64_t>> scalableDims, + DefaultingPyMlirContext context) { + if (scalable && scalableDims) { + throw nanobind::value_error("'scalable' and 'scalable_dims' kwargs " + "are mutually exclusive."); + } + + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirType type; + if (scalable) { + if (scalable->size() != shape.size()) + throw nanobind::value_error("Expected len(scalable) == len(shape)."); + + SmallVector<bool> scalableDimFlags = llvm::to_vector( + llvm::map_range(*scalable, [](const nanobind::handle &h) { + return nanobind::cast<bool>(h); + })); + type = mlirVectorTypeGetScalable(shape.size(), shape.data(), + scalableDimFlags.data(), elementType); + } else if (scalableDims) { + SmallVector<bool> scalableDimFlags(shape.size(), false); + for (int64_t dim : *scalableDims) { + if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0) + throw nanobind::value_error( + "Scalable dimension index out of bounds."); + scalableDimFlags[dim] = true; + } + type = mlirVectorTypeGetScalable(shape.size(), shape.data(), + scalableDimFlags.data(), elementType); + } else { + type = mlirVectorTypeGet(shape.size(), shape.data(), elementType); + } + if (mlirTypeIsNull(type)) + throw MLIRError("Invalid type", errors.take()); + return PyVectorType(elementType.getContext(), type); + } +}; + +/// Ranked Tensor Type subclass - RankedTensorType. 
+class MLIR_PYTHON_API_EXPORTED PyRankedTensorType + : public PyConcreteType<PyRankedTensorType, PyShapedType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirRankedTensorTypeGetTypeID; + static constexpr const char *pyClassName = "RankedTensorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Unranked Tensor Type subclass - UnrankedTensorType. +class MLIR_PYTHON_API_EXPORTED PyUnrankedTensorType + : public PyConcreteType<PyUnrankedTensorType, PyShapedType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnrankedTensorTypeGetTypeID; + static constexpr const char *pyClassName = "UnrankedTensorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Ranked MemRef Type subclass - MemRefType. +class MLIR_PYTHON_API_EXPORTED PyMemRefType + : public PyConcreteType<PyMemRefType, PyShapedType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirMemRefTypeGetTypeID; + static constexpr const char *pyClassName = "MemRefType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Unranked MemRef Type subclass - UnrankedMemRefType. +class MLIR_PYTHON_API_EXPORTED PyUnrankedMemRefType + : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirUnrankedMemRefTypeGetTypeID; + static constexpr const char *pyClassName = "UnrankedMemRefType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Tuple Type subclass - TupleType. 
+class MLIR_PYTHON_API_EXPORTED PyTupleType + : public PyConcreteType<PyTupleType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirTupleTypeGetTypeID; + static constexpr const char *pyClassName = "TupleType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Function type. +class MLIR_PYTHON_API_EXPORTED PyFunctionType + : public PyConcreteType<PyFunctionType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirFunctionTypeGetTypeID; + static constexpr const char *pyClassName = "FunctionType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + +/// Opaque Type subclass - OpaqueType. +class MLIR_PYTHON_API_EXPORTED PyOpaqueType + : public PyConcreteType<PyOpaqueType> { +public: + static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirOpaqueTypeGetTypeID; + static constexpr const char *pyClassName = "OpaqueType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c); +}; + } // namespace MLIR_BINDINGS_PYTHON_DOMAIN } // namespace python } // namespace mlir diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp index 7350046f428c7..951486b818a4e 100644 --- a/mlir/lib/Bindings/Python/IRTypes.cpp +++ b/mlir/lib/Bindings/Python/IRTypes.cpp @@ -28,492 +28,6 @@ using llvm::Twine; namespace mlir { namespace python { namespace MLIR_BINDINGS_PYTHON_DOMAIN { - -/// Checks whether the given type is an integer or float type. 
-static int mlirTypeIsAIntegerOrFloat(MlirType type) { - return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) || - mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type); -} - -class PyIntegerType : public PyConcreteType<PyIntegerType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirIntegerTypeGetTypeID; - static constexpr const char *pyClassName = "IntegerType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get_signless", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - nb::arg("width"), nb::arg("context") = nb::none(), - "Create a signless integer type"); - c.def_static( - "get_signed", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeSignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - nb::arg("width"), nb::arg("context") = nb::none(), - "Create a signed integer type"); - c.def_static( - "get_unsigned", - [](unsigned width, DefaultingPyMlirContext context) { - MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); - return PyIntegerType(context->getRef(), t); - }, - nb::arg("width"), nb::arg("context") = nb::none(), - "Create an unsigned integer type"); - c.def_prop_ro( - "width", - [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, - "Returns the width of the integer type"); - c.def_prop_ro( - "is_signless", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSignless(self); - }, - "Returns whether this is a signless integer"); - c.def_prop_ro( - "is_signed", - [](PyIntegerType &self) -> bool { - return mlirIntegerTypeIsSigned(self); - }, - "Returns whether this is a signed integer"); - c.def_prop_ro( - "is_unsigned", - [](PyIntegerType &self) -> bool { - return 
mlirIntegerTypeIsUnsigned(self); - }, - "Returns whether this is an unsigned integer"); - } -}; - -/// Index Type subclass - IndexType. -class PyIndexType : public PyConcreteType<PyIndexType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirIndexTypeGetTypeID; - static constexpr const char *pyClassName = "IndexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirIndexTypeGet(context->get()); - return PyIndexType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a index type."); - } -}; - -class PyFloatType : public PyConcreteType<PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat; - static constexpr const char *pyClassName = "FloatType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_prop_ro( - "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, - "Returns the width of the floating-point type"); - } -}; - -/// Floating Point Type subclass - Float4E2M1FNType. -class PyFloat4E2M1FNType - : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat4E2M1FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float4E2M1FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); - return PyFloat4E2M1FNType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float4_e2m1fn type."); - } -}; - -/// Floating Point Type subclass - Float6E2M3FNType. 
-class PyFloat6E2M3FNType - : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat6E2M3FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float6E2M3FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); - return PyFloat6E2M3FNType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float6_e2m3fn type."); - } -}; - -/// Floating Point Type subclass - Float6E3M2FNType. -class PyFloat6E3M2FNType - : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat6E3M2FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float6E3M2FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); - return PyFloat6E3M2FNType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float6_e3m2fn type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3FNType. 
-class PyFloat8E4M3FNType - : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3FNTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3FNType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3FNTypeGet(context->get()); - return PyFloat8E4M3FNType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e4m3fn type."); - } -}; - -/// Floating Point Type subclass - Float8E5M2Type. -class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E5M2TypeGetTypeID; - static constexpr const char *pyClassName = "Float8E5M2Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E5M2TypeGet(context->get()); - return PyFloat8E5M2Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e5m2 type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3Type. 
-class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3TypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3TypeGet(context->get()); - return PyFloat8E4M3Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e4m3 type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3FNUZ. -class PyFloat8E4M3FNUZType - : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3FNUZTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); - return PyFloat8E4M3FNUZType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type."); - } -}; - -/// Floating Point Type subclass - Float8E4M3B11FNUZ. 
-class PyFloat8E4M3B11FNUZType - : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E4M3B11FNUZTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E4M3B11FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); - return PyFloat8E4M3B11FNUZType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type."); - } -}; - -/// Floating Point Type subclass - Float8E5M2FNUZ. -class PyFloat8E5M2FNUZType - : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E5M2FNUZTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E5M2FNUZType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); - return PyFloat8E5M2FNUZType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type."); - } -}; - -/// Floating Point Type subclass - Float8E3M4Type. 
-class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E3M4TypeGetTypeID; - static constexpr const char *pyClassName = "Float8E3M4Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E3M4TypeGet(context->get()); - return PyFloat8E3M4Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e3m4 type."); - } -}; - -/// Floating Point Type subclass - Float8E8M0FNUType. -class PyFloat8E8M0FNUType - : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat8E8M0FNUTypeGetTypeID; - static constexpr const char *pyClassName = "Float8E8M0FNUType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); - return PyFloat8E8M0FNUType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type."); - } -}; - -/// Floating Point Type subclass - BF16Type. 
-class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirBFloat16TypeGetTypeID; - static constexpr const char *pyClassName = "BF16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirBF16TypeGet(context->get()); - return PyBF16Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a bf16 type."); - } -}; - -/// Floating Point Type subclass - F16Type. -class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat16TypeGetTypeID; - static constexpr const char *pyClassName = "F16Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF16TypeGet(context->get()); - return PyF16Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a f16 type."); - } -}; - -/// Floating Point Type subclass - TF32Type. -class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloatTF32TypeGetTypeID; - static constexpr const char *pyClassName = "FloatTF32Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirTF32TypeGet(context->get()); - return PyTF32Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a tf32 type."); - } -}; - -/// Floating Point Type subclass - F32Type. 
-class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat32TypeGetTypeID; - static constexpr const char *pyClassName = "F32Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF32TypeGet(context->get()); - return PyF32Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a f32 type."); - } -}; - -/// Floating Point Type subclass - F64Type. -class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFloat64TypeGetTypeID; - static constexpr const char *pyClassName = "F64Type"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirF64TypeGet(context->get()); - return PyF64Type(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a f64 type."); - } -}; - -/// None Type subclass - NoneType. -class PyNoneType : public PyConcreteType<PyNoneType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirNoneTypeGetTypeID; - static constexpr const char *pyClassName = "NoneType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](DefaultingPyMlirContext context) { - MlirType t = mlirNoneTypeGet(context->get()); - return PyNoneType(context->getRef(), t); - }, - nb::arg("context") = nb::none(), "Create a none type."); - } -}; - -/// Complex Type subclass - ComplexType. 
-class PyComplexType : public PyConcreteType<PyComplexType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirComplexTypeGetTypeID; - static constexpr const char *pyClassName = "ComplexType"; - using PyConcreteType::PyConcreteType; - - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType) { - // The element must be a floating point or integer scalar type. - if (mlirTypeIsAIntegerOrFloat(elementType)) { - MlirType t = mlirComplexTypeGet(elementType); - return PyComplexType(elementType.getContext(), t); - } - throw nb::value_error( - (Twine("invalid '") + - nb::cast<std::string>(nb::repr(nb::cast(elementType))) + - "' and expected floating point or integer type.") - .str() - .c_str()); - }, - "Create a complex type"); - c.def_prop_ro( - "element_type", - [](PyComplexType &self) -> nb::typed<nb::object, PyType> { - return PyType(self.getContext(), mlirComplexTypeGetElementType(self)) - .maybeDownCast(); - }, - "Returns element type."); - } -}; - -} // namespace MLIR_BINDINGS_PYTHON_DOMAIN -} // namespace python -} // namespace mlir - // Shaped Type Interface - ShapedType void PyShapedType::bindDerived(ClassTy &c) { c.def_prop_ro( @@ -627,521 +141,632 @@ void PyShapedType::requireHasRank() { } } -const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = mlirTypeIsAShaped; +void PyIntegerType::bindDerived(ClassTy &c) { + c.def_static( + "get_signless", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + nanobind::arg("width"), nanobind::arg("context") = nanobind::none(), + "Create a signless integer type"); + c.def_static( + "get_signed", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeSignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + 
nanobind::arg("width"), nanobind::arg("context") = nanobind::none(), + "Create a signed integer type"); + c.def_static( + "get_unsigned", + [](unsigned width, DefaultingPyMlirContext context) { + MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width); + return PyIntegerType(context->getRef(), t); + }, + nanobind::arg("width"), nanobind::arg("context") = nanobind::none(), + "Create an unsigned integer type"); + c.def_prop_ro( + "width", + [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); }, + "Returns the width of the integer type"); + c.def_prop_ro( + "is_signless", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsSignless(self); + }, + "Returns whether this is a signless integer"); + c.def_prop_ro( + "is_signed", + [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsSigned(self); }, + "Returns whether this is a signed integer"); + c.def_prop_ro( + "is_unsigned", + [](PyIntegerType &self) -> bool { + return mlirIntegerTypeIsUnsigned(self); + }, + "Returns whether this is an unsigned integer"); +} -namespace mlir { -namespace python { -namespace MLIR_BINDINGS_PYTHON_DOMAIN { +void PyIndexType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirIndexTypeGet(context->get()); + return PyIndexType(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), "Create a index type."); +} -/// Vector Type subclass - VectorType. 
-class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirVectorTypeGetTypeID; - static constexpr const char *pyClassName = "VectorType"; - using PyConcreteType::PyConcreteType; +void PyFloatType::bindDerived(ClassTy &c) { + c.def_prop_ro( + "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); }, + "Returns the width of the floating-point type"); +} - static void bindDerived(ClassTy &c) { - c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"), - nb::arg("element_type"), nb::kw_only(), - nb::arg("scalable") = nb::none(), - nb::arg("scalable_dims") = nb::none(), - nb::arg("loc") = nb::none(), "Create a vector type") - .def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"), - nb::arg("element_type"), nb::kw_only(), - nb::arg("scalable") = nb::none(), - nb::arg("scalable_dims") = nb::none(), - nb::arg("context") = nb::none(), "Create a vector type") - .def_prop_ro( - "scalable", - [](MlirType self) { return mlirVectorTypeIsScalable(self); }) - .def_prop_ro("scalable_dims", [](MlirType self) { - std::vector<bool> scalableDims; - size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self)); - scalableDims.reserve(rank); - for (size_t i = 0; i < rank; ++i) - scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i)); - return scalableDims; - }); - } +void PyFloat4E2M1FNType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat4E2M1FNTypeGet(context->get()); + return PyFloat4E2M1FNType(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), + "Create a float4_e2m1fn type."); +} -private: - static PyVectorType - getChecked(std::vector<int64_t> shape, PyType &elementType, - std::optional<nb::list> scalable, - std::optional<std::vector<int64_t>> scalableDims, - DefaultingPyLocation loc) { - if 
(scalable && scalableDims) { - throw nb::value_error("'scalable' and 'scalable_dims' kwargs " - "are mutually exclusive."); - } +void PyFloat6E2M3FNType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat6E2M3FNTypeGet(context->get()); + return PyFloat6E2M3FNType(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), + "Create a float6_e2m3fn type."); +} - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirType type; - if (scalable) { - if (scalable->size() != shape.size()) - throw nb::value_error("Expected len(scalable) == len(shape)."); +void PyFloat6E3M2FNType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat6E3M2FNTypeGet(context->get()); + return PyFloat6E3M2FNType(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), + "Create a float6_e3m2fn type."); +} - SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range( - *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); })); - type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), - scalableDimFlags.data(), - elementType); - } else if (scalableDims) { - SmallVector<bool> scalableDimFlags(shape.size(), false); - for (int64_t dim : *scalableDims) { - if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0) - throw nb::value_error("Scalable dimension index out of bounds."); - scalableDimFlags[dim] = true; - } - type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(), - scalableDimFlags.data(), - elementType); - } else { - type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(), - elementType); - } - if (mlirTypeIsNull(type)) - throw MLIRError("Invalid type", errors.take()); - return PyVectorType(elementType.getContext(), type); - } +void PyFloat8E4M3FNType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = 
mlirFloat8E4M3FNTypeGet(context->get()); + return PyFloat8E4M3FNType(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), + "Create a float8_e4m3fn type."); +} - static PyVectorType get(std::vector<int64_t> shape, PyType &elementType, - std::optional<nb::list> scalable, - std::optional<std::vector<int64_t>> scalableDims, - DefaultingPyMlirContext context) { - if (scalable && scalableDims) { - throw nb::value_error("'scalable' and 'scalable_dims' kwargs " - "are mutually exclusive."); - } +void PyFloat8E5M2Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E5M2TypeGet(context->get()); + return PyFloat8E5M2Type(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), + "Create a float8_e5m2 type."); +} - PyMlirContext::ErrorCapture errors(context->getRef()); - MlirType type; - if (scalable) { - if (scalable->size() != shape.size()) - throw nb::value_error("Expected len(scalable) == len(shape)."); +void PyFloat8E4M3Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3TypeGet(context->get()); + return PyFloat8E4M3Type(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), + "Create a float8_e4m3 type."); +} - SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range( - *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); })); - type = mlirVectorTypeGetScalable(shape.size(), shape.data(), - scalableDimFlags.data(), elementType); - } else if (scalableDims) { - SmallVector<bool> scalableDimFlags(shape.size(), false); - for (int64_t dim : *scalableDims) { - if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0) - throw nb::value_error("Scalable dimension index out of bounds."); - scalableDimFlags[dim] = true; - } - type = mlirVectorTypeGetScalable(shape.size(), shape.data(), - scalableDimFlags.data(), elementType); - } else { - type = 
mlirVectorTypeGet(shape.size(), shape.data(), elementType); - } - if (mlirTypeIsNull(type)) - throw MLIRError("Invalid type", errors.take()); - return PyVectorType(elementType.getContext(), type); - } -}; +void PyFloat8E4M3FNUZType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get()); + return PyFloat8E4M3FNUZType(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), + "Create a float8_e4m3fnuz type."); +} -/// Ranked Tensor Type subclass - RankedTensorType. -class PyRankedTensorType - : public PyConcreteType<PyRankedTensorType, PyShapedType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirRankedTensorTypeGetTypeID; - static constexpr const char *pyClassName = "RankedTensorType"; - using PyConcreteType::PyConcreteType; +void PyFloat8E4M3B11FNUZType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get()); + return PyFloat8E4M3B11FNUZType(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), + "Create a float8_e4m3b11fnuz type."); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector<int64_t> shape, PyType &elementType, - std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirType t = mlirRankedTensorTypeGetChecked( - loc, shape.size(), shape.data(), elementType, - encodingAttr ? 
encodingAttr->get() : mlirAttributeGetNull()); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyRankedTensorType(elementType.getContext(), t); - }, - nb::arg("shape"), nb::arg("element_type"), - nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(), - "Create a ranked tensor type"); - c.def_static( - "get_unchecked", - [](std::vector<int64_t> shape, PyType &elementType, - std::optional<PyAttribute> &encodingAttr, - DefaultingPyMlirContext context) { - PyMlirContext::ErrorCapture errors(context->getRef()); - MlirType t = mlirRankedTensorTypeGet( - shape.size(), shape.data(), elementType, - encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyRankedTensorType(elementType.getContext(), t); - }, - nb::arg("shape"), nb::arg("element_type"), - nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(), - "Create a ranked tensor type"); - c.def_prop_ro( - "encoding", - [](PyRankedTensorType &self) - -> std::optional<nb::typed<nb::object, PyAttribute>> { - MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); - if (mlirAttributeIsNull(encoding)) - return std::nullopt; - return PyAttribute(self.getContext(), encoding).maybeDownCast(); - }); - } -}; +void PyFloat8E5M2FNUZType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get()); + return PyFloat8E5M2FNUZType(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), + "Create a float8_e5m2fnuz type."); +} -/// Unranked Tensor Type subclass - UnrankedTensorType. 
-class PyUnrankedTensorType - : public PyConcreteType<PyUnrankedTensorType, PyShapedType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirUnrankedTensorTypeGetTypeID; - static constexpr const char *pyClassName = "UnrankedTensorType"; - using PyConcreteType::PyConcreteType; +void PyFloat8E3M4Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E3M4TypeGet(context->get()); + return PyFloat8E3M4Type(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), + "Create a float8_e3m4 type."); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType, DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyUnrankedTensorType(elementType.getContext(), t); - }, - nb::arg("element_type"), nb::arg("loc") = nb::none(), - "Create a unranked tensor type"); - c.def_static( - "get_unchecked", - [](PyType &elementType, DefaultingPyMlirContext context) { - PyMlirContext::ErrorCapture errors(context->getRef()); - MlirType t = mlirUnrankedTensorTypeGet(elementType); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyUnrankedTensorType(elementType.getContext(), t); - }, - nb::arg("element_type"), nb::arg("context") = nb::none(), - "Create a unranked tensor type"); - } -}; +void PyFloat8E8M0FNUType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirFloat8E8M0FNUTypeGet(context->get()); + return PyFloat8E8M0FNUType(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), + "Create a float8_e8m0fnu type."); +} -/// Ranked MemRef Type subclass - MemRefType. 
-class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirMemRefTypeGetTypeID; - static constexpr const char *pyClassName = "MemRefType"; - using PyConcreteType::PyConcreteType; +void PyBF16Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirBF16TypeGet(context->get()); + return PyBF16Type(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), "Create a bf16 type."); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector<int64_t> shape, PyType &elementType, - PyAttribute *layout, PyAttribute *memorySpace, - DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); - MlirAttribute memSpaceAttr = - memorySpace ? *memorySpace : mlirAttributeGetNull(); - MlirType t = - mlirMemRefTypeGetChecked(loc, elementType, shape.size(), - shape.data(), layoutAttr, memSpaceAttr); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyMemRefType(elementType.getContext(), t); - }, - nb::arg("shape"), nb::arg("element_type"), - nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(), - nb::arg("loc") = nb::none(), "Create a memref type") - .def_static( - "get_unchecked", - [](std::vector<int64_t> shape, PyType &elementType, - PyAttribute *layout, PyAttribute *memorySpace, - DefaultingPyMlirContext context) { - PyMlirContext::ErrorCapture errors(context->getRef()); - MlirAttribute layoutAttr = - layout ? *layout : mlirAttributeGetNull(); - MlirAttribute memSpaceAttr = - memorySpace ? 
*memorySpace : mlirAttributeGetNull(); - MlirType t = - mlirMemRefTypeGet(elementType, shape.size(), shape.data(), - layoutAttr, memSpaceAttr); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyMemRefType(elementType.getContext(), t); - }, - nb::arg("shape"), nb::arg("element_type"), - nb::arg("layout") = nb::none(), - nb::arg("memory_space") = nb::none(), - nb::arg("context") = nb::none(), "Create a memref type") - .def_prop_ro( - "layout", - [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> { - return PyAttribute(self.getContext(), - mlirMemRefTypeGetLayout(self)) - .maybeDownCast(); - }, - "The layout of the MemRef type.") - .def( - "get_strides_and_offset", - [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> { - std::vector<int64_t> strides(mlirShapedTypeGetRank(self)); - int64_t offset; - if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset( - self, strides.data(), &offset))) - throw std::runtime_error( - "Failed to extract strides and offset from memref."); - return {strides, offset}; - }, - "The strides and offset of the MemRef type.") - .def_prop_ro( - "affine_map", - [](PyMemRefType &self) -> PyAffineMap { - MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); - return PyAffineMap(self.getContext(), map); - }, - "The layout of the MemRef type as an affine map.") - .def_prop_ro( - "memory_space", - [](PyMemRefType &self) - -> std::optional<nb::typed<nb::object, PyAttribute>> { - MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); - if (mlirAttributeIsNull(a)) - return std::nullopt; - return PyAttribute(self.getContext(), a).maybeDownCast(); - }, - "Returns the memory space of the given MemRef type."); - } -}; +void PyF16Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF16TypeGet(context->get()); + return PyF16Type(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), "Create a f16 type."); +} 
-/// Unranked MemRef Type subclass - UnrankedMemRefType. -class PyUnrankedMemRefType - : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirUnrankedMemRefTypeGetTypeID; - static constexpr const char *pyClassName = "UnrankedMemRefType"; - using PyConcreteType::PyConcreteType; +void PyTF32Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirTF32TypeGet(context->get()); + return PyTF32Type(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), "Create a tf32 type."); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](PyType &elementType, PyAttribute *memorySpace, - DefaultingPyLocation loc) { - PyMlirContext::ErrorCapture errors(loc->getContext()); - MlirAttribute memSpaceAttr = {}; - if (memorySpace) - memSpaceAttr = *memorySpace; +void PyF32Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirF32TypeGet(context->get()); + return PyF32Type(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), "Create a f32 type."); +} - MlirType t = - mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyUnrankedMemRefType(elementType.getContext(), t); - }, - nb::arg("element_type"), nb::arg("memory_space").none(), - nb::arg("loc") = nb::none(), "Create a unranked memref type") - .def_static( - "get_unchecked", - [](PyType &elementType, PyAttribute *memorySpace, - DefaultingPyMlirContext context) { - PyMlirContext::ErrorCapture errors(context->getRef()); - MlirAttribute memSpaceAttr = {}; - if (memorySpace) - memSpaceAttr = *memorySpace; +void PyF64Type::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + 
MlirType t = mlirF64TypeGet(context->get()); + return PyF64Type(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), "Create a f64 type."); +} - MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr); - if (mlirTypeIsNull(t)) - throw MLIRError("Invalid type", errors.take()); - return PyUnrankedMemRefType(elementType.getContext(), t); - }, - nb::arg("element_type"), nb::arg("memory_space").none(), - nb::arg("context") = nb::none(), "Create a unranked memref type") - .def_prop_ro( - "memory_space", - [](PyUnrankedMemRefType &self) - -> std::optional<nb::typed<nb::object, PyAttribute>> { - MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); - if (mlirAttributeIsNull(a)) - return std::nullopt; - return PyAttribute(self.getContext(), a).maybeDownCast(); - }, - "Returns the memory space of the given Unranked MemRef type."); - } -}; +void PyNoneType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](DefaultingPyMlirContext context) { + MlirType t = mlirNoneTypeGet(context->get()); + return PyNoneType(context->getRef(), t); + }, + nanobind::arg("context") = nanobind::none(), "Create a none type."); +} + +void PyComplexType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType) { + // The element must be a floating point or integer scalar type. + if (mlirTypeIsAIntegerOrFloat(elementType)) { + MlirType t = mlirComplexTypeGet(elementType); + return PyComplexType(elementType.getContext(), t); + } + throw nanobind::value_error( + (Twine("invalid '") + + nanobind::cast<std::string>( + nanobind::repr(nanobind::cast(elementType))) + + "' and expected floating point or integer type.") + .str() + .c_str()); + }, + "Create a complex type"); + c.def_prop_ro( + "element_type", + [](PyComplexType &self) -> nanobind::typed<nanobind::object, PyType> { + return PyType(self.getContext(), mlirComplexTypeGetElementType(self)) + .maybeDownCast(); + }, + "Returns element type."); +} -/// Tuple Type subclass - TupleType. 
-class PyTupleType : public PyConcreteType<PyTupleType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirTupleTypeGetTypeID; - static constexpr const char *pyClassName = "TupleType"; - using PyConcreteType::PyConcreteType; +void PyVectorType::bindDerived(ClassTy &c) { + c.def_static("get", &PyVectorType::getChecked, nanobind::arg("shape"), + nanobind::arg("element_type"), nanobind::kw_only(), + nanobind::arg("scalable") = nanobind::none(), + nanobind::arg("scalable_dims") = nanobind::none(), + nanobind::arg("loc") = nanobind::none(), "Create a vector type") + .def_static("get_unchecked", &PyVectorType::get, nanobind::arg("shape"), + nanobind::arg("element_type"), nanobind::kw_only(), + nanobind::arg("scalable") = nanobind::none(), + nanobind::arg("scalable_dims") = nanobind::none(), + nanobind::arg("context") = nanobind::none(), + "Create a vector type") + .def_prop_ro("scalable", + [](MlirType self) { return mlirVectorTypeIsScalable(self); }) + .def_prop_ro("scalable_dims", [](MlirType self) { + std::vector<bool> scalableDims; + size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self)); + scalableDims.reserve(rank); + for (size_t i = 0; i < rank; ++i) + scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i)); + return scalableDims; + }); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get_tuple", - [](const std::vector<PyType> &elements, - DefaultingPyMlirContext context) { - std::vector<MlirType> mlirElements; - mlirElements.reserve(elements.size()); - for (const auto &element : elements) - mlirElements.push_back(element.get()); - MlirType t = mlirTupleTypeGet(context->get(), elements.size(), - mlirElements.data()); - return PyTupleType(context->getRef(), t); - }, - nb::arg("elements"), nb::arg("context") = nb::none(), - "Create a tuple type"); - c.def_static( - "get_tuple", - [](std::vector<MlirType> elements, DefaultingPyMlirContext context) { - 
MlirType t = mlirTupleTypeGet(context->get(), elements.size(), - elements.data()); - return PyTupleType(context->getRef(), t); - }, - nb::arg("elements"), nb::arg("context") = nb::none(), - // clang-format off - nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"), - // clang-format on - "Create a tuple type"); - c.def( - "get_type", - [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> { - return PyType(self.getContext(), mlirTupleTypeGetType(self, pos)) - .maybeDownCast(); - }, - nb::arg("pos"), "Returns the pos-th type in the tuple type."); - c.def_prop_ro( - "num_types", - [](PyTupleType &self) -> intptr_t { - return mlirTupleTypeGetNumTypes(self); - }, - "Returns the number of types contained in a tuple."); - } -}; +void PyRankedTensorType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector<int64_t> shape, PyType &elementType, + std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType t = mlirRankedTensorTypeGetChecked( + loc, shape.size(), shape.data(), elementType, + encodingAttr ? encodingAttr->get() : mlirAttributeGetNull()); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyRankedTensorType(elementType.getContext(), t); + }, + nanobind::arg("shape"), nanobind::arg("element_type"), + nanobind::arg("encoding") = nanobind::none(), + nanobind::arg("loc") = nanobind::none(), "Create a ranked tensor type"); + c.def_static( + "get_unchecked", + [](std::vector<int64_t> shape, PyType &elementType, + std::optional<PyAttribute> &encodingAttr, + DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirType t = mlirRankedTensorTypeGet( + shape.size(), shape.data(), elementType, + encodingAttr ? 
encodingAttr->get() : mlirAttributeGetNull()); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyRankedTensorType(elementType.getContext(), t); + }, + nanobind::arg("shape"), nanobind::arg("element_type"), + nanobind::arg("encoding") = nanobind::none(), + nanobind::arg("context") = nanobind::none(), + "Create a ranked tensor type"); + c.def_prop_ro( + "encoding", + [](PyRankedTensorType &self) + -> std::optional<nanobind::typed<nanobind::object, PyAttribute>> { + MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get()); + if (mlirAttributeIsNull(encoding)) + return std::nullopt; + return PyAttribute(self.getContext(), encoding).maybeDownCast(); + }); +} -/// Function type. -class PyFunctionType : public PyConcreteType<PyFunctionType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirFunctionTypeGetTypeID; - static constexpr const char *pyClassName = "FunctionType"; - using PyConcreteType::PyConcreteType; +void PyUnrankedTensorType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyUnrankedTensorType(elementType.getContext(), t); + }, + nanobind::arg("element_type"), nanobind::arg("loc") = nanobind::none(), + "Create a unranked tensor type"); + c.def_static( + "get_unchecked", + [](PyType &elementType, DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirType t = mlirUnrankedTensorTypeGet(elementType); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyUnrankedTensorType(elementType.getContext(), t); + }, + nanobind::arg("element_type"), + nanobind::arg("context") = nanobind::none(), + 
"Create a unranked tensor type"); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](std::vector<PyType> inputs, std::vector<PyType> results, - DefaultingPyMlirContext context) { - std::vector<MlirType> mlirInputs; - mlirInputs.reserve(inputs.size()); - for (const auto &input : inputs) - mlirInputs.push_back(input.get()); - std::vector<MlirType> mlirResults; - mlirResults.reserve(results.size()); - for (const auto &result : results) - mlirResults.push_back(result.get()); +void PyMemRefType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector<int64_t> shape, PyType &elementType, PyAttribute *layout, + PyAttribute *memorySpace, DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull(); + MlirAttribute memSpaceAttr = + memorySpace ? *memorySpace : mlirAttributeGetNull(); + MlirType t = + mlirMemRefTypeGetChecked(loc, elementType, shape.size(), + shape.data(), layoutAttr, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyMemRefType(elementType.getContext(), t); + }, + nanobind::arg("shape"), nanobind::arg("element_type"), + nanobind::arg("layout") = nanobind::none(), + nanobind::arg("memory_space") = nanobind::none(), + nanobind::arg("loc") = nanobind::none(), "Create a memref type") + .def_static( + "get_unchecked", + [](std::vector<int64_t> shape, PyType &elementType, + PyAttribute *layout, PyAttribute *memorySpace, + DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirAttribute layoutAttr = + layout ? *layout : mlirAttributeGetNull(); + MlirAttribute memSpaceAttr = + memorySpace ? 
*memorySpace : mlirAttributeGetNull(); + MlirType t = + mlirMemRefTypeGet(elementType, shape.size(), shape.data(), + layoutAttr, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyMemRefType(elementType.getContext(), t); + }, + nanobind::arg("shape"), nanobind::arg("element_type"), + nanobind::arg("layout") = nanobind::none(), + nanobind::arg("memory_space") = nanobind::none(), + nanobind::arg("context") = nanobind::none(), "Create a memref type") + .def_prop_ro( + "layout", + [](PyMemRefType &self) + -> nanobind::typed<nanobind::object, PyAttribute> { + return PyAttribute(self.getContext(), mlirMemRefTypeGetLayout(self)) + .maybeDownCast(); + }, + "The layout of the MemRef type.") + .def( + "get_strides_and_offset", + [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> { + std::vector<int64_t> strides(mlirShapedTypeGetRank(self)); + int64_t offset; + if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset( + self, strides.data(), &offset))) + throw std::runtime_error( + "Failed to extract strides and offset from memref."); + return {strides, offset}; + }, + "The strides and offset of the MemRef type.") + .def_prop_ro( + "affine_map", + [](PyMemRefType &self) -> PyAffineMap { + MlirAffineMap map = mlirMemRefTypeGetAffineMap(self); + return PyAffineMap(self.getContext(), map); + }, + "The layout of the MemRef type as an affine map.") + .def_prop_ro( + "memory_space", + [](PyMemRefType &self) + -> std::optional<nanobind::typed<nanobind::object, PyAttribute>> { + MlirAttribute a = mlirMemRefTypeGetMemorySpace(self); + if (mlirAttributeIsNull(a)) + return std::nullopt; + return PyAttribute(self.getContext(), a).maybeDownCast(); + }, + "Returns the memory space of the given MemRef type."); +} - MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(), - mlirInputs.data(), results.size(), - mlirResults.data()); - return PyFunctionType(context->getRef(), t); - }, - nb::arg("inputs"), 
nb::arg("results"), nb::arg("context") = nb::none(), - "Gets a FunctionType from a list of input and result types"); - c.def_static( - "get", - [](std::vector<MlirType> inputs, std::vector<MlirType> results, - DefaultingPyMlirContext context) { - MlirType t = - mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(), - results.size(), results.data()); - return PyFunctionType(context->getRef(), t); - }, - nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(), - // clang-format off - nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"), - // clang-format on - "Gets a FunctionType from a list of input and result types"); - c.def_prop_ro( - "inputs", - [](PyFunctionType &self) { - MlirType t = self; - nb::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; - ++i) { - types.append(mlirFunctionTypeGetInput(t, i)); - } - return types; - }, - "Returns the list of input types in the FunctionType."); - c.def_prop_ro( - "results", - [](PyFunctionType &self) { - nb::list types; - for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; - ++i) { - types.append(mlirFunctionTypeGetResult(self, i)); - } - return types; - }, - "Returns the list of result types in the FunctionType."); - } -}; +void PyUnrankedMemRefType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](PyType &elementType, PyAttribute *memorySpace, + DefaultingPyLocation loc) { + PyMlirContext::ErrorCapture errors(loc->getContext()); + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + + MlirType t = + mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyUnrankedMemRefType(elementType.getContext(), t); + }, + nanobind::arg("element_type"), nanobind::arg("memory_space").none(), + nanobind::arg("loc") = nanobind::none(), "Create a unranked memref type") 
+ .def_static( + "get_unchecked", + [](PyType &elementType, PyAttribute *memorySpace, + DefaultingPyMlirContext context) { + PyMlirContext::ErrorCapture errors(context->getRef()); + MlirAttribute memSpaceAttr = {}; + if (memorySpace) + memSpaceAttr = *memorySpace; + + MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr); + if (mlirTypeIsNull(t)) + throw MLIRError("Invalid type", errors.take()); + return PyUnrankedMemRefType(elementType.getContext(), t); + }, + nanobind::arg("element_type"), nanobind::arg("memory_space").none(), + nanobind::arg("context") = nanobind::none(), + "Create a unranked memref type") + .def_prop_ro( + "memory_space", + [](PyUnrankedMemRefType &self) + -> std::optional<nanobind::typed<nanobind::object, PyAttribute>> { + MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self); + if (mlirAttributeIsNull(a)) + return std::nullopt; + return PyAttribute(self.getContext(), a).maybeDownCast(); + }, + "Returns the memory space of the given Unranked MemRef type."); +} -/// Opaque Type subclass - OpaqueType. 
-class PyOpaqueType : public PyConcreteType<PyOpaqueType> { -public: - static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque; - static constexpr GetTypeIDFunctionTy getTypeIdFunction = - mlirOpaqueTypeGetTypeID; - static constexpr const char *pyClassName = "OpaqueType"; - using PyConcreteType::PyConcreteType; +void PyTupleType::bindDerived(ClassTy &c) { + c.def_static( + "get_tuple", + [](const std::vector<PyType> &elements, DefaultingPyMlirContext context) { + std::vector<MlirType> mlirElements; + mlirElements.reserve(elements.size()); + for (const auto &element : elements) + mlirElements.push_back(element.get()); + MlirType t = mlirTupleTypeGet(context->get(), elements.size(), + mlirElements.data()); + return PyTupleType(context->getRef(), t); + }, + nanobind::arg("elements"), nanobind::arg("context") = nanobind::none(), + "Create a tuple type"); + c.def_static( + "get_tuple", + [](std::vector<MlirType> elements, DefaultingPyMlirContext context) { + MlirType t = + mlirTupleTypeGet(context->get(), elements.size(), elements.data()); + return PyTupleType(context->getRef(), t); + }, + nanobind::arg("elements"), nanobind::arg("context") = nanobind::none(), + // clang-format off + nanobind::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"), + // clang-format on + "Create a tuple type"); + c.def( + "get_type", + [](PyTupleType &self, + intptr_t pos) -> nanobind::typed<nanobind::object, PyType> { + return PyType(self.getContext(), mlirTupleTypeGetType(self, pos)) + .maybeDownCast(); + }, + nanobind::arg("pos"), "Returns the pos-th type in the tuple type."); + c.def_prop_ro( + "num_types", + [](PyTupleType &self) -> intptr_t { + return mlirTupleTypeGetNumTypes(self); + }, + "Returns the number of types contained in a tuple."); +} - static void bindDerived(ClassTy &c) { - c.def_static( - "get", - [](const std::string &dialectNamespace, const std::string &typeData, - DefaultingPyMlirContext context) { - MlirType type = 
mlirOpaqueTypeGet(context->get(), - toMlirStringRef(dialectNamespace), - toMlirStringRef(typeData)); - return PyOpaqueType(context->getRef(), type); - }, - nb::arg("dialect_namespace"), nb::arg("buffer"), - nb::arg("context") = nb::none(), - "Create an unregistered (opaque) dialect type."); - c.def_prop_ro( - "dialect_namespace", - [](PyOpaqueType &self) { - MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); - return nb::str(stringRef.data, stringRef.length); - }, - "Returns the dialect namespace for the Opaque type as a string."); - c.def_prop_ro( - "data", - [](PyOpaqueType &self) { - MlirStringRef stringRef = mlirOpaqueTypeGetData(self); - return nb::str(stringRef.data, stringRef.length); - }, - "Returns the data for the Opaque type as a string."); - } -}; +void PyFunctionType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector<PyType> inputs, std::vector<PyType> results, + DefaultingPyMlirContext context) { + std::vector<MlirType> mlirInputs; + mlirInputs.reserve(inputs.size()); + for (const auto &input : inputs) + mlirInputs.push_back(input.get()); + std::vector<MlirType> mlirResults; + mlirResults.reserve(results.size()); + for (const auto &result : results) + mlirResults.push_back(result.get()); + + MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(), + mlirInputs.data(), results.size(), + mlirResults.data()); + return PyFunctionType(context->getRef(), t); + }, + nanobind::arg("inputs"), nanobind::arg("results"), + nanobind::arg("context") = nanobind::none(), + "Gets a FunctionType from a list of input and result types"); + c.def_static( + "get", + [](std::vector<MlirType> inputs, std::vector<MlirType> results, + DefaultingPyMlirContext context) { + MlirType t = + mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(), + results.size(), results.data()); + return PyFunctionType(context->getRef(), t); + }, + nanobind::arg("inputs"), nanobind::arg("results"), + nanobind::arg("context") = nanobind::none(), + 
// clang-format off + nanobind::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"), + // clang-format on + "Gets a FunctionType from a list of input and result types"); + c.def_prop_ro( + "inputs", + [](PyFunctionType &self) { + MlirType t = self; + nanobind::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e; + ++i) { + types.append(mlirFunctionTypeGetInput(t, i)); + } + return types; + }, + "Returns the list of input types in the FunctionType."); + c.def_prop_ro( + "results", + [](PyFunctionType &self) { + nanobind::list types; + for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e; + ++i) { + types.append(mlirFunctionTypeGetResult(self, i)); + } + return types; + }, + "Returns the list of result types in the FunctionType."); +} + +void PyOpaqueType::bindDerived(ClassTy &c) { + c.def_static( + "get", + [](const std::string &dialectNamespace, const std::string &typeData, + DefaultingPyMlirContext context) { + MlirType type = + mlirOpaqueTypeGet(context->get(), toMlirStringRef(dialectNamespace), + toMlirStringRef(typeData)); + return PyOpaqueType(context->getRef(), type); + }, + nanobind::arg("dialect_namespace"), nanobind::arg("buffer"), + nanobind::arg("context") = nanobind::none(), + "Create an unregistered (opaque) dialect type."); + c.def_prop_ro( + "dialect_namespace", + [](PyOpaqueType &self) { + MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self); + return nanobind::str(stringRef.data, stringRef.length); + }, + "Returns the dialect namespace for the Opaque type as a string."); + c.def_prop_ro( + "data", + [](PyOpaqueType &self) { + MlirStringRef stringRef = mlirOpaqueTypeGetData(self); + return nanobind::str(stringRef.data, stringRef.length); + }, + "Returns the data for the Opaque type as a string."); +} +const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = mlirTypeIsAShaped; } // namespace MLIR_BINDINGS_PYTHON_DOMAIN } // 
namespace python } // namespace mlir diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index 4a9fb127ee08c..582863ffcbb0d 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -535,7 +535,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Core IRAffine.cpp IRAttributes.cpp IRInterfaces.cpp - IRTypes.cpp Pass.cpp Rewrite.cpp @@ -846,8 +845,9 @@ declare_mlir_python_extension(MLIRPythonExtension.MLIRPythonSupport ADD_TO_PARENT MLIRPythonSources.Core ROOT_DIR "${PYTHON_SOURCE_DIR}" SOURCES - IRCore.cpp Globals.cpp + IRCore.cpp + IRTypes.cpp ) ################################################################################ diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp index 43573cbc305fa..a296b5e814b4b 100644 --- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp +++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp @@ -15,6 +15,7 @@ #include "mlir-c/IR.h" #include "mlir/Bindings/Python/Diagnostics.h" #include "mlir/Bindings/Python/IRCore.h" +#include "mlir/Bindings/Python/IRTypes.h" #include "mlir/Bindings/Python/Nanobind.h" #include "mlir/Bindings/Python/NanobindAdaptors.h" #include "nanobind/nanobind.h" @@ -47,6 +48,49 @@ struct PyTestType } }; +struct PyTestIntegerRankedTensorType + : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType< + PyTestIntegerRankedTensorType, + mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> { + static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedIntegerTensor; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirRankedTensorTypeGetTypeID; + static constexpr const char *pyClassName = "TestIntegerRankedTensorType"; + using PyConcreteType::PyConcreteType; + + static void bindDerived(ClassTy &c) { + c.def_static( + "get", + [](std::vector<int64_t> shape, unsigned width, + mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext + ctx) { + MlirAttribute encoding = 
mlirAttributeGetNull(); + return PyTestIntegerRankedTensorType( + ctx->getRef(), + mlirRankedTensorTypeGet( + shape.size(), shape.data(), + mlirIntegerTypeGet(ctx.get()->get(), width), encoding)); + }, + nb::arg("shape"), nb::arg("width"), + nb::arg("context").none() = nb::none()); + } +}; + +struct PyTestTensorValue + : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue< + PyTestTensorValue> { + static constexpr IsAFunctionTy isaFunction = + mlirTypeIsAPythonTestTestTensorValue; + static constexpr GetTypeIDFunctionTy getTypeIdFunction = + mlirRankedTensorTypeGetTypeID; + static constexpr const char *pyClassName = "TestTensorValue"; + using PyConcreteValue::PyConcreteValue; + + static void bindDerived(ClassTy &c) { + c.def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); }); + } +}; + class PyTestAttr : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute< PyTestAttr> { @@ -73,18 +117,18 @@ class PyTestAttr NB_MODULE(_mlirPythonTestNanobind, m) { m.def( "register_python_test_dialect", - [](MlirContext context, bool load) { + [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext + context, + bool load) { MlirDialectHandle pythonTestDialect = mlirGetDialectHandle__python_test__(); - mlirDialectHandleRegisterDialect(pythonTestDialect, context); + mlirDialectHandleRegisterDialect(pythonTestDialect, + context.get()->get()); if (load) { - mlirDialectHandleLoadDialect(pythonTestDialect, context); + mlirDialectHandleLoadDialect(pythonTestDialect, context.get()->get()); } }, - nb::arg("context"), nb::arg("load") = true, - // clang-format off - nb::sig("def register_python_test_dialect(context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", load: bool = True) -> None")); - // clang-format on + nb::arg("context").none() = nb::none(), nb::arg("load") = true); m.def( "register_dialect", @@ -100,73 +144,16 @@ NB_MODULE(_mlirPythonTestNanobind, m) { m.def( "test_diagnostics_with_errors_and_notes", - [](MlirContext ctx) { - 
mlir::python::CollectDiagnosticsToStringScope handler(ctx); - mlirPythonTestEmitDiagnosticWithNote(ctx); + [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext + ctx) { + mlir::python::CollectDiagnosticsToStringScope handler(ctx.get()->get()); + mlirPythonTestEmitDiagnosticWithNote(ctx.get()->get()); throw nb::value_error(handler.takeMessage().c_str()); }, - // clang-format off - nb::sig("def test_diagnostics_with_errors_and_notes(arg: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", /) -> None")); - // clang-format on + nb::arg("context").none() = nb::none()); PyTestAttr::bind(m); PyTestType::bind(m); - - auto typeCls = - mlir_type_subclass(m, "TestIntegerRankedTensorType", - mlirTypeIsARankedIntegerTensor, - nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr("RankedTensorType")) - .def_classmethod( - "get", - [](const nb::object &cls, std::vector<int64_t> shape, - unsigned width, MlirContext ctx) { - MlirAttribute encoding = mlirAttributeGetNull(); - return cls(mlirRankedTensorTypeGet( - shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width), - encoding)); - }, - // clang-format off - nb::sig("def get(cls: object, shape: collections.abc.Sequence[int], width: int, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"), - // clang-format on - nb::arg("cls"), nb::arg("shape"), nb::arg("width"), - nb::arg("context").none() = nb::none()); - - assert(nb::hasattr(typeCls.get_class(), "static_typeid") && - "TestIntegerRankedTensorType has no static_typeid"); - - MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID(); - - nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( - mlirRankedTensorTypeID, nb::arg("replace") = true)( - nanobind::cpp_function([typeCls](const nb::object &mlirType) { - return typeCls.get_class()(mlirType); - })); - - auto valueCls = mlir_value_subclass(m, "TestTensorValue", - mlirTypeIsAPythonTestTestTensorValue) - .def("is_null", 
[](MlirValue &self) { - return mlirValueIsNull(self); - }); - - nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)( - mlirRankedTensorTypeID)( - nanobind::cpp_function([valueCls](const nb::object &valueObj) { - std::optional<nb::object> capsule = - mlirApiObjectToCapsule(valueObj); - assert(capsule.has_value() && "capsule is not null"); - MlirValue v = mlirPythonCapsuleToValue(capsule.value().ptr()); - MlirType t = mlirValueGetType(v); - // This is hyper-specific in order to exercise/test registering a - // value caster from cpp (but only for a single test case; see - // testTensorValue python_test.py). - if (mlirShapedTypeHasStaticShape(t) && - mlirShapedTypeGetDimSize(t, 0) == 1 && - mlirShapedTypeGetDimSize(t, 1) == 2 && - mlirShapedTypeGetDimSize(t, 2) == 3) - return valueCls.get_class()(valueObj); - return valueObj; - })); + PyTestIntegerRankedTensorType::bind(m); + PyTestTensorValue::bind(m); } _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
