fhahn created this revision. fhahn added reviewers: rjmccall, anemet, Bigcheese, rsmith, martong. Herald added subscribers: tschuett, dexonsmith, rnkovacs. Herald added a project: clang.
This patch implements the * binary operator for values of MatrixType. It adds support for matrix * matrix, scalar * matrix and matrix * scalar. For the matrix * matrix case, the number of columns of the first operand must match the number of rows of the second, and both operands must have the same element type. For the scalar * matrix and matrix * scalar variants, the element type of the matrix must match the scalar type. Repository: rG LLVM Github Monorepo https://reviews.llvm.org/D76794 Files: clang/include/clang/Sema/Sema.h clang/lib/CodeGen/CGExprScalar.cpp clang/lib/Sema/SemaExpr.cpp clang/test/CodeGen/matrix-type-operators.c clang/test/CodeGenCXX/matrix-type-operators.cpp clang/test/Sema/matrix-type-operators.c clang/test/SemaCXX/matrix-type-operators.cpp llvm/include/llvm/IR/MatrixBuilder.h
Index: llvm/include/llvm/IR/MatrixBuilder.h =================================================================== --- llvm/include/llvm/IR/MatrixBuilder.h +++ llvm/include/llvm/IR/MatrixBuilder.h @@ -144,15 +144,24 @@ : B.CreateSub(LHS, RHS); } - /// Multiply matrix \p LHS with scalar \p RHS. + /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p + /// RHS. Value *CreateScalarMultiply(Value *LHS, Value *RHS) { - Value *ScalarVector = - B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getNumElements(), - RHS, "scalar.splat"); - if (RHS->getType()->isFloatingPointTy()) - return B.CreateFMul(LHS, ScalarVector); - - return B.CreateMul(LHS, ScalarVector); + assert(LHS->getType()->isVectorTy() || + RHS->getType()->isVectorTy() && + "One of the operands must be a matrix (embedded in a vector)"); + Value *ScalarVector = B.CreateVectorSplat( + cast<VectorType>(LHS->getType())->getNumElements(), + LHS->getType()->isVectorTy() ? RHS : LHS, "scalar.splat"); + if (RHS->getType()->isFloatingPointTy()) { + if (LHS->getType()->isVectorTy()) + return B.CreateFMul(LHS, ScalarVector); + return B.CreateFMul(ScalarVector, RHS); + } + + if (LHS->getType()->isVectorTy()) + return B.CreateMul(LHS, ScalarVector); + return B.CreateMul(ScalarVector, RHS); } /// Extracts the element at (\p Row, \p Column) from \p Matrix. 
Index: clang/test/SemaCXX/matrix-type-operators.cpp =================================================================== --- clang/test/SemaCXX/matrix-type-operators.cpp +++ clang/test/SemaCXX/matrix-type-operators.cpp @@ -122,3 +122,26 @@ Mat1.value = subtract<unsigned, 3, 3, float, 2, 2, unsigned, 2, 2>(Mat2, Mat3); // expected-note@-1 {{in instantiation of function template specialization 'subtract<unsigned int, 3, 3, float, 2, 2, unsigned int, 2, 2>' requested here}} } + +template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2> +typename MyMatrix<EltTy2, R2, C2>::matrix_t multiply(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) { + char *v1 = A.value * B.value; + // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(2, 2))) '}} + // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} + + return A.value * B.value; + // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}} +} + +void test_multiply_template(unsigned *Ptr1, float *Ptr2) { + MyMatrix<unsigned, 2, 2> Mat1; + MyMatrix<unsigned, 3, 3> Mat2; + MyMatrix<float, 2, 2> Mat3; + Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1); + unsigned v1 = multiply<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1); + // expected-note@-1 {{in instantiation of function template specialization 'multiply<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}} + // expected-error@-2 {{cannot initialize a variable of type 
'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}} + + Mat1.value = multiply<unsigned, 2, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat1, Mat2); + // expected-note@-1 {{in instantiation of function template specialization 'multiply<unsigned int, 2, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}} +} Index: clang/test/Sema/matrix-type-operators.c =================================================================== --- clang/test/Sema/matrix-type-operators.c +++ clang/test/Sema/matrix-type-operators.c @@ -96,3 +96,22 @@ a = b - &c; // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*'))}} } + +void mat_mat_multiply(sx10x10_t a, sx5x10_t b, sx10x5_t c) { + // Invalid dimensions for operands. + a = c * c; + // expected-error@-1 {{invalid operands to binary expression ('sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))') and 'sx10x5_t')}} + + // Shape of multiplication result does not match the type of b. + b = a * a; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'float __attribute__((matrix_type(10, 10)))'}} + + b = a * &c; + // expected-error@-1 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*'))}} +} + +void mat_scalar_multiply(sx10x10_t a, sx5x10_t b, float scalar) { + // Shape of multiplication result does not match the type of b. 
+ b = a * scalar; + // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}} +} Index: clang/test/CodeGenCXX/matrix-type-operators.cpp =================================================================== --- clang/test/CodeGenCXX/matrix-type-operators.cpp +++ clang/test/CodeGenCXX/matrix-type-operators.cpp @@ -285,3 +285,53 @@ MyMatrix<float, 2, 5> Mat2; Mat1.value = subtract(Mat1, Mat2); } + +void multiply1_matrix(dx5x5_t *a, dx5x5_t *b, dx5x5_t *c) { + *a = *b * *c; + + // CHECK-LABEL: @_Z16multiply1_matrixPDm5_5_dS0_S0_([25 x double]* %a, [25 x double]* %b, [25 x double]* %c) + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %c.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: store [25 x double]* %c, [25 x double]** %c.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %c.addr, align 8 + // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>* + // CHECK-NEXT: %5 = load <25 x double>, <25 x double>* %4, align 8 + // CHECK-NEXT: %6 = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> %2, <25 x double> %5, i32 5, i32 5, i32 5) + // CHECK-NEXT: %7 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %8 = bitcast [25 x double]* %7 to <25 x double>* + // CHECK-NEXT: store <25 x double> %6, <25 x double>* %8, align 8 + // CHECK-NEXT: ret void +} + +// CHECK: declare <25 x double> 
@llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double>, <25 x double>, i32 immarg, i32 immarg, i32 immarg) + +void multiply1_scalar(dx5x5_t *a, dx5x5_t *b, double c) { + *a = *b * c; + + // CHECK-LABEL:@_Z16multiply1_scalarPDm5_5_dS0_d([25 x double]* %a, [25 x double]* %b, double %c + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %c.addr = alloca double, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: store double %c, double* %c.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load double, double* %c.addr, align 8 + // CHECK-NEXT: %scalar.splat.splatinsert = insertelement <25 x double> undef, double %3, i32 0 + // CHECK-NEXT: %scalar.splat.splat = shufflevector <25 x double> %scalar.splat.splatinsert, <25 x double> undef, <25 x i32> zeroinitializer + // CHECK-NEXT: %4 = fmul <25 x double> %2, %scalar.splat.splat + // CHECK-NEXT: %5 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %6 = bitcast [25 x double]* %5 to <25 x double>* + // CHECK-NEXT: store <25 x double> %4, <25 x double>* %6, align 8 + // CHECK-NEXT: ret void +} Index: clang/test/CodeGen/matrix-type-operators.c =================================================================== --- clang/test/CodeGen/matrix-type-operators.c +++ clang/test/CodeGen/matrix-type-operators.c @@ -225,3 +225,55 @@ // CHECK-NEXT: store <27 x i32> %11, <27 x i32>* %3, align 4 // CHECK-NEXT: ret void } + +void multiply_matrix_matrix(dx5x5_t *a, dx5x5_t *b, dx5x5_t *c) { + *a = *b * *c; + + // CHECK-LABEL: @multiply_matrix_matrix( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = 
alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %c.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: store [25 x double]* %c, [25 x double]** %c.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %c.addr, align 8 + // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>* + // CHECK-NEXT: %5 = load <25 x double>, <25 x double>* %4, align 8 + // CHECK-NEXT: %6 = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> %2, <25 x double> %5, i32 5, i32 5, i32 5) + // CHECK-NEXT: %7 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %8 = bitcast [25 x double]* %7 to <25 x double>* + // CHECK-NEXT: store <25 x double> %6, <25 x double>* %8, align 8 + // CHECK-NEXT: ret void +} + +// CHECK: declare <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double>, <25 x double>, i32 immarg, i32 immarg, i32 immarg) [[READNONE:#[0-9]]] +// +void multiply_matrix_scalar(dx5x5_t *a, dx5x5_t *b, double c) { + *a = *b * c; + + // CHECK-LABEL: @multiply_matrix_scalar( + // CHECK-NEXT: entry: + // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8 + // CHECK-NEXT: %c.addr = alloca double, align 8 + // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: store double %c, double* %c.addr, align 8 + // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8 + // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x 
double>* + // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8 + // CHECK-NEXT: %3 = load double, double* %c.addr, align 8 + // CHECK-NEXT: %scalar.splat.splatinsert = insertelement <25 x double> undef, double %3, i32 0 + // CHECK-NEXT: %scalar.splat.splat = shufflevector <25 x double> %scalar.splat.splatinsert, <25 x double> undef, <25 x i32> zeroinitializer + // CHECK-NEXT: %4 = fmul <25 x double> %2, %scalar.splat.splat + // CHECK-NEXT: %5 = load [25 x double]*, [25 x double]** %a.addr, align 8 + // CHECK-NEXT: %6 = bitcast [25 x double]* %5 to <25 x double>* + // CHECK-NEXT: store <25 x double> %4, <25 x double>* %6, align 8 + // CHECK-NEXT: ret void +} + +// CHECK: attributes [[READNONE]] = { nounwind readnone speculatable willreturn } Index: clang/lib/Sema/SemaExpr.cpp =================================================================== --- clang/lib/Sema/SemaExpr.cpp +++ clang/lib/Sema/SemaExpr.cpp @@ -9544,6 +9544,9 @@ return CheckVectorOperands(LHS, RHS, Loc, IsCompAssign, /*AllowBothBool*/getLangOpts().AltiVec, /*AllowBoolConversions*/false); + if (!IsDiv && (LHS.get()->getType()->isMatrixType() || + RHS.get()->getType()->isMatrixType())) + return CheckMatrixMultiplyOperands(LHS, RHS, Loc); QualType compType = UsualArithmeticConversions( LHS, RHS, Loc, IsCompAssign ? ACK_CompAssign : ACK_Arithmetic); @@ -11578,6 +11581,41 @@ // assert(LHSMatType || RHSMatType); } +QualType Sema::CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc) { + // For conversion purposes, we ignore any qualifiers. + // For example, "const float" and "float" are equivalent. 
+ QualType LHSType = LHS.get()->getType().getUnqualifiedType(); + QualType RHSType = RHS.get()->getType().getUnqualifiedType(); + + const MatrixType *LHSMatType = LHSType->getAs<MatrixType>(); + const MatrixType *RHSMatType = RHSType->getAs<MatrixType>(); + assert((LHSMatType || RHSMatType) && "At least one operand must be a matrix"); + + if (LHSMatType && !RHSMatType) { + if (!Context.hasSameType(LHSMatType->getElementType(), RHSType)) + return InvalidOperands(Loc, LHS, RHS); + return LHSType; + } + + if (!LHSMatType && RHSMatType) { + if (!Context.hasSameType(LHSType, RHSMatType->getElementType())) + return InvalidOperands(Loc, LHS, RHS); + return RHSType; + } + + if (LHSMatType->getNumColumns() != RHSMatType->getNumRows()) + return InvalidOperands(Loc, LHS, RHS); + + if (!Context.hasSameType(LHSMatType->getElementType(), + RHSMatType->getElementType())) + return InvalidOperands(Loc, LHS, RHS); + + return Context.getMatrixType(LHSMatType->getElementType(), + LHSMatType->getNumRows(), + RHSMatType->getNumColumns()); +} + inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc, BinaryOperatorKind Opc) { Index: clang/lib/CodeGen/CGExprScalar.cpp =================================================================== --- clang/lib/CodeGen/CGExprScalar.cpp +++ clang/lib/CodeGen/CGExprScalar.cpp @@ -738,6 +738,22 @@ } } + if (Ops.Ty->isMatrixType()) { + llvm::MatrixBuilder<CGBuilderTy> MB(Builder); + // We need to check the types of the operands of the operator to get the + // correct matrix dimensions. 
+ auto *BO = cast<BinaryOperator>(Ops.E); + auto *LHSMatTy = + dyn_cast<MatrixType>(BO->getLHS()->getType().getCanonicalType()); + auto *RHSMatTy = + dyn_cast<MatrixType>(BO->getRHS()->getType().getCanonicalType()); + if (LHSMatTy && RHSMatTy) + return MB.CreateMatrixMultiply(Ops.LHS, Ops.RHS, LHSMatTy->getNumRows(), + LHSMatTy->getNumColumns(), + RHSMatTy->getNumColumns()); + return MB.CreateScalarMultiply(Ops.LHS, Ops.RHS); + } + if (Ops.Ty->isUnsignedIntegerType() && CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) && !CanElideOverflowCheck(CGF.getContext(), Ops)) Index: clang/include/clang/Sema/Sema.h =================================================================== --- clang/include/clang/Sema/Sema.h +++ clang/include/clang/Sema/Sema.h @@ -11077,6 +11077,8 @@ /// Type checking for matrix binary operators. QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS, SourceLocation Loc); + QualType CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS, + SourceLocation Loc); bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType); bool isLaxVectorConversion(QualType srcType, QualType destType);
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits