https://github.com/makslevental updated https://github.com/llvm/llvm-project/pull/173939

>From dc24520fc192e2774509f15093d815910aeec1f4 Mon Sep 17 00:00:00 2001
From: makslevental <[email protected]>
Date: Mon, 29 Dec 2025 16:57:05 -0800
Subject: [PATCH] [mlir][Python] move IRTypes and IRAttributes to public
 headers

---
 mlir/include/mlir/Bindings/Python/IRCore.h    |   17 +-
 mlir/include/mlir/Bindings/Python/IRTypes.h   |  465 ++++-
 mlir/lib/Bindings/Python/IRTypes.cpp          | 1573 +++++++----------
 mlir/python/CMakeLists.txt                    |    4 +-
 .../python/lib/PythonTestModuleNanobind.cpp   |  129 +-
 5 files changed, 1138 insertions(+), 1050 deletions(-)

diff --git a/mlir/include/mlir/Bindings/Python/IRCore.h b/mlir/include/mlir/Bindings/Python/IRCore.h
index 0f402b4ce15ff..340b16bcdf558 100644
--- a/mlir/include/mlir/Bindings/Python/IRCore.h
+++ b/mlir/include/mlir/Bindings/Python/IRCore.h
@@ -979,7 +979,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteType : public BaseTy {
       PyGlobals::get().registerTypeCaster(
           DerivedTy::getTypeIdFunction(),
           nanobind::cast<nanobind::callable>(nanobind::cpp_function(
-              [](PyType pyType) -> DerivedTy { return pyType; })));
+              [](PyType pyType) -> DerivedTy { return pyType; })),
+          /*replace*/ true);
     }
 
     DerivedTy::bindDerived(cls);
@@ -1123,7 +1124,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteAttribute : public BaseTy {
           nanobind::cast<nanobind::callable>(
               nanobind::cpp_function([](PyAttribute pyAttribute) -> DerivedTy {
                 return pyAttribute;
-              })));
+              })),
+          /*replace*/ true);
     }
 
     DerivedTy::bindDerived(cls);
@@ -1511,6 +1513,8 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
   // and redefine bindDerived.
   using ClassTy = nanobind::class_<DerivedTy, PyValue>;
   using IsAFunctionTy = bool (*)(MlirValue);
+  using GetTypeIDFunctionTy = MlirTypeID (*)();
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction = nullptr;
 
   PyConcreteValue() = default;
   PyConcreteValue(PyOperationRef operationRef, MlirValue value)
@@ -1553,6 +1557,15 @@ class MLIR_PYTHON_API_EXPORTED PyConcreteValue : public PyValue {
         [](DerivedTy &self) -> nanobind::typed<nanobind::object, DerivedTy> {
           return self.maybeDownCast();
         });
+
+    if (DerivedTy::getTypeIdFunction) {
+      PyGlobals::get().registerValueCaster(
+          DerivedTy::getTypeIdFunction(),
+          nanobind::cast<nanobind::callable>(nanobind::cpp_function(
+              [](PyValue pyValue) -> DerivedTy { return pyValue; })),
+          /*replace*/ true);
+    }
+
     DerivedTy::bindDerived(cls);
   }
 
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
index 87e0e10764bd8..db478e8d33f37 100644
--- a/mlir/include/mlir/Bindings/Python/IRTypes.h
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -9,13 +9,14 @@
 #ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
 #define MLIR_BINDINGS_PYTHON_IRTYPES_H
 
+#include "mlir-c/BuiltinTypes.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 
 namespace mlir {
 namespace python {
 namespace MLIR_BINDINGS_PYTHON_DOMAIN {
 /// Shaped Type Interface - ShapedType
 class MLIR_PYTHON_API_EXPORTED PyShapedType
     : public PyConcreteType<PyShapedType> {
 public:
   static const IsAFunctionTy isaFunction;
@@ -27,6 +28,468 @@ class MLIR_PYTHON_API_EXPORTED PyShapedType
 private:
   void requireHasRank();
 };
+
+/// Returns nonzero if `type` is an integer or a BF16, F16, F32 or F64 type.
+inline int mlirTypeIsAIntegerOrFloat(MlirType type) {
+  return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
+         mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
+}
+/// Integer Type subclass - IntegerType.
+class MLIR_PYTHON_API_EXPORTED PyIntegerType
+    : public PyConcreteType<PyIntegerType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIntegerTypeGetTypeID;
+  static constexpr const char *pyClassName = "IntegerType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Index Type subclass - IndexType.
+class MLIR_PYTHON_API_EXPORTED PyIndexType
+    : public PyConcreteType<PyIndexType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirIndexTypeGetTypeID;
+  static constexpr const char *pyClassName = "IndexType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+/// Floating Point Type base class - FloatType.
+class MLIR_PYTHON_API_EXPORTED PyFloatType
+    : public PyConcreteType<PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
+  static constexpr const char *pyClassName = "FloatType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float4E2M1FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat4E2M1FNType
+    : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat4E2M1FNTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float4E2M1FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float6E2M3FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat6E2M3FNType
+    : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat6E2M3FNTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float6E2M3FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float6E3M2FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat6E3M2FNType
+    : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat6E3M2FNTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float6E3M2FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3FNType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3FNType
+    : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3FNTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E4M3FNType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E5M2Type.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E5M2Type
+    : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E5M2TypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E5M2Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3Type.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3Type
+    : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3TypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E4M3Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3FNUZType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3FNUZType
+    : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3FNUZTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E4M3FNUZType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E4M3B11FNUZType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E4M3B11FNUZType
+    : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E4M3B11FNUZTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E5M2FNUZType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E5M2FNUZType
+    : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E5M2FNUZTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E5M2FNUZType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E3M4Type.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E3M4Type
+    : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E3M4TypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E3M4Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - Float8E8M0FNUType.
+class MLIR_PYTHON_API_EXPORTED PyFloat8E8M0FNUType
+    : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat8E8M0FNUTypeGetTypeID;
+  static constexpr const char *pyClassName = "Float8E8M0FNUType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - BF16Type.
+class MLIR_PYTHON_API_EXPORTED PyBF16Type
+    : public PyConcreteType<PyBF16Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirBFloat16TypeGetTypeID;
+  static constexpr const char *pyClassName = "BF16Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - F16Type.
+class MLIR_PYTHON_API_EXPORTED PyF16Type
+    : public PyConcreteType<PyF16Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat16TypeGetTypeID;
+  static constexpr const char *pyClassName = "F16Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - TF32Type (Python name: FloatTF32Type).
+class MLIR_PYTHON_API_EXPORTED PyTF32Type
+    : public PyConcreteType<PyTF32Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloatTF32TypeGetTypeID;
+  static constexpr const char *pyClassName = "FloatTF32Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - F32Type.
+class MLIR_PYTHON_API_EXPORTED PyF32Type
+    : public PyConcreteType<PyF32Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat32TypeGetTypeID;
+  static constexpr const char *pyClassName = "F32Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Floating Point Type subclass - F64Type.
+class MLIR_PYTHON_API_EXPORTED PyF64Type
+    : public PyConcreteType<PyF64Type, PyFloatType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFloat64TypeGetTypeID;
+  static constexpr const char *pyClassName = "F64Type";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// None Type subclass - NoneType.
+class MLIR_PYTHON_API_EXPORTED PyNoneType : public PyConcreteType<PyNoneType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirNoneTypeGetTypeID;
+  static constexpr const char *pyClassName = "NoneType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Complex Type subclass - ComplexType.
+class MLIR_PYTHON_API_EXPORTED PyComplexType
+    : public PyConcreteType<PyComplexType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirComplexTypeGetTypeID;
+  static constexpr const char *pyClassName = "ComplexType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Vector Type subclass - VectorType.
+class MLIR_PYTHON_API_EXPORTED PyVectorType
+    : public PyConcreteType<PyVectorType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirVectorTypeGetTypeID;
+  static constexpr const char *pyClassName = "VectorType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+
+private:
+  /// Converts the mutually exclusive `scalable` and `scalable_dims` kwargs
+  /// into one flag per dimension of `shape` (true marks a scalable dim):
+  ///  - `scalable` is a list of bools, one per dimension, so its length must
+  ///    equal `len(shape)`;
+  ///  - `scalable_dims` lists dimension indices, each in [0, len(shape)).
+  /// Returns std::nullopt when neither kwarg is given (non-scalable vector).
+  /// Throws nanobind::value_error if the kwargs are invalid or both are set.
+  static std::optional<SmallVector<bool>>
+  getScalableDimFlags(const std::vector<int64_t> &shape,
+                      const std::optional<nanobind::list> &scalable,
+                      const std::optional<std::vector<int64_t>> &scalableDims) {
+    if (scalable && scalableDims) {
+      throw nanobind::value_error("'scalable' and 'scalable_dims' kwargs "
+                                  "are mutually exclusive.");
+    }
+    if (scalable) {
+      if (scalable->size() != shape.size())
+        throw nanobind::value_error("Expected len(scalable) == len(shape).");
+      SmallVector<bool> scalableDimFlags = llvm::to_vector(
+          llvm::map_range(*scalable, [](const nanobind::handle &h) {
+            return nanobind::cast<bool>(h);
+          }));
+      return scalableDimFlags;
+    }
+    if (scalableDims) {
+      SmallVector<bool> scalableDimFlags(shape.size(), false);
+      for (int64_t dim : *scalableDims) {
+        if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
+          throw nanobind::value_error(
+              "Scalable dimension index out of bounds.");
+        scalableDimFlags[dim] = true;
+      }
+      return scalableDimFlags;
+    }
+    return std::nullopt;
+  }
+
+  /// Builds a vector type with the checked C API builders at location `loc`;
+  /// throws MLIRError with the captured diagnostics if the result is null.
+  static PyVectorType
+  getChecked(std::vector<int64_t> shape, PyType &elementType,
+             std::optional<nanobind::list> scalable,
+             std::optional<std::vector<int64_t>> scalableDims,
+             DefaultingPyLocation loc) {
+    // Validate the kwargs (may throw value_error) before calling the C API.
+    std::optional<SmallVector<bool>> scalableDimFlags =
+        getScalableDimFlags(shape, scalable, scalableDims);
+    PyMlirContext::ErrorCapture errors(loc->getContext());
+    MlirType type;
+    if (scalableDimFlags)
+      type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
+                                              scalableDimFlags->data(),
+                                              elementType);
+    else
+      type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
+                                      elementType);
+    if (mlirTypeIsNull(type))
+      throw MLIRError("Invalid type", errors.take());
+    return PyVectorType(elementType.getContext(), type);
+  }
+
+  /// Builds a vector type in `context` with the unchecked C API builders;
+  /// throws MLIRError with the captured diagnostics if the result is null.
+  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
+                          std::optional<nanobind::list> scalable,
+                          std::optional<std::vector<int64_t>> scalableDims,
+                          DefaultingPyMlirContext context) {
+    // Validate the kwargs (may throw value_error) before calling the C API.
+    std::optional<SmallVector<bool>> scalableDimFlags =
+        getScalableDimFlags(shape, scalable, scalableDims);
+    PyMlirContext::ErrorCapture errors(context->getRef());
+    MlirType type;
+    if (scalableDimFlags)
+      type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
+                                       scalableDimFlags->data(), elementType);
+    else
+      type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
+    if (mlirTypeIsNull(type))
+      throw MLIRError("Invalid type", errors.take());
+    return PyVectorType(elementType.getContext(), type);
+  }
+};
+
+/// Ranked Tensor Type subclass - RankedTensorType.
+class MLIR_PYTHON_API_EXPORTED PyRankedTensorType
+    : public PyConcreteType<PyRankedTensorType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirRankedTensorTypeGetTypeID;
+  static constexpr const char *pyClassName = "RankedTensorType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Unranked Tensor Type subclass - UnrankedTensorType.
+class MLIR_PYTHON_API_EXPORTED PyUnrankedTensorType
+    : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirUnrankedTensorTypeGetTypeID;
+  static constexpr const char *pyClassName = "UnrankedTensorType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Ranked MemRef Type subclass - MemRefType.
+class MLIR_PYTHON_API_EXPORTED PyMemRefType
+    : public PyConcreteType<PyMemRefType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirMemRefTypeGetTypeID;
+  static constexpr const char *pyClassName = "MemRefType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Unranked MemRef Type subclass - UnrankedMemRefType.
+class MLIR_PYTHON_API_EXPORTED PyUnrankedMemRefType
+    : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirUnrankedMemRefTypeGetTypeID;
+  static constexpr const char *pyClassName = "UnrankedMemRefType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Tuple Type subclass - TupleType.
+class MLIR_PYTHON_API_EXPORTED PyTupleType
+    : public PyConcreteType<PyTupleType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirTupleTypeGetTypeID;
+  static constexpr const char *pyClassName = "TupleType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Function Type subclass - FunctionType.
+class MLIR_PYTHON_API_EXPORTED PyFunctionType
+    : public PyConcreteType<PyFunctionType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirFunctionTypeGetTypeID;
+  static constexpr const char *pyClassName = "FunctionType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
+/// Opaque Type subclass - OpaqueType.
+class MLIR_PYTHON_API_EXPORTED PyOpaqueType
+    : public PyConcreteType<PyOpaqueType> {
+public:
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirOpaqueTypeGetTypeID;
+  static constexpr const char *pyClassName = "OpaqueType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+};
+
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
 } // namespace mlir
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index 7350046f428c7..951486b818a4e 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -28,492 +28,6 @@ using llvm::Twine;
 namespace mlir {
 namespace python {
 namespace MLIR_BINDINGS_PYTHON_DOMAIN {
-
-/// Checks whether the given type is an integer or float type.
-static int mlirTypeIsAIntegerOrFloat(MlirType type) {
-  return mlirTypeIsAInteger(type) || mlirTypeIsABF16(type) ||
-         mlirTypeIsAF16(type) || mlirTypeIsAF32(type) || mlirTypeIsAF64(type);
-}
-
-class PyIntegerType : public PyConcreteType<PyIntegerType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAInteger;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirIntegerTypeGetTypeID;
-  static constexpr const char *pyClassName = "IntegerType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get_signless",
-        [](unsigned width, DefaultingPyMlirContext context) {
-          MlirType t = mlirIntegerTypeGet(context->get(), width);
-          return PyIntegerType(context->getRef(), t);
-        },
-        nb::arg("width"), nb::arg("context") = nb::none(),
-        "Create a signless integer type");
-    c.def_static(
-        "get_signed",
-        [](unsigned width, DefaultingPyMlirContext context) {
-          MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
-          return PyIntegerType(context->getRef(), t);
-        },
-        nb::arg("width"), nb::arg("context") = nb::none(),
-        "Create a signed integer type");
-    c.def_static(
-        "get_unsigned",
-        [](unsigned width, DefaultingPyMlirContext context) {
-          MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
-          return PyIntegerType(context->getRef(), t);
-        },
-        nb::arg("width"), nb::arg("context") = nb::none(),
-        "Create an unsigned integer type");
-    c.def_prop_ro(
-        "width",
-        [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
-        "Returns the width of the integer type");
-    c.def_prop_ro(
-        "is_signless",
-        [](PyIntegerType &self) -> bool {
-          return mlirIntegerTypeIsSignless(self);
-        },
-        "Returns whether this is a signless integer");
-    c.def_prop_ro(
-        "is_signed",
-        [](PyIntegerType &self) -> bool {
-          return mlirIntegerTypeIsSigned(self);
-        },
-        "Returns whether this is a signed integer");
-    c.def_prop_ro(
-        "is_unsigned",
-        [](PyIntegerType &self) -> bool {
-          return mlirIntegerTypeIsUnsigned(self);
-        },
-        "Returns whether this is an unsigned integer");
-  }
-};
-
-/// Index Type subclass - IndexType.
-class PyIndexType : public PyConcreteType<PyIndexType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAIndex;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirIndexTypeGetTypeID;
-  static constexpr const char *pyClassName = "IndexType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirIndexTypeGet(context->get());
-          return PyIndexType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a index type.");
-  }
-};
-
-class PyFloatType : public PyConcreteType<PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat;
-  static constexpr const char *pyClassName = "FloatType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_prop_ro(
-        "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
-        "Returns the width of the floating-point type");
-  }
-};
-
-/// Floating Point Type subclass - Float4E2M1FNType.
-class PyFloat4E2M1FNType
-    : public PyConcreteType<PyFloat4E2M1FNType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat4E2M1FN;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat4E2M1FNTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float4E2M1FNType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
-          return PyFloat4E2M1FNType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float4_e2m1fn type.");
-  }
-};
-
-/// Floating Point Type subclass - Float6E2M3FNType.
-class PyFloat6E2M3FNType
-    : public PyConcreteType<PyFloat6E2M3FNType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E2M3FN;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat6E2M3FNTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float6E2M3FNType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
-          return PyFloat6E2M3FNType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float6_e2m3fn type.");
-  }
-};
-
-/// Floating Point Type subclass - Float6E3M2FNType.
-class PyFloat6E3M2FNType
-    : public PyConcreteType<PyFloat6E3M2FNType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat6E3M2FN;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat6E3M2FNTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float6E3M2FNType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
-          return PyFloat6E3M2FNType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float6_e3m2fn type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNType.
-class PyFloat8E4M3FNType
-    : public PyConcreteType<PyFloat8E4M3FNType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FN;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E4M3FNTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E4M3FNType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
-          return PyFloat8E4M3FNType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e4m3fn type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E5M2Type.
-class PyFloat8E5M2Type : public PyConcreteType<PyFloat8E5M2Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E5M2TypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E5M2Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E5M2TypeGet(context->get());
-          return PyFloat8E5M2Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e5m2 type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E4M3Type.
-class PyFloat8E4M3Type : public PyConcreteType<PyFloat8E4M3Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E4M3TypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E4M3Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E4M3TypeGet(context->get());
-          return PyFloat8E4M3Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e4m3 type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E4M3FNUZ.
-class PyFloat8E4M3FNUZType
-    : public PyConcreteType<PyFloat8E4M3FNUZType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3FNUZ;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E4M3FNUZTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E4M3FNUZType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
-          return PyFloat8E4M3FNUZType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e4m3fnuz type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E4M3B11FNUZ.
-class PyFloat8E4M3B11FNUZType
-    : public PyConcreteType<PyFloat8E4M3B11FNUZType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E4M3B11FNUZ;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E4M3B11FNUZTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E4M3B11FNUZType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
-          return PyFloat8E4M3B11FNUZType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e4m3b11fnuz type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E5M2FNUZ.
-class PyFloat8E5M2FNUZType
-    : public PyConcreteType<PyFloat8E5M2FNUZType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E5M2FNUZ;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E5M2FNUZTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E5M2FNUZType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
-          return PyFloat8E5M2FNUZType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e5m2fnuz type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E3M4Type.
-class PyFloat8E3M4Type : public PyConcreteType<PyFloat8E3M4Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E3M4;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E3M4TypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E3M4Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E3M4TypeGet(context->get());
-          return PyFloat8E3M4Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e3m4 type.");
-  }
-};
-
-/// Floating Point Type subclass - Float8E8M0FNUType.
-class PyFloat8E8M0FNUType
-    : public PyConcreteType<PyFloat8E8M0FNUType, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFloat8E8M0FNU;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat8E8M0FNUTypeGetTypeID;
-  static constexpr const char *pyClassName = "Float8E8M0FNUType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
-          return PyFloat8E8M0FNUType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a float8_e8m0fnu type.");
-  }
-};
-
-/// Floating Point Type subclass - BF16Type.
-class PyBF16Type : public PyConcreteType<PyBF16Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsABF16;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirBFloat16TypeGetTypeID;
-  static constexpr const char *pyClassName = "BF16Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirBF16TypeGet(context->get());
-          return PyBF16Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a bf16 type.");
-  }
-};
-
-/// Floating Point Type subclass - F16Type.
-class PyF16Type : public PyConcreteType<PyF16Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF16;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat16TypeGetTypeID;
-  static constexpr const char *pyClassName = "F16Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirF16TypeGet(context->get());
-          return PyF16Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a f16 type.");
-  }
-};
-
-/// Floating Point Type subclass - TF32Type.
-class PyTF32Type : public PyConcreteType<PyTF32Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATF32;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloatTF32TypeGetTypeID;
-  static constexpr const char *pyClassName = "FloatTF32Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirTF32TypeGet(context->get());
-          return PyTF32Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a tf32 type.");
-  }
-};
-
-/// Floating Point Type subclass - F32Type.
-class PyF32Type : public PyConcreteType<PyF32Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF32;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat32TypeGetTypeID;
-  static constexpr const char *pyClassName = "F32Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirF32TypeGet(context->get());
-          return PyF32Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a f32 type.");
-  }
-};
-
-/// Floating Point Type subclass - F64Type.
-class PyF64Type : public PyConcreteType<PyF64Type, PyFloatType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAF64;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFloat64TypeGetTypeID;
-  static constexpr const char *pyClassName = "F64Type";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirF64TypeGet(context->get());
-          return PyF64Type(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a f64 type.");
-  }
-};
-
-/// None Type subclass - NoneType.
-class PyNoneType : public PyConcreteType<PyNoneType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsANone;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirNoneTypeGetTypeID;
-  static constexpr const char *pyClassName = "NoneType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](DefaultingPyMlirContext context) {
-          MlirType t = mlirNoneTypeGet(context->get());
-          return PyNoneType(context->getRef(), t);
-        },
-        nb::arg("context") = nb::none(), "Create a none type.");
-  }
-};
-
-/// Complex Type subclass - ComplexType.
-class PyComplexType : public PyConcreteType<PyComplexType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAComplex;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirComplexTypeGetTypeID;
-  static constexpr const char *pyClassName = "ComplexType";
-  using PyConcreteType::PyConcreteType;
-
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyType &elementType) {
-          // The element must be a floating point or integer scalar type.
-          if (mlirTypeIsAIntegerOrFloat(elementType)) {
-            MlirType t = mlirComplexTypeGet(elementType);
-            return PyComplexType(elementType.getContext(), t);
-          }
-          throw nb::value_error(
-              (Twine("invalid '") +
-               nb::cast<std::string>(nb::repr(nb::cast(elementType))) +
-               "' and expected floating point or integer type.")
-                  .str()
-                  .c_str());
-        },
-        "Create a complex type");
-    c.def_prop_ro(
-        "element_type",
-        [](PyComplexType &self) -> nb::typed<nb::object, PyType> {
-          return PyType(self.getContext(), mlirComplexTypeGetElementType(self))
-              .maybeDownCast();
-        },
-        "Returns element type.");
-  }
-};
-
-} // namespace MLIR_BINDINGS_PYTHON_DOMAIN
-} // namespace python
-} // namespace mlir
-
 // Shaped Type Interface - ShapedType
 void PyShapedType::bindDerived(ClassTy &c) {
   c.def_prop_ro(
@@ -627,521 +141,632 @@ void PyShapedType::requireHasRank() {
   }
 }
 
-const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = 
mlirTypeIsAShaped;
+void PyIntegerType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get_signless",
+      [](unsigned width, DefaultingPyMlirContext context) {
+        MlirType t = mlirIntegerTypeGet(context->get(), width);
+        return PyIntegerType(context->getRef(), t);
+      },
+      nanobind::arg("width"), nanobind::arg("context") = nanobind::none(),
+      "Create a signless integer type");
+  c.def_static(
+      "get_signed",
+      [](unsigned width, DefaultingPyMlirContext context) {
+        MlirType t = mlirIntegerTypeSignedGet(context->get(), width);
+        return PyIntegerType(context->getRef(), t);
+      },
+      nanobind::arg("width"), nanobind::arg("context") = nanobind::none(),
+      "Create a signed integer type");
+  c.def_static(
+      "get_unsigned",
+      [](unsigned width, DefaultingPyMlirContext context) {
+        MlirType t = mlirIntegerTypeUnsignedGet(context->get(), width);
+        return PyIntegerType(context->getRef(), t);
+      },
+      nanobind::arg("width"), nanobind::arg("context") = nanobind::none(),
+      "Create an unsigned integer type");
+  c.def_prop_ro(
+      "width",
+      [](PyIntegerType &self) { return mlirIntegerTypeGetWidth(self); },
+      "Returns the width of the integer type");
+  c.def_prop_ro(
+      "is_signless",
+      [](PyIntegerType &self) -> bool {
+        return mlirIntegerTypeIsSignless(self);
+      },
+      "Returns whether this is a signless integer");
+  c.def_prop_ro(
+      "is_signed",
+      [](PyIntegerType &self) -> bool { return mlirIntegerTypeIsSigned(self); 
},
+      "Returns whether this is a signed integer");
+  c.def_prop_ro(
+      "is_unsigned",
+      [](PyIntegerType &self) -> bool {
+        return mlirIntegerTypeIsUnsigned(self);
+      },
+      "Returns whether this is an unsigned integer");
+}
 
-namespace mlir {
-namespace python {
-namespace MLIR_BINDINGS_PYTHON_DOMAIN {
+void PyIndexType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirIndexTypeGet(context->get());
+        return PyIndexType(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(), "Create a index type.");
+}
 
-/// Vector Type subclass - VectorType.
-class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAVector;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirVectorTypeGetTypeID;
-  static constexpr const char *pyClassName = "VectorType";
-  using PyConcreteType::PyConcreteType;
+void PyFloatType::bindDerived(ClassTy &c) {
+  c.def_prop_ro(
+      "width", [](PyFloatType &self) { return mlirFloatTypeGetWidth(self); },
+      "Returns the width of the floating-point type");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static("get", &PyVectorType::getChecked, nb::arg("shape"),
-                 nb::arg("element_type"), nb::kw_only(),
-                 nb::arg("scalable") = nb::none(),
-                 nb::arg("scalable_dims") = nb::none(),
-                 nb::arg("loc") = nb::none(), "Create a vector type")
-        .def_static("get_unchecked", &PyVectorType::get, nb::arg("shape"),
-                    nb::arg("element_type"), nb::kw_only(),
-                    nb::arg("scalable") = nb::none(),
-                    nb::arg("scalable_dims") = nb::none(),
-                    nb::arg("context") = nb::none(), "Create a vector type")
-        .def_prop_ro(
-            "scalable",
-            [](MlirType self) { return mlirVectorTypeIsScalable(self); })
-        .def_prop_ro("scalable_dims", [](MlirType self) {
-          std::vector<bool> scalableDims;
-          size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
-          scalableDims.reserve(rank);
-          for (size_t i = 0; i < rank; ++i)
-            scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
-          return scalableDims;
-        });
-  }
+void PyFloat4E2M1FNType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirFloat4E2M1FNTypeGet(context->get());
+        return PyFloat4E2M1FNType(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(),
+      "Create a float4_e2m1fn type.");
+}
 
-private:
-  static PyVectorType
-  getChecked(std::vector<int64_t> shape, PyType &elementType,
-             std::optional<nb::list> scalable,
-             std::optional<std::vector<int64_t>> scalableDims,
-             DefaultingPyLocation loc) {
-    if (scalable && scalableDims) {
-      throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
-                            "are mutually exclusive.");
-    }
+void PyFloat6E2M3FNType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirFloat6E2M3FNTypeGet(context->get());
+        return PyFloat6E2M3FNType(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(),
+      "Create a float6_e2m3fn type.");
+}
 
-    PyMlirContext::ErrorCapture errors(loc->getContext());
-    MlirType type;
-    if (scalable) {
-      if (scalable->size() != shape.size())
-        throw nb::value_error("Expected len(scalable) == len(shape).");
+void PyFloat6E3M2FNType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirFloat6E3M2FNTypeGet(context->get());
+        return PyFloat6E3M2FNType(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(),
+      "Create a float6_e3m2fn type.");
+}
 
-      SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
-          *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
-      type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
-                                              scalableDimFlags.data(),
-                                              elementType);
-    } else if (scalableDims) {
-      SmallVector<bool> scalableDimFlags(shape.size(), false);
-      for (int64_t dim : *scalableDims) {
-        if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
-          throw nb::value_error("Scalable dimension index out of bounds.");
-        scalableDimFlags[dim] = true;
-      }
-      type = mlirVectorTypeGetScalableChecked(loc, shape.size(), shape.data(),
-                                              scalableDimFlags.data(),
-                                              elementType);
-    } else {
-      type = mlirVectorTypeGetChecked(loc, shape.size(), shape.data(),
-                                      elementType);
-    }
-    if (mlirTypeIsNull(type))
-      throw MLIRError("Invalid type", errors.take());
-    return PyVectorType(elementType.getContext(), type);
-  }
+void PyFloat8E4M3FNType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirFloat8E4M3FNTypeGet(context->get());
+        return PyFloat8E4M3FNType(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(),
+      "Create a float8_e4m3fn type.");
+}
 
-  static PyVectorType get(std::vector<int64_t> shape, PyType &elementType,
-                          std::optional<nb::list> scalable,
-                          std::optional<std::vector<int64_t>> scalableDims,
-                          DefaultingPyMlirContext context) {
-    if (scalable && scalableDims) {
-      throw nb::value_error("'scalable' and 'scalable_dims' kwargs "
-                            "are mutually exclusive.");
-    }
+void PyFloat8E5M2Type::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirFloat8E5M2TypeGet(context->get());
+        return PyFloat8E5M2Type(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(),
+      "Create a float8_e5m2 type.");
+}
 
-    PyMlirContext::ErrorCapture errors(context->getRef());
-    MlirType type;
-    if (scalable) {
-      if (scalable->size() != shape.size())
-        throw nb::value_error("Expected len(scalable) == len(shape).");
+void PyFloat8E4M3Type::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirFloat8E4M3TypeGet(context->get());
+        return PyFloat8E4M3Type(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(),
+      "Create a float8_e4m3 type.");
+}
 
-      SmallVector<bool> scalableDimFlags = llvm::to_vector(llvm::map_range(
-          *scalable, [](const nb::handle &h) { return nb::cast<bool>(h); }));
-      type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
-                                       scalableDimFlags.data(), elementType);
-    } else if (scalableDims) {
-      SmallVector<bool> scalableDimFlags(shape.size(), false);
-      for (int64_t dim : *scalableDims) {
-        if (static_cast<size_t>(dim) >= scalableDimFlags.size() || dim < 0)
-          throw nb::value_error("Scalable dimension index out of bounds.");
-        scalableDimFlags[dim] = true;
-      }
-      type = mlirVectorTypeGetScalable(shape.size(), shape.data(),
-                                       scalableDimFlags.data(), elementType);
-    } else {
-      type = mlirVectorTypeGet(shape.size(), shape.data(), elementType);
-    }
-    if (mlirTypeIsNull(type))
-      throw MLIRError("Invalid type", errors.take());
-    return PyVectorType(elementType.getContext(), type);
-  }
-};
+void PyFloat8E4M3FNUZType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirFloat8E4M3FNUZTypeGet(context->get());
+        return PyFloat8E4M3FNUZType(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(),
+      "Create a float8_e4m3fnuz type.");
+}
 
-/// Ranked Tensor Type subclass - RankedTensorType.
-class PyRankedTensorType
-    : public PyConcreteType<PyRankedTensorType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedTensor;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirRankedTensorTypeGetTypeID;
-  static constexpr const char *pyClassName = "RankedTensorType";
-  using PyConcreteType::PyConcreteType;
+void PyFloat8E4M3B11FNUZType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirFloat8E4M3B11FNUZTypeGet(context->get());
+        return PyFloat8E4M3B11FNUZType(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(),
+      "Create a float8_e4m3b11fnuz type.");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](std::vector<int64_t> shape, PyType &elementType,
-           std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) 
{
-          PyMlirContext::ErrorCapture errors(loc->getContext());
-          MlirType t = mlirRankedTensorTypeGetChecked(
-              loc, shape.size(), shape.data(), elementType,
-              encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
-          if (mlirTypeIsNull(t))
-            throw MLIRError("Invalid type", errors.take());
-          return PyRankedTensorType(elementType.getContext(), t);
-        },
-        nb::arg("shape"), nb::arg("element_type"),
-        nb::arg("encoding") = nb::none(), nb::arg("loc") = nb::none(),
-        "Create a ranked tensor type");
-    c.def_static(
-        "get_unchecked",
-        [](std::vector<int64_t> shape, PyType &elementType,
-           std::optional<PyAttribute> &encodingAttr,
-           DefaultingPyMlirContext context) {
-          PyMlirContext::ErrorCapture errors(context->getRef());
-          MlirType t = mlirRankedTensorTypeGet(
-              shape.size(), shape.data(), elementType,
-              encodingAttr ? encodingAttr->get() : mlirAttributeGetNull());
-          if (mlirTypeIsNull(t))
-            throw MLIRError("Invalid type", errors.take());
-          return PyRankedTensorType(elementType.getContext(), t);
-        },
-        nb::arg("shape"), nb::arg("element_type"),
-        nb::arg("encoding") = nb::none(), nb::arg("context") = nb::none(),
-        "Create a ranked tensor type");
-    c.def_prop_ro(
-        "encoding",
-        [](PyRankedTensorType &self)
-            -> std::optional<nb::typed<nb::object, PyAttribute>> {
-          MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
-          if (mlirAttributeIsNull(encoding))
-            return std::nullopt;
-          return PyAttribute(self.getContext(), encoding).maybeDownCast();
-        });
-  }
-};
+void PyFloat8E5M2FNUZType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirFloat8E5M2FNUZTypeGet(context->get());
+        return PyFloat8E5M2FNUZType(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(),
+      "Create a float8_e5m2fnuz type.");
+}
 
-/// Unranked Tensor Type subclass - UnrankedTensorType.
-class PyUnrankedTensorType
-    : public PyConcreteType<PyUnrankedTensorType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedTensor;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirUnrankedTensorTypeGetTypeID;
-  static constexpr const char *pyClassName = "UnrankedTensorType";
-  using PyConcreteType::PyConcreteType;
+void PyFloat8E3M4Type::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirFloat8E3M4TypeGet(context->get());
+        return PyFloat8E3M4Type(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(),
+      "Create a float8_e3m4 type.");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](PyType &elementType, DefaultingPyLocation loc) {
-          PyMlirContext::ErrorCapture errors(loc->getContext());
-          MlirType t = mlirUnrankedTensorTypeGetChecked(loc, elementType);
-          if (mlirTypeIsNull(t))
-            throw MLIRError("Invalid type", errors.take());
-          return PyUnrankedTensorType(elementType.getContext(), t);
-        },
-        nb::arg("element_type"), nb::arg("loc") = nb::none(),
-        "Create a unranked tensor type");
-    c.def_static(
-        "get_unchecked",
-        [](PyType &elementType, DefaultingPyMlirContext context) {
-          PyMlirContext::ErrorCapture errors(context->getRef());
-          MlirType t = mlirUnrankedTensorTypeGet(elementType);
-          if (mlirTypeIsNull(t))
-            throw MLIRError("Invalid type", errors.take());
-          return PyUnrankedTensorType(elementType.getContext(), t);
-        },
-        nb::arg("element_type"), nb::arg("context") = nb::none(),
-        "Create a unranked tensor type");
-  }
-};
+void PyFloat8E8M0FNUType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirFloat8E8M0FNUTypeGet(context->get());
+        return PyFloat8E8M0FNUType(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(),
+      "Create a float8_e8m0fnu type.");
+}
 
-/// Ranked MemRef Type subclass - MemRefType.
-class PyMemRefType : public PyConcreteType<PyMemRefType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAMemRef;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirMemRefTypeGetTypeID;
-  static constexpr const char *pyClassName = "MemRefType";
-  using PyConcreteType::PyConcreteType;
+void PyBF16Type::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirBF16TypeGet(context->get());
+        return PyBF16Type(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(), "Create a bf16 type.");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-         "get",
-         [](std::vector<int64_t> shape, PyType &elementType,
-            PyAttribute *layout, PyAttribute *memorySpace,
-            DefaultingPyLocation loc) {
-           PyMlirContext::ErrorCapture errors(loc->getContext());
-           MlirAttribute layoutAttr = layout ? *layout : 
mlirAttributeGetNull();
-           MlirAttribute memSpaceAttr =
-               memorySpace ? *memorySpace : mlirAttributeGetNull();
-           MlirType t =
-               mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
-                                        shape.data(), layoutAttr, 
memSpaceAttr);
-           if (mlirTypeIsNull(t))
-             throw MLIRError("Invalid type", errors.take());
-           return PyMemRefType(elementType.getContext(), t);
-         },
-         nb::arg("shape"), nb::arg("element_type"),
-         nb::arg("layout") = nb::none(), nb::arg("memory_space") = nb::none(),
-         nb::arg("loc") = nb::none(), "Create a memref type")
-        .def_static(
-            "get_unchecked",
-            [](std::vector<int64_t> shape, PyType &elementType,
-               PyAttribute *layout, PyAttribute *memorySpace,
-               DefaultingPyMlirContext context) {
-              PyMlirContext::ErrorCapture errors(context->getRef());
-              MlirAttribute layoutAttr =
-                  layout ? *layout : mlirAttributeGetNull();
-              MlirAttribute memSpaceAttr =
-                  memorySpace ? *memorySpace : mlirAttributeGetNull();
-              MlirType t =
-                  mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
-                                    layoutAttr, memSpaceAttr);
-              if (mlirTypeIsNull(t))
-                throw MLIRError("Invalid type", errors.take());
-              return PyMemRefType(elementType.getContext(), t);
-            },
-            nb::arg("shape"), nb::arg("element_type"),
-            nb::arg("layout") = nb::none(),
-            nb::arg("memory_space") = nb::none(),
-            nb::arg("context") = nb::none(), "Create a memref type")
-        .def_prop_ro(
-            "layout",
-            [](PyMemRefType &self) -> nb::typed<nb::object, PyAttribute> {
-              return PyAttribute(self.getContext(),
-                                 mlirMemRefTypeGetLayout(self))
-                  .maybeDownCast();
-            },
-            "The layout of the MemRef type.")
-        .def(
-            "get_strides_and_offset",
-            [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> 
{
-              std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
-              int64_t offset;
-              if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
-                      self, strides.data(), &offset)))
-                throw std::runtime_error(
-                    "Failed to extract strides and offset from memref.");
-              return {strides, offset};
-            },
-            "The strides and offset of the MemRef type.")
-        .def_prop_ro(
-            "affine_map",
-            [](PyMemRefType &self) -> PyAffineMap {
-              MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
-              return PyAffineMap(self.getContext(), map);
-            },
-            "The layout of the MemRef type as an affine map.")
-        .def_prop_ro(
-            "memory_space",
-            [](PyMemRefType &self)
-                -> std::optional<nb::typed<nb::object, PyAttribute>> {
-              MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
-              if (mlirAttributeIsNull(a))
-                return std::nullopt;
-              return PyAttribute(self.getContext(), a).maybeDownCast();
-            },
-            "Returns the memory space of the given MemRef type.");
-  }
-};
+void PyF16Type::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirF16TypeGet(context->get());
+        return PyF16Type(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(), "Create a f16 type.");
+}
 
-/// Unranked MemRef Type subclass - UnrankedMemRefType.
-class PyUnrankedMemRefType
-    : public PyConcreteType<PyUnrankedMemRefType, PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAUnrankedMemRef;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirUnrankedMemRefTypeGetTypeID;
-  static constexpr const char *pyClassName = "UnrankedMemRefType";
-  using PyConcreteType::PyConcreteType;
+void PyTF32Type::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirTF32TypeGet(context->get());
+        return PyTF32Type(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(), "Create a tf32 type.");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-         "get",
-         [](PyType &elementType, PyAttribute *memorySpace,
-            DefaultingPyLocation loc) {
-           PyMlirContext::ErrorCapture errors(loc->getContext());
-           MlirAttribute memSpaceAttr = {};
-           if (memorySpace)
-             memSpaceAttr = *memorySpace;
+void PyF32Type::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirF32TypeGet(context->get());
+        return PyF32Type(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(), "Create a f32 type.");
+}
 
-           MlirType t =
-               mlirUnrankedMemRefTypeGetChecked(loc, elementType, 
memSpaceAttr);
-           if (mlirTypeIsNull(t))
-             throw MLIRError("Invalid type", errors.take());
-           return PyUnrankedMemRefType(elementType.getContext(), t);
-         },
-         nb::arg("element_type"), nb::arg("memory_space").none(),
-         nb::arg("loc") = nb::none(), "Create a unranked memref type")
-        .def_static(
-            "get_unchecked",
-            [](PyType &elementType, PyAttribute *memorySpace,
-               DefaultingPyMlirContext context) {
-              PyMlirContext::ErrorCapture errors(context->getRef());
-              MlirAttribute memSpaceAttr = {};
-              if (memorySpace)
-                memSpaceAttr = *memorySpace;
+void PyF64Type::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirF64TypeGet(context->get());
+        return PyF64Type(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(), "Create a f64 type.");
+}
 
-              MlirType t = mlirUnrankedMemRefTypeGet(elementType, 
memSpaceAttr);
-              if (mlirTypeIsNull(t))
-                throw MLIRError("Invalid type", errors.take());
-              return PyUnrankedMemRefType(elementType.getContext(), t);
-            },
-            nb::arg("element_type"), nb::arg("memory_space").none(),
-            nb::arg("context") = nb::none(), "Create a unranked memref type")
-        .def_prop_ro(
-            "memory_space",
-            [](PyUnrankedMemRefType &self)
-                -> std::optional<nb::typed<nb::object, PyAttribute>> {
-              MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
-              if (mlirAttributeIsNull(a))
-                return std::nullopt;
-              return PyAttribute(self.getContext(), a).maybeDownCast();
-            },
-            "Returns the memory space of the given Unranked MemRef type.");
-  }
-};
+void PyNoneType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](DefaultingPyMlirContext context) {
+        MlirType t = mlirNoneTypeGet(context->get());
+        return PyNoneType(context->getRef(), t);
+      },
+      nanobind::arg("context") = nanobind::none(), "Create a none type.");
+}
+
+/// Binds ComplexType.get(element_type), which accepts only integer or
+/// floating-point element types, and the read-only `element_type` property.
+void PyComplexType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyType &elementType) {
+        // Guard first: anything that is not an integer/float type is rejected
+        // with a ValueError quoting the offending type's repr.
+        if (!mlirTypeIsAIntegerOrFloat(elementType)) {
+          std::string typeRepr = nanobind::cast<std::string>(
+              nanobind::repr(nanobind::cast(elementType)));
+          throw nanobind::value_error(
+              (Twine("invalid '") + typeRepr +
+               "' and expected floating point or integer type.")
+                  .str()
+                  .c_str());
+        }
+        MlirType complexType = mlirComplexTypeGet(elementType);
+        return PyComplexType(elementType.getContext(), complexType);
+      },
+      "Create a complex type");
+  c.def_prop_ro(
+      "element_type",
+      [](PyComplexType &self) -> nanobind::typed<nanobind::object, PyType> {
+        MlirType elementType = mlirComplexTypeGetElementType(self);
+        return PyType(self.getContext(), elementType).maybeDownCast();
+      },
+      "Returns element type.");
+}
 
-/// Tuple Type subclass - TupleType.
-class PyTupleType : public PyConcreteType<PyTupleType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsATuple;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirTupleTypeGetTypeID;
-  static constexpr const char *pyClassName = "TupleType";
-  using PyConcreteType::PyConcreteType;
+/// Binds VectorType.get (-> PyVectorType::getChecked, takes `loc`),
+/// VectorType.get_unchecked (-> PyVectorType::get, takes `context`) and the
+/// read-only `scalable` / `scalable_dims` properties.
+void PyVectorType::bindDerived(ClassTy &c) {
+  c.def_static("get", &PyVectorType::getChecked, nanobind::arg("shape"),
+               nanobind::arg("element_type"), nanobind::kw_only(),
+               nanobind::arg("scalable") = nanobind::none(),
+               nanobind::arg("scalable_dims") = nanobind::none(),
+               nanobind::arg("loc") = nanobind::none(), "Create a vector type")
+      .def_static("get_unchecked", &PyVectorType::get, nanobind::arg("shape"),
+                  nanobind::arg("element_type"), nanobind::kw_only(),
+                  nanobind::arg("scalable") = nanobind::none(),
+                  nanobind::arg("scalable_dims") = nanobind::none(),
+                  nanobind::arg("context") = nanobind::none(),
+                  "Create a vector type")
+      .def_prop_ro("scalable",
+                   [](MlirType self) { return mlirVectorTypeIsScalable(self); })
+      .def_prop_ro("scalable_dims", [](MlirType self) {
+        // One flag per dimension (0..rank-1), each queried individually
+        // through mlirVectorTypeIsDimScalable.
+        std::vector<bool> scalableDims;
+        size_t rank = static_cast<size_t>(mlirShapedTypeGetRank(self));
+        scalableDims.reserve(rank);
+        for (size_t i = 0; i < rank; ++i)
+          scalableDims.push_back(mlirVectorTypeIsDimScalable(self, i));
+        return scalableDims;
+      });
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get_tuple",
-        [](const std::vector<PyType> &elements,
-           DefaultingPyMlirContext context) {
-          std::vector<MlirType> mlirElements;
-          mlirElements.reserve(elements.size());
-          for (const auto &element : elements)
-            mlirElements.push_back(element.get());
-          MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
-                                        mlirElements.data());
-          return PyTupleType(context->getRef(), t);
-        },
-        nb::arg("elements"), nb::arg("context") = nb::none(),
-        "Create a tuple type");
-    c.def_static(
-        "get_tuple",
-        [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
-          MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
-                                        elements.data());
-          return PyTupleType(context->getRef(), t);
-        },
-        nb::arg("elements"), nb::arg("context") = nb::none(),
-        // clang-format off
-        nb::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"),
-        // clang-format on
-        "Create a tuple type");
-    c.def(
-        "get_type",
-        [](PyTupleType &self, intptr_t pos) -> nb::typed<nb::object, PyType> {
-          return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
-              .maybeDownCast();
-        },
-        nb::arg("pos"), "Returns the pos-th type in the tuple type.");
-    c.def_prop_ro(
-        "num_types",
-        [](PyTupleType &self) -> intptr_t {
-          return mlirTupleTypeGetNumTypes(self);
-        },
-        "Returns the number of types contained in a tuple.");
-  }
-};
+/// Binds RankedTensorType.get (checked; takes `loc`), get_unchecked (takes
+/// `context`) and the `encoding` property. Both constructors raise MLIRError
+/// with the captured errors when the C API returns a null type.
+void PyRankedTensorType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](std::vector<int64_t> shape, PyType &elementType,
+         std::optional<PyAttribute> &encodingAttr, DefaultingPyLocation loc) {
+        PyMlirContext::ErrorCapture capture(loc->getContext());
+        MlirAttribute encoding =
+            encodingAttr ? encodingAttr->get() : mlirAttributeGetNull();
+        MlirType tensorType = mlirRankedTensorTypeGetChecked(
+            loc, shape.size(), shape.data(), elementType, encoding);
+        if (!mlirTypeIsNull(tensorType))
+          return PyRankedTensorType(elementType.getContext(), tensorType);
+        throw MLIRError("Invalid type", capture.take());
+      },
+      nanobind::arg("shape"), nanobind::arg("element_type"),
+      nanobind::arg("encoding") = nanobind::none(),
+      nanobind::arg("loc") = nanobind::none(), "Create a ranked tensor type");
+  c.def_static(
+      "get_unchecked",
+      [](std::vector<int64_t> shape, PyType &elementType,
+         std::optional<PyAttribute> &encodingAttr,
+         DefaultingPyMlirContext context) {
+        PyMlirContext::ErrorCapture capture(context->getRef());
+        MlirAttribute encoding =
+            encodingAttr ? encodingAttr->get() : mlirAttributeGetNull();
+        MlirType tensorType = mlirRankedTensorTypeGet(
+            shape.size(), shape.data(), elementType, encoding);
+        if (!mlirTypeIsNull(tensorType))
+          return PyRankedTensorType(elementType.getContext(), tensorType);
+        throw MLIRError("Invalid type", capture.take());
+      },
+      nanobind::arg("shape"), nanobind::arg("element_type"),
+      nanobind::arg("encoding") = nanobind::none(),
+      nanobind::arg("context") = nanobind::none(),
+      "Create a ranked tensor type");
+  c.def_prop_ro(
+      "encoding",
+      [](PyRankedTensorType &self)
+          -> std::optional<nanobind::typed<nanobind::object, PyAttribute>> {
+        // A null encoding attribute is surfaced to Python as None.
+        MlirAttribute encoding = mlirRankedTensorTypeGetEncoding(self.get());
+        if (!mlirAttributeIsNull(encoding))
+          return PyAttribute(self.getContext(), encoding).maybeDownCast();
+        return std::nullopt;
+      });
+}
 
-/// Function type.
-class PyFunctionType : public PyConcreteType<PyFunctionType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAFunction;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirFunctionTypeGetTypeID;
-  static constexpr const char *pyClassName = "FunctionType";
-  using PyConcreteType::PyConcreteType;
+/// Binds UnrankedTensorType.get (checked; takes `loc`) and get_unchecked
+/// (takes `context`). Both raise MLIRError with the captured errors when the
+/// C API returns a null type.
+void PyUnrankedTensorType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](PyType &elementType, DefaultingPyLocation loc) {
+        PyMlirContext::ErrorCapture capture(loc->getContext());
+        MlirType tensorType =
+            mlirUnrankedTensorTypeGetChecked(loc, elementType);
+        if (!mlirTypeIsNull(tensorType))
+          return PyUnrankedTensorType(elementType.getContext(), tensorType);
+        throw MLIRError("Invalid type", capture.take());
+      },
+      nanobind::arg("element_type"), nanobind::arg("loc") = nanobind::none(),
+      "Create a unranked tensor type");
+  c.def_static(
+      "get_unchecked",
+      [](PyType &elementType, DefaultingPyMlirContext context) {
+        PyMlirContext::ErrorCapture capture(context->getRef());
+        MlirType tensorType = mlirUnrankedTensorTypeGet(elementType);
+        if (!mlirTypeIsNull(tensorType))
+          return PyUnrankedTensorType(elementType.getContext(), tensorType);
+        throw MLIRError("Invalid type", capture.take());
+      },
+      nanobind::arg("element_type"),
+      nanobind::arg("context") = nanobind::none(),
+      "Create a unranked tensor type");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](std::vector<PyType> inputs, std::vector<PyType> results,
-           DefaultingPyMlirContext context) {
-          std::vector<MlirType> mlirInputs;
-          mlirInputs.reserve(inputs.size());
-          for (const auto &input : inputs)
-            mlirInputs.push_back(input.get());
-          std::vector<MlirType> mlirResults;
-          mlirResults.reserve(results.size());
-          for (const auto &result : results)
-            mlirResults.push_back(result.get());
+/// Binds MemRefType.get (via mlirMemRefTypeGetChecked; takes `loc`) and
+/// MemRefType.get_unchecked (takes `context`). Both raise MLIRError with the
+/// captured errors when the C API returns a null type; a `None` layout or
+/// memory space is forwarded as a null attribute. Also binds the `layout`,
+/// `affine_map` and `memory_space` properties and get_strides_and_offset().
+void PyMemRefType::bindDerived(ClassTy &c) {
+  c.def_static(
+       "get",
+       [](std::vector<int64_t> shape, PyType &elementType, PyAttribute *layout,
+          PyAttribute *memorySpace, DefaultingPyLocation loc) {
+         PyMlirContext::ErrorCapture errors(loc->getContext());
+         MlirAttribute layoutAttr = layout ? *layout : mlirAttributeGetNull();
+         MlirAttribute memSpaceAttr =
+             memorySpace ? *memorySpace : mlirAttributeGetNull();
+         MlirType t =
+             mlirMemRefTypeGetChecked(loc, elementType, shape.size(),
+                                      shape.data(), layoutAttr, memSpaceAttr);
+         if (mlirTypeIsNull(t))
+           throw MLIRError("Invalid type", errors.take());
+         return PyMemRefType(elementType.getContext(), t);
+       },
+       nanobind::arg("shape"), nanobind::arg("element_type"),
+       nanobind::arg("layout") = nanobind::none(),
+       nanobind::arg("memory_space") = nanobind::none(),
+       nanobind::arg("loc") = nanobind::none(), "Create a memref type")
+      .def_static(
+          "get_unchecked",
+          [](std::vector<int64_t> shape, PyType &elementType,
+             PyAttribute *layout, PyAttribute *memorySpace,
+             DefaultingPyMlirContext context) {
+            PyMlirContext::ErrorCapture errors(context->getRef());
+            MlirAttribute layoutAttr =
+                layout ? *layout : mlirAttributeGetNull();
+            MlirAttribute memSpaceAttr =
+                memorySpace ? *memorySpace : mlirAttributeGetNull();
+            MlirType t =
+                mlirMemRefTypeGet(elementType, shape.size(), shape.data(),
+                                  layoutAttr, memSpaceAttr);
+            if (mlirTypeIsNull(t))
+              throw MLIRError("Invalid type", errors.take());
+            return PyMemRefType(elementType.getContext(), t);
+          },
+          nanobind::arg("shape"), nanobind::arg("element_type"),
+          nanobind::arg("layout") = nanobind::none(),
+          nanobind::arg("memory_space") = nanobind::none(),
+          nanobind::arg("context") = nanobind::none(), "Create a memref type")
+      .def_prop_ro(
+          "layout",
+          [](PyMemRefType &self)
+              -> nanobind::typed<nanobind::object, PyAttribute> {
+            return PyAttribute(self.getContext(), mlirMemRefTypeGetLayout(self))
+                .maybeDownCast();
+          },
+          "The layout of the MemRef type.")
+      .def(
+          "get_strides_and_offset",
+          [](PyMemRefType &self) -> std::pair<std::vector<int64_t>, int64_t> {
+            // Buffer sized to the rank for mlirMemRefTypeGetStridesAndOffset
+            // to write into; `offset` is zero-initialized defensively.
+            std::vector<int64_t> strides(mlirShapedTypeGetRank(self));
+            int64_t offset = 0;
+            if (mlirLogicalResultIsFailure(mlirMemRefTypeGetStridesAndOffset(
+                    self, strides.data(), &offset)))
+              throw std::runtime_error(
+                  "Failed to extract strides and offset from memref.");
+            return {strides, offset};
+          },
+          "The strides and offset of the MemRef type.")
+      .def_prop_ro(
+          "affine_map",
+          [](PyMemRefType &self) -> PyAffineMap {
+            MlirAffineMap map = mlirMemRefTypeGetAffineMap(self);
+            return PyAffineMap(self.getContext(), map);
+          },
+          "The layout of the MemRef type as an affine map.")
+      .def_prop_ro(
+          "memory_space",
+          [](PyMemRefType &self)
+              -> std::optional<nanobind::typed<nanobind::object, PyAttribute>> {
+            // A null memory-space attribute is surfaced to Python as None.
+            MlirAttribute a = mlirMemRefTypeGetMemorySpace(self);
+            if (mlirAttributeIsNull(a))
+              return std::nullopt;
+            return PyAttribute(self.getContext(), a).maybeDownCast();
+          },
+          "Returns the memory space of the given MemRef type.");
+}
 
-          MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
-                                           mlirInputs.data(), results.size(),
-                                           mlirResults.data());
-          return PyFunctionType(context->getRef(), t);
-        },
-        nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
-        "Gets a FunctionType from a list of input and result types");
-    c.def_static(
-        "get",
-        [](std::vector<MlirType> inputs, std::vector<MlirType> results,
-           DefaultingPyMlirContext context) {
-          MlirType t =
-              mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
-                                  results.size(), results.data());
-          return PyFunctionType(context->getRef(), t);
-        },
-        nb::arg("inputs"), nb::arg("results"), nb::arg("context") = nb::none(),
-        // clang-format off
-        nb::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"),
-        // clang-format on
-        "Gets a FunctionType from a list of input and result types");
-    c.def_prop_ro(
-        "inputs",
-        [](PyFunctionType &self) {
-          MlirType t = self;
-          nb::list types;
-          for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
-               ++i) {
-            types.append(mlirFunctionTypeGetInput(t, i));
-          }
-          return types;
-        },
-        "Returns the list of input types in the FunctionType.");
-    c.def_prop_ro(
-        "results",
-        [](PyFunctionType &self) {
-          nb::list types;
-          for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
-               ++i) {
-            types.append(mlirFunctionTypeGetResult(self, i));
-          }
-          return types;
-        },
-        "Returns the list of result types in the FunctionType.");
-  }
-};
+/// Binds UnrankedMemRefType.get (checked; takes `loc`), get_unchecked (takes
+/// `context`) and the `memory_space` property. `memory_space` is a required
+/// argument that may be None, which is forwarded as a null attribute. Both
+/// constructors raise MLIRError with the captured errors on a null result.
+void PyUnrankedMemRefType::bindDerived(ClassTy &c) {
+  c.def_static(
+       "get",
+       [](PyType &elementType, PyAttribute *memorySpace,
+          DefaultingPyLocation loc) {
+         PyMlirContext::ErrorCapture errors(loc->getContext());
+         MlirAttribute memSpaceAttr =
+             memorySpace ? *memorySpace : mlirAttributeGetNull();
+
+         MlirType t =
+             mlirUnrankedMemRefTypeGetChecked(loc, elementType, memSpaceAttr);
+         if (mlirTypeIsNull(t))
+           throw MLIRError("Invalid type", errors.take());
+         return PyUnrankedMemRefType(elementType.getContext(), t);
+       },
+       nanobind::arg("element_type"), nanobind::arg("memory_space").none(),
+       nanobind::arg("loc") = nanobind::none(), "Create a unranked memref type")
+      .def_static(
+          "get_unchecked",
+          [](PyType &elementType, PyAttribute *memorySpace,
+             DefaultingPyMlirContext context) {
+            PyMlirContext::ErrorCapture errors(context->getRef());
+            MlirAttribute memSpaceAttr =
+                memorySpace ? *memorySpace : mlirAttributeGetNull();
+
+            MlirType t = mlirUnrankedMemRefTypeGet(elementType, memSpaceAttr);
+            if (mlirTypeIsNull(t))
+              throw MLIRError("Invalid type", errors.take());
+            return PyUnrankedMemRefType(elementType.getContext(), t);
+          },
+          nanobind::arg("element_type"), nanobind::arg("memory_space").none(),
+          nanobind::arg("context") = nanobind::none(),
+          "Create a unranked memref type")
+      .def_prop_ro(
+          "memory_space",
+          [](PyUnrankedMemRefType &self)
+              -> std::optional<nanobind::typed<nanobind::object, PyAttribute>> {
+            // A null memory-space attribute is surfaced to Python as None.
+            MlirAttribute a = mlirUnrankedMemrefGetMemorySpace(self);
+            if (mlirAttributeIsNull(a))
+              return std::nullopt;
+            return PyAttribute(self.getContext(), a).maybeDownCast();
+          },
+          "Returns the memory space of the given Unranked MemRef type.");
+}
 
-/// Opaque Type subclass - OpaqueType.
-class PyOpaqueType : public PyConcreteType<PyOpaqueType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAOpaque;
-  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
-      mlirOpaqueTypeGetTypeID;
-  static constexpr const char *pyClassName = "OpaqueType";
-  using PyConcreteType::PyConcreteType;
+/// Binds TupleType.get_tuple (two overloads: a list of PyType, and a list of
+/// MlirType whose stub signature is given by an explicit `nanobind::sig`),
+/// get_type(pos) and the `num_types` property.
+void PyTupleType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get_tuple",
+      [](const std::vector<PyType> &elements, DefaultingPyMlirContext context) {
+        std::vector<MlirType> mlirElements;
+        mlirElements.reserve(elements.size());
+        for (const auto &element : elements)
+          mlirElements.push_back(element.get());
+        MlirType t = mlirTupleTypeGet(context->get(), elements.size(),
+                                      mlirElements.data());
+        return PyTupleType(context->getRef(), t);
+      },
+      nanobind::arg("elements"), nanobind::arg("context") = nanobind::none(),
+      "Create a tuple type");
+  c.def_static(
+      "get_tuple",
+      [](std::vector<MlirType> elements, DefaultingPyMlirContext context) {
+        MlirType t =
+            mlirTupleTypeGet(context->get(), elements.size(), elements.data());
+        return PyTupleType(context->getRef(), t);
+      },
+      nanobind::arg("elements"), nanobind::arg("context") = nanobind::none(),
+      // clang-format off
+      nanobind::sig("def get_tuple(elements: Sequence[Type], context: Context | None = None) -> TupleType"),
+      // clang-format on
+      "Create a tuple type");
+  c.def(
+      "get_type",
+      [](PyTupleType &self,
+         intptr_t pos) -> nanobind::typed<nanobind::object, PyType> {
+        return PyType(self.getContext(), mlirTupleTypeGetType(self, pos))
+            .maybeDownCast();
+      },
+      nanobind::arg("pos"), "Returns the pos-th type in the tuple type.");
+  c.def_prop_ro(
+      "num_types",
+      [](PyTupleType &self) -> intptr_t {
+        return mlirTupleTypeGetNumTypes(self);
+      },
+      "Returns the number of types contained in a tuple.");
+}
 
-  static void bindDerived(ClassTy &c) {
-    c.def_static(
-        "get",
-        [](const std::string &dialectNamespace, const std::string &typeData,
-           DefaultingPyMlirContext context) {
-          MlirType type = mlirOpaqueTypeGet(context->get(),
-                                            toMlirStringRef(dialectNamespace),
-                                            toMlirStringRef(typeData));
-          return PyOpaqueType(context->getRef(), type);
-        },
-        nb::arg("dialect_namespace"), nb::arg("buffer"),
-        nb::arg("context") = nb::none(),
-        "Create an unregistered (opaque) dialect type.");
-    c.def_prop_ro(
-        "dialect_namespace",
-        [](PyOpaqueType &self) {
-          MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
-          return nb::str(stringRef.data, stringRef.length);
-        },
-        "Returns the dialect namespace for the Opaque type as a string.");
-    c.def_prop_ro(
-        "data",
-        [](PyOpaqueType &self) {
-          MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
-          return nb::str(stringRef.data, stringRef.length);
-        },
-        "Returns the data for the Opaque type as a string.");
-  }
-};
+/// Binds FunctionType.get (two overloads: lists of PyType, and lists of
+/// MlirType whose stub signature is given by an explicit `nanobind::sig`)
+/// plus the read-only `inputs` / `results` list properties.
+void PyFunctionType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](std::vector<PyType> inputs, std::vector<PyType> results,
+         DefaultingPyMlirContext context) {
+        std::vector<MlirType> mlirInputs;
+        mlirInputs.reserve(inputs.size());
+        for (const auto &input : inputs)
+          mlirInputs.push_back(input.get());
+        std::vector<MlirType> mlirResults;
+        mlirResults.reserve(results.size());
+        for (const auto &result : results)
+          mlirResults.push_back(result.get());
+
+        MlirType t = mlirFunctionTypeGet(context->get(), inputs.size(),
+                                         mlirInputs.data(), results.size(),
+                                         mlirResults.data());
+        return PyFunctionType(context->getRef(), t);
+      },
+      nanobind::arg("inputs"), nanobind::arg("results"),
+      nanobind::arg("context") = nanobind::none(),
+      "Gets a FunctionType from a list of input and result types");
+  c.def_static(
+      "get",
+      [](std::vector<MlirType> inputs, std::vector<MlirType> results,
+         DefaultingPyMlirContext context) {
+        MlirType t =
+            mlirFunctionTypeGet(context->get(), inputs.size(), inputs.data(),
+                                results.size(), results.data());
+        return PyFunctionType(context->getRef(), t);
+      },
+      nanobind::arg("inputs"), nanobind::arg("results"),
+      nanobind::arg("context") = nanobind::none(),
+      // clang-format off
+      nanobind::sig("def get(inputs: Sequence[Type], results: Sequence[Type], context: Context | None = None) -> FunctionType"),
+      // clang-format on
+      "Gets a FunctionType from a list of input and result types");
+  c.def_prop_ro(
+      "inputs",
+      [](PyFunctionType &self) {
+        nanobind::list types;
+        for (intptr_t i = 0, e = mlirFunctionTypeGetNumInputs(self); i < e;
+             ++i) {
+          types.append(mlirFunctionTypeGetInput(self, i));
+        }
+        return types;
+      },
+      "Returns the list of input types in the FunctionType.");
+  c.def_prop_ro(
+      "results",
+      [](PyFunctionType &self) {
+        nanobind::list types;
+        for (intptr_t i = 0, e = mlirFunctionTypeGetNumResults(self); i < e;
+             ++i) {
+          types.append(mlirFunctionTypeGetResult(self, i));
+        }
+        return types;
+      },
+      "Returns the list of result types in the FunctionType.");
+}
+
+/// Binds OpaqueType.get(dialect_namespace, buffer, context=None), where
+/// `buffer` supplies the opaque type's data string, and the read-only
+/// `dialect_namespace` / `data` string properties.
+void PyOpaqueType::bindDerived(ClassTy &c) {
+  c.def_static(
+      "get",
+      [](const std::string &dialectNamespace, const std::string &typeData,
+         DefaultingPyMlirContext context) {
+        MlirType type =
+            mlirOpaqueTypeGet(context->get(), toMlirStringRef(dialectNamespace),
+                              toMlirStringRef(typeData));
+        return PyOpaqueType(context->getRef(), type);
+      },
+      nanobind::arg("dialect_namespace"), nanobind::arg("buffer"),
+      nanobind::arg("context") = nanobind::none(),
+      "Create an unregistered (opaque) dialect type.");
+  c.def_prop_ro(
+      "dialect_namespace",
+      [](PyOpaqueType &self) {
+        MlirStringRef stringRef = mlirOpaqueTypeGetDialectNamespace(self);
+        return nanobind::str(stringRef.data, stringRef.length);
+      },
+      "Returns the dialect namespace for the Opaque type as a string.");
+  c.def_prop_ro(
+      "data",
+      [](PyOpaqueType &self) {
+        MlirStringRef stringRef = mlirOpaqueTypeGetData(self);
+        return nanobind::str(stringRef.data, stringRef.length);
+      },
+      "Returns the data for the Opaque type as a string.");
+}
 
+// Out-of-line definition of PyShapedType's static `isaFunction` member.
+const PyShapedType::IsAFunctionTy PyShapedType::isaFunction = mlirTypeIsAShaped;
 } // namespace MLIR_BINDINGS_PYTHON_DOMAIN
 } // namespace python
 } // namespace mlir
diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt
index 4a9fb127ee08c..582863ffcbb0d 100644
--- a/mlir/python/CMakeLists.txt
+++ b/mlir/python/CMakeLists.txt
@@ -535,7 +535,6 @@ declare_mlir_python_extension(MLIRPythonExtension.Core
     IRAffine.cpp
     IRAttributes.cpp
     IRInterfaces.cpp
-    IRTypes.cpp
     Pass.cpp
     Rewrite.cpp
 
@@ -846,8 +845,9 @@ declare_mlir_python_extension(MLIRPythonExtension.MLIRPythonSupport
   ADD_TO_PARENT MLIRPythonSources.Core
   ROOT_DIR "${PYTHON_SOURCE_DIR}"
   SOURCES
-    IRCore.cpp
     Globals.cpp
+    IRCore.cpp
+    IRTypes.cpp
 )
 
 ################################################################################
diff --git a/mlir/test/python/lib/PythonTestModuleNanobind.cpp b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
index 43573cbc305fa..a296b5e814b4b 100644
--- a/mlir/test/python/lib/PythonTestModuleNanobind.cpp
+++ b/mlir/test/python/lib/PythonTestModuleNanobind.cpp
@@ -15,6 +15,7 @@
 #include "mlir-c/IR.h"
 #include "mlir/Bindings/Python/Diagnostics.h"
 #include "mlir/Bindings/Python/IRCore.h"
+#include "mlir/Bindings/Python/IRTypes.h"
 #include "mlir/Bindings/Python/Nanobind.h"
 #include "mlir/Bindings/Python/NanobindAdaptors.h"
 #include "nanobind/nanobind.h"
@@ -47,6 +48,49 @@ struct PyTestType
   }
 };
 
+/// RankedTensorType subclass whose instances satisfy isaFunction below.
+struct PyTestIntegerRankedTensorType
+    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteType<
+          PyTestIntegerRankedTensorType,
+          mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyRankedTensorType> {
+  static constexpr IsAFunctionTy isaFunction = mlirTypeIsARankedIntegerTensor;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirRankedTensorTypeGetTypeID;
+  static constexpr const char *pyClassName = "TestIntegerRankedTensorType";
+  using PyConcreteType::PyConcreteType;
+
+  /// Binds get(shape, width, context=None): a ranked tensor whose element
+  /// type is mlirIntegerTypeGet(ctx, width), with a null encoding attribute.
+  static void bindDerived(ClassTy &c) {
+    auto build = [](std::vector<int64_t> shape, unsigned width,
+                    mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::
+                        DefaultingPyMlirContext ctx) {
+      MlirType elementType = mlirIntegerTypeGet(ctx->get(), width);
+      MlirType tensorType =
+          mlirRankedTensorTypeGet(shape.size(), shape.data(), elementType,
+                                  /*encoding=*/mlirAttributeGetNull());
+      return PyTestIntegerRankedTensorType(ctx->getRef(), tensorType);
+    };
+    c.def_static("get", build, nb::arg("shape"), nb::arg("width"),
+                 nb::arg("context").none() = nb::none());
+  }
+};
+
+/// Value subclass exposed as TestTensorValue, matched by isaFunction below.
+struct PyTestTensorValue
+    : mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteValue<
+          PyTestTensorValue> {
+  static constexpr IsAFunctionTy isaFunction =
+      mlirTypeIsAPythonTestTestTensorValue;
+  static constexpr GetTypeIDFunctionTy getTypeIdFunction =
+      mlirRankedTensorTypeGetTypeID;
+  static constexpr const char *pyClassName = "TestTensorValue";
+  using PyConcreteValue::PyConcreteValue;
+  static void bindDerived(ClassTy &c) {
+    c.def("is_null", [](MlirValue &value) { return mlirValueIsNull(value); });
+  }
+};
+
 class PyTestAttr
     : public mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::PyConcreteAttribute<
           PyTestAttr> {
@@ -73,18 +117,18 @@ class PyTestAttr
 NB_MODULE(_mlirPythonTestNanobind, m) {
   m.def(
       "register_python_test_dialect",
-      [](MlirContext context, bool load) {
+      [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+             context,
+         bool load) {
         MlirDialectHandle pythonTestDialect =
             mlirGetDialectHandle__python_test__();
-        mlirDialectHandleRegisterDialect(pythonTestDialect, context);
+        mlirDialectHandleRegisterDialect(pythonTestDialect,
+                                         context.get()->get());
         if (load) {
-          mlirDialectHandleLoadDialect(pythonTestDialect, context);
+          mlirDialectHandleLoadDialect(pythonTestDialect, context.get()->get());
         }
       },
-      nb::arg("context"), nb::arg("load") = true,
-      // clang-format off
-      nb::sig("def register_python_test_dialect(context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", load: bool = True) -> None"));
-  // clang-format on
+      nb::arg("context").none() = nb::none(), nb::arg("load") = true);
 
   m.def(
       "register_dialect",
@@ -100,73 +144,16 @@ NB_MODULE(_mlirPythonTestNanobind, m) {
 
   m.def(
       "test_diagnostics_with_errors_and_notes",
-      [](MlirContext ctx) {
-        mlir::python::CollectDiagnosticsToStringScope handler(ctx);
-        mlirPythonTestEmitDiagnosticWithNote(ctx);
+      [](mlir::python::MLIR_BINDINGS_PYTHON_DOMAIN::DefaultingPyMlirContext
+             ctx) {
+        mlir::python::CollectDiagnosticsToStringScope handler(ctx.get()->get());
+        mlirPythonTestEmitDiagnosticWithNote(ctx.get()->get());
         throw nb::value_error(handler.takeMessage().c_str());
       },
-      // clang-format off
-      nb::sig("def test_diagnostics_with_errors_and_notes(arg: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") ", /) -> None"));
-  // clang-format on
+      nb::arg("context").none() = nb::none());
 
   PyTestAttr::bind(m);
   PyTestType::bind(m);
-
-  auto typeCls =
-      mlir_type_subclass(m, "TestIntegerRankedTensorType",
-                         mlirTypeIsARankedIntegerTensor,
-                         nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-                             .attr("RankedTensorType"))
-          .def_classmethod(
-              "get",
-              [](const nb::object &cls, std::vector<int64_t> shape,
-                 unsigned width, MlirContext ctx) {
-                MlirAttribute encoding = mlirAttributeGetNull();
-                return cls(mlirRankedTensorTypeGet(
-                    shape.size(), shape.data(), mlirIntegerTypeGet(ctx, width),
-                    encoding));
-              },
-              // clang-format off
-              nb::sig("def get(cls: object, shape: collections.abc.Sequence[int], width: int, context: " MAKE_MLIR_PYTHON_QUALNAME("ir.Context") " | None = None) -> object"),
-              // clang-format on
-              nb::arg("cls"), nb::arg("shape"), nb::arg("width"),
-              nb::arg("context").none() = nb::none());
-
-  assert(nb::hasattr(typeCls.get_class(), "static_typeid") &&
-         "TestIntegerRankedTensorType has no static_typeid");
-
-  MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID();
-
-  nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-      .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
-          mlirRankedTensorTypeID, nb::arg("replace") = true)(
-          nanobind::cpp_function([typeCls](const nb::object &mlirType) {
-            return typeCls.get_class()(mlirType);
-          }));
-
-  auto valueCls = mlir_value_subclass(m, "TestTensorValue",
-                                      mlirTypeIsAPythonTestTestTensorValue)
-                      .def("is_null", [](MlirValue &self) {
-                        return mlirValueIsNull(self);
-                      });
-
-  nb::module_::import_(MAKE_MLIR_PYTHON_QUALNAME("ir"))
-      .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)(
-          mlirRankedTensorTypeID)(
-          nanobind::cpp_function([valueCls](const nb::object &valueObj) {
-            std::optional<nb::object> capsule =
-                mlirApiObjectToCapsule(valueObj);
-            assert(capsule.has_value() && "capsule is not null");
-            MlirValue v = mlirPythonCapsuleToValue(capsule.value().ptr());
-            MlirType t = mlirValueGetType(v);
-            // This is hyper-specific in order to exercise/test registering a
-            // value caster from cpp (but only for a single test case; see
-            // testTensorValue python_test.py).
-            if (mlirShapedTypeHasStaticShape(t) &&
-                mlirShapedTypeGetDimSize(t, 0) == 1 &&
-                mlirShapedTypeGetDimSize(t, 1) == 2 &&
-                mlirShapedTypeGetDimSize(t, 2) == 3)
-              return valueCls.get_class()(valueObj);
-            return valueObj;
-          }));
+  PyTestIntegerRankedTensorType::bind(m);
+  PyTestTensorValue::bind(m);
 }

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to