fhahn created this revision.
Herald added a subscriber: tschuett.
Herald added a project: clang.
Repository:
rG LLVM Github Monorepo
https://reviews.llvm.org/D72782
Files:
clang/include/clang/Basic/Builtins.def
clang/include/clang/Sema/Sema.h
clang/lib/CodeGen/CGBuiltin.cpp
clang/lib/Sema/SemaChecking.cpp
clang/test/CodeGen/builtin-matrix.c
Index: clang/test/CodeGen/builtin-matrix.c
===================================================================
--- clang/test/CodeGen/builtin-matrix.c
+++ clang/test/CodeGen/builtin-matrix.c
@@ -289,5 +289,24 @@
}
// CHECK: declare <25 x double> @llvm.matrix.columnwise.load.v25f64.p0f64(double*, i32, i32 immarg, i32 immarg) [[READONLY:#[0-9]]]
+// Calling __builtin_matrix_column_store on a dereferenced dx5x5_t pointer with
+// a stride of 10 is expected to load the matrix as <25 x double> and emit a
+// single call to llvm.matrix.columnwise.store with stride 10 and dimension
+// operands 5 and 5.
+void column_store1(dx5x5_t *a, double *b) {
+ __builtin_matrix_column_store(*a, b, 10);
+
+ // CHECK-LABEL: @column_store1(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8
+ // CHECK-NEXT: %b.addr = alloca double*, align 8
+ // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8
+ // CHECK-NEXT: store double* %b, double** %b.addr, align 8
+ // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %a.addr, align 8
+ // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>*
+ // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8
+ // CHECK-NEXT: %3 = load double*, double** %b.addr, align 8
+ // CHECK-NEXT: call void @llvm.matrix.columnwise.store.v25f64.p0f64(<25 x double> %2, double* %3, i32 10, i32 5, i32 5)
+ // CHECK-NEXT: ret void
+}
+// CHECK: declare void @llvm.matrix.columnwise.store.v25f64.p0f64(<25 x double>, double* writeonly, i32, i32 immarg, i32 immarg) [[WRITEONLY:#[0-9]]]
+
// CHECK: attributes [[READNONE]] = { nounwind readnone speculatable willreturn }
// CHECK: attributes [[READONLY]] = { nounwind readonly willreturn }
+// CHECK: attributes [[WRITEONLY]] = { nounwind willreturn }
Index: clang/lib/Sema/SemaChecking.cpp
===================================================================
--- clang/lib/Sema/SemaChecking.cpp
+++ clang/lib/Sema/SemaChecking.cpp
@@ -1620,6 +1620,7 @@
case Builtin::BI__builtin_matrix_multiply:
case Builtin::BI__builtin_matrix_transpose:
case Builtin::BI__builtin_matrix_column_load:
+ case Builtin::BI__builtin_matrix_column_store:
if (!getLangOpts().EnableMatrix) {
Diag(TheCall->getBeginLoc(), diag::err_builtin_matrix_disabled);
return ExprError();
@@ -1639,6 +1640,8 @@
return SemaBuiltinMatrixTransposeOverload(TheCall, TheCallResult);
case Builtin::BI__builtin_matrix_column_load:
return SemaBuiltinMatrixColumnLoadOverload(TheCall, TheCallResult);
+ case Builtin::BI__builtin_matrix_column_store:
+ return SemaBuiltinMatrixColumnStoreOverload(TheCall, TheCallResult);
default:
llvm_unreachable("All matrix builtins should be handled here!");
}
@@ -15651,3 +15654,81 @@
return CallResult;
}
+
+/// Semantic checking for a call to __builtin_matrix_column_store.
+///
+/// The call must have exactly three arguments:
+///   1. the matrix to store,
+///   2. the pointer to store to,
+///   3. the stride (checked to be of integral type).
+///
+/// All three argument-type checks run before bailing out, so one malformed
+/// call can report several diagnostics at once. On success, matrix and pointer
+/// arguments that are not already r-values are wrapped in an
+/// l-value-to-r-value cast, and the callee is replaced by a
+/// builtin-function-to-function-pointer cast of a fresh reference to the
+/// builtin's declaration.
+///
+/// \param TheCall the call expression to check; its arguments and callee are
+///        updated in place.
+/// \param CallResult the result handed back unchanged when the call is
+///        accepted.
+/// \returns CallResult on success, ExprError() if the argument count or any
+///          argument type is rejected.
+ExprResult Sema::SemaBuiltinMatrixColumnStoreOverload(CallExpr *TheCall,
+                                                      ExprResult CallResult) {
+  if (checkArgCount(*this, TheCall, 3))
+    return ExprError();
+
+  Expr *MatrixExpr = TheCall->getArg(0);
+  Expr *DataExpr = TheCall->getArg(1);
+  Expr *StrideExpr = TheCall->getArg(2);
+
+  // Diagnose every bad argument before returning so the user sees all
+  // problems in a single pass.
+  bool ArgError = false;
+  if (!MatrixExpr->getType()->isMatrixType()) {
+    Diag(MatrixExpr->getBeginLoc(), diag::err_builtin_matrix_arg) << 0;
+    ArgError = true;
+  }
+  if (!DataExpr->getType()->isPointerType()) {
+    Diag(DataExpr->getBeginLoc(), diag::err_builtin_matrix_pointer_arg)
+        << 1 << 0;
+    ArgError = true;
+  }
+  if (!StrideExpr->getType()->isIntegralType(Context)) {
+    // NOTE(review): the matrix and pointer diagnostics above pass 0 and 1 as
+    // their first operand, which looks like a 0-based argument index; the
+    // stride is argument 2, so the 3 here may be a typo. Confirm against the
+    // diagnostic's format string.
+    Diag(StrideExpr->getBeginLoc(), diag::err_builtin_matrix_scalar_int_arg)
+        << 3 << 0;
+    ArgError = true;
+  }
+  if (ArgError)
+    return ExprError();
+
+  // TODO: Check element type compatibility, and possibly up/down cast element
+  // types
+
+  // Cast matrix to an rvalue
+  if (!MatrixExpr->isRValue()) {
+    ExprResult CastExprResult = ImplicitCastExpr::Create(
+        Context, MatrixExpr->getType(), CK_LValueToRValue, MatrixExpr, nullptr,
+        VK_RValue);
+    assert(!CastExprResult.isInvalid() && "Matrix cast to an R-value failed");
+    MatrixExpr = CastExprResult.get();
+    TheCall->setArg(0, MatrixExpr);
+  }
+
+  // Likewise cast the destination pointer to an rvalue
+  if (!DataExpr->isRValue()) {
+    ExprResult CastExprResult = ImplicitCastExpr::Create(
+        Context, DataExpr->getType(), CK_LValueToRValue, DataExpr, nullptr,
+        VK_RValue);
+    assert(!CastExprResult.isInvalid() && "Pointer cast to R-value failed");
+    DataExpr = CastExprResult.get();
+    TheCall->setArg(1, DataExpr);
+  }
+
+  Expr *Callee = TheCall->getCallee();
+  DeclRefExpr *DRE = cast<DeclRefExpr>(Callee->IgnoreParenCasts());
+  FunctionDecl *FDecl = cast<FunctionDecl>(DRE->getDecl());
+
+  // Create a new DeclRefExpr referring to the builtin's declaration.
+  DeclRefExpr *NewDRE = DeclRefExpr::Create(
+      Context, DRE->getQualifierLoc(), SourceLocation(), FDecl,
+      /*enclosing*/ false, DRE->getLocation(), Context.BuiltinFnTy,
+      DRE->getValueKind(), nullptr, nullptr, DRE->isNonOdrUse());
+
+  // Set the callee in the CallExpr.
+  // FIXME: This loses syntactic information.
+  QualType CalleePtrTy = Context.getPointerType(FDecl->getType());
+  ExprResult PromotedCall =
+      ImpCastExprToType(NewDRE, CalleePtrTy, CK_BuiltinFnToFnPtr);
+  TheCall->setCallee(PromotedCall.get());
+
+  return CallResult;
+}
Index: clang/lib/CodeGen/CGBuiltin.cpp
===================================================================
--- clang/lib/CodeGen/CGBuiltin.cpp
+++ clang/lib/CodeGen/CGBuiltin.cpp
@@ -2373,6 +2373,21 @@
}
return RValue::get(Result);
}
+  case Builtin::BI__builtin_matrix_column_store: {
+    // Lower __builtin_matrix_column_store(matrix, pointer, stride) to a
+    // column-wise matrix store created through MatrixBuilder. The row and
+    // column counts are taken from the matrix argument's type.
+    MatrixBuilder<CGBuilderTy> MB(Builder);
+    Value *Matrix = EmitScalarExpr(E->getArg(0));
+    const MatrixType *MatrixTy = getMatrixTy(E->getArg(0)->getType());
+    Address Dst = EmitPointerWithAlignment(E->getArg(1));
+    // Emit the non-null argument check for the destination pointer
+    // (argument index 1).
+    EmitNonNullArgCheck(RValue::get(Dst.getPointer()), E->getArg(1)->getType(),
+                        E->getArg(1)->getExprLoc(), FD, 1);
+    Value *Stride = EmitScalarExpr(E->getArg(2));
+
+    // TODO: Pass Dst alignment to intrinsic
+    MB.CreateMatrixColumnwiseStore(Matrix, Dst.getPointer(), Stride,
+                                   MatrixTy->getNumRows(),
+                                   MatrixTy->getNumColumns());
+    // The store produces no result value, so return an empty RValue instead
+    // of the destination pointer.
+    return RValue::get(nullptr);
+  }
case Builtin::BI__builtin_matrix_insert: {
MatrixBuilder<CGBuilderTy> MB(Builder);
Value *MatValue = EmitScalarExpr(E->getArg(0));
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -11625,6 +11625,8 @@
ExprResult SemaBuiltinMatrixColumnLoadOverload(CallExpr *TheCall,
ExprResult CallResult);
+ ExprResult SemaBuiltinMatrixColumnStoreOverload(CallExpr *TheCall,
+ ExprResult CallResult);
public:
enum FormatStringType {
Index: clang/include/clang/Basic/Builtins.def
===================================================================
--- clang/include/clang/Basic/Builtins.def
+++ clang/include/clang/Basic/Builtins.def
@@ -580,6 +580,7 @@
BUILTIN(__builtin_matrix_multiply, "v.", "nt")
BUILTIN(__builtin_matrix_transpose, "v.", "nFt")
BUILTIN(__builtin_matrix_column_load, "v.", "nFt")
+BUILTIN(__builtin_matrix_column_store, "v.", "nFt")
// "Overloaded" Atomic operator builtins. These are overloaded to support data
// types of i8, i16, i32, i64, and i128. The front-end sees calls to the
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits