everton.constantino created this revision.
everton.constantino added reviewers: anemet, rjmccall, rsmith, Bigcheese, fhahn.
Herald added subscribers: dexonsmith, tschuett, hiraditya.
everton.constantino requested review of this revision.
Herald added subscribers: llvm-commits, cfe-commits, jdoerfert.
Herald added projects: clang, LLVM.

This patch adds a new builtin to support matrix multiply-add. Currently, writing
C = A*B + C incurs the overhead of a separate set of fadds after the multiply. With
this builtin, the accumulators are initialized with the C matrix during the
multiplication, considerably reducing the number of operations.
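
For example, with -fenable-matrix the builtin can be used directly on matrix values
(a minimal sketch; the typedef and function names are illustrative):

  typedef float fx4x4_t __attribute__((matrix_type(4, 4)));

  fx4x4_t madd(fx4x4_t a, fx4x4_t b, fx4x4_t c) {
    // Lowers to a single llvm.matrix.multiply.add intrinsic call instead of
    // llvm.matrix.multiply followed by element-wise fadds.
    return __builtin_matrix_multiply_add(a, b, c);
  }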


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D99433

Files:
  clang/docs/MatrixTypes.rst
  clang/include/clang/Basic/Builtins.def
  clang/include/clang/Basic/DiagnosticSemaKinds.td
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGBuiltin.cpp
  clang/lib/Sema/SemaChecking.cpp
  clang/test/CodeGen/matrix-type-builtins.c
  clang/test/Sema/matrix-type-builtins.c
  llvm/include/llvm/IR/Intrinsics.td
  llvm/include/llvm/IR/MatrixBuilder.h
  llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp

Index: llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
+++ llvm/lib/Transforms/Scalar/LowerMatrixIntrinsics.cpp
@@ -511,6 +511,7 @@
     if (II)
       switch (II->getIntrinsicID()) {
       case Intrinsic::matrix_multiply:
+      case Intrinsic::matrix_multiply_add:
       case Intrinsic::matrix_transpose:
       case Intrinsic::matrix_column_major_load:
       case Intrinsic::matrix_column_major_store:
@@ -540,6 +541,7 @@
 
       Value *MatrixA;
       Value *MatrixB;
+      Value *MatrixC;
       Value *M;
       Value *N;
       Value *K;
@@ -547,6 +549,11 @@
                           m_Value(MatrixA), m_Value(MatrixB), m_Value(M),
                           m_Value(N), m_Value(K)))) {
         Propagate = setShapeInfo(Inst, {M, K});
+      } else if (match(Inst,
+                       m_Intrinsic<Intrinsic::matrix_multiply_add>(
+                           m_Value(MatrixA), m_Value(MatrixB), m_Value(MatrixC),
+                           m_Value(M), m_Value(N), m_Value(K)))) {
+        Propagate = setShapeInfo(Inst, {M, K});
       } else if (match(Inst, m_Intrinsic<Intrinsic::matrix_transpose>(
                                  m_Value(MatrixA), m_Value(M), m_Value(N)))) {
         // Flip dimensions.
@@ -611,6 +618,7 @@
 
       Value *MatrixA;
       Value *MatrixB;
+      Value *MatrixC;
       Value *M;
       Value *N;
       Value *K;
@@ -622,7 +630,18 @@
 
         if (setShapeInfo(MatrixB, {N, K}))
           pushInstruction(MatrixB, WorkList);
+      } else if (match(V,
+                       m_Intrinsic<Intrinsic::matrix_multiply_add>(
+                           m_Value(MatrixA), m_Value(MatrixB), m_Value(MatrixC),
+                           m_Value(M), m_Value(N), m_Value(K)))) {
+        if (setShapeInfo(MatrixA, {M, N}))
+          pushInstruction(MatrixA, WorkList);
 
+        if (setShapeInfo(MatrixB, {N, K}))
+          pushInstruction(MatrixB, WorkList);
+
+        if (setShapeInfo(MatrixC, {M, K}))
+          pushInstruction(MatrixC, WorkList);
       } else if (match(V, m_Intrinsic<Intrinsic::matrix_transpose>(
                               m_Value(MatrixA), m_Value(M), m_Value(N)))) {
         // Flip dimensions.
@@ -673,6 +692,7 @@
 
           switch (II->getIntrinsicID()) {
           case Intrinsic::matrix_multiply:
+          case Intrinsic::matrix_multiply_add:
           case Intrinsic::matrix_transpose:
           case Intrinsic::matrix_column_major_load:
           case Intrinsic::matrix_column_major_store:
@@ -769,6 +789,9 @@
     case Intrinsic::matrix_column_major_store:
       LowerColumnMajorStore(Inst);
       break;
+    case Intrinsic::matrix_multiply_add:
+      LowerMultiplyAdd(Inst);
+      break;
     default:
       return false;
     }
@@ -1009,11 +1032,13 @@
     }
   }
 
-  /// Compute \p Result += \p A * \p B for input matrices with left-associating
-  /// addition.
+  /// Compute \p Result += \p A * \p B for input matrices with
+  /// left-associating addition. If \p isAccumulating is true, \p ACC provides
+  /// the initial value of the accumulator instead of \p Result.
+  template <bool isAccumulating = false>
   void emitMatrixMultiply(MatrixTy &Result, const MatrixTy &A,
                           const MatrixTy &B, bool AllowContraction,
-                          IRBuilder<> &Builder, bool isTiled) {
+                          IRBuilder<> &Builder, bool isTiled,
+                          const MatrixTy *ACC = nullptr) {
     const unsigned VF = std::max<unsigned>(
         TTI.getRegisterBitWidth(TargetTransformInfo::RGK_FixedWidthVector)
                 .getFixedSize() /
@@ -1030,20 +1055,25 @@
     unsigned NumComputeOps = 0;
     if (A.isColumnMajor()) {
       // Multiply columns from the first operand with scalars from the second
-      // operand. Then move along the K axes and accumulate the columns.  With
+      // operand. Then move along the K axes and accumulate the columns. With
       // this the adds can be vectorized without reassociation.
       for (unsigned J = 0; J < C; ++J) {
         unsigned BlockSize = VF;
         // If Result is zero, we don't need to accumulate in the K==0 iteration.
-        bool isSumZero = isa<ConstantAggregateZero>(Result.getColumn(J));
+        bool isSumZero = isAccumulating
+                             ? false
+                             : isa<ConstantAggregateZero>(Result.getColumn(J));
 
         for (unsigned I = 0; I < R; I += BlockSize) {
           // Gradually lower the vectorization factor to cover the remainder.
           while (I + BlockSize > R)
             BlockSize /= 2;
 
-          Value *Sum = isTiled ? Result.extractVector(I, J, BlockSize, Builder)
-                               : nullptr;
+          Value *Sum =
+              isAccumulating ? ACC->extractVector(I, J, BlockSize, Builder)
+              : isTiled      ? Result.extractVector(I, J, BlockSize, Builder)
+                             : nullptr;
           for (unsigned K = 0; K < M; ++K) {
             Value *L = A.extractVector(I, K, BlockSize, Builder);
             Value *RH = Builder.CreateExtractElement(B.getColumn(J), K);
@@ -1062,13 +1092,17 @@
       // the adds can be vectorized without reassociation.
       for (unsigned I = 0; I < R; ++I) {
         unsigned BlockSize = VF;
-        bool isSumZero = isa<ConstantAggregateZero>(Result.getRow(I));
+        bool isSumZero = isAccumulating
+                             ? false
+                             : isa<ConstantAggregateZero>(Result.getRow(I));
         for (unsigned J = 0; J < C; J += BlockSize) {
           // Gradually lower the vectorization factor to cover the remainder.
           while (J + BlockSize > C)
             BlockSize /= 2;
 
-          Value *Sum = nullptr;
+          Value *Sum = isAccumulating
+                           ? ACC->extractVector(I, J, BlockSize, Builder)
+                           : nullptr;
           for (unsigned K = 0; K < M; ++K) {
             Value *R = B.extractVector(K, J, BlockSize, Builder);
             Value *LH = Builder.CreateExtractElement(A.getVector(I), K);
@@ -1367,6 +1401,40 @@
     }
   }
 
+  /// Lowers llvm.matrix.multiply.add.
+  void LowerMultiplyAdd(CallInst *MatMulAdd) {
+    IRBuilder<> Builder(MatMulAdd);
+    auto *EltType = cast<VectorType>(MatMulAdd->getType())->getElementType();
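+    // Operands: A, B, ACC, rows(A), columns(A) (== rows(B)), columns(B).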
+    ShapeInfo LShape(MatMulAdd->getArgOperand(3), MatMulAdd->getArgOperand(4));
+    ShapeInfo RShape(MatMulAdd->getArgOperand(4), MatMulAdd->getArgOperand(5));
+    ShapeInfo AShape(MatMulAdd->getArgOperand(3), MatMulAdd->getArgOperand(5));
+
+    const MatrixTy &Lhs =
+        getMatrix(MatMulAdd->getArgOperand(0), LShape, Builder);
+    const MatrixTy &Rhs =
+        getMatrix(MatMulAdd->getArgOperand(1), RShape, Builder);
+    const MatrixTy &Acc =
+        getMatrix(MatMulAdd->getArgOperand(2), AShape, Builder);
+    assert(Lhs.getElementType() == Rhs.getElementType() &&
+           "Matrix multiply argument element types do not match.");
+
+    const unsigned R = LShape.NumRows;
+    const unsigned C = RShape.NumColumns;
+    assert(LShape.NumColumns == RShape.NumRows);
+
+    // Initialize the output
+    MatrixTy Result(R, C, EltType);
+    assert(Lhs.getElementType() == Result.getElementType() &&
+           "Matrix multiply result element type does not match arguments.");
+
+    bool AllowContract =
+        AllowContractEnabled ||
+        (isa<FPMathOperator>(MatMulAdd) && MatMulAdd->hasAllowContract());
+    emitMatrixMultiply<true>(Result, Lhs, Rhs, AllowContract, Builder, false,
+                             &Acc);
+    finalizeLowering(MatMulAdd, Result, Builder);
+  }
+
   /// Lowers llvm.matrix.multiply.
   void LowerMultiply(CallInst *MatMul) {
     IRBuilder<> Builder(MatMul);
@@ -1648,6 +1716,14 @@
           prettyPrintMatrixType(II->getOperand(1), SS);
           SS << "." << *II->getType()->getScalarType();
           break;
+        case Intrinsic::matrix_multiply_add:
+          prettyPrintMatrixType(II->getOperand(0), SS);
+          SS << ".";
+          prettyPrintMatrixType(II->getOperand(1), SS);
+          SS << ".";
+          prettyPrintMatrixType(II->getOperand(2), SS);
+          SS << "." << *II->getType()->getScalarType();
+          break;
         case Intrinsic::matrix_transpose:
           prettyPrintMatrixType(II->getOperand(0), SS);
           SS << "." << *II->getType()->getScalarType();
@@ -1672,6 +1748,7 @@
       if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) {
         switch (II->getIntrinsicID()) {
         case Intrinsic::matrix_multiply:
+        case Intrinsic::matrix_multiply_add:
           return 3;
         case Intrinsic::matrix_transpose:
           return 2;
Index: llvm/include/llvm/IR/MatrixBuilder.h
===================================================================
--- llvm/include/llvm/IR/MatrixBuilder.h
+++ llvm/include/llvm/IR/MatrixBuilder.h
@@ -125,6 +125,31 @@
     return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
   }
 
+  /// Create a llvm.matrix.multiply.add call, multiplying matrices \p LHS and
+  /// \p RHS and adding \p ACC to the product.
+  CallInst *CreateMatrixMultiplyAdd(Value *LHS, Value *RHS, Value *ACC,
+                                    unsigned LHSRows, unsigned LHSColumns,
+                                    unsigned RHSColumns,
+                                    const Twine &Name = "") {
+    auto *LHSType = cast<VectorType>(LHS->getType());
+    auto *RHSType = cast<VectorType>(RHS->getType());
+    auto *AccType = cast<VectorType>(ACC->getType());
+
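+    // The result has LHSRows x RHSColumns elements; ACC must have the same shape.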
+    auto *ReturnType =
+        FixedVectorType::get(LHSType->getElementType(), LHSRows * RHSColumns);
+    Value *Ops[] = {LHS,
+                    RHS,
+                    ACC,
+                    B.getInt32(LHSRows),
+                    B.getInt32(LHSColumns),
+                    B.getInt32(RHSColumns)};
+    Type *OverloadedTypes[] = {ReturnType, LHSType, RHSType, AccType};
+
+    Function *TheFn = Intrinsic::getDeclaration(
+        getModule(), Intrinsic::matrix_multiply_add, OverloadedTypes);
+    return B.CreateCall(TheFn->getFunctionType(), TheFn, Ops, Name);
+  }
+
   /// Create a llvm.matrix.multiply call, multiplying matrixes \p LHS and \p
   /// RHS.
   CallInst *CreateMatrixMultiply(Value *LHS, Value *RHS, unsigned LHSRows,
Index: llvm/include/llvm/IR/Intrinsics.td
===================================================================
--- llvm/include/llvm/IR/Intrinsics.td
+++ llvm/include/llvm/IR/Intrinsics.td
@@ -1571,6 +1571,13 @@
               [IntrNoSync, IntrWillReturn, IntrNoMem, IntrSpeculatable, ImmArg<ArgIndex<2>>,
                ImmArg<ArgIndex<3>>, ImmArg<ArgIndex<4>>]>;
 
+def int_matrix_multiply_add
+  : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
+              [llvm_anyvector_ty, llvm_anyvector_ty, llvm_anyvector_ty, llvm_i32_ty, llvm_i32_ty,
+               llvm_i32_ty],
+              [IntrNoSync, IntrWillReturn, IntrNoMem, IntrSpeculatable, ImmArg<ArgIndex<3>>,
+               ImmArg<ArgIndex<4>>, ImmArg<ArgIndex<5>>]>;
+
 def int_matrix_column_major_load
   : DefaultAttrsIntrinsic<[llvm_anyvector_ty],
               [LLVMPointerToElt<0>, llvm_i64_ty, llvm_i1_ty,
Index: clang/test/Sema/matrix-type-builtins.c
===================================================================
--- clang/test/Sema/matrix-type-builtins.c
+++ clang/test/Sema/matrix-type-builtins.c
@@ -96,3 +96,14 @@
   __builtin_matrix_column_major_store(*m1, p4, 20);
   // expected-error@-1 {{cannot store matrix to read-only pointer}}
 }
+
+void multiply_add(sx5x10_t a, sx5x10_t b, sx5x10_t c, dx3x3 d, dx3x3 e, ix3x3 f) {
+  c = __builtin_matrix_multiply_add(a, b, c);
+  // expected-error@-1 {{the number of columns of the 1st argument must match the number of rows of the 2nd argument, and the 3rd argument must have as many rows as the 1st argument and as many columns as the 2nd argument}}
+
+  f = __builtin_matrix_multiply_add(d, e, f);
+  // expected-error@-1 {{the element types of all arguments must match}}
+
+  f = __builtin_matrix_multiply_add(d, e, e);
+  // expected-error@-1 {{assigning to 'ix3x3' (aka 'unsigned int __attribute__((matrix_type(3, 3)))') from incompatible type 'double __attribute__((matrix_type(3, 3)))'}}
+}
Index: clang/test/CodeGen/matrix-type-builtins.c
===================================================================
--- clang/test/CodeGen/matrix-type-builtins.c
+++ clang/test/CodeGen/matrix-type-builtins.c
@@ -9,11 +9,36 @@
 typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
 typedef float fx2x3_t __attribute__((matrix_type(2, 3)));
 typedef float fx3x2_t __attribute__((matrix_type(3, 2)));
+typedef float fx5x5_t __attribute__((matrix_type(5, 5)));
 typedef int ix20x4_t __attribute__((matrix_type(20, 4)));
 typedef int ix4x20_t __attribute__((matrix_type(4, 20)));
 typedef unsigned ux1x6_t __attribute__((matrix_type(1, 6)));
 typedef unsigned ux6x1_t __attribute__((matrix_type(6, 1)));
 
+void multiply_add_float_5x5(const fx5x5_t *a, const fx5x5_t *b, fx5x5_t *c) {
+  // CHECK-LABEL: define{{.*}} void @multiply_add_float_5x5(
+  // CHECK:       [[A_ADDR:%.*]] = alloca [25 x float]*, align 8
+  // CHECK-NEXT:  [[B_ADDR:%.*]] = alloca [25 x float]*, align 8
+  // CHECK-NEXT:  [[C_ADDR:%.*]] = alloca [25 x float]*, align 8
+  // CHECK-NEXT:  store [25 x float]* %a, [25 x float]** [[A_ADDR]], align 8
+  // CHECK-NEXT:  store [25 x float]* %b, [25 x float]** [[B_ADDR]], align 8
+  // CHECK-NEXT:  store [25 x float]* %c, [25 x float]** [[C_ADDR]], align 8
+  // CHECK-NEXT:  [[A_L:%.*]] = load [25 x float]*, [25 x float]** [[A_ADDR]], align 8
+  // CHECK-NEXT:  [[A_B:%.*]] = bitcast [25 x float]* [[A_L]] to <25 x float>*
+  // CHECK-NEXT:  [[A:%.*]] = load <25 x float>, <25 x float>* [[A_B]], align 4
+  // CHECK-NEXT:  [[B_L:%.*]] = load [25 x float]*, [25 x float]** [[B_ADDR]], align 8
+  // CHECK-NEXT:  [[B_B:%.*]] = bitcast [25 x float]* [[B_L]] to <25 x float>*
+  // CHECK-NEXT:  [[B:%.*]] = load <25 x float>, <25 x float>* [[B_B]], align 4
+  // CHECK-NEXT:  [[C_L:%.*]] = load [25 x float]*, [25 x float]** [[C_ADDR]], align 8
+  // CHECK-NEXT:  [[C_B:%.*]] = bitcast [25 x float]* [[C_L]] to <25 x float>*
+  // CHECK-NEXT:  [[C:%.*]] = load <25 x float>, <25 x float>* [[C_B]], align 4
+  // CHECK-NEXT:  [[MADD:%.*]] = call <25 x float> @llvm.matrix.multiply.add.v25f32.v25f32.v25f32.v25f32(<25 x float> [[A]], <25 x float> [[B]], <25 x float> [[C]], i32 5, i32 5, i32 5)
+  // CHECK-NEXT:  [[CR_L:%.*]] = load [25 x float]*, [25 x float]** [[C_ADDR]], align 8
+  // CHECK-NEXT:  [[CR_B:%.*]] = bitcast [25 x float]* [[CR_L]] to <25 x float>*
+  // CHECK-NEXT: store <25 x float> [[MADD]], <25 x float>* [[CR_B]], align 4
+  *c = __builtin_matrix_multiply_add(*a, *b, *c);
+}
+
 void transpose_double_5x5(dx5x5_t *a) {
   // CHECK-LABEL: define{{.*}} void @transpose_double_5x5(
   // CHECK:        [[A:%.*]] = load <25 x double>, <25 x double>* {{.*}}, align 8
Index: clang/lib/Sema/SemaChecking.cpp
===================================================================
--- clang/lib/Sema/SemaChecking.cpp
+++ clang/lib/Sema/SemaChecking.cpp
@@ -1967,6 +1967,9 @@
   case Builtin::BI__builtin_matrix_column_major_store:
     return SemaBuiltinMatrixColumnMajorStore(TheCall, TheCallResult);
 
+  case Builtin::BI__builtin_matrix_multiply_add:
+    return SemaBuiltinMatrixMultiplyAdd(TheCall, TheCallResult);
+
   case Builtin::BI__builtin_get_device_side_mangled_name: {
     auto Check = [](CallExpr *TheCall) {
       if (TheCall->getNumArgs() != 1)
@@ -16152,6 +16155,78 @@
   return CallResult;
 }
 
+ExprResult Sema::SemaBuiltinMatrixMultiplyAdd(CallExpr *TheCall,
+                                              ExprResult CallResult) {
+  if (!getLangOpts().MatrixTypes) {
+    Diag(TheCall->getBeginLoc(), diag::err_builtin_matrix_disabled);
+    return ExprError();
+  }
+
+  if (checkArgCount(*this, TheCall, 3))
+    return ExprError();
+
+  ExprResult MatrixAArg = DefaultLvalueConversion(TheCall->getArg(0));
+  if (MatrixAArg.isInvalid())
+    return MatrixAArg;
+  Expr *MatrixA = MatrixAArg.get();
+
+  auto *MTypeA = MatrixA->getType()->getAs<ConstantMatrixType>();
+  if (!MTypeA) {
+    Diag(MatrixA->getBeginLoc(), diag::err_builtin_matrix_arg);
+    return ExprError();
+  }
+
+  ExprResult MatrixBArg = DefaultLvalueConversion(TheCall->getArg(1));
+  if (MatrixBArg.isInvalid())
+    return MatrixBArg;
+  Expr *MatrixB = MatrixBArg.get();
+
+  auto *MTypeB = MatrixB->getType()->getAs<ConstantMatrixType>();
+  if (!MTypeB) {
+    Diag(MatrixB->getBeginLoc(), diag::err_builtin_matrix_arg);
+    return ExprError();
+  }
+
+  ExprResult MatrixCArg = DefaultLvalueConversion(TheCall->getArg(2));
+  if (MatrixCArg.isInvalid())
+    return MatrixCArg;
+  Expr *MatrixC = MatrixCArg.get();
+
+  auto *MTypeC = MatrixC->getType()->getAs<ConstantMatrixType>();
+  if (!MTypeC) {
+    Diag(MatrixC->getBeginLoc(), diag::err_builtin_matrix_arg);
+    return ExprError();
+  }
+
+  // Check whether all matrices have the same element type. We don't support
+  // mixed precision yet.
+  if (!(Context.hasSameType(MTypeC->getElementType(),
+                            MTypeA->getElementType()) &&
+        Context.hasSameType(MTypeC->getElementType(),
+                            MTypeB->getElementType()))) {
+    Diag(MatrixC->getBeginLoc(), diag::err_builtin_matrix_scalar_type);
+    return ExprError();
+  }
+
+  // Check if dimensions are appropriate.
+  if (MTypeA->getNumColumns() != MTypeB->getNumRows() ||
+      !(MTypeC->getNumColumns() == MTypeB->getNumColumns() &&
+        MTypeC->getNumRows() == MTypeA->getNumRows())) {
+    Diag(MatrixC->getBeginLoc(), diag::err_builtin_matrix_dimension_mismatch);
+    return ExprError();
+  }
+
+  // Prepare Result matrix.
+  QualType ResultType = Context.getConstantMatrixType(
+      MTypeC->getElementType(), MTypeC->getNumRows(), MTypeC->getNumColumns());
+
+  TheCall->setType(ResultType);
+  TheCall->setArg(0, MatrixA);
+  TheCall->setArg(1, MatrixB);
+  TheCall->setArg(2, MatrixC);
+  return CallResult;
+}
+
 ExprResult Sema::SemaBuiltinMatrixColumnMajorStore(CallExpr *TheCall,
                                                    ExprResult CallResult) {
   if (checkArgCount(*this, TheCall, 3))
Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -3067,6 +3067,23 @@
     return RValue::get(Result);
   }
 
+  case Builtin::BI__builtin_matrix_multiply_add: {
+    MatrixBuilder<CGBuilderTy> MB(Builder);
+    Value *MatrixA = EmitScalarExpr(E->getArg(0));
+    Value *MatrixB = EmitScalarExpr(E->getArg(1));
+    Value *MatrixC = EmitScalarExpr(E->getArg(2));
+
+    const auto *MatrixTy1 =
+        E->getArg(0)->getType()->getAs<ConstantMatrixType>();
+    const auto *MatrixTy2 =
+        E->getArg(1)->getType()->getAs<ConstantMatrixType>();
+
+    Value *Result = MB.CreateMatrixMultiplyAdd(
+        MatrixA, MatrixB, MatrixC, MatrixTy1->getNumRows(),
+        MatrixTy1->getNumColumns(), MatrixTy2->getNumColumns());
+    return RValue::get(Result);
+  }
+
   case Builtin::BIfinite:
   case Builtin::BI__finite:
   case Builtin::BIfinitef:
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -12514,6 +12514,8 @@
                                               ExprResult CallResult);
   ExprResult SemaBuiltinMatrixColumnMajorStore(CallExpr *TheCall,
                                                ExprResult CallResult);
+  ExprResult SemaBuiltinMatrixMultiplyAdd(CallExpr *TheCall,
+                                          ExprResult CallResult);
 
 public:
   enum FormatStringType {
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -11107,6 +11107,10 @@
 def err_matrix_subscript_comma: Error<
   "comma expressions are not allowed as indices in matrix subscript expressions">;
 def err_builtin_matrix_arg: Error<"1st argument must be a matrix">;
+def err_builtin_matrix_dimension_mismatch: Error<
+  "the number of columns of the 1st argument must match the number of rows of "
+  "the 2nd argument, and the 3rd argument must have as many rows as the 1st "
+  "argument and as many columns as the 2nd argument">;
+def err_builtin_matrix_scalar_type: Error<
+  "the element types of all arguments must match">;
 def err_builtin_matrix_scalar_unsigned_arg: Error<
   "%0 argument must be a constant unsigned integer expression">;
 def err_builtin_matrix_pointer_arg: Error<
Index: clang/include/clang/Basic/Builtins.def
===================================================================
--- clang/include/clang/Basic/Builtins.def
+++ clang/include/clang/Basic/Builtins.def
@@ -642,6 +642,7 @@
 BUILTIN(__builtin_matrix_transpose, "v.", "nFt")
 BUILTIN(__builtin_matrix_column_major_load, "v.", "nFt")
 BUILTIN(__builtin_matrix_column_major_store, "v.", "nFt")
+BUILTIN(__builtin_matrix_multiply_add, "v.", "nFt")
 
 // "Overloaded" Atomic operator builtins.  These are overloaded to support data
 // types of i8, i16, i32, i64, and i128.  The front-end sees calls to the
Index: clang/docs/MatrixTypes.rst
===================================================================
--- clang/docs/MatrixTypes.rst
+++ clang/docs/MatrixTypes.rst
@@ -204,6 +204,23 @@
 * *T* - Element type
 * *row*, *col* - Row and column arguments respectively.
 
+``M3 __builtin_matrix_multiply_add(M1 matrixA, M2 matrixB, M3 matrixC)``
+
+**Returns**: A matrix ``Res`` equivalent to the code below, where ``row`` is the
+number of rows of ``M1``, ``depth`` is the number of columns of ``M1`` (which must
+equal the number of rows of ``M2``), and ``col`` is the number of columns of ``M2``.
+
+**Effects**: Equivalent to:
+
+.. code-block:: c++
+
+  M3 Res;
+  for (int C = 0; C < col; ++C)
+    for (int R = 0; R < row; ++R) {
+      Res[R][C] = matrixC[R][C];
+      for (int K = 0; K < depth; ++K)
+        Res[R][C] += matrixA[R][K] * matrixB[K][C];
+    }
 
 ``M2 __builtin_matrix_transpose(M1 matrix)``
 