fhahn updated this revision to Diff 263879.
fhahn edited the summary of this revision.
fhahn added a comment.

Ping. Simplified the code and extended the tests; this should now be ready for review.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D72781/new/

https://reviews.llvm.org/D72781

Files:
  clang/include/clang/AST/Type.h
  clang/include/clang/Basic/Builtins.def
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGBuiltin.cpp
  clang/lib/Sema/SemaChecking.cpp
  clang/test/CodeGen/matrix-type-builtins.c
  clang/test/CodeGenCXX/matrix-type-builtins.cpp
  clang/test/Sema/matrix-type-builtins.c
  clang/test/SemaCXX/matrix-type-builtins.cpp

Index: clang/test/SemaCXX/matrix-type-builtins.cpp
===================================================================
--- clang/test/SemaCXX/matrix-type-builtins.cpp
+++ clang/test/SemaCXX/matrix-type-builtins.cpp
@@ -32,3 +32,27 @@
   Mat1.value = transpose<unsigned, 3, 3, unsigned, 2, 3>(Mat2);
   // expected-note@-1 {{in instantiation of function template specialization 'transpose<unsigned int, 3, 3, unsigned int, 2, 3>' requested here}}
 }
+
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1>
+typename MyMatrix<EltTy1, R1, C1>::matrix_t column_major_load(MyMatrix<EltTy0, R0, C0> &A, EltTy0 *Ptr) { // Diagnostics in this body are reported once per instantiation.
+  char *v1 = __builtin_matrix_column_major_load(Ptr, 9, 4, 10); // A 9 x 4 matrix never converts to char *.
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(9, 4)))'}}
+  // expected-error@-2 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(9, 4)))'}}
+  // expected-error@-3 {{cannot initialize a variable of type 'char *' with an rvalue of type 'float __attribute__((matrix_type(9, 4)))'}}
+
+  return __builtin_matrix_column_major_load(Ptr, R0, C0, R0); // R0 x C0 matrix of EltTy0; must match MyMatrix<EltTy1, R1, C1>.
+  // expected-error@-1 {{cannot initialize return object of type 'typename MyMatrix<unsigned int, 5U, 5U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(5, 5)))') with an rvalue of type 'unsigned int __attribute__((matrix_type(2, 3)))'}}
+  // expected-error@-2 {{cannot initialize return object of type 'typename MyMatrix<unsigned int, 2U, 3U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 3)))') with an rvalue of type 'float __attribute__((matrix_type(2, 3)))'}}
+}
+
+void test_column_major_loads_template(unsigned *Ptr1, float *Ptr2) { // Instantiates column_major_load three times.
+  MyMatrix<unsigned, 2, 3> Mat1;
+  Mat1.value = column_major_load<unsigned, 2, 3, unsigned, 2, 3>(Mat1, Ptr1); // Loaded and returned types match.
+  // expected-note@-1 {{in instantiation of function template specialization 'column_major_load<unsigned int, 2, 3, unsigned int, 2, 3>' requested here}}
+  column_major_load<unsigned, 2, 3, unsigned, 5, 5>(Mat1, Ptr1); // Returns 5 x 5 but loads 2 x 3.
+  // expected-note@-1 {{in instantiation of function template specialization 'column_major_load<unsigned int, 2, 3, unsigned int, 5, 5>' requested here}}
+
+  MyMatrix<float, 2, 3> Mat2;
+  Mat1.value = column_major_load<float, 2, 3, unsigned, 2, 3>(Mat2, Ptr2); // Loads float but returns unsigned.
+  // expected-note@-1 {{in instantiation of function template specialization 'column_major_load<float, 2, 3, unsigned int, 2, 3>' requested here}}
+}
Index: clang/test/Sema/matrix-type-builtins.c
===================================================================
--- clang/test/Sema/matrix-type-builtins.c
+++ clang/test/Sema/matrix-type-builtins.c
@@ -15,3 +15,48 @@
   __builtin_matrix_transpose("test");
   // expected-error@-1 {{first argument must be a matrix}}
 }
+
+struct Foo { // Used as a pointee type that is not a valid matrix element type.
+  unsigned x;
+};
+
+void column_major_load(float *p1, int *p2, _Bool *p3, struct Foo *p4) { // Sema diagnostics for __builtin_matrix_column_major_load.
+  sx5x10_t a1 = __builtin_matrix_column_major_load(p1, 5, 11, 5);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(5, 11)))'}}
+  sx5x10_t a2 = __builtin_matrix_column_major_load(p1, 5, 9, 5);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(5, 9)))'}}
+  sx5x10_t a3 = __builtin_matrix_column_major_load(p1, 6, 10, 6);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(6, 10)))'}}
+  sx5x10_t a4 = __builtin_matrix_column_major_load(p1, 4, 10, 4);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(4, 10)))'}}
+  sx5x10_t a5 = __builtin_matrix_column_major_load(p1, 6, 9, 6);
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'float __attribute__((matrix_type(6, 9)))'}}
+  sx5x10_t a6 = __builtin_matrix_column_major_load(p2, 5, 10, 6); // The element type follows the pointee type (int).
+  // expected-error@-1 {{initializing 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') with an expression of incompatible type 'int __attribute__((matrix_type(5, 10)))'}}
+
+  sx5x10_t a7 = __builtin_matrix_column_major_load(p1, 5, 10, 3); // Stride 3 is less than the 5 rows.
+  // expected-error@-1 {{stride must be greater or equal to the number of rows}}
+
+  sx5x10_t a8 = __builtin_matrix_column_major_load(p3, 5, 10, 6);
+  // expected-error@-1 {{first argument must be a pointer to a valid matrix element type}}
+
+  sx5x10_t a9 = __builtin_matrix_column_major_load(p4, 5, 10, 6);
+  // expected-error@-1 {{first argument must be a pointer to a valid matrix element type}}
+
+  sx5x10_t a10 = __builtin_matrix_column_major_load(p1, 1ull << 21, 10, 6); // 1 << 21 exceeds the maximum dimension.
+  // expected-error@-1 {{row dimension is outside the allowed range [1, 1048575}}
+  sx5x10_t a11 = __builtin_matrix_column_major_load(p1, 10, 1ull << 21, 10);
+  // expected-error@-1 {{column dimension is outside the allowed range [1, 1048575}}
+
+  sx5x10_t a12 = __builtin_matrix_column_major_load( // Every invalid argument is diagnosed, not just the first.
+      10,         // expected-error {{first argument must be a pointer to a valid matrix element type}}
+      1ull << 21, // expected-error {{row dimension is outside the allowed range [1, 1048575]}}
+      1ull << 21, // expected-error {{column dimension is outside the allowed range [1, 1048575]}}
+      "");        // expected-error {{stride argument must be a constant unsigned integer expression}}
+
+  sx5x10_t a13 = __builtin_matrix_column_major_load( // Non-integer row and column arguments.
+      10, // expected-error {{first argument must be a pointer to a valid matrix element type}}
+      "", // expected-error {{row argument must be a constant unsigned integer expression}}
+      "", // expected-error {{column argument must be a constant unsigned integer expression}}
+      10);
+}
Index: clang/test/CodeGenCXX/matrix-type-builtins.cpp
===================================================================
--- clang/test/CodeGenCXX/matrix-type-builtins.cpp
+++ clang/test/CodeGenCXX/matrix-type-builtins.cpp
@@ -122,3 +122,92 @@
 // CHECK-NEXT:    %1 = load <42 x double>, <42 x double>* %0, align 8
 // CHECK-NEXT:    %2 = call <42 x double> @llvm.matrix.transpose.v42f64(<42 x double> %1, i32 6, i32 7)
 // CHECK-NEXT:    ret <42 x double> %2
+
+void column_major_load(float *Ptr1, unsigned *Ptr2) {
+  // CHECK-LABEL: define void @_Z17column_major_loadPfPj(float* %Ptr1, i32* %Ptr2) #0 {
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Ptr1.addr = alloca float*, align 8
+  // CHECK-NEXT:    %Ptr2.addr = alloca i32*, align 8
+  // CHECK-NEXT:    %M1 = alloca [28 x float], align 4
+  // CHECK-NEXT:    %M2 = alloca [105 x float], align 4
+  // CHECK-NEXT:    %M3 = alloca [80 x i32], align 4
+  // CHECK-NEXT:    store float* %Ptr1, float** %Ptr1.addr, align 8
+  // CHECK-NEXT:    store i32* %Ptr2, i32** %Ptr2.addr, align 8
+  // CHECK-NEXT:    %0 = load float*, float** %Ptr1.addr, align 8
+  // CHECK-NEXT:    %matrix = call <28 x float> @llvm.matrix.columnwise.load.v28f32.p0f32(float* %0, i32 5, i32 4, i32 7)
+  // CHECK-NEXT:    %1 = bitcast [28 x float]* %M1 to <28 x float>*
+  // CHECK-NEXT:    store <28 x float> %matrix, <28 x float>* %1, align 4
+  matrix_t<float, 4, 7> M1 = __builtin_matrix_column_major_load(Ptr1, 4, 7, 5); // Intrinsic operand order: stride 5, rows 4, columns 7.
+
+  // CHECK-NEXT:    %2 = load float*, float** %Ptr1.addr, align 8
+  // CHECK-NEXT:    %matrix1 = call <105 x float> @llvm.matrix.columnwise.load.v105f32.p0f32(float* %2, i32 15, i32 15, i32 7)
+  // CHECK-NEXT:    %3 = bitcast [105 x float]* %M2 to <105 x float>*
+  // CHECK-NEXT:    store <105 x float> %matrix1, <105 x float>* %3, align 4
+  matrix_t<float, 15, 7> M2 = __builtin_matrix_column_major_load(Ptr1, 15, 7, 15); // Stride equal to the row count.
+
+  // CHECK-NEXT:    %4 = load i32*, i32** %Ptr2.addr, align 8
+  // CHECK-NEXT:    %matrix2 = call <80 x i32> @llvm.matrix.columnwise.load.v80i32.p0i32(i32* %4, i32 32, i32 16, i32 5)
+  // CHECK-NEXT:    %5 = bitcast [80 x i32]* %M3 to <80 x i32>*
+  // CHECK-NEXT:    store <80 x i32> %matrix2, <80 x i32>* %5, align 4
+  // CHECK-NEXT:    ret void
+  auto M3 = __builtin_matrix_column_major_load(Ptr2, 16, 5, 32); // Deduced as a 16 x 5 unsigned matrix.
+}
+
+template <typename T, unsigned R, unsigned C>
+matrix_t<T, R, C> column_major_load(T *Ptr) { // Uses the row count R as the stride.
+  return __builtin_matrix_column_major_load(Ptr, R, C, R);
+}
+
+void test_column_major_load_template(int *Ptr1, double *Ptr2) { // Calls three instantiations of column_major_load<T, R, C>.
+  // CHECK-LABEL: define void @_Z31test_column_major_load_templatePiPd(i32* %Ptr1, double* %Ptr2) #5 {
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %Ptr1.addr = alloca i32*, align 8
+  // CHECK-NEXT:    %Ptr2.addr = alloca double*, align 8
+  // CHECK-NEXT:    %M1 = alloca [40 x i32], align 4
+  // CHECK-NEXT:    %M2 = alloca [63 x i32], align 4
+  // CHECK-NEXT:    %M3 = alloca [63 x double], align 8
+  // CHECK-NEXT:    store i32* %Ptr1, i32** %Ptr1.addr, align 8
+  // CHECK-NEXT:    store double* %Ptr2, double** %Ptr2.addr, align 8
+  // CHECK-NEXT:    %0 = load i32*, i32** %Ptr1.addr, align 8
+  // CHECK-NEXT:    %call = call <40 x i32> @_Z17column_major_loadIiLj10ELj4EEU11matrix_typeXT0_EXT1_ET_PS0_(i32* %0)
+  // CHECK-NEXT:    %1 = bitcast [40 x i32]* %M1 to <40 x i32>*
+  // CHECK-NEXT:    store <40 x i32> %call, <40 x i32>* %1, align 4
+  matrix_t<int, 10, 4> M1 = column_major_load<int, 10, 4>(Ptr1);
+
+  // CHECK-NEXT:    %2 = load i32*, i32** %Ptr1.addr, align 8
+  // CHECK-NEXT:    %call1 = call <63 x i32> @_Z17column_major_loadIiLj7ELj9EEU11matrix_typeXT0_EXT1_ET_PS0_(i32* %2)
+  // CHECK-NEXT:    %3 = bitcast [63 x i32]* %M2 to <63 x i32>*
+  // CHECK-NEXT:    store <63 x i32> %call1, <63 x i32>* %3, align 4
+  matrix_t<int, 7, 9> M2 = column_major_load<int, 7, 9>(Ptr1);
+
+  // CHECK-NEXT:    %4 = load double*, double** %Ptr2.addr, align 8
+  // CHECK-NEXT:    %call2 = call <63 x double> @_Z17column_major_loadIdLj7ELj9EEU11matrix_typeXT0_EXT1_ET_PS0_(double* %4)
+  // CHECK-NEXT:    %5 = bitcast [63 x double]* %M3 to <63 x double>*
+  // CHECK-NEXT:    store <63 x double> %call2, <63 x double>* %5, align 8
+  // CHECK-NEXT:    ret void
+  matrix_t<double, 7, 9> M3 = column_major_load<double, 7, 9>(Ptr2);
+}
+
+// CHECK-LABEL:  define linkonce_odr <40 x i32> @_Z17column_major_loadIiLj10ELj4EEU11matrix_typeXT0_EXT1_ET_PS0_(i32* %Ptr)
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    %Ptr.addr = alloca i32*, align 8
+// CHECK-NEXT:    store i32* %Ptr, i32** %Ptr.addr, align 8
+// CHECK-NEXT:    %0 = load i32*, i32** %Ptr.addr, align 8
+// CHECK-NEXT:    %matrix = call <40 x i32> @llvm.matrix.columnwise.load.v40i32.p0i32(i32* %0, i32 10, i32 10, i32 4)
+// CHECK-NEXT:    ret <40 x i32> %matrix
+
+// CHECK-LABEL: define linkonce_odr <63 x i32> @_Z17column_major_loadIiLj7ELj9EEU11matrix_typeXT0_EXT1_ET_PS0_(i32* %Ptr)
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    %Ptr.addr = alloca i32*, align 8
+// CHECK-NEXT:    store i32* %Ptr, i32** %Ptr.addr, align 8
+// CHECK-NEXT:    %0 = load i32*, i32** %Ptr.addr, align 8
+// CHECK-NEXT:    %matrix = call <63 x i32> @llvm.matrix.columnwise.load.v63i32.p0i32(i32* %0, i32 7, i32 7, i32 9)
+// CHECK-NEXT:    ret <63 x i32> %matrix
+
+// CHECK-LABEL: define linkonce_odr <63 x double> @_Z17column_major_loadIdLj7ELj9EEU11matrix_typeXT0_EXT1_ET_PS0_(double* %Ptr)
+// CHECK-NEXT:  entry:
+// CHECK-NEXT:    %Ptr.addr = alloca double*, align 8
+// CHECK-NEXT:    store double* %Ptr, double** %Ptr.addr, align 8
+// CHECK-NEXT:    %0 = load double*, double** %Ptr.addr, align 8
+// CHECK-NEXT:    %matrix = call <63 x double> @llvm.matrix.columnwise.load.v63f64.p0f64(double* %0, i32 7, i32 7, i32 9)
+// CHECK-NEXT:    ret <63 x double> %matrix
Index: clang/test/CodeGen/matrix-type-builtins.c
===================================================================
--- clang/test/CodeGen/matrix-type-builtins.c
+++ clang/test/CodeGen/matrix-type-builtins.c
@@ -72,3 +72,42 @@
 }
 
 // CHECK: declare <6 x i32> @llvm.matrix.transpose.v6i32(<6 x i32>, i32 immarg, i32 immarg)
+
+void column_major_load1(double *a, float *b, int *c) { // Covers double, float and int element types.
+  // CHECK-LABEL: define void @column_major_load1(double* %a, float* %b, i32* %c) #0 {
+  // CHECK-NEXT:  entry:
+  // CHECK-NEXT:    %a.addr = alloca double*, align 8
+  // CHECK-NEXT:    %b.addr = alloca float*, align 8
+  // CHECK-NEXT:    %c.addr = alloca i32*, align 8
+  // CHECK-NEXT:    %m_a1 = alloca [25 x double], align 8
+  // CHECK-NEXT:    %m_a2 = alloca [25 x double], align 8
+  // CHECK-NEXT:    %m_b = alloca [6 x float], align 4
+  // CHECK-NEXT:    %m_c = alloca [80 x i32], align 4
+  // CHECK-NEXT:    store double* %a, double** %a.addr, align 8
+  // CHECK-NEXT:    store float* %b, float** %b.addr, align 8
+  // CHECK-NEXT:    store i32* %c, i32** %c.addr, align 8
+  // CHECK-NEXT:    %0 = load double*, double** %a.addr, align 8
+  // CHECK-NEXT:    %matrix = call <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double* %0, i32 5, i32 5, i32 5)
+  // CHECK-NEXT:    %1 = bitcast [25 x double]* %m_a1 to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %matrix, <25 x double>* %1, align 8
+  dx5x5_t m_a1 = __builtin_matrix_column_major_load(a, 5, 5, 5);
+
+  // CHECK-NEXT:    %2 = load double*, double** %a.addr, align 8
+  // CHECK-NEXT:    %matrix1 = call <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double* %2, i32 9, i32 5, i32 5)
+  // CHECK-NEXT:    %3 = bitcast [25 x double]* %m_a2 to <25 x double>*
+  // CHECK-NEXT:    store <25 x double> %matrix1, <25 x double>* %3, align 8
+  dx5x5_t m_a2 = __builtin_matrix_column_major_load(a, 5, 5, 9); // Stride 9 is larger than the 5 rows.
+
+  // CHECK-NEXT:    %4 = load float*, float** %b.addr, align 8
+  // CHECK-NEXT:    %matrix2 = call <6 x float> @llvm.matrix.columnwise.load.v6f32.p0f32(float* %4, i32 3, i32 2, i32 3)
+  // CHECK-NEXT:    %5 = bitcast [6 x float]* %m_b to <6 x float>*
+  // CHECK-NEXT:    store <6 x float> %matrix2, <6 x float>* %5, align 4
+  fx2x3_t m_b = __builtin_matrix_column_major_load(b, 2, 3, 3);
+
+  // CHECK-NEXT:    %6 = load i32*, i32** %c.addr, align 8
+  // CHECK-NEXT:    %matrix3 = call <80 x i32> @llvm.matrix.columnwise.load.v80i32.p0i32(i32* %6, i32 4, i32 4, i32 20)
+  // CHECK-NEXT:    %7 = bitcast [80 x i32]* %m_c to <80 x i32>*
+  // CHECK-NEXT:    store <80 x i32> %matrix3, <80 x i32>* %7, align 4
+  // CHECK-NEXT:    ret void
+  ix4x20_t m_c = __builtin_matrix_column_major_load(c, 4, 20, 4);
+}
Index: clang/lib/Sema/SemaChecking.cpp
===================================================================
--- clang/lib/Sema/SemaChecking.cpp
+++ clang/lib/Sema/SemaChecking.cpp
@@ -1913,6 +1913,7 @@
   }
 
   case Builtin::BI__builtin_matrix_transpose:
+  case Builtin::BI__builtin_matrix_column_major_load:
     if (!getLangOpts().MatrixTypes) {
       Diag(TheCall->getBeginLoc(), diag::err_builtin_matrix_disabled);
       return ExprError();
@@ -1921,6 +1922,8 @@
     switch (BuiltinID) {
     case Builtin::BI__builtin_matrix_transpose:
       return SemaBuiltinMatrixTransposeOverload(TheCall, TheCallResult);
+    case Builtin::BI__builtin_matrix_column_major_load:
+      return SemaBuiltinMatrixColumnMajorLoadOverload(TheCall, TheCallResult);
     default:
       llvm_unreachable("All matrix builtins should be handled here!");
     }
@@ -14814,3 +14817,106 @@
   TheCall->setType(ResultType);
   return CallResult;
 }
+
+/// Evaluate \p E as an integer constant expression and check that it is a
+/// valid matrix dimension.
+///
+/// \param ErrIdx selects the dimension named in diagnostics (0 = row,
+///        1 = column).
+/// \returns the dimension, or None after a diagnostic has been emitted.
+static llvm::Optional<unsigned>
+getAndVerifyMatrixDimension(Expr *E, unsigned ErrIdx, Sema &S) {
+  llvm::APSInt Value(64);
+  if (!E->isIntegerConstantExpr(Value, S.Context)) {
+    S.Diag(E->getBeginLoc(), diag::err_builtin_matrix_scalar_int_arg)
+        << ErrIdx << /*constant unsigned integer expression*/ 1;
+    return llvm::None;
+  }
+  // Reject negative values explicitly: zero-extending e.g. (short)-1 would
+  // yield 65535, which passes the range check. Values that need more than
+  // 64 bits are out of range and must not reach getZExtValue().
+  if ((Value.isSigned() && Value.isNegative()) ||
+      Value.getActiveBits() > 64 ||
+      !ConstantMatrixType::isDimensionValid(Value.getZExtValue())) {
+    S.Diag(E->getBeginLoc(), diag::err_builtin_matrix_invalid_dimension)
+        << ErrIdx << ConstantMatrixType::getMaxElementsPerDimension();
+    return llvm::None;
+  }
+  return static_cast<unsigned>(Value.getZExtValue());
+}
+
+/// Sema checking for __builtin_matrix_column_major_load(Ptr, Rows, Columns,
+/// Stride).
+///
+/// Ptr must be a pointer to, or an array of, a valid matrix element type.
+/// Rows and Columns must be integer constant expressions within the allowed
+/// dimension range. Stride must have integral type; if it is a constant it
+/// must be neither negative nor smaller than Rows. On success the call gets
+/// the constant matrix type Rows x Columns of the unqualified element type.
+ExprResult
+Sema::SemaBuiltinMatrixColumnMajorLoadOverload(CallExpr *TheCall,
+                                               ExprResult CallResult) {
+  if (checkArgCount(*this, TheCall, 4))
+    return ExprError();
+
+  Expr *DataExpr = TheCall->getArg(0);
+  Expr *StrideExpr = TheCall->getArg(3);
+  bool ArgError = false;
+
+  // Check data pointer. Extract the element type through type sugar: a plain
+  // dyn_cast<PointerType> fails for e.g. a typedef'd pointer type, whereas
+  // getAs<>/getAsArrayType look at the canonical type.
+  QualType DataType = DataExpr->getType();
+  QualType ElementType;
+  if (const auto *PTy = DataType->getAs<PointerType>())
+    ElementType = PTy->getPointeeType();
+  else if (const ArrayType *ATy = Context.getAsArrayType(DataType))
+    ElementType = ATy->getElementType();
+
+  // Strip all qualifiers, not just a local const, so that e.g. volatile or a
+  // const hidden behind a typedef does not end up in the result matrix type.
+  if (!ElementType.isNull())
+    ElementType = ElementType.getUnqualifiedType();
+
+  if (ElementType.isNull() ||
+      !ConstantMatrixType::isValidElementType(ElementType)) {
+    Diag(DataExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg) << 0;
+    ArgError = true;
+  }
+
+  // Check rows and columns arguments.
+  llvm::Optional<unsigned> MaybeRows =
+      getAndVerifyMatrixDimension(TheCall->getArg(1), 0, *this);
+  llvm::Optional<unsigned> MaybeCols =
+      getAndVerifyMatrixDimension(TheCall->getArg(2), 1, *this);
+
+  // Check stride. It only has to be an integer; a constant stride must be at
+  // least the number of rows.
+  if (!StrideExpr->getType()->isIntegralType(Context)) {
+    Diag(StrideExpr->getBeginLoc(), diag::err_builtin_matrix_scalar_int_arg)
+        << 2 << 1;
+    ArgError = true;
+  } else {
+    llvm::APSInt Value(64);
+    if (MaybeRows && StrideExpr->isIntegerConstantExpr(Value, Context)) {
+      // A negative stride is always too small; a stride needing more than 64
+      // bits never is (and must not reach getZExtValue()).
+      bool TooSmall = (Value.isSigned() && Value.isNegative()) ||
+                      (Value.getActiveBits() <= 64 &&
+                       Value.getZExtValue() < *MaybeRows);
+      if (TooSmall) {
+        Diag(StrideExpr->getBeginLoc(),
+             diag::err_builtin_matrix_stride_too_small);
+        ArgError = true;
+      }
+    }
+  }
+
+  if (ArgError || !MaybeRows || !MaybeCols)
+    return ExprError();
+
+  QualType ReturnType =
+      Context.getConstantMatrixType(ElementType, *MaybeRows, *MaybeCols);
+  TheCall->setType(ReturnType);
+  return CallResult;
+}
Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -2376,6 +2376,39 @@
     return RValue::get(Builder.CreateZExt(V, ConvertType(E->getType())));
   }
 
+  case Builtin::BI__builtin_matrix_column_major_load: {
+    MatrixBuilder<CGBuilderTy> MB(Builder);
+    // Emit everything that isn't dependent on the first parameter type.
+    // Sema accepts a stride of any integral type (including bool), but the
+    // intrinsic takes an i32 stride, so convert it.
+    const Expr *StrideArg = E->getArg(3);
+    Value *Stride = Builder.CreateIntCast(
+        EmitScalarExpr(StrideArg), Int32Ty,
+        StrideArg->getType()->isSignedIntegerOrEnumerationType());
+    const ConstantMatrixType *ResultTy = getMatrixTy(E->getType());
+
+    // The data argument is either a pointer or an array that is decayed to a
+    // pointer. Test the canonical type (isPointerType()/isArrayType()) so that
+    // typedef'd pointer and array types are handled, too.
+    const Expr *DataArg = E->getArg(0);
+    QualType DataTy = DataArg->getType();
+    Address Src = Address::invalid();
+    if (DataTy->isPointerType())
+      Src = EmitPointerWithAlignment(DataArg);
+    else if (DataTy->isArrayType())
+      Src = EmitArrayToPointerDecay(DataArg);
+    else
+      llvm_unreachable(
+          "CGBuiltin.cpp: First argument must either be a pointer or an array");
+    EmitNonNullArgCheck(RValue::get(Src.getPointer()), DataTy,
+                        DataArg->getExprLoc(), FD, 0);
+
+    // Load a matrix with the dimensions of the call's result type.
+    Value *Result = MB.CreateMatrixColumnwiseLoad(
+        Src.getPointer(), ResultTy->getNumRows(), ResultTy->getNumColumns(),
+        Stride, "matrix");
+    return RValue::get(Result);
+  }
   case Builtin::BI__builtin_matrix_transpose: {
     const ConstantMatrixType *MatrixTy = getMatrixTy(E->getArg(0)->getType());
     Value *MatValue = EmitScalarExpr(E->getArg(0));
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -12086,6 +12086,8 @@
   // Matrix builtin handling.
   ExprResult SemaBuiltinMatrixTransposeOverload(CallExpr *TheCall,
                                                 ExprResult CallResult);
+  ExprResult SemaBuiltinMatrixColumnMajorLoadOverload(CallExpr *TheCall,
+                                                      ExprResult CallResult);
 
 public:
   enum FormatStringType {
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -10757,6 +10757,18 @@
 def err_builtin_matrix_arg: Error<
   "%select{first|second}0 argument must be a matrix">;
 
+def err_builtin_matrix_scalar_int_arg: Error< // %0: 0 = row, 1 = column, 2 = stride; %1: 0 = any unsigned integer, 1 = constant expression required.
+  "%select{row|column|stride}0 argument must be %select{an unsigned integer|a constant unsigned integer expression}1">;
+
+def err_builtin_matrix_pointer_arg: Error< // %0 selects the argument position (first/second).
+  "%select{first|second}0 argument must be a pointer to a valid matrix element type">;
+
+def err_builtin_matrix_stride_too_small: Error< // Emitted when a constant stride is less than the row count.
+  "stride must be greater or equal to the number of rows">;
+
+def err_builtin_matrix_invalid_dimension: Error< // %0: 0 = row, 1 = column; %1: largest allowed dimension.
+  "%select{row|column}0 dimension is outside the allowed range [1, %1]">;
+
 def err_preserve_field_info_not_field : Error<
   "__builtin_preserve_field_info argument %0 not a field access">;
 def err_preserve_field_info_not_const: Error<
Index: clang/include/clang/Basic/Builtins.def
===================================================================
--- clang/include/clang/Basic/Builtins.def
+++ clang/include/clang/Basic/Builtins.def
@@ -576,6 +576,7 @@
 BUILTIN(__builtin_call_with_static_chain, "v.", "nt")
 
 BUILTIN(__builtin_matrix_transpose, "v.", "nFt")
+BUILTIN(__builtin_matrix_column_major_load, "v.", "nFt")
 
 // "Overloaded" Atomic operator builtins.  These are overloaded to support data
 // types of i8, i16, i32, i64, and i128.  The front-end sees calls to the
Index: clang/include/clang/AST/Type.h
===================================================================
--- clang/include/clang/AST/Type.h
+++ clang/include/clang/AST/Type.h
@@ -3475,6 +3475,11 @@
            NumElements <= ConstantMatrixTypeBitfields::MaxElementsPerDimension;
   }
 
+  /// Returns the largest number of rows or columns accepted by isDimensionValid().
+  static unsigned getMaxElementsPerDimension() {
+    return ConstantMatrixTypeBitfields::MaxElementsPerDimension;
+  }
+
   void Profile(llvm::FoldingSetNodeID &ID) {
     Profile(ID, getElementType(), getNumRows(), getNumColumns(),
             getTypeClass());
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to