fhahn updated this revision to Diff 254601.
fhahn added a comment.

Implement conversion for matrix/scalar variants.
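To illustrate what the matrix/scalar variants enable (a minimal usage sketch,
not part of the patch; the typedef mirrors the ones used in the tests below),
the scalar operand may appear on either side of the multiplication and is
splat to a matrix of matching shape first:

  typedef double dx5x5_t __attribute__((matrix_type(5, 5)));

  void scale(dx5x5_t *a, double s) {
    *a = *a * s; /* s is splat to a 5x5 matrix, then multiplied elementwise */
    *a = s * *a; /* the scalar may equally appear on the left */
  }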
Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D76794/new/

https://reviews.llvm.org/D76794

Files:
  clang/include/clang/Sema/Sema.h
  clang/lib/CodeGen/CGExprScalar.cpp
  clang/lib/Sema/SemaExpr.cpp
  clang/test/CodeGen/matrix-type-operators.c
  clang/test/CodeGenCXX/matrix-type-operators.cpp
  clang/test/Sema/matrix-type-operators.c
  clang/test/SemaCXX/matrix-type-operators.cpp
  llvm/include/llvm/IR/MatrixBuilder.h
Index: llvm/include/llvm/IR/MatrixBuilder.h
===================================================================
--- llvm/include/llvm/IR/MatrixBuilder.h
+++ llvm/include/llvm/IR/MatrixBuilder.h
@@ -33,6 +33,19 @@
   IRBuilderTy &B;
   Module *getModule() { return B.GetInsertBlock()->getParent()->getParent(); }
 
+  std::pair<Value *, Value *> splatScalarOperandIfNeeded(Value *LHS, Value *RHS) {
+    assert((LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy()) && "One of the operands must be a matrix (embedded in a vector)");
+    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy())
+      RHS = B.CreateVectorSplat(
+          cast<VectorType>(LHS->getType())->getNumElements(), RHS,
+          "scalar.splat");
+    else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy())
+      LHS = B.CreateVectorSplat(
+          cast<VectorType>(RHS->getType())->getNumElements(), LHS,
+          "scalar.splat");
+    return {LHS, RHS};
+  }
+
 public:
   MatrixBuilder(IRBuilderTy &Builder) : B(Builder) {}
 
@@ -127,16 +140,7 @@
   /// Add matrixes \p LHS and \p RHS. Support both integer and floating point
   /// matrixes.
   Value *CreateAdd(Value *LHS, Value *RHS) {
-    assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
-    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy())
-      RHS = B.CreateVectorSplat(
-          cast<VectorType>(LHS->getType())->getNumElements(), RHS,
-          "scalar.splat");
-    else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy())
-      LHS = B.CreateVectorSplat(
-          cast<VectorType>(RHS->getType())->getNumElements(), LHS,
-          "scalar.splat");
-
+    std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
     return cast<VectorType>(LHS->getType())
                    ->getElementType()
                    ->isFloatingPointTy()
@@ -147,16 +151,7 @@
   /// Subtract matrixes \p LHS and \p RHS. Support both integer and floating
   /// point matrixes.
   Value *CreateSub(Value *LHS, Value *RHS) {
-    assert(LHS->getType()->isVectorTy() || RHS->getType()->isVectorTy());
-    if (LHS->getType()->isVectorTy() && !RHS->getType()->isVectorTy())
-      RHS = B.CreateVectorSplat(
-          cast<VectorType>(LHS->getType())->getNumElements(), RHS,
-          "scalar.splat");
-    else if (!LHS->getType()->isVectorTy() && RHS->getType()->isVectorTy())
-      LHS = B.CreateVectorSplat(
-          cast<VectorType>(RHS->getType())->getNumElements(), LHS,
-          "scalar.splat");
-
+    std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
     return cast<VectorType>(LHS->getType())
                    ->getElementType()
                    ->isFloatingPointTy()
@@ -164,15 +159,13 @@
                : B.CreateSub(LHS, RHS);
   }
 
-  /// Multiply matrix \p LHS with scalar \p RHS.
+  /// Multiply matrix \p LHS with scalar \p RHS or scalar \p LHS with matrix \p
+  /// RHS.
   Value *CreateScalarMultiply(Value *LHS, Value *RHS) {
-    Value *ScalarVector =
-        B.CreateVectorSplat(cast<VectorType>(LHS->getType())->getNumElements(),
-                            RHS, "scalar.splat");
-    if (RHS->getType()->isFloatingPointTy())
-      return B.CreateFMul(LHS, ScalarVector);
-
-    return B.CreateMul(LHS, ScalarVector);
+    std::tie(LHS, RHS) = splatScalarOperandIfNeeded(LHS, RHS);
+    if (LHS->getType()->getScalarType()->isFloatingPointTy())
+      return B.CreateFMul(LHS, RHS);
+    return B.CreateMul(LHS, RHS);
   }
 
   /// Extracts the element at (\p Row, \p Column) from \p Matrix.
Index: clang/test/SemaCXX/matrix-type-operators.cpp
===================================================================
--- clang/test/SemaCXX/matrix-type-operators.cpp
+++ clang/test/SemaCXX/matrix-type-operators.cpp
@@ -129,3 +129,26 @@
   Mat1.value = subtract<unsigned, 3, 3, float, 2, 2, unsigned, 2, 2>(Mat2, Mat3);
   // expected-note@-1 {{in instantiation of function template specialization 'subtract<unsigned int, 3, 3, float, 2, 2, unsigned int, 2, 2>' requested here}}
 }
+
+template <typename EltTy0, unsigned R0, unsigned C0, typename EltTy1, unsigned R1, unsigned C1, typename EltTy2, unsigned R2, unsigned C2>
+typename MyMatrix<EltTy2, R2, C2>::matrix_t multiply(MyMatrix<EltTy0, R0, C0> &A, MyMatrix<EltTy1, R1, C1> &B) {
+  char *v1 = A.value * B.value;
+  // expected-error@-1 {{cannot initialize a variable of type 'char *' with an rvalue of type 'unsigned int __attribute__((matrix_type(2, 2)))'}}
+  // expected-error@-2 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+
+  return A.value * B.value;
+  // expected-error@-1 {{invalid operands to binary expression ('MyMatrix<unsigned int, 2, 2>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))') and 'MyMatrix<unsigned int, 3, 3>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(3, 3)))'))}}
+}
+
+void test_multiply_template(unsigned *Ptr1, float *Ptr2) {
+  MyMatrix<unsigned, 2, 2> Mat1;
+  MyMatrix<unsigned, 3, 3> Mat2;
+  MyMatrix<float, 2, 2> Mat3;
+  Mat1.value = *((decltype(Mat1)::matrix_t *)Ptr1);
+  unsigned v1 = multiply<unsigned, 2, 2, unsigned, 2, 2, unsigned, 2, 2>(Mat1, Mat1);
+  // expected-note@-1 {{in instantiation of function template specialization 'multiply<unsigned int, 2, 2, unsigned int, 2, 2, unsigned int, 2, 2>' requested here}}
+  // expected-error@-2 {{cannot initialize a variable of type 'unsigned int' with an rvalue of type 'typename MyMatrix<unsigned int, 2U, 2U>::matrix_t' (aka 'unsigned int __attribute__((matrix_type(2, 2)))')}}
+
+  Mat1.value = multiply<unsigned, 2, 2, unsigned, 3, 3, unsigned, 2, 2>(Mat1, Mat2);
+  // expected-note@-1 {{in instantiation of function template specialization 'multiply<unsigned int, 2, 2, unsigned int, 3, 3, unsigned int, 2, 2>' requested here}}
+}
Index: clang/test/Sema/matrix-type-operators.c
===================================================================
--- clang/test/Sema/matrix-type-operators.c
+++ clang/test/Sema/matrix-type-operators.c
@@ -94,3 +94,22 @@
   a = b - &c;
   // expected-error@-1 {{invalid operands to binary expression ('sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') and 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*'))}}
 }
+
+void mat_mat_multiply(sx10x10_t a, sx5x10_t b, sx10x5_t c) {
+  // Invalid dimensions for operands.
+  a = c * c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx10x5_t' (aka 'float __attribute__((matrix_type(10, 5)))') and 'sx10x5_t')}}
+
+  // Shape of multiplication result does not match the type of b.
+  b = a * a;
+  // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'float __attribute__((matrix_type(10, 10)))'}}
+
+  b = a * &c;
+  // expected-error@-1 {{invalid operands to binary expression ('sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))') and 'sx10x5_t *' (aka 'float __attribute__((matrix_type(10, 5)))*'))}}
+}
+
+void mat_scalar_multiply(sx10x10_t a, sx5x10_t b, float scalar) {
+  // Shape of multiplication result does not match the type of b.
+  b = a * scalar;
+  // expected-error@-1 {{assigning to 'sx5x10_t' (aka 'float __attribute__((matrix_type(5, 10)))') from incompatible type 'sx10x10_t' (aka 'float __attribute__((matrix_type(10, 10)))')}}
+}
Index: clang/test/CodeGenCXX/matrix-type-operators.cpp
===================================================================
--- clang/test/CodeGenCXX/matrix-type-operators.cpp
+++ clang/test/CodeGenCXX/matrix-type-operators.cpp
@@ -285,3 +285,53 @@
   MyMatrix<float, 2, 5> Mat2;
   Mat1.value = subtract(Mat1, Mat2);
 }
+
+void multiply1_matrix(dx5x5_t *a, dx5x5_t *b, dx5x5_t *c) {
+  *a = *b * *c;
+
+  // CHECK-LABEL: @_Z16multiply1_matrixPDm5_5_dS0_S0_([25 x double]* %a, [25 x double]* %b, [25 x double]* %c)
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT: %c.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT: store [25 x double]* %c, [25 x double]** %c.addr, align 8
+  // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %c.addr, align 8
+  // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>*
+  // CHECK-NEXT: %5 = load <25 x double>, <25 x double>* %4, align 8
+  // CHECK-NEXT: %6 = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> %2, <25 x double> %5, i32 5, i32 5, i32 5)
+  // CHECK-NEXT: %7 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT: %8 = bitcast [25 x double]* %7 to <25 x double>*
+  // CHECK-NEXT: store <25 x double> %6, <25 x double>* %8, align 8
+  // CHECK-NEXT: ret void
+}
+
+// CHECK: declare <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double>, <25 x double>, i32 immarg, i32 immarg, i32 immarg)
+
+void multiply1_scalar(dx5x5_t *a, dx5x5_t *b, double c) {
+  *a = *b * c;
+
+  // CHECK-LABEL:@_Z16multiply1_scalarPDm5_5_dS0_d([25 x double]* %a, [25 x double]* %b, double %c
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT: %c.addr = alloca double, align 8
+  // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT: store double %c, double* %c.addr, align 8
+  // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT: %3 = load double, double* %c.addr, align 8
+  // CHECK-NEXT: %scalar.splat.splatinsert = insertelement <25 x double> undef, double %3, i32 0
+  // CHECK-NEXT: %scalar.splat.splat = shufflevector <25 x double> %scalar.splat.splatinsert, <25 x double> undef, <25 x i32> zeroinitializer
+  // CHECK-NEXT: %4 = fmul <25 x double> %2, %scalar.splat.splat
+  // CHECK-NEXT: %5 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT: %6 = bitcast [25 x double]* %5 to <25 x double>*
+  // CHECK-NEXT: store <25 x double> %4, <25 x double>* %6, align 8
+  // CHECK-NEXT: ret void
+}
Index: clang/test/CodeGen/matrix-type-operators.c
===================================================================
--- clang/test/CodeGen/matrix-type-operators.c
+++ clang/test/CodeGen/matrix-type-operators.c
@@ -463,3 +463,211 @@
   // CHECK-NEXT: store <27 x i32> %19, <27 x i32>* %1, align 4
   // CHECK-NEXT: ret void
 }
+
+void multiply_matrix_matrix(dx5x5_t *a, dx5x5_t *b, dx5x5_t *c) {
+  *a = *b * *c;
+
+  // CHECK-LABEL: @multiply_matrix_matrix(
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT: %c.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT: store [25 x double]* %c, [25 x double]** %c.addr, align 8
+  // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+  // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %c.addr, align 8
+  // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>*
+  // CHECK-NEXT: %5 = load <25 x double>, <25 x double>* %4, align 8
+  // CHECK-NEXT: %6 = call <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double> %2, <25 x double> %5, i32 5, i32 5, i32 5)
+  // CHECK-NEXT: %7 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT: %8 = bitcast [25 x double]* %7 to <25 x double>*
+  // CHECK-NEXT: store <25 x double> %6, <25 x double>* %8, align 8
+  // CHECK-NEXT: ret void
+}
+
+// CHECK: declare <25 x double> @llvm.matrix.multiply.v25f64.v25f64.v25f64(<25 x double>, <25 x double>, i32 immarg, i32 immarg, i32 immarg) [[READNONE:#[0-9]]]
+//
+void multiply_matrix_scalar_floats(dx5x5_t *a, fx2x3_t *b, double vf, float vd) {
+  *a = *a * vf;
+  *a = *a * vd;
+
+  // CHECK-LABEL: define void @multiply_matrix_scalar_floats([25 x double]* %a, [6 x float]* %b, double %vf, float %vd)
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8
+  // CHECK-NEXT: %b.addr = alloca [6 x float]*, align 8
+  // CHECK-NEXT: %vf.addr = alloca double, align 8
+  // CHECK-NEXT: %vd.addr = alloca float, align 4
+  // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT: store [6 x float]* %b, [6 x float]** %b.addr, align 8
+  // CHECK-NEXT: store double %vf, double* %vf.addr, align 8
+  // CHECK-NEXT: store float %vd, float* %vd.addr, align 4
+  // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>*
+  // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8
+  // CHECK-NEXT: %3 = load double, double* %vf.addr, align 8
+  // CHECK-NEXT: %scalar.splat.splatinsert = insertelement <25 x double> undef, double %3, i32 0
+  // CHECK-NEXT: %scalar.splat.splat = shufflevector <25 x double> %scalar.splat.splatinsert, <25 x double> undef, <25 x i32> zeroinitializer
+  // CHECK-NEXT: %4 = fmul <25 x double> %2, %scalar.splat.splat
+  // CHECK-NEXT: %5 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT: %6 = bitcast [25 x double]* %5 to <25 x double>*
+  // CHECK-NEXT: store <25 x double> %4, <25 x double>* %6, align 8
+  // CHECK-NEXT: %7 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT: %8 = bitcast [25 x double]* %7 to <25 x double>*
+  // CHECK-NEXT: %9 = load <25 x double>, <25 x double>* %8, align 8
+  // CHECK-NEXT: %10 = load float, float* %vd.addr, align 4
+  // CHECK-NEXT: %conv = fpext float %10 to double
+  // CHECK-NEXT: %scalar.splat.splatinsert1 = insertelement <25 x double> undef, double %conv, i32 0
+  // CHECK-NEXT: %scalar.splat.splat2 = shufflevector <25 x double> %scalar.splat.splatinsert1, <25 x double> undef, <25 x i32> zeroinitializer
+  // CHECK-NEXT: %11 = fmul <25 x double> %9, %scalar.splat.splat2
+  // CHECK-NEXT: %12 = load [25 x double]*, [25 x double]** %a.addr, align 8
+  // CHECK-NEXT: %13 = bitcast [25 x double]* %12 to <25 x double>*
+  // CHECK-NEXT: store <25 x double> %11, <25 x double>* %13, align 8
+
+  *b = vf * *b;
+  *b = vd * *b;
+
+  // CHECK-NEXT: %14 = load double, double* %vf.addr, align 8
+  // CHECK-NEXT: %conv3 = fptrunc double %14 to float
+  // CHECK-NEXT: %15 = load [6 x float]*, [6 x float]** %b.addr, align 8
+  // CHECK-NEXT: %16 = bitcast [6 x float]* %15 to <6 x float>*
+  // CHECK-NEXT: %17 = load <6 x float>, <6 x float>* %16, align 4
+  // CHECK-NEXT: %scalar.splat.splatinsert4 = insertelement <6 x float> undef, float %conv3, i32 0
+  // CHECK-NEXT: %scalar.splat.splat5 = shufflevector <6 x float> %scalar.splat.splatinsert4, <6 x float> undef, <6 x i32> zeroinitializer
+  // CHECK-NEXT: %18 = fmul <6 x float> %scalar.splat.splat5, %17
+  // CHECK-NEXT: %19 = load [6 x float]*, [6 x float]** %b.addr, align 8
+  // CHECK-NEXT: %20 = bitcast [6 x float]* %19 to <6 x float>*
+  // CHECK-NEXT: store <6 x float> %18, <6 x float>* %20, align 4
+  // CHECK-NEXT: %21 = load float, float* %vd.addr, align 4
+  // CHECK-NEXT: %22 = load [6 x float]*, [6 x float]** %b.addr, align 8
+  // CHECK-NEXT: %23 = bitcast [6 x float]* %22 to <6 x float>*
+  // CHECK-NEXT: %24 = load <6 x float>, <6 x float>* %23, align 4
+  // CHECK-NEXT: %scalar.splat.splatinsert6 = insertelement <6 x float> undef, float %21, i32 0
+  // CHECK-NEXT: %scalar.splat.splat7 = shufflevector <6 x float> %scalar.splat.splatinsert6, <6 x float> undef, <6 x i32> zeroinitializer
+  // CHECK-NEXT: %25 = fmul <6 x float> %scalar.splat.splat7, %24
+  // CHECK-NEXT: %26 = load [6 x float]*, [6 x float]** %b.addr, align 8
+  // CHECK-NEXT: %27 = bitcast [6 x float]* %26 to <6 x float>*
+  // CHECK-NEXT: store <6 x float> %25, <6 x float>* %27, align 4
+  // CHECK-NEXT: ret void
+}
+
+void multiply_matrix_scalar_ints(ix9x3_t a, llix9x3_t b, short vs, long int vli, unsigned long long int vulli) {
+  a = a * vs;
+  a = a * vli;
+  a = a * vulli;
+
+  // CHECK-LABEL: define void @multiply_matrix_scalar_ints(<27 x i32> %a, <27 x i32> %b, i16 signext %vs, i64 %vli, i64 %vulli)
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT: %a.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT: %b.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT: %vs.addr = alloca i16, align 2
+  // CHECK-NEXT: %vli.addr = alloca i64, align 8
+  // CHECK-NEXT: %vulli.addr = alloca i64, align 8
+  // CHECK-NEXT: %0 = bitcast [27 x i32]* %a.addr to <27 x i32>*
+  // CHECK-NEXT: store <27 x i32> %a, <27 x i32>* %0, align 4
+  // CHECK-NEXT: %1 = bitcast [27 x i32]* %b.addr to <27 x i32>*
+  // CHECK-NEXT: store <27 x i32> %b, <27 x i32>* %1, align 4
+  // CHECK-NEXT: store i16 %vs, i16* %vs.addr, align 2
+  // CHECK-NEXT: store i64 %vli, i64* %vli.addr, align 8
+  // CHECK-NEXT: store i64 %vulli, i64* %vulli.addr, align 8
+  // CHECK-NEXT: %2 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT: %3 = load i16, i16* %vs.addr, align 2
+  // CHECK-NEXT: %conv = sext i16 %3 to i32
+  // CHECK-NEXT: %scalar.splat.splatinsert = insertelement <27 x i32> undef, i32 %conv, i32 0
+  // CHECK-NEXT: %scalar.splat.splat = shufflevector <27 x i32> %scalar.splat.splatinsert, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT: %4 = mul <27 x i32> %2, %scalar.splat.splat
+  // CHECK-NEXT: store <27 x i32> %4, <27 x i32>* %0, align 4
+  // CHECK-NEXT: %5 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT: %6 = load i64, i64* %vli.addr, align 8
+  // CHECK-NEXT: %conv1 = trunc i64 %6 to i32
+  // CHECK-NEXT: %scalar.splat.splatinsert2 = insertelement <27 x i32> undef, i32 %conv1, i32 0
+  // CHECK-NEXT: %scalar.splat.splat3 = shufflevector <27 x i32> %scalar.splat.splatinsert2, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT: %7 = mul <27 x i32> %5, %scalar.splat.splat3
+  // CHECK-NEXT: store <27 x i32> %7, <27 x i32>* %0, align 4
+  // CHECK-NEXT: %8 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT: %9 = load i64, i64* %vulli.addr, align 8
+  // CHECK-NEXT: %conv4 = trunc i64 %9 to i32
+  // CHECK-NEXT: %scalar.splat.splatinsert5 = insertelement <27 x i32> undef, i32 %conv4, i32 0
+  // CHECK-NEXT: %scalar.splat.splat6 = shufflevector <27 x i32> %scalar.splat.splatinsert5, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT: %10 = mul <27 x i32> %8, %scalar.splat.splat6
+  // CHECK-NEXT: store <27 x i32> %10, <27 x i32>* %0, align 4
+
+  b = vs * b;
+  b = vli * b;
+  b = vulli * b;
+  // CHECK-NEXT: %11 = load i16, i16* %vs.addr, align 2
+  // CHECK-NEXT: %conv7 = sext i16 %11 to i32
+  // CHECK-NEXT: %12 = load <27 x i32>, <27 x i32>* %1, align 4
+  // CHECK-NEXT: %scalar.splat.splatinsert8 = insertelement <27 x i32> undef, i32 %conv7, i32 0
+  // CHECK-NEXT: %scalar.splat.splat9 = shufflevector <27 x i32> %scalar.splat.splatinsert8, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT: %13 = mul <27 x i32> %scalar.splat.splat9, %12
+  // CHECK-NEXT: store <27 x i32> %13, <27 x i32>* %1, align 4
+  // CHECK-NEXT: %14 = load i64, i64* %vli.addr, align 8
+  // CHECK-NEXT: %conv10 = trunc i64 %14 to i32
+  // CHECK-NEXT: %15 = load <27 x i32>, <27 x i32>* %1, align 4
+  // CHECK-NEXT: %scalar.splat.splatinsert11 = insertelement <27 x i32> undef, i32 %conv10, i32 0
+  // CHECK-NEXT: %scalar.splat.splat12 = shufflevector <27 x i32> %scalar.splat.splatinsert11, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT: %16 = mul <27 x i32> %scalar.splat.splat12, %15
+  // CHECK-NEXT: store <27 x i32> %16, <27 x i32>* %1, align 4
+  // CHECK-NEXT: %17 = load i64, i64* %vulli.addr, align 8
+  // CHECK-NEXT: %conv13 = trunc i64 %17 to i32
+  // CHECK-NEXT: %18 = load <27 x i32>, <27 x i32>* %1, align 4
+  // CHECK-NEXT: %scalar.splat.splatinsert14 = insertelement <27 x i32> undef, i32 %conv13, i32 0
+  // CHECK-NEXT: %scalar.splat.splat15 = shufflevector <27 x i32> %scalar.splat.splatinsert14, <27 x i32> undef, <27 x i32> zeroinitializer
+  // CHECK-NEXT: %19 = mul <27 x i32> %scalar.splat.splat15, %18
+  // CHECK-NEXT: store <27 x i32> %19, <27 x i32>* %1, align 4
+  // CHECK-NEXT: ret void
+}
+
+void multiply_matrix_scalar_constants(ix9x3_t a, fx2x3_t b, dx5x5_t c) {
+  a = a * 10;
+  a = a * 20ull;
+  a = a * 30ll;
+
+  // CHECK-LABEL: define void @multiply_matrix_scalar_constants(<27 x i32> %a, <6 x float> %b, <25 x double> %c)
+  // CHECK-NEXT: entry:
+  // CHECK-NEXT: %a.addr = alloca [27 x i32], align 4
+  // CHECK-NEXT: %b.addr = alloca [6 x float], align 4
+  // CHECK-NEXT: %c.addr = alloca [25 x double], align 8
+  // CHECK-NEXT: %0 = bitcast [27 x i32]* %a.addr to <27 x i32>*
+  // CHECK-NEXT: store <27 x i32> %a, <27 x i32>* %0, align 4
+  // CHECK-NEXT: %1 = bitcast [6 x float]* %b.addr to <6 x float>*
+  // CHECK-NEXT: store <6 x float> %b, <6 x float>* %1, align 4
+  // CHECK-NEXT: %2 = bitcast [25 x double]* %c.addr to <25 x double>*
+  // CHECK-NEXT: store <25 x double> %c, <25 x double>* %2, align 8
+  // CHECK-NEXT: %3 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT: %4 = mul <27 x i32> %3, <i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10, i32 10>
+  // CHECK-NEXT: store <27 x i32> %4, <27 x i32>* %0, align 4
+  // CHECK-NEXT: %5 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT: %6 = mul <27 x i32> %5, <i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20, i32 20>
+  // CHECK-NEXT: store <27 x i32> %6, <27 x i32>* %0, align 4
+  // CHECK-NEXT: %7 = load <27 x i32>, <27 x i32>* %0, align 4
+  // CHECK-NEXT: %8 = mul <27 x i32> %7, <i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30, i32 30>
+  // CHECK-NEXT: store <27 x i32> %8, <27 x i32>* %0, align 4
+
+  b = 10.0 * b;
+  b = ((float)20.0) * b;
+
+  // CHECK-NEXT: %9 = load <6 x float>, <6 x float>* %1, align 4
+  // CHECK-NEXT: %10 = fmul <6 x float> <float 1.000000e+01, float 1.000000e+01, float 1.000000e+01, float 1.000000e+01, float 1.000000e+01, float 1.000000e+01>, %9
+  // CHECK-NEXT: store <6 x float> %10, <6 x float>* %1, align 4
+  // CHECK-NEXT: %11 = load <6 x float>, <6 x float>* %1, align 4
+  // CHECK-NEXT: %12 = fmul <6 x float> <float 2.000000e+01, float 2.000000e+01, float 2.000000e+01, float 2.000000e+01, float 2.000000e+01, float 2.000000e+01>, %11
+  // CHECK-NEXT: store <6 x float> %12, <6 x float>* %1, align 4
+
+  c = 10.0 * c;
+  c = ((float)20.0) * c;
+
+  // CHECK-NEXT: %13 = load <25 x double>, <25 x double>* %2, align 8
+  // CHECK-NEXT: %14 = fmul <25 x double> <double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01, double 1.000000e+01>, %13
+  // CHECK-NEXT: store <25 x double> %14, <25 x double>* %2, align 8
+  // CHECK-NEXT: %15 = load <25 x double>, <25 x double>* %2, align 8
+  // CHECK-NEXT: %16 = fmul <25 x double> <double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01, double 2.000000e+01>, %15
+  // CHECK-NEXT: store <25 x double> %16, <25 x double>* %2, align 8
+  // CHECK-NEXT: ret void
+}
+
+
+// CHECK: attributes [[READNONE]] = { nounwind readnone speculatable willreturn }
Index: clang/lib/Sema/SemaExpr.cpp
===================================================================
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -9639,6 +9639,9 @@
     return CheckVectorOperands(LHS, RHS, Loc, IsCompAssign,
                                /*AllowBothBool*/getLangOpts().AltiVec,
                                /*AllowBoolConversions*/false);
+  if (!IsDiv && (LHS.get()->getType()->isMatrixType() ||
+                 RHS.get()->getType()->isMatrixType()))
+    return CheckMatrixMultiplyOperands(LHS, RHS, Loc, IsCompAssign);
 
   QualType compType = UsualArithmeticConversions(
       LHS, RHS, Loc, IsCompAssign ? ACK_CompAssign : ACK_Arithmetic);
@@ -11720,6 +11723,33 @@
   return InvalidOperands(Loc, LHS, RHS);
 }
 
+QualType Sema::CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS,
+                                           SourceLocation Loc, bool IsCompAssign) {
+
+  // For conversion purposes, we ignore any qualifiers.
+  // For example, "const float" and "float" are equivalent.
+  QualType LHSType = LHS.get()->getType().getUnqualifiedType();
+  QualType RHSType = RHS.get()->getType().getUnqualifiedType();
+
+  const MatrixType *LHSMatType = LHSType->getAs<MatrixType>();
+  const MatrixType *RHSMatType = RHSType->getAs<MatrixType>();
+  assert((LHSMatType || RHSMatType) && "At least one operand must be a matrix");
+
+  if (LHSMatType && RHSMatType) {
+    if (LHSMatType->getNumColumns() != RHSMatType->getNumRows())
+      return InvalidOperands(Loc, LHS, RHS);
+
+    if (!Context.hasSameType(LHSMatType->getElementType(),
+                             RHSMatType->getElementType()))
+      return InvalidOperands(Loc, LHS, RHS);
+
+    return Context.getMatrixType(LHSMatType->getElementType(),
+                                 LHSMatType->getNumRows(),
+                                 RHSMatType->getNumColumns());
+  }
+  return CheckMatrixElementwiseOperands(LHS, RHS, Loc, IsCompAssign);
+}
+
 inline QualType Sema::CheckBitwiseOperands(ExprResult &LHS, ExprResult &RHS,
                                            SourceLocation Loc,
                                            BinaryOperatorKind Opc) {
Index: clang/lib/CodeGen/CGExprScalar.cpp
===================================================================
--- clang/lib/CodeGen/CGExprScalar.cpp
+++ clang/lib/CodeGen/CGExprScalar.cpp
@@ -738,6 +738,22 @@
     }
   }
 
+  if (Ops.Ty->isMatrixType()) {
+    llvm::MatrixBuilder<CGBuilderTy> MB(Builder);
+    // We need to check the types of the operands of the operator to get the
+    // correct matrix dimensions.
+    auto *BO = cast<BinaryOperator>(Ops.E);
+    auto *LHSMatTy =
+        dyn_cast<MatrixType>(BO->getLHS()->getType().getCanonicalType());
+    auto *RHSMatTy =
+        dyn_cast<MatrixType>(BO->getRHS()->getType().getCanonicalType());
+    if (LHSMatTy && RHSMatTy)
+      return MB.CreateMatrixMultiply(Ops.LHS, Ops.RHS, LHSMatTy->getNumRows(),
+                                     LHSMatTy->getNumColumns(),
+                                     RHSMatTy->getNumColumns());
+    return MB.CreateScalarMultiply(Ops.LHS, Ops.RHS);
+  }
+
   if (Ops.Ty->isUnsignedIntegerType() &&
       CGF.SanOpts.has(SanitizerKind::UnsignedIntegerOverflow) &&
       !CanElideOverflowCheck(CGF.getContext(), Ops))
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11143,6 +11143,9 @@
   QualType CheckMatrixElementwiseOperands(ExprResult &LHS, ExprResult &RHS,
                                           SourceLocation Loc,
                                           bool IsCompAssign);
+  QualType CheckMatrixMultiplyOperands(ExprResult &LHS, ExprResult &RHS,
+                                       SourceLocation Loc,
+                                       bool IsCompAssign);
 
   bool areLaxCompatibleVectorTypes(QualType srcType, QualType destType);
   bool isLaxVectorConversion(QualType srcType, QualType destType);
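For reference, a small worked example of the shape rule that
CheckMatrixMultiplyOperands enforces (an illustrative sketch, not part of the
patch; the typedefs mirror the Sema tests above): a matrix-matrix product
requires the column count of the LHS to equal the row count of the RHS and
matching element types, and yields a result with the LHS' rows and the RHS'
columns. With a scalar operand, the operands are instead checked like the
other elementwise operators.

  typedef float sx10x5_t __attribute__((matrix_type(10, 5)));
  typedef float sx5x10_t __attribute__((matrix_type(5, 10)));
  typedef float sx10x10_t __attribute__((matrix_type(10, 10)));

  void shapes(sx10x5_t a, sx5x10_t b) {
    sx10x10_t c = a * b; /* OK: (10x5) * (5x10) yields a 10x10 result */
    /* a * a; would be rejected: LHS columns (5) != RHS rows (10) */
  }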
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits