fhahn updated this revision to Diff 262188.
fhahn marked 9 inline comments as done.
fhahn added a comment.
Thanks for the extensive comments! They should all be addressed in this update:
the dependent type construction has been refactored, template argument deduction
has been updated, and the mangling has been adjusted.
In D72281#2019417 <https://reviews.llvm.org/D72281#2019417>, @rjmccall wrote:
> The test cases for function template partial specialization would look
> something like this:
>
> template <class T, size_t R, size_t C>
> using matrix = T __attribute__((matrix_type(R, C)));
>
> template <int N> struct selector {};
>
> template <class T, size_t R, size_t C>
> selector<0> use_matrix(matrix<T, R, C> m) {}
>
> template <class T, size_t R>
> selector<1> use_matrix(matrix<T, R, 10> m) {}
>
> template <class T>
> selector<2> use_matrix(matrix<T, 10, 10> m) {}
>
> void test() {
> selector<2> x = use_matrix(matrix<int, 10, 10>());
> selector<1> y = use_matrix(matrix<int, 12, 10>());
> selector<0> z = use_matrix(matrix<int, 12, 12>());
> }
>
That's a great example; it highlighted a few other issues (e.g.
`BuildMatrixType` not supporting dependent element types).
With the latest version of the patch, each `use_matrix` definition compiles
when it is the only template definition of `use_matrix`, but the snippet below
still fails with an ambiguity error when all 3 `use_matrix` definitions are
available:
matrix<int, 10, 10> m1;
selector<2> x = use_matrix(m1);
The type of `m1` matches the matrix parameter in all 3 definitions of
`use_matrix`, and for some reason the return type is not used to disambiguate
between them:
llvm-project/clang/test/CodeGenCXX/matrix-type.cpp:310:19: error: call to
'use_matrix' is ambiguous
selector<2> x = use_matrix(m1);
^~~~~~~~~~
llvm-project/clang/test/CodeGenCXX/matrix-type.cpp:300:13: note: candidate
function [with T = int, R = 10, C = 10]
selector<0> use_matrix(matrix<T, R, C> &m) {}
^
llvm-project/clang/test/CodeGenCXX/matrix-type.cpp:303:13: note: candidate
function [with T = int, R = 10]
selector<1> use_matrix(matrix<T, R, 10> &m) {}
^
llvm-project/clang/test/CodeGenCXX/matrix-type.cpp:306:13: note: candidate
function [with T = int]
selector<2> use_matrix(matrix<T, 10, 10> &m) {}
Unfortunately, I am not sure where things are going wrong. The template argument
deduction for matrix types should mirror the code for types like
`DependentSizedArrayType`. Do you have any idea what could be missing?
> But you should include some weirder kinds of template, expressions that
> aren't deducible, and so on.
Will do, once the issue above is sorted out :)
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D72281/new/
https://reviews.llvm.org/D72281
Files:
clang/include/clang/AST/ASTContext.h
clang/include/clang/AST/RecursiveASTVisitor.h
clang/include/clang/AST/Type.h
clang/include/clang/AST/TypeLoc.h
clang/include/clang/AST/TypeProperties.td
clang/include/clang/Basic/Attr.td
clang/include/clang/Basic/DiagnosticSemaKinds.td
clang/include/clang/Basic/Features.def
clang/include/clang/Basic/LangOptions.def
clang/include/clang/Basic/TypeNodes.td
clang/include/clang/Driver/Options.td
clang/include/clang/Sema/Sema.h
clang/include/clang/Serialization/TypeBitCodes.def
clang/lib/AST/ASTContext.cpp
clang/lib/AST/ASTStructuralEquivalence.cpp
clang/lib/AST/ExprConstant.cpp
clang/lib/AST/ItaniumMangle.cpp
clang/lib/AST/MicrosoftMangle.cpp
clang/lib/AST/Type.cpp
clang/lib/AST/TypePrinter.cpp
clang/lib/CodeGen/CGDebugInfo.cpp
clang/lib/CodeGen/CGDebugInfo.h
clang/lib/CodeGen/CGExpr.cpp
clang/lib/CodeGen/CodeGenFunction.cpp
clang/lib/CodeGen/CodeGenTypes.cpp
clang/lib/CodeGen/ItaniumCXXABI.cpp
clang/lib/Driver/ToolChains/Clang.cpp
clang/lib/Frontend/CompilerInvocation.cpp
clang/lib/Sema/SemaExpr.cpp
clang/lib/Sema/SemaLookup.cpp
clang/lib/Sema/SemaTemplate.cpp
clang/lib/Sema/SemaTemplateDeduction.cpp
clang/lib/Sema/SemaType.cpp
clang/lib/Sema/TreeTransform.h
clang/lib/Serialization/ASTReader.cpp
clang/lib/Serialization/ASTWriter.cpp
clang/test/CodeGen/debug-info-matrix-types.c
clang/test/CodeGen/matrix-type.c
clang/test/CodeGenCXX/matrix-type.cpp
clang/test/Parser/matrix-type-disabled.c
clang/test/SemaCXX/matrix-type.cpp
clang/tools/libclang/CIndex.cpp
Index: clang/tools/libclang/CIndex.cpp
===================================================================
--- clang/tools/libclang/CIndex.cpp
+++ clang/tools/libclang/CIndex.cpp
@@ -1795,6 +1795,8 @@
DEFAULT_TYPELOC_IMPL(DependentSizedExtVector, Type)
DEFAULT_TYPELOC_IMPL(Vector, Type)
DEFAULT_TYPELOC_IMPL(ExtVector, VectorType)
+DEFAULT_TYPELOC_IMPL(ConstantMatrix, MatrixType)
+DEFAULT_TYPELOC_IMPL(DependentSizedMatrix, MatrixType)
DEFAULT_TYPELOC_IMPL(FunctionProto, FunctionType)
DEFAULT_TYPELOC_IMPL(FunctionNoProto, FunctionType)
DEFAULT_TYPELOC_IMPL(Record, TagType)
Index: clang/test/SemaCXX/matrix-type.cpp
===================================================================
--- /dev/null
+++ clang/test/SemaCXX/matrix-type.cpp
@@ -0,0 +1,61 @@
+// RUN: %clang_cc1 -fsyntax-only -pedantic -fenable-matrix -std=c++11 -verify -triple x86_64-apple-darwin %s
+
+using matrix_double_t = double __attribute__((matrix_type(6, 6)));
+using matrix_float_t = float __attribute__((matrix_type(6, 6)));
+using matrix_int_t = int __attribute__((matrix_type(6, 6)));
+
+void matrix_var_dimensions(int Rows, unsigned Columns, char C) {
+ using matrix1_t = int __attribute__((matrix_type(Rows, 1))); // expected-error{{matrix_type attribute requires an integer constant}}
+ using matrix2_t = int __attribute__((matrix_type(1, Columns))); // expected-error{{matrix_type attribute requires an integer constant}}
+ using matrix3_t = int __attribute__((matrix_type(C, C))); // expected-error{{matrix_type attribute requires an integer constant}}
+ using matrix4_t = int __attribute__((matrix_type(-1, 1))); // expected-error{{matrix row size too large}}
+ using matrix5_t = int __attribute__((matrix_type(1, -1))); // expected-error{{matrix column size too large}}
+ using matrix6_t = int __attribute__((matrix_type(0, 1))); // expected-error{{zero matrix size}}
+ using matrix7_t = int __attribute__((matrix_type(1, 0))); // expected-error{{zero matrix size}}
+ using matrix7_t = int __attribute__((matrix_type(char, 0))); // expected-error{{expected '(' for function-style cast or type construction}}
+ using matrix8_t = int __attribute__((matrix_type(1048576, 1))); // expected-error{{matrix row size too large}}
+}
+
+struct S1 {};
+
+enum TestEnum {
+ A,
+ B
+};
+
+void matrix_unsupported_element_type() {
+ using matrix1_t = char *__attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'char *'}}
+ using matrix2_t = S1 __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'S1'}}
+ using matrix3_t = bool __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'bool'}}
+ using matrix4_t = TestEnum __attribute__((matrix_type(1, 1))); // expected-error{{invalid matrix element type 'TestEnum'}}
+}
+
+template <typename T> // expected-note{{declared here}}
+void matrix_template_1() {
+ using matrix1_t = float __attribute__((matrix_type(T, T))); // expected-error{{'T' does not refer to a value}}
+}
+
+template <class C> // expected-note{{declared here}}
+void matrix_template_2() {
+ using matrix1_t = float __attribute__((matrix_type(C, C))); // expected-error{{'C' does not refer to a value}}
+}
+
+template <unsigned Rows, unsigned Cols>
+void matrix_template_3() {
+ using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{zero matrix size}}
+}
+
+void instantiate_template_3() {
+ matrix_template_3<1, 10>();
+ matrix_template_3<0, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_3<0, 10>' requested here}}
+}
+
+template <int Rows, unsigned Cols>
+void matrix_template_4() {
+ using matrix1_t = float __attribute__((matrix_type(Rows, Cols))); // expected-error{{matrix row size too large}}
+}
+
+void instantiate_template_4() {
+ matrix_template_4<2, 10>();
+ matrix_template_4<-3, 10>(); // expected-note{{in instantiation of function template specialization 'matrix_template_4<-3, 10>' requested here}}
+}
Index: clang/test/Parser/matrix-type-disabled.c
===================================================================
--- /dev/null
+++ clang/test/Parser/matrix-type-disabled.c
@@ -0,0 +1,14 @@
+// RUN: %clang_cc1 %s -triple i686-apple-darwin -verify -fsyntax-only
+
+// Matrix types are disabled by default.
+
+#if __has_extension(matrix_types)
+#error Expected extension 'matrix_types' to be disabled
+#endif
+
+typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
+// expected-error@-1 {{matrix types extension is disabled. Pass -fenable-matrix to enable it}}
+
+void load_store_double(dx5x5_t *a, dx5x5_t *b) {
+ *a = *b;
+}
Index: clang/test/CodeGenCXX/matrix-type.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeGenCXX/matrix-type.cpp
@@ -0,0 +1,315 @@
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
+typedef float fx3x4_t __attribute__((matrix_type(3, 4)));
+
+// CHECK: %struct.Matrix = type { i8, [12 x float], float }
+
+void load_store(dx5x5_t *a, dx5x5_t *b) {
+ // CHECK-LABEL: define void @_Z10load_storePU11matrix_typeLm5ELm5EdS0_(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8
+ // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8
+ // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8
+ // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8
+ // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+ // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>*
+ // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8
+ // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %a.addr, align 8
+ // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>*
+ // CHECK-NEXT: store <25 x double> %2, <25 x double>* %4, align 8
+ // CHECK-NEXT: ret void
+
+ *a = *b;
+}
+
+typedef float fx3x3_t __attribute__((matrix_type(3, 3)));
+
+void parameter_passing(fx3x3_t a, fx3x3_t *b) {
+ // CHECK-LABEL: define void @_Z17parameter_passingU11matrix_typeLm3ELm3EfPS_(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [9 x float], align 4
+ // CHECK-NEXT: %b.addr = alloca [9 x float]*, align 8
+ // CHECK-NEXT: %0 = bitcast [9 x float]* %a.addr to <9 x float>*
+ // CHECK-NEXT: store <9 x float> %a, <9 x float>* %0, align 4
+ // CHECK-NEXT: store [9 x float]* %b, [9 x float]** %b.addr, align 8
+ // CHECK-NEXT: %1 = load <9 x float>, <9 x float>* %0, align 4
+ // CHECK-NEXT: %2 = load [9 x float]*, [9 x float]** %b.addr, align 8
+ // CHECK-NEXT: %3 = bitcast [9 x float]* %2 to <9 x float>*
+ // CHECK-NEXT: store <9 x float> %1, <9 x float>* %3, align 4
+ // CHECK-NEXT: ret void
+ *b = a;
+}
+
+fx3x3_t return_matrix(fx3x3_t *a) {
+ // CHECK-LABEL: define <9 x float> @_Z13return_matrixPU11matrix_typeLm3ELm3Ef(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [9 x float]*, align 8
+ // CHECK-NEXT: store [9 x float]* %a, [9 x float]** %a.addr, align 8
+ // CHECK-NEXT: %0 = load [9 x float]*, [9 x float]** %a.addr, align 8
+ // CHECK-NEXT: %1 = bitcast [9 x float]* %0 to <9 x float>*
+ // CHECK-NEXT: %2 = load <9 x float>, <9 x float>* %1, align 4
+ // CHECK-NEXT: ret <9 x float> %2
+ return *a;
+}
+
+struct Matrix {
+ char Tmp1;
+ fx3x4_t Data;
+ float Tmp2;
+};
+
+void matrix_struct_pointers(Matrix *a, Matrix *b) {
+ // CHECK-LABEL: define void @_Z22matrix_struct_pointersP6MatrixS0_(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca %struct.Matrix*, align 8
+ // CHECK-NEXT: %b.addr = alloca %struct.Matrix*, align 8
+ // CHECK-NEXT: store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+ // CHECK-NEXT: store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+ // CHECK-NEXT: %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+ // CHECK-NEXT: %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+ // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>*
+ // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4
+ // CHECK-NEXT: %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+ // CHECK-NEXT: %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+ // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+ // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4
+ // CHECK-NEXT: ret void
+ b->Data = a->Data;
+}
+
+void matrix_struct_reference(Matrix &a, Matrix &b) {
+ // CHECK-LABEL: define void @_Z23matrix_struct_referenceR6MatrixS0_(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca %struct.Matrix*, align 8
+ // CHECK-NEXT: %b.addr = alloca %struct.Matrix*, align 8
+ // CHECK-NEXT: store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+ // CHECK-NEXT: store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+ // CHECK-NEXT: %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+ // CHECK-NEXT: %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+ // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>*
+ // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4
+ // CHECK-NEXT: %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+ // CHECK-NEXT: %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+ // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+ // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4
+ // CHECK-NEXT: ret void
+ b.Data = a.Data;
+}
+
+class MatrixClass {
+public:
+ int Tmp1;
+ fx3x4_t Data;
+ long Tmp2;
+};
+
+void matrix_class_reference(MatrixClass &a, MatrixClass &b) {
+ // CHECK-LABEL: define void @_Z22matrix_class_referenceR11MatrixClassS0_(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca %class.MatrixClass*, align 8
+ // CHECK-NEXT: %b.addr = alloca %class.MatrixClass*, align 8
+ // CHECK-NEXT: store %class.MatrixClass* %a, %class.MatrixClass** %a.addr, align 8
+ // CHECK-NEXT: store %class.MatrixClass* %b, %class.MatrixClass** %b.addr, align 8
+ // CHECK-NEXT: %0 = load %class.MatrixClass*, %class.MatrixClass** %a.addr, align 8
+ // CHECK-NEXT: %Data = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %0, i32 0, i32 1
+ // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>*
+ // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4
+ // CHECK-NEXT: %3 = load %class.MatrixClass*, %class.MatrixClass** %b.addr, align 8
+ // CHECK-NEXT: %Data1 = getelementptr inbounds %class.MatrixClass, %class.MatrixClass* %3, i32 0, i32 1
+ // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+ // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4
+ // CHECK-NEXT: ret void
+ b.Data = a.Data;
+}
+
+template <typename Ty, unsigned Rows, unsigned Cols>
+class MatrixClassTemplate {
+public:
+ using MatrixTy = Ty __attribute__((matrix_type(Rows, Cols)));
+ int Tmp1;
+ MatrixTy Data;
+ long Tmp2;
+};
+
+template <typename Ty, unsigned Rows, unsigned Cols>
+void matrix_template_reference(MatrixClassTemplate<Ty, Rows, Cols> &a, MatrixClassTemplate<Ty, Rows, Cols> &b) {
+ b.Data = a.Data;
+}
+
+MatrixClassTemplate<float, 10, 15> matrix_template_reference_caller(float *Data) {
+ // CHECK-LABEL: define void @_Z32matrix_template_reference_callerPf(%class.MatrixClassTemplate* noalias sret align 8 %agg.result, float* %Data
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %Data.addr = alloca float*, align 8
+ // CHECK-NEXT: %Arg = alloca %class.MatrixClassTemplate, align 8
+ // CHECK-NEXT: store float* %Data, float** %Data.addr, align 8
+ // CHECK-NEXT: %0 = load float*, float** %Data.addr, align 8
+ // CHECK-NEXT: %1 = bitcast float* %0 to [150 x float]*
+ // CHECK-NEXT: %2 = bitcast [150 x float]* %1 to <150 x float>*
+ // CHECK-NEXT: %3 = load <150 x float>, <150 x float>* %2, align 4
+ // CHECK-NEXT: %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %Arg, i32 0, i32 1
+ // CHECK-NEXT: %4 = bitcast [150 x float]* %Data1 to <150 x float>*
+ // CHECK-NEXT: store <150 x float> %3, <150 x float>* %4, align 4
+ // CHECK-NEXT: call void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %Arg, %class.MatrixClassTemplate* dereferenceable(616) %agg.result)
+ // CHECK-NEXT: ret void
+
+ // CHECK-LABEL: define linkonce_odr void @_Z25matrix_template_referenceIfLj10ELj15EEvR19MatrixClassTemplateIT_XT0_EXT1_EES3_(%class.MatrixClassTemplate* dereferenceable(616) %a, %class.MatrixClassTemplate* dereferenceable(616) %b)
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca %class.MatrixClassTemplate*, align 8
+ // CHECK-NEXT: %b.addr = alloca %class.MatrixClassTemplate*, align 8
+ // CHECK-NEXT: store %class.MatrixClassTemplate* %a, %class.MatrixClassTemplate** %a.addr, align 8
+ // CHECK-NEXT: store %class.MatrixClassTemplate* %b, %class.MatrixClassTemplate** %b.addr, align 8
+ // CHECK-NEXT: %0 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %a.addr, align 8
+ // CHECK-NEXT: %Data = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %0, i32 0, i32 1
+ // CHECK-NEXT: %1 = bitcast [150 x float]* %Data to <150 x float>*
+ // CHECK-NEXT: %2 = load <150 x float>, <150 x float>* %1, align 4
+ // CHECK-NEXT: %3 = load %class.MatrixClassTemplate*, %class.MatrixClassTemplate** %b.addr, align 8
+ // CHECK-NEXT: %Data1 = getelementptr inbounds %class.MatrixClassTemplate, %class.MatrixClassTemplate* %3, i32 0, i32 1
+ // CHECK-NEXT: %4 = bitcast [150 x float]* %Data1 to <150 x float>*
+ // CHECK-NEXT: store <150 x float> %2, <150 x float>* %4, align 4
+ // CHECK-NEXT: ret void
+
+ MatrixClassTemplate<float, 10, 15> Result, Arg;
+ Arg.Data = *((MatrixClassTemplate<float, 10, 15>::MatrixTy *)Data);
+ matrix_template_reference(Arg, Result);
+ return Result;
+}
+
+template <class T, unsigned R, unsigned C>
+using matrix = T __attribute__((matrix_type(R, C)));
+
+template <class T, unsigned R, unsigned C>
+matrix<T, R, C> &use_matrix_1(matrix<T, R, C> &m) { return m; }
+
+void test_use_matrix_1() {
+ // CHECK-LABEL: define void @_Z17test_use_matrix_1v()
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %m = alloca [144 x i32], align 4
+ // CHECK-NEXT: %m2 = alloca [144 x i32]*, align 8
+ // CHECK-NEXT: %call = call dereferenceable(576) [144 x i32]* @_Z12use_matrix_1IiLj12ELj12EERU11matrix_typeXT0_EXT1_ET_S2_([144 x i32]* dereferenceable(576) %m)
+ // CHECK-NEXT: store [144 x i32]* %call, [144 x i32]** %m2, align 8
+ // CHECK-NEXT: ret void
+
+ // CHECK-LABEL: define linkonce_odr dereferenceable(576) [144 x i32]* @_Z12use_matrix_1IiLj12ELj12EERU11matrix_typeXT0_EXT1_ET_S2_([144 x i32]* dereferenceable(576) %m)
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %m.addr = alloca [144 x i32]*, align 8
+ // CHECK-NEXT: store [144 x i32]* %m, [144 x i32]** %m.addr, align 8
+ // CHECK-NEXT: %0 = load [144 x i32]*, [144 x i32]** %m.addr, align 8
+ // CHECK-NEXT: ret [144 x i32]* %0
+
+ matrix<int, 12, 12> m;
+ auto &m2 = use_matrix_1(m);
+}
+
+template <class T, unsigned R>
+matrix<T, R, 10> &use_matrix_2(matrix<T, R, 10> &m) { return m; }
+
+void test_use_matrix_2() {
+ // CHECK-LABEL: define void @_Z17test_use_matrix_2v()
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %m = alloca [120 x i32], align 4
+ // CHECK-NEXT: %m2 = alloca [120 x i32]*, align 8
+ // CHECK-NEXT: %call = call dereferenceable(480) [120 x i32]* @_Z12use_matrix_2IiLj12EERU11matrix_typeXT0_EXLj10EET_S2_([120 x i32]* dereferenceable(480) %m)
+ // CHECK-NEXT: store [120 x i32]* %call, [120 x i32]** %m2, align 8
+ // CHECK-NEXT: ret void
+
+ // CHECK-LABEL: define linkonce_odr dereferenceable(480) [120 x i32]* @_Z12use_matrix_2IiLj12EERU11matrix_typeXT0_EXLj10EET_S2_([120 x i32]* dereferenceable(480) %m)
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %m.addr = alloca [120 x i32]*, align 8
+ // CHECK-NEXT: store [120 x i32]* %m, [120 x i32]** %m.addr, align 8
+ // CHECK-NEXT: %0 = load [120 x i32]*, [120 x i32]** %m.addr, align 8
+ // CHECK-NEXT: ret [120 x i32]* %0
+
+ matrix<int, 12, 10> m;
+ auto &m2 = use_matrix_2(m);
+}
+
+template <class T, unsigned C>
+matrix<T, 10, C> &use_matrix_3(matrix<T, 10, C> &m) { return m; }
+
+void test_use_matrix_3() {
+ // CHECK-LABEL: define void @_Z17test_use_matrix_3v()
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %m = alloca [120 x i32], align 4
+ // CHECK-NEXT: %m2 = alloca [120 x i32]*, align 8
+ // CHECK-NEXT: %call = call dereferenceable(480) [120 x i32]* @_Z12use_matrix_3IiLj12EERU11matrix_typeXLj10EEXT0_ET_S2_([120 x i32]* dereferenceable(480) %m)
+ // CHECK-NEXT: store [120 x i32]* %call, [120 x i32]** %m2, align 8
+ // CHECK-NEXT: ret void
+
+ // CHECK-LABEL: define linkonce_odr dereferenceable(480) [120 x i32]* @_Z12use_matrix_3IiLj12EERU11matrix_typeXLj10EEXT0_ET_S2_([120 x i32]* dereferenceable(480) %m)
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %m.addr = alloca [120 x i32]*, align 8
+ // CHECK-NEXT: store [120 x i32]* %m, [120 x i32]** %m.addr, align 8
+ // CHECK-NEXT: %0 = load [120 x i32]*, [120 x i32]** %m.addr, align 8
+ // CHECK-NEXT: ret [120 x i32]* %0
+
+ matrix<int, 10, 12> m;
+ auto &m2 = use_matrix_3(m);
+}
+
+template <class T, int C>
+matrix<T, 10, C> &use_matrix_int(matrix<T, 10, C> &m) { return m; }
+
+void test_use_matrix_int() {
+ // CHECK-LABEL: define void @_Z19test_use_matrix_intv()
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %m = alloca [120 x i32], align 4
+ // CHECK-NEXT: %m2 = alloca [120 x i32]*, align 8
+ // CHECK-NEXT: %call = call dereferenceable(480) [120 x i32]* @_Z14use_matrix_intIiLi12EERU11matrix_typeXLj10EEXT0_ET_S2_([120 x i32]* dereferenceable(480) %m)
+ // CHECK-NEXT: store [120 x i32]* %call, [120 x i32]** %m2, align 8
+ // CHECK-NEXT: ret void
+
+ // CHECK-LABEL: define linkonce_odr dereferenceable(480) [120 x i32]* @_Z14use_matrix_intIiLi12EERU11matrix_typeXLj10EEXT0_ET_S2_([120 x i32]* dereferenceable(480) %m)
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %m.addr = alloca [120 x i32]*, align 8
+ // CHECK-NEXT: store [120 x i32]* %m, [120 x i32]** %m.addr, align 8
+ // CHECK-NEXT: %0 = load [120 x i32]*, [120 x i32]** %m.addr, align 8
+ // CHECK-NEXT: ret [120 x i32]* %0
+
+ matrix<int, 10, 12> m;
+ auto &m2 = use_matrix_int(m);
+}
+
+template <class T, unsigned long long C>
+matrix<T, 10, C> &use_matrix_ull(matrix<T, 10, C> &m) { return m; }
+
+void test_use_matrix_ull() {
+ // CHECK-LABEL: define void @_Z19test_use_matrix_ullv()
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %m = alloca [120 x i32], align 4
+ // CHECK-NEXT: %m2 = alloca [120 x i32]*, align 8
+ // CHECK-NEXT: %call = call dereferenceable(480) [120 x i32]* @_Z14use_matrix_ullIiLy12EERU11matrix_typeXLj10EEXT0_ET_S2_([120 x i32]* dereferenceable(480) %m)
+ // CHECK-NEXT: store [120 x i32]* %call, [120 x i32]** %m2, align 8
+ // CHECK-NEXT: ret void
+
+ // CHECK-LABEL: define linkonce_odr dereferenceable(480) [120 x i32]* @_Z14use_matrix_ullIiLy12EERU11matrix_typeXLj10EEXT0_ET_S2_([120 x i32]* dereferenceable(480) %m)
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %m.addr = alloca [120 x i32]*, align 8
+ // CHECK-NEXT: store [120 x i32]* %m, [120 x i32]** %m.addr, align 8
+ // CHECK-NEXT: %0 = load [120 x i32]*, [120 x i32]** %m.addr, align 8
+ // CHECK-NEXT: ret [120 x i32]* %0
+
+ matrix<int, 10, 12> m;
+ auto &m2 = use_matrix_ull(m);
+}
+
+template <int N>
+struct selector {};
+
+template <class T, unsigned R, unsigned C>
+selector<0> use_matrix(matrix<T, R, C> &m) {}
+
+template <class T, unsigned R>
+selector<1> use_matrix(matrix<T, R, 10> &m) {}
+
+template <class T>
+selector<2> use_matrix(matrix<T, 10, 10> &m) {}
+
+void test() {
+ /* matrix<int, 10, 10> m1;*/
+ /*selector<2> x = use_matrix(m1);*/
+ /* matrix<int, 12, 10> m2;*/
+ /*selector<1> y = use_matrix(m2);*/
+ matrix<int, 12, 12> m3;
+ selector<0> z = use_matrix(m3);
+}
Index: clang/test/CodeGen/matrix-type.c
===================================================================
--- /dev/null
+++ clang/test/CodeGen/matrix-type.c
@@ -0,0 +1,158 @@
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+#if !__has_extension(matrix_types)
+#error Expected extension 'matrix_types' to be enabled
+#endif
+
+typedef double dx5x5_t __attribute__((matrix_type(5, 5)));
+
+// CHECK: %struct.Matrix = type { i8, [12 x float], float }
+
+void load_store_double(dx5x5_t *a, dx5x5_t *b) {
+ // CHECK-LABEL: define void @load_store_double(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [25 x double]*, align 8
+ // CHECK-NEXT: %b.addr = alloca [25 x double]*, align 8
+ // CHECK-NEXT: store [25 x double]* %a, [25 x double]** %a.addr, align 8
+ // CHECK-NEXT: store [25 x double]* %b, [25 x double]** %b.addr, align 8
+ // CHECK-NEXT: %0 = load [25 x double]*, [25 x double]** %b.addr, align 8
+ // CHECK-NEXT: %1 = bitcast [25 x double]* %0 to <25 x double>*
+ // CHECK-NEXT: %2 = load <25 x double>, <25 x double>* %1, align 8
+ // CHECK-NEXT: %3 = load [25 x double]*, [25 x double]** %a.addr, align 8
+ // CHECK-NEXT: %4 = bitcast [25 x double]* %3 to <25 x double>*
+ // CHECK-NEXT: store <25 x double> %2, <25 x double>* %4, align 8
+ // CHECK-NEXT: ret void
+
+ *a = *b;
+}
+
+typedef float fx3x4_t __attribute__((matrix_type(3, 4)));
+void load_store_float(fx3x4_t *a, fx3x4_t *b) {
+ // CHECK-LABEL: define void @load_store_float(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [12 x float]*, align 8
+ // CHECK-NEXT: %b.addr = alloca [12 x float]*, align 8
+ // CHECK-NEXT: store [12 x float]* %a, [12 x float]** %a.addr, align 8
+ // CHECK-NEXT: store [12 x float]* %b, [12 x float]** %b.addr, align 8
+ // CHECK-NEXT: %0 = load [12 x float]*, [12 x float]** %b.addr, align 8
+ // CHECK-NEXT: %1 = bitcast [12 x float]* %0 to <12 x float>*
+ // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4
+ // CHECK-NEXT: %3 = load [12 x float]*, [12 x float]** %a.addr, align 8
+ // CHECK-NEXT: %4 = bitcast [12 x float]* %3 to <12 x float>*
+ // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4
+ // CHECK-NEXT: ret void
+
+ *a = *b;
+}
+
+typedef int ix3x4_t __attribute__((matrix_type(4, 3)));
+void load_store_int(ix3x4_t *a, ix3x4_t *b) {
+ // CHECK-LABEL: define void @load_store_int(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [12 x i32]*, align 8
+ // CHECK-NEXT: %b.addr = alloca [12 x i32]*, align 8
+ // CHECK-NEXT: store [12 x i32]* %a, [12 x i32]** %a.addr, align 8
+ // CHECK-NEXT: store [12 x i32]* %b, [12 x i32]** %b.addr, align 8
+ // CHECK-NEXT: %0 = load [12 x i32]*, [12 x i32]** %b.addr, align 8
+ // CHECK-NEXT: %1 = bitcast [12 x i32]* %0 to <12 x i32>*
+ // CHECK-NEXT: %2 = load <12 x i32>, <12 x i32>* %1, align 4
+ // CHECK-NEXT: %3 = load [12 x i32]*, [12 x i32]** %a.addr, align 8
+ // CHECK-NEXT: %4 = bitcast [12 x i32]* %3 to <12 x i32>*
+ // CHECK-NEXT: store <12 x i32> %2, <12 x i32>* %4, align 4
+ // CHECK-NEXT: ret void
+
+ *a = *b;
+}
+
+typedef unsigned long long ullx3x4_t __attribute__((matrix_type(4, 3)));
+void load_store_ull(ullx3x4_t *a, ullx3x4_t *b) {
+ // CHECK-LABEL: define void @load_store_ull(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [12 x i64]*, align 8
+ // CHECK-NEXT: %b.addr = alloca [12 x i64]*, align 8
+ // CHECK-NEXT: store [12 x i64]* %a, [12 x i64]** %a.addr, align 8
+ // CHECK-NEXT: store [12 x i64]* %b, [12 x i64]** %b.addr, align 8
+ // CHECK-NEXT: %0 = load [12 x i64]*, [12 x i64]** %b.addr, align 8
+ // CHECK-NEXT: %1 = bitcast [12 x i64]* %0 to <12 x i64>*
+ // CHECK-NEXT: %2 = load <12 x i64>, <12 x i64>* %1, align 8
+ // CHECK-NEXT: %3 = load [12 x i64]*, [12 x i64]** %a.addr, align 8
+ // CHECK-NEXT: %4 = bitcast [12 x i64]* %3 to <12 x i64>*
+ // CHECK-NEXT: store <12 x i64> %2, <12 x i64>* %4, align 8
+ // CHECK-NEXT: ret void
+
+ *a = *b;
+}
+
+typedef __fp16 fp16x3x4_t __attribute__((matrix_type(4, 3)));
+void load_store_fp16(fp16x3x4_t *a, fp16x3x4_t *b) {
+ // CHECK-LABEL: define void @load_store_fp16(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [12 x half]*, align 8
+ // CHECK-NEXT: %b.addr = alloca [12 x half]*, align 8
+ // CHECK-NEXT: store [12 x half]* %a, [12 x half]** %a.addr, align 8
+ // CHECK-NEXT: store [12 x half]* %b, [12 x half]** %b.addr, align 8
+ // CHECK-NEXT: %0 = load [12 x half]*, [12 x half]** %b.addr, align 8
+ // CHECK-NEXT: %1 = bitcast [12 x half]* %0 to <12 x half>*
+ // CHECK-NEXT: %2 = load <12 x half>, <12 x half>* %1, align 2
+ // CHECK-NEXT: %3 = load [12 x half]*, [12 x half]** %a.addr, align 8
+ // CHECK-NEXT: %4 = bitcast [12 x half]* %3 to <12 x half>*
+ // CHECK-NEXT: store <12 x half> %2, <12 x half>* %4, align 2
+ // CHECK-NEXT: ret void
+
+ *a = *b;
+}
+
+typedef float fx3x3_t __attribute__((matrix_type(3, 3)));
+
+void parameter_passing(fx3x3_t a, fx3x3_t *b) {
+ // CHECK-LABEL: define void @parameter_passing(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [9 x float], align 4
+ // CHECK-NEXT: %b.addr = alloca [9 x float]*, align 8
+ // CHECK-NEXT: %0 = bitcast [9 x float]* %a.addr to <9 x float>*
+ // CHECK-NEXT: store <9 x float> %a, <9 x float>* %0, align 4
+ // CHECK-NEXT: store [9 x float]* %b, [9 x float]** %b.addr, align 8
+ // CHECK-NEXT: %1 = load <9 x float>, <9 x float>* %0, align 4
+ // CHECK-NEXT: %2 = load [9 x float]*, [9 x float]** %b.addr, align 8
+ // CHECK-NEXT: %3 = bitcast [9 x float]* %2 to <9 x float>*
+ // CHECK-NEXT: store <9 x float> %1, <9 x float>* %3, align 4
+ // CHECK-NEXT: ret void
+ *b = a;
+}
+
+fx3x3_t return_matrix(fx3x3_t *a) {
+ // CHECK-LABEL: define <9 x float> @return_matrix
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca [9 x float]*, align 8
+ // CHECK-NEXT: store [9 x float]* %a, [9 x float]** %a.addr, align 8
+ // CHECK-NEXT: %0 = load [9 x float]*, [9 x float]** %a.addr, align 8
+ // CHECK-NEXT: %1 = bitcast [9 x float]* %0 to <9 x float>*
+ // CHECK-NEXT: %2 = load <9 x float>, <9 x float>* %1, align 4
+ // CHECK-NEXT: ret <9 x float> %2
+ return *a;
+}
+
+typedef struct {
+ char Tmp1;
+ fx3x4_t Data;
+ float Tmp2;
+} Matrix;
+
+void matrix_struct(Matrix *a, Matrix *b) {
+ // CHECK-LABEL: define void @matrix_struct(
+ // CHECK-NEXT: entry:
+ // CHECK-NEXT: %a.addr = alloca %struct.Matrix*, align 8
+ // CHECK-NEXT: %b.addr = alloca %struct.Matrix*, align 8
+ // CHECK-NEXT: store %struct.Matrix* %a, %struct.Matrix** %a.addr, align 8
+ // CHECK-NEXT: store %struct.Matrix* %b, %struct.Matrix** %b.addr, align 8
+ // CHECK-NEXT: %0 = load %struct.Matrix*, %struct.Matrix** %a.addr, align 8
+ // CHECK-NEXT: %Data = getelementptr inbounds %struct.Matrix, %struct.Matrix* %0, i32 0, i32 1
+ // CHECK-NEXT: %1 = bitcast [12 x float]* %Data to <12 x float>*
+ // CHECK-NEXT: %2 = load <12 x float>, <12 x float>* %1, align 4
+ // CHECK-NEXT: %3 = load %struct.Matrix*, %struct.Matrix** %b.addr, align 8
+ // CHECK-NEXT: %Data1 = getelementptr inbounds %struct.Matrix, %struct.Matrix* %3, i32 0, i32 1
+ // CHECK-NEXT: %4 = bitcast [12 x float]* %Data1 to <12 x float>*
+ // CHECK-NEXT: store <12 x float> %2, <12 x float>* %4, align 4
+ // CHECK-NEXT: ret void
+ b->Data = a->Data;
+}
Index: clang/test/CodeGen/debug-info-matrix-types.c
===================================================================
--- /dev/null
+++ clang/test/CodeGen/debug-info-matrix-types.c
@@ -0,0 +1,19 @@
+// RUN: %clang_cc1 -fenable-matrix -triple x86_64-apple-darwin %s -debug-info-kind=limited -emit-llvm -disable-llvm-passes -o - | FileCheck %s
+
+typedef double dx2x3_t __attribute__((matrix_type(2, 3)));
+
+void load_store_double(dx2x3_t *a, dx2x3_t *b) {
+ // CHECK-DAG: @llvm.dbg.declare(metadata [6 x double]** %a.addr, metadata [[EXPR_A:![0-9]+]]
+ // CHECK-DAG: @llvm.dbg.declare(metadata [6 x double]** %b.addr, metadata [[EXPR_B:![0-9]+]]
+ // CHECK: [[PTR_TY:![0-9]+]] = !DIDerivedType(tag: DW_TAG_pointer_type, baseType: [[TYPEDEF:![0-9]+]], size: 64)
+ // CHECK: [[TYPEDEF]] = !DIDerivedType(tag: DW_TAG_typedef, name: "dx2x3_t", {{.+}} baseType: [[MATRIX_TY:![0-9]+]])
+ // CHECK: [[MATRIX_TY]] = !DICompositeType(tag: DW_TAG_array_type, baseType: [[ELT_TY:![0-9]+]], size: 384, elements: [[ELEMENTS:![0-9]+]])
+ // CHECK: [[ELT_TY]] = !DIBasicType(name: "double", size: 64, encoding: DW_ATE_float)
+ // CHECK: [[ELEMENTS]] = !{[[COLS:![0-9]+]], [[ROWS:![0-9]+]]}
+ // CHECK: [[COLS]] = !DISubrange(count: 3)
+ // CHECK: [[ROWS]] = !DISubrange(count: 2)
+ // CHECK: [[EXPR_A]] = !DILocalVariable(name: "a", arg: 1, {{.+}} type: [[PTR_TY]])
+ // CHECK: [[EXPR_B]] = !DILocalVariable(name: "b", arg: 2, {{.+}} type: [[PTR_TY]])
+
+ *a = *b;
+}
Index: clang/lib/Serialization/ASTWriter.cpp
===================================================================
--- clang/lib/Serialization/ASTWriter.cpp
+++ clang/lib/Serialization/ASTWriter.cpp
@@ -288,6 +288,25 @@
Record.AddSourceLocation(TL.getNameLoc());
}
+void TypeLocWriter::VisitConstantMatrixTypeLoc(ConstantMatrixTypeLoc TL) {
+ Record.AddSourceLocation(TL.getAttrNameLoc());
+ SourceRange range = TL.getAttrOperandParensRange();
+ Record.AddSourceLocation(range.getBegin());
+ Record.AddSourceLocation(range.getEnd());
+ Record.AddStmt(TL.getAttrRowOperand());
+ Record.AddStmt(TL.getAttrColumnOperand());
+}
+
+void TypeLocWriter::VisitDependentSizedMatrixTypeLoc(
+ DependentSizedMatrixTypeLoc TL) {
+ Record.AddSourceLocation(TL.getAttrNameLoc());
+ SourceRange range = TL.getAttrOperandParensRange();
+ Record.AddSourceLocation(range.getBegin());
+ Record.AddSourceLocation(range.getEnd());
+ Record.AddStmt(TL.getAttrRowOperand());
+ Record.AddStmt(TL.getAttrColumnOperand());
+}
+
void TypeLocWriter::VisitFunctionTypeLoc(FunctionTypeLoc TL) {
Record.AddSourceLocation(TL.getLocalRangeBegin());
Record.AddSourceLocation(TL.getLParenLoc());
Index: clang/lib/Serialization/ASTReader.cpp
===================================================================
--- clang/lib/Serialization/ASTReader.cpp
+++ clang/lib/Serialization/ASTReader.cpp
@@ -6554,6 +6554,21 @@
TL.setNameLoc(readSourceLocation());
}
+void TypeLocReader::VisitConstantMatrixTypeLoc(ConstantMatrixTypeLoc TL) {
+ TL.setAttrNameLoc(readSourceLocation());
+ TL.setAttrOperandParensRange(Reader.readSourceRange());
+ TL.setAttrRowOperand(Reader.readExpr());
+ TL.setAttrColumnOperand(Reader.readExpr());
+}
+
+void TypeLocReader::VisitDependentSizedMatrixTypeLoc(
+ DependentSizedMatrixTypeLoc TL) {
+ TL.setAttrNameLoc(readSourceLocation());
+ TL.setAttrOperandParensRange(Reader.readSourceRange());
+ TL.setAttrRowOperand(Reader.readExpr());
+ TL.setAttrColumnOperand(Reader.readExpr());
+}
+
void TypeLocReader::VisitFunctionTypeLoc(FunctionTypeLoc TL) {
TL.setLocalRangeBegin(readSourceLocation());
TL.setLParenLoc(readSourceLocation());
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -894,6 +894,16 @@
Expr *SizeExpr,
SourceLocation AttributeLoc);
+ /// Build a new matrix type given the element type and dimensions.
+ QualType RebuildConstantMatrixType(QualType ElementType, unsigned NumRows,
+ unsigned NumColumns);
+
+ /// Build a new matrix type given the type and dependently-defined
+ /// dimensions.
+ QualType RebuildDependentSizedMatrixType(QualType ElementType, Expr *RowExpr,
+ Expr *ColumnExpr,
+ SourceLocation AttributeLoc);
+
/// Build a new DependentAddressSpaceType or return the pointee
/// type variable with the correct address space (retrieved from
/// AddrSpaceExpr) applied to it. The former will be returned in cases
@@ -5179,6 +5189,75 @@
return Result;
}
+template <typename Derived>
+QualType
+TreeTransform<Derived>::TransformConstantMatrixType(TypeLocBuilder &TLB,
+ ConstantMatrixTypeLoc TL) {
+ const ConstantMatrixType *T = TL.getTypePtr();
+ QualType ElementType = getDerived().TransformType(T->getElementType());
+ if (ElementType.isNull())
+ return QualType();
+
+ QualType Result = TL.getType();
+ if (getDerived().AlwaysRebuild() || ElementType != T->getElementType()) {
+ Result = getDerived().RebuildConstantMatrixType(
+ ElementType, T->getNumRows(), T->getNumColumns());
+ if (Result.isNull())
+ return QualType();
+ }
+
+ ConstantMatrixTypeLoc NewTL = TLB.push<ConstantMatrixTypeLoc>(Result);
+ NewTL.setAttrNameLoc(TL.getAttrNameLoc());
+ NewTL.setAttrOperandParensRange(TL.getAttrOperandParensRange());
+ NewTL.setAttrRowOperand(TL.getAttrRowOperand());
+ NewTL.setAttrColumnOperand(TL.getAttrColumnOperand());
+
+ return Result;
+}
+
+/// Transform the element type and dimension expressions of a matrix type.
+template <typename Derived>
+QualType TreeTransform<Derived>::TransformDependentSizedMatrixType(
+ TypeLocBuilder &TLB, DependentSizedMatrixTypeLoc TL) {
+ const DependentSizedMatrixType *T = TL.getTypePtr();
+ QualType ElementType = getDerived().TransformType(T->getElementType());
+ if (ElementType.isNull())
+ return QualType();
+
+ EnterExpressionEvaluationContext Unevaluated(
+ SemaRef, Sema::ExpressionEvaluationContext::ConstantEvaluated);
+ ExprResult Rows = getDerived().TransformExpr(T->getRowExpr());
+ ExprResult Cols = getDerived().TransformExpr(T->getColumnExpr());
+ // Bail out if either dimension failed to transform, like the element type.
+ if (Rows.isInvalid() || Cols.isInvalid())
+ return QualType();
+
+ QualType Result = TL.getType();
+ if (getDerived().AlwaysRebuild() || ElementType != T->getElementType() ||
+ Rows.get() != T->getRowExpr() || Cols.get() != T->getColumnExpr()) {
+ Result = getDerived().RebuildDependentSizedMatrixType(
+ ElementType, Rows.get(), Cols.get(), T->getAttributeLoc());
+ if (Result.isNull())
+ return QualType();
+ }
+
+ if (auto *ResultMTy = dyn_cast<DependentSizedMatrixType>(Result)) {
+ DependentSizedMatrixTypeLoc NewTL =
+ TLB.push<DependentSizedMatrixTypeLoc>(Result);
+ NewTL.setAttrNameLoc(TL.getAttrNameLoc());
+ NewTL.setAttrOperandParensRange(TL.getAttrOperandParensRange());
+ NewTL.setAttrRowOperand(ResultMTy->getRowExpr());
+ NewTL.setAttrColumnOperand(ResultMTy->getColumnExpr());
+ } else {
+ MatrixTypeLoc NewTL = TLB.push<MatrixTypeLoc>(Result);
+ NewTL.setAttrNameLoc(TL.getAttrNameLoc());
+ NewTL.setAttrOperandParensRange(TL.getAttrOperandParensRange());
+ NewTL.setAttrRowOperand(TL.getAttrRowOperand());
+ NewTL.setAttrColumnOperand(TL.getAttrColumnOperand());
+ }
+ return Result;
+}
+
template <typename Derived>
QualType TreeTransform<Derived>::TransformDependentAddressSpaceType(
TypeLocBuilder &TLB, DependentAddressSpaceTypeLoc TL) {
@@ -13750,6 +13829,21 @@
return SemaRef.BuildExtVectorType(ElementType, SizeExpr, AttributeLoc);
}
+template <typename Derived>
+QualType TreeTransform<Derived>::RebuildConstantMatrixType(
+ QualType ElementType, unsigned NumRows, unsigned NumColumns) {
+ return SemaRef.Context.getConstantMatrixType(ElementType, NumRows,
+ NumColumns);
+}
+
+template <typename Derived>
+QualType TreeTransform<Derived>::RebuildDependentSizedMatrixType(
+ QualType ElementType, Expr *RowExpr, Expr *ColumnExpr,
+ SourceLocation AttributeLoc) {
+ return SemaRef.BuildMatrixType(ElementType, RowExpr, ColumnExpr,
+ AttributeLoc);
+}
+
template<typename Derived>
QualType TreeTransform<Derived>::RebuildFunctionProtoType(
QualType T,
Index: clang/lib/Sema/SemaType.cpp
===================================================================
--- clang/lib/Sema/SemaType.cpp
+++ clang/lib/Sema/SemaType.cpp
@@ -2492,14 +2492,15 @@
if (!VecSize.isIntN(61)) {
// Bit size will overflow uint64.
Diag(AttrLoc, diag::err_attribute_size_too_large)
- << SizeExpr->getSourceRange();
+ << SizeExpr->getSourceRange() << "vector";
return QualType();
}
uint64_t VectorSizeBits = VecSize.getZExtValue() * 8;
unsigned TypeSize = static_cast<unsigned>(Context.getTypeSize(CurType));
if (VectorSizeBits == 0) {
- Diag(AttrLoc, diag::err_attribute_zero_size) << SizeExpr->getSourceRange();
+ Diag(AttrLoc, diag::err_attribute_zero_size)
+ << SizeExpr->getSourceRange() << "vector";
return QualType();
}
@@ -2511,7 +2512,7 @@
if (VectorSizeBits / TypeSize > std::numeric_limits<uint32_t>::max()) {
Diag(AttrLoc, diag::err_attribute_size_too_large)
- << SizeExpr->getSourceRange();
+ << SizeExpr->getSourceRange() << "vector";
return QualType();
}
@@ -2549,7 +2550,7 @@
if (!vecSize.isIntN(32)) {
Diag(AttrLoc, diag::err_attribute_size_too_large)
- << ArraySize->getSourceRange();
+ << ArraySize->getSourceRange() << "vector";
return QualType();
}
// Unlike gcc's vector_size attribute, the size is specified as the
@@ -2558,7 +2559,7 @@
if (vectorSize == 0) {
Diag(AttrLoc, diag::err_attribute_zero_size)
- << ArraySize->getSourceRange();
+ << ArraySize->getSourceRange() << "vector";
return QualType();
}
@@ -2568,6 +2569,83 @@
return Context.getDependentSizedExtVectorType(T, ArraySize, AttrLoc);
}
+QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
+ SourceLocation AttrLoc) {
+ assert(Context.getLangOpts().MatrixTypes &&
+ "Should never build a matrix type when it is disabled");
+
+ if (ElementTy->isDependentType() || NumRows->isTypeDependent() ||
+ NumCols->isTypeDependent() || NumRows->isValueDependent() ||
+ NumCols->isValueDependent())
+ return Context.getDependentSizedMatrixType(ElementTy, NumRows, NumCols,
+ AttrLoc);
+
+ if (!MatrixType::isValidElementType(ElementTy)) {
+ Diag(AttrLoc, diag::err_attribute_invalid_matrix_type) << ElementTy;
+ return QualType();
+ }
+
+ // Both row and column values can only be 20 bit wide currently.
+ llvm::APSInt ValueRows(32), ValueColumns(32);
+
+ bool const RowsIsInteger = NumRows->isIntegerConstantExpr(ValueRows, Context);
+ bool const ColumnsIsInteger =
+ NumCols->isIntegerConstantExpr(ValueColumns, Context);
+
+ auto const RowRange = NumRows->getSourceRange();
+ auto const ColRange = NumCols->getSourceRange();
+
+ // Both the row and the column expression are invalid.
+ if (!RowsIsInteger && !ColumnsIsInteger) {
+ Diag(AttrLoc, diag::err_attribute_argument_type)
+ << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange
+ << ColRange;
+ return QualType();
+ }
+
+ // Only the row expression is invalid.
+ if (!RowsIsInteger) {
+ Diag(AttrLoc, diag::err_attribute_argument_type)
+ << "matrix_type" << AANT_ArgumentIntegerConstant << RowRange;
+ return QualType();
+ }
+
+ // Only the column expression is invalid.
+ if (!ColumnsIsInteger) {
+ Diag(AttrLoc, diag::err_attribute_argument_type)
+ << "matrix_type" << AANT_ArgumentIntegerConstant << ColRange;
+ return QualType();
+ }
+
+ // Check the matrix dimensions.
+ unsigned MatrixRows = static_cast<unsigned>(ValueRows.getZExtValue());
+ unsigned MatrixColumns = static_cast<unsigned>(ValueColumns.getZExtValue());
+ if (MatrixRows == 0 && MatrixColumns == 0) {
+ Diag(AttrLoc, diag::err_attribute_zero_size)
+ << "matrix" << RowRange << ColRange;
+ return QualType();
+ }
+ if (MatrixRows == 0) {
+ Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << RowRange;
+ return QualType();
+ }
+ if (MatrixColumns == 0) {
+ Diag(AttrLoc, diag::err_attribute_zero_size) << "matrix" << ColRange;
+ return QualType();
+ }
+ if (!ConstantMatrixType::isDimensionValid(MatrixRows)) {
+ Diag(AttrLoc, diag::err_attribute_size_too_large)
+ << RowRange << "matrix row";
+ return QualType();
+ }
+ if (!ConstantMatrixType::isDimensionValid(MatrixColumns)) {
+ Diag(AttrLoc, diag::err_attribute_size_too_large)
+ << ColRange << "matrix column";
+ return QualType();
+ }
+ return Context.getConstantMatrixType(ElementTy, MatrixRows, MatrixColumns);
+}
+
bool Sema::CheckFunctionReturnType(QualType T, SourceLocation Loc) {
if (T->isArrayType() || T->isFunctionType()) {
Diag(Loc, diag::err_func_returning_array_function)
@@ -6013,6 +6091,21 @@
"no address_space attribute found at the expected location!");
}
+static void fillMatrixTypeLoc(MatrixTypeLoc MTL,
+ const ParsedAttributesView &Attrs) {
+ for (const ParsedAttr &AL : Attrs) {
+ if (AL.getKind() == ParsedAttr::AT_MatrixType) {
+ MTL.setAttrNameLoc(AL.getLoc());
+ MTL.setAttrRowOperand(AL.getArgAsExpr(0));
+ MTL.setAttrColumnOperand(AL.getArgAsExpr(1));
+ MTL.setAttrOperandParensRange(SourceRange());
+ return;
+ }
+ }
+
+ llvm_unreachable("no matrix_type attribute found at the expected location!");
+}
+
/// Create and instantiate a TypeSourceInfo with type source information.
///
/// \param T QualType referring to the type as written in source code.
@@ -6061,6 +6154,9 @@
CurrTL = TL.getPointeeTypeLoc().getUnqualifiedLoc();
}
+ if (MatrixTypeLoc TL = CurrTL.getAs<MatrixTypeLoc>())
+ fillMatrixTypeLoc(TL, D.getTypeObject(i).getAttrs());
+
// FIXME: Ordering here?
while (AdjustedTypeLoc TL = CurrTL.getAs<AdjustedTypeLoc>())
CurrTL = TL.getNextTypeLoc().getUnqualifiedLoc();
@@ -7706,6 +7802,68 @@
}
}
+/// HandleMatrixTypeAttr - Apply "matrix_type(rows, cols)" to CurType, like ext_vector_type
+static void HandleMatrixTypeAttr(QualType &CurType, const ParsedAttr &Attr,
+ Sema &S) {
+ if (!S.getLangOpts().MatrixTypes) {
+ S.Diag(Attr.getLoc(), diag::err_builtin_matrix_disabled);
+ return;
+ }
+
+ if (Attr.getNumArgs() != 2) {
+ S.Diag(Attr.getLoc(), diag::err_attribute_wrong_number_arguments)
+ << Attr << 2;
+ return;
+ }
+
+ Expr *RowsExpr = nullptr;
+ Expr *ColsExpr = nullptr;
+
+ // TODO: Refactor parameter extraction into separate function
+ // Get the number of rows
+ if (Attr.isArgIdent(0)) {
+ CXXScopeSpec SS;
+ SourceLocation TemplateKeywordLoc;
+ UnqualifiedId id;
+ id.setIdentifier(Attr.getArgAsIdent(0)->Ident, Attr.getLoc());
+ ExprResult Rows = S.ActOnIdExpression(S.getCurScope(), SS,
+ TemplateKeywordLoc, id, false, false);
+
+ if (Rows.isInvalid())
+ // TODO: maybe a good error message would be nice here
+ return;
+ RowsExpr = Rows.get();
+ } else {
+ assert(Attr.isArgExpr(0) &&
+ "Argument to should either be an identity or expression");
+ RowsExpr = Attr.getArgAsExpr(0);
+ }
+
+ // Get the number of columns
+ if (Attr.isArgIdent(1)) {
+ CXXScopeSpec SS;
+ SourceLocation TemplateKeywordLoc;
+ UnqualifiedId id;
+ id.setIdentifier(Attr.getArgAsIdent(1)->Ident, Attr.getLoc());
+ ExprResult Columns = S.ActOnIdExpression(
+ S.getCurScope(), SS, TemplateKeywordLoc, id, false, false);
+
+ if (Columns.isInvalid())
+ // TODO: a good error message would be nice here
+ return;
+ ColsExpr = Columns.get();
+ } else {
+ assert(Attr.isArgExpr(1) &&
+ "Argument to should either be an identity or expression");
+ ColsExpr = Attr.getArgAsExpr(1);
+ }
+
+ // Create the matrix type.
+ QualType T = S.BuildMatrixType(CurType, RowsExpr, ColsExpr, Attr.getLoc());
+ if (!T.isNull())
+ CurType = T;
+}
+
static void HandleLifetimeBoundAttr(TypeProcessingState &State,
QualType &CurType,
ParsedAttr &Attr) {
@@ -7857,6 +8015,11 @@
break;
}
+ case ParsedAttr::AT_MatrixType:
+ HandleMatrixTypeAttr(type, attr, state.getSema());
+ attr.setUsedAsTypeAttr();
+ break;
+
MS_TYPE_ATTRS_CASELIST:
if (!handleMSPointerTypeQualifierAttr(state, attr, type))
attr.setUsedAsTypeAttr();
Index: clang/lib/Sema/SemaTemplateDeduction.cpp
===================================================================
--- clang/lib/Sema/SemaTemplateDeduction.cpp
+++ clang/lib/Sema/SemaTemplateDeduction.cpp
@@ -2055,6 +2055,98 @@
return Sema::TDK_NonDeducedMismatch;
}
+ // (clang extension)
+ //
+ // T __attribute__((matrix_type(<integral constant>,
+ // <integral constant>)))
+ case Type::ConstantMatrix: {
+ const ConstantMatrixType *MatrixArg = dyn_cast<ConstantMatrixType>(Arg);
+ if (!MatrixArg)
+ return Sema::TDK_NonDeducedMismatch;
+
+ const ConstantMatrixType *MatrixParam = cast<ConstantMatrixType>(Param);
+ // Check that the dimensions are the same
+ if (MatrixParam->getNumRows() != MatrixArg->getNumRows() ||
+ MatrixParam->getNumColumns() != MatrixArg->getNumColumns()) {
+ return Sema::TDK_NonDeducedMismatch;
+ }
+ // Perform deduction on element types.
+ unsigned SubTDF = TDF & TDF_IgnoreQualifiers;
+ return DeduceTemplateArgumentsByTypeMatch(
+ S, TemplateParams, MatrixParam->getElementType(),
+ MatrixArg->getElementType(), Info, Deduced, SubTDF);
+ }
+
+ case Type::DependentSizedMatrix: {
+ const MatrixType *MatrixArg = dyn_cast<MatrixType>(Arg);
+ if (!MatrixArg)
+ return Sema::TDK_NonDeducedMismatch;
+
+ unsigned SubTDF = TDF & TDF_IgnoreQualifiers;
+
+ // Check the element type of the matrixes.
+ const DependentSizedMatrixType *MatrixParam =
+ cast<DependentSizedMatrixType>(Param);
+ if (Sema::TemplateDeductionResult Result =
+ DeduceTemplateArgumentsByTypeMatch(
+ S, TemplateParams, MatrixParam->getElementType(),
+ MatrixArg->getElementType(), Info, Deduced, SubTDF))
+ return Result;
+
+ // Determine if the number of rows and columns is something we can deduce.
+ NonTypeTemplateParmDecl *RowNTTP =
+ getDeducedParameterFromExpr(Info, MatrixParam->getRowExpr());
+ NonTypeTemplateParmDecl *ColumnNTTP =
+ getDeducedParameterFromExpr(Info, MatrixParam->getColumnExpr());
+ if (!RowNTTP && !ColumnNTTP)
+ return Sema::TDK_Success;
+
+ // Otherwise perform template argument deduction for the given non-type
+ // template parameters.
+ auto DeduceMatrixArg =
+ [&](NonTypeTemplateParmDecl *NTTP,
+ std::function<unsigned(const ConstantMatrixType *)> GetDimension,
+ std::function<Expr *(const DependentSizedMatrixType *)>
+ GetDimensionExpr) {
+ if (!NTTP)
+ return Sema::TDK_Success;
+ auto Result = Sema::TDK_NonDeducedMismatch;
+ assert(NTTP->getDepth() == Info.getDeducedDepth() &&
+ "saw non-type template parameter with wrong depth");
+ if (const ConstantMatrixType *ConstantMatrixArg =
+ dyn_cast<ConstantMatrixType>(MatrixArg)) {
+ llvm::APSInt NumAPSInt(S.Context.getTypeSize(NTTP->getType()));
+ NumAPSInt = GetDimension(ConstantMatrixArg);
+ Result = DeduceNonTypeTemplateArgument(
+ S, TemplateParams, NTTP,
+ NumAPSInt, // S.Context.getSizeType(),
+ NTTP->getType(),
+ /*ArrayBound=*/false, Info, Deduced);
+
+ } else if (const DependentSizedMatrixType *DepMatrixArg =
+ dyn_cast<DependentSizedMatrixType>(MatrixArg))
+ Result = DeduceNonTypeTemplateArgument(
+ S, TemplateParams, NTTP, GetDimensionExpr(DepMatrixArg), Info,
+ Deduced);
+
+ return Result;
+ };
+
+ auto Result = DeduceMatrixArg(
+ RowNTTP,
+ [](const ConstantMatrixType *MT) { return MT->getNumRows(); },
+ [](const DependentSizedMatrixType *MT) { return MT->getRowExpr(); });
+ if (Result != Sema::TDK_Success)
+ return Result;
+
+ return DeduceMatrixArg(
+ ColumnNTTP,
+ [](const ConstantMatrixType *MT) { return MT->getNumColumns(); },
+ [](const DependentSizedMatrixType *MT) {
+ return MT->getColumnExpr();
+ });
+ }
+
// (clang extension)
//
// T __attribute__(((address_space(N))))
@@ -5723,6 +5815,24 @@
break;
}
+ case Type::ConstantMatrix: {
+ const ConstantMatrixType *MatType = cast<ConstantMatrixType>(T);
+ MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced,
+ Depth, Used);
+ break;
+ }
+
+ case Type::DependentSizedMatrix: {
+ const DependentSizedMatrixType *MatType = cast<DependentSizedMatrixType>(T);
+ MarkUsedTemplateParameters(Ctx, MatType->getElementType(), OnlyDeduced,
+ Depth, Used);
+ MarkUsedTemplateParameters(Ctx, MatType->getRowExpr(), OnlyDeduced, Depth,
+ Used);
+ MarkUsedTemplateParameters(Ctx, MatType->getColumnExpr(), OnlyDeduced,
+ Depth, Used);
+ break;
+ }
+
case Type::FunctionProto: {
const FunctionProtoType *Proto = cast<FunctionProtoType>(T);
MarkUsedTemplateParameters(Ctx, Proto->getReturnType(), OnlyDeduced, Depth,
Index: clang/lib/Sema/SemaTemplate.cpp
===================================================================
--- clang/lib/Sema/SemaTemplate.cpp
+++ clang/lib/Sema/SemaTemplate.cpp
@@ -5867,6 +5867,11 @@
return Visit(T->getElementType());
}
+bool UnnamedLocalNoLinkageFinder::VisitDependentSizedMatrixType(
+ const DependentSizedMatrixType *T) {
+ return Visit(T->getElementType());
+}
+
bool UnnamedLocalNoLinkageFinder::VisitDependentAddressSpaceType(
const DependentAddressSpaceType *T) {
return Visit(T->getPointeeType());
@@ -5885,6 +5890,11 @@
return Visit(T->getElementType());
}
+bool UnnamedLocalNoLinkageFinder::VisitConstantMatrixType(
+ const ConstantMatrixType *T) {
+ return Visit(T->getElementType());
+}
+
bool UnnamedLocalNoLinkageFinder::VisitFunctionProtoType(
const FunctionProtoType* T) {
for (const auto &A : T->param_types()) {
Index: clang/lib/Sema/SemaLookup.cpp
===================================================================
--- clang/lib/Sema/SemaLookup.cpp
+++ clang/lib/Sema/SemaLookup.cpp
@@ -2966,6 +2966,7 @@
// These are fundamental types.
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
case Type::Complex:
case Type::ExtInt:
break;
Index: clang/lib/Sema/SemaExpr.cpp
===================================================================
--- clang/lib/Sema/SemaExpr.cpp
+++ clang/lib/Sema/SemaExpr.cpp
@@ -4257,6 +4257,7 @@
case Type::Complex:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
case Type::Record:
case Type::Enum:
case Type::Elaborated:
Index: clang/lib/Frontend/CompilerInvocation.cpp
===================================================================
--- clang/lib/Frontend/CompilerInvocation.cpp
+++ clang/lib/Frontend/CompilerInvocation.cpp
@@ -3336,6 +3336,8 @@
Opts.CompleteMemberPointers = Args.hasArg(OPT_fcomplete_member_pointers);
Opts.BuildingPCHWithObjectFile = Args.hasArg(OPT_building_pch_with_obj);
+ Opts.MatrixTypes = Args.hasArg(OPT_fenable_matrix);
+
Opts.MaxTokens = getLastArgIntValue(Args, OPT_fmax_tokens_EQ, 0, Diags);
if (Arg *A = Args.getLastArg(OPT_msign_return_address_EQ)) {
Index: clang/lib/Driver/ToolChains/Clang.cpp
===================================================================
--- clang/lib/Driver/ToolChains/Clang.cpp
+++ clang/lib/Driver/ToolChains/Clang.cpp
@@ -4565,6 +4565,13 @@
if (Args.hasFlag(options::OPT_mrtd, options::OPT_mno_rtd, false))
CmdArgs.push_back("-fdefault-calling-conv=stdcall");
+ if (Args.hasArg(options::OPT_fenable_matrix)) {
+ // enable-matrix is needed by both the LangOpts and by LLVM.
+ CmdArgs.push_back("-fenable-matrix");
+ CmdArgs.push_back("-mllvm");
+ CmdArgs.push_back("-enable-matrix");
+ }
+
CodeGenOptions::FramePointerKind FPKeepKind =
getFramePointerKind(Args, RawTriple);
const char *FPKeepKindStr = nullptr;
Index: clang/lib/CodeGen/ItaniumCXXABI.cpp
===================================================================
--- clang/lib/CodeGen/ItaniumCXXABI.cpp
+++ clang/lib/CodeGen/ItaniumCXXABI.cpp
@@ -3223,6 +3223,7 @@
// GCC treats vector and complex types as fundamental types.
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
case Type::Complex:
case Type::Atomic:
// FIXME: GCC treats block pointers as fundamental types?!
@@ -3458,6 +3459,7 @@
case Type::Builtin:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
case Type::Complex:
case Type::BlockPointer:
// Itanium C++ ABI 2.9.5p4:
Index: clang/lib/CodeGen/CodeGenTypes.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenTypes.cpp
+++ clang/lib/CodeGen/CodeGenTypes.cpp
@@ -82,6 +82,13 @@
/// a type. For example, the scalar representation for _Bool is i1, but the
/// memory representation is usually i8 or i32, depending on the target.
llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T, bool ForBitField) {
+ if (T->isConstantMatrixType()) {
+ const Type *Ty = Context.getCanonicalType(T).getTypePtr();
+ const ConstantMatrixType *MT = cast<ConstantMatrixType>(Ty);
+ return llvm::ArrayType::get(ConvertType(MT->getElementType()),
+ MT->getNumRows() * MT->getNumColumns());
+ }
+
llvm::Type *R = ConvertType(T);
// If this is a bool type, or an ExtIntType in a bitfield representation,
@@ -646,6 +653,12 @@
VT->getNumElements());
break;
}
+ case Type::ConstantMatrix: {
+ const ConstantMatrixType *MT = cast<ConstantMatrixType>(Ty);
+ ResultType = llvm::VectorType::get(ConvertType(MT->getElementType()),
+ MT->getNumRows() * MT->getNumColumns());
+ break;
+ }
case Type::FunctionNoProto:
case Type::FunctionProto:
ResultType = ConvertFunctionTypeInternal(T);
Index: clang/lib/CodeGen/CodeGenFunction.cpp
===================================================================
--- clang/lib/CodeGen/CodeGenFunction.cpp
+++ clang/lib/CodeGen/CodeGenFunction.cpp
@@ -247,6 +247,7 @@
case Type::MemberPointer:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
case Type::FunctionProto:
case Type::FunctionNoProto:
case Type::Enum:
@@ -2000,6 +2001,7 @@
case Type::Complex:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
case Type::Record:
case Type::Enum:
case Type::Elaborated:
Index: clang/lib/CodeGen/CGExpr.cpp
===================================================================
--- clang/lib/CodeGen/CGExpr.cpp
+++ clang/lib/CodeGen/CGExpr.cpp
@@ -145,8 +145,19 @@
Address CodeGenFunction::CreateMemTemp(QualType Ty, CharUnits Align,
const Twine &Name, Address *Alloca) {
- return CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name,
- /*ArraySize=*/nullptr, Alloca);
+ Address Result = CreateTempAlloca(ConvertTypeForMem(Ty), Align, Name,
+ /*ArraySize=*/nullptr, Alloca);
+
+ if (Ty->isConstantMatrixType()) {
+ auto *ArrayTy = cast<llvm::ArrayType>(Result.getType()->getElementType());
+ auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(),
+ ArrayTy->getNumElements());
+
+ Result = Address(
+ Builder.CreateBitCast(Result.getPointer(), VectorTy->getPointerTo()),
+ Result.getAlignment());
+ }
+ return Result;
}
Address CodeGenFunction::CreateMemTempWithoutCast(QualType Ty, CharUnits Align,
@@ -1732,6 +1743,42 @@
return Value;
}
+// If \p IsVector, cast a pointer to an array (memory type of MatrixType) to a
+// pointer to a vector (value type); otherwise cast vector to array pointer.
+static Address MaybeConvertMatrixAddress(Address Addr, CodeGenFunction &CGF,
+ bool IsVector = true) {
+ auto *ArrayTy = dyn_cast<llvm::ArrayType>(
+ cast<llvm::PointerType>(Addr.getPointer()->getType())->getElementType());
+ if (ArrayTy && IsVector) {
+ auto *VectorTy = llvm::VectorType::get(ArrayTy->getElementType(),
+ ArrayTy->getNumElements());
+
+ return Address(CGF.Builder.CreateElementBitCast(Addr, VectorTy));
+ }
+ auto *VectorTy = dyn_cast<llvm::VectorType>(
+ cast<llvm::PointerType>(Addr.getPointer()->getType())->getElementType());
+ if (VectorTy && !IsVector) {
+ auto *ArrayTy = llvm::ArrayType::get(VectorTy->getElementType(),
+ VectorTy->getNumElements());
+
+ return Address(CGF.Builder.CreateElementBitCast(Addr, ArrayTy));
+ }
+
+ return Addr;
+}
+
+// Emit a store of a matrix LValue. This may require casting the original
+// pointer to memory address (ArrayType) to a pointer to the value type
+// (VectorType).
+static void EmitStoreOfMatrixScalar(llvm::Value *value, LValue lvalue,
+ bool isInit, CodeGenFunction &CGF) {
+ Address Addr = MaybeConvertMatrixAddress(lvalue.getAddress(CGF), CGF,
+ value->getType()->isVectorTy());
+ CGF.EmitStoreOfScalar(value, Addr, lvalue.isVolatile(), lvalue.getType(),
+ lvalue.getBaseInfo(), lvalue.getTBAAInfo(), isInit,
+ lvalue.isNontemporal());
+}
+
void CodeGenFunction::EmitStoreOfScalar(llvm::Value *Value, Address Addr,
bool Volatile, QualType Ty,
LValueBaseInfo BaseInfo,
@@ -1779,11 +1826,26 @@
void CodeGenFunction::EmitStoreOfScalar(llvm::Value *value, LValue lvalue,
bool isInit) {
+ if (lvalue.getType()->isConstantMatrixType()) {
+ EmitStoreOfMatrixScalar(value, lvalue, isInit, *this);
+ return;
+ }
+
EmitStoreOfScalar(value, lvalue.getAddress(*this), lvalue.isVolatile(),
lvalue.getType(), lvalue.getBaseInfo(),
lvalue.getTBAAInfo(), isInit, lvalue.isNontemporal());
}
+// Emit a load of a LValue of matrix type. This may require casting the pointer
+// to memory address (ArrayType) to a pointer to the value type (VectorType).
+static RValue EmitLoadOfMatrixLValue(LValue LV, SourceLocation Loc,
+ CodeGenFunction &CGF) {
+ assert(LV.getType()->isConstantMatrixType());
+ Address Addr = MaybeConvertMatrixAddress(LV.getAddress(CGF), CGF);
+ LV.setAddress(Addr);
+ return RValue::get(CGF.EmitLoadOfScalar(LV, Loc));
+}
+
/// EmitLoadOfLValue - Given an expression that represents a value lvalue, this
/// method emits the address of the lvalue, then loads the result as an rvalue,
/// returning the rvalue.
@@ -1809,6 +1871,9 @@
if (LV.isSimple()) {
assert(!LV.getType()->isFunctionType());
+ if (LV.getType()->isConstantMatrixType())
+ return EmitLoadOfMatrixLValue(LV, Loc, *this);
+
// Everything needs a load.
return RValue::get(EmitLoadOfScalar(LV, Loc));
}
Index: clang/lib/CodeGen/CGDebugInfo.h
===================================================================
--- clang/lib/CodeGen/CGDebugInfo.h
+++ clang/lib/CodeGen/CGDebugInfo.h
@@ -192,6 +192,7 @@
llvm::DIType *CreateType(const ObjCTypeParamType *Ty, llvm::DIFile *Unit);
llvm::DIType *CreateType(const VectorType *Ty, llvm::DIFile *F);
+ llvm::DIType *CreateType(const ConstantMatrixType *Ty, llvm::DIFile *F);
llvm::DIType *CreateType(const ArrayType *Ty, llvm::DIFile *F);
llvm::DIType *CreateType(const LValueReferenceType *Ty, llvm::DIFile *F);
llvm::DIType *CreateType(const RValueReferenceType *Ty, llvm::DIFile *Unit);
Index: clang/lib/CodeGen/CGDebugInfo.cpp
===================================================================
--- clang/lib/CodeGen/CGDebugInfo.cpp
+++ clang/lib/CodeGen/CGDebugInfo.cpp
@@ -2736,6 +2736,23 @@
return DBuilder.createVectorType(Size, Align, ElementTy, SubscriptArray);
}
+llvm::DIType *CGDebugInfo::CreateType(const ConstantMatrixType *Ty,
+ llvm::DIFile *Unit) {
+ // FIXME: Create another debug type for matrices
+ // For the time being, it treats it like a nested ArrayType.
+
+ llvm::DIType *ElementTy = getOrCreateType(Ty->getElementType(), Unit);
+ uint64_t Size = CGM.getContext().getTypeSize(Ty);
+ uint32_t Align = getTypeAlignIfRequired(Ty, CGM.getContext());
+
+ // Create ranges for both dimensions.
+ llvm::SmallVector<llvm::Metadata *, 2> Subscripts;
+ Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumColumns()));
+ Subscripts.push_back(DBuilder.getOrCreateSubrange(0, Ty->getNumRows()));
+ llvm::DINodeArray SubscriptArray = DBuilder.getOrCreateArray(Subscripts);
+ return DBuilder.createArrayType(Size, Align, ElementTy, SubscriptArray);
+}
+
llvm::DIType *CGDebugInfo::CreateType(const ArrayType *Ty, llvm::DIFile *Unit) {
uint64_t Size;
uint32_t Align;
@@ -3129,6 +3146,8 @@
case Type::ExtVector:
case Type::Vector:
return CreateType(cast<VectorType>(Ty), Unit);
+ case Type::ConstantMatrix:
+ return CreateType(cast<ConstantMatrixType>(Ty), Unit);
case Type::ObjCObjectPointer:
return CreateType(cast<ObjCObjectPointerType>(Ty), Unit);
case Type::ObjCObject:
Index: clang/lib/AST/TypePrinter.cpp
===================================================================
--- clang/lib/AST/TypePrinter.cpp
+++ clang/lib/AST/TypePrinter.cpp
@@ -256,6 +256,8 @@
case Type::DependentSizedExtVector:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
+ case Type::DependentSizedMatrix:
case Type::FunctionProto:
case Type::FunctionNoProto:
case Type::Paren:
@@ -720,6 +722,38 @@
OS << ")))";
}
+void TypePrinter::printConstantMatrixBefore(const ConstantMatrixType *T,
+ raw_ostream &OS) {
+ printBefore(T->getElementType(), OS);
+ OS << " __attribute__((matrix_type(";
+ OS << T->getNumRows() << ", " << T->getNumColumns();
+ OS << ")))";
+}
+
+void TypePrinter::printConstantMatrixAfter(const ConstantMatrixType *T,
+ raw_ostream &OS) {
+ printAfter(T->getElementType(), OS);
+}
+
+void TypePrinter::printDependentSizedMatrixBefore(
+ const DependentSizedMatrixType *T, raw_ostream &OS) {
+ printBefore(T->getElementType(), OS);
+ OS << " __attribute__((matrix_type(";
+ if (T->getRowExpr()) {
+ T->getRowExpr()->printPretty(OS, nullptr, Policy);
+ }
+ OS << ", ";
+ if (T->getColumnExpr()) {
+ T->getColumnExpr()->printPretty(OS, nullptr, Policy);
+ }
+ OS << ")))";
+}
+
+void TypePrinter::printDependentSizedMatrixAfter(
+ const DependentSizedMatrixType *T, raw_ostream &OS) {
+ printAfter(T->getElementType(), OS);
+}
+
void
FunctionProtoType::printExceptionSpecification(raw_ostream &OS,
const PrintingPolicy &Policy)
Index: clang/lib/AST/Type.cpp
===================================================================
--- clang/lib/AST/Type.cpp
+++ clang/lib/AST/Type.cpp
@@ -282,6 +282,53 @@
AddrSpaceExpr->Profile(ID, Context, true);
}
+MatrixType::MatrixType(TypeClass tc, QualType matrixType, QualType canonType,
+ const Expr *RowExpr, const Expr *ColumnExpr)
+ : Type(tc, canonType,
+ (RowExpr
+ ? (TypeDependence::Dependent | TypeDependence::Instantiation |
+ (matrixType->isVariablyModifiedType()
+ ? TypeDependence::VariablyModified
+ : TypeDependence::None) |
+ (matrixType->containsUnexpandedParameterPack() ||
+ (RowExpr &&
+ RowExpr->containsUnexpandedParameterPack()) ||
+ (ColumnExpr &&
+ ColumnExpr->containsUnexpandedParameterPack())
+ ? TypeDependence::UnexpandedPack
+ : TypeDependence::None))
+ : matrixType->getDependence())),
+ ElementType(matrixType) {}
+
+ConstantMatrixType::ConstantMatrixType(QualType matrixType, unsigned nRows,
+ unsigned nColumns, QualType canonType)
+ : ConstantMatrixType(ConstantMatrix, matrixType, nRows, nColumns,
+ canonType) {}
+
+ConstantMatrixType::ConstantMatrixType(TypeClass tc, QualType matrixType,
+ unsigned nRows, unsigned nColumns,
+ QualType canonType)
+ : MatrixType(tc, matrixType, canonType) {
+ ConstantMatrixTypeBits.NumRows = nRows;
+ ConstantMatrixTypeBits.NumColumns = nColumns;
+}
+
+DependentSizedMatrixType::DependentSizedMatrixType(
+ const ASTContext &CTX, QualType ElementType, QualType CanonicalType,
+ Expr *RowExpr, Expr *ColumnExpr, SourceLocation loc)
+ : MatrixType(DependentSizedMatrix, ElementType, CanonicalType, RowExpr,
+ ColumnExpr),
+ Context(CTX), RowExpr(RowExpr), ColumnExpr(ColumnExpr), loc(loc) {}
+
+void DependentSizedMatrixType::Profile(llvm::FoldingSetNodeID &ID,
+ const ASTContext &CTX,
+ QualType ElementType, Expr *RowExpr,
+ Expr *ColumnExpr) {
+ ID.AddPointer(ElementType.getAsOpaquePtr());
+ RowExpr->Profile(ID, CTX, true);
+ ColumnExpr->Profile(ID, CTX, true);
+}
+
VectorType::VectorType(QualType vecType, unsigned nElements, QualType canonType,
VectorKind vecKind)
: VectorType(Vector, vecType, nElements, canonType, vecKind) {}
@@ -971,6 +1018,17 @@
return Ctx.getExtVectorType(elementType, T->getNumElements());
}
+ QualType VisitConstantMatrixType(const ConstantMatrixType *T) {
+ QualType elementType = recurse(T->getElementType());
+ if (elementType.isNull())
+ return {};
+ if (elementType.getAsOpaquePtr() == T->getElementType().getAsOpaquePtr())
+ return QualType(T, 0);
+
+ return Ctx.getConstantMatrixType(elementType, T->getNumRows(),
+ T->getNumColumns());
+ }
+
QualType VisitFunctionNoProtoType(const FunctionNoProtoType *T) {
QualType returnType = recurse(T->getReturnType());
if (returnType.isNull())
@@ -1790,6 +1848,14 @@
return Visit(T->getElementType());
}
+ Type *VisitDependentSizedMatrixType(const DependentSizedMatrixType *T) {
+ return Visit(T->getElementType());
+ }
+
+ Type *VisitConstantMatrixType(const ConstantMatrixType *T) {
+ return Visit(T->getElementType());
+ }
+
Type *VisitFunctionProtoType(const FunctionProtoType *T) {
if (Syntactic && T->hasTrailingReturn())
return const_cast<FunctionProtoType*>(T);
@@ -3744,6 +3810,8 @@
case Type::Vector:
case Type::ExtVector:
return Cache::get(cast<VectorType>(T)->getElementType());
+ case Type::ConstantMatrix:
+ return Cache::get(cast<ConstantMatrixType>(T)->getElementType());
case Type::FunctionNoProto:
return Cache::get(cast<FunctionType>(T)->getReturnType());
case Type::FunctionProto: {
@@ -3830,6 +3898,9 @@
case Type::Vector:
case Type::ExtVector:
return computeTypeLinkageInfo(cast<VectorType>(T)->getElementType());
+ case Type::ConstantMatrix:
+ return computeTypeLinkageInfo(
+ cast<ConstantMatrixType>(T)->getElementType());
case Type::FunctionNoProto:
return computeTypeLinkageInfo(cast<FunctionType>(T)->getReturnType());
case Type::FunctionProto: {
@@ -3993,6 +4064,8 @@
case Type::DependentSizedExtVector:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
+ case Type::DependentSizedMatrix:
case Type::DependentAddressSpace:
case Type::FunctionProto:
case Type::FunctionNoProto:
Index: clang/lib/AST/MicrosoftMangle.cpp
===================================================================
--- clang/lib/AST/MicrosoftMangle.cpp
+++ clang/lib/AST/MicrosoftMangle.cpp
@@ -2730,6 +2730,23 @@
<< Range;
}
+void MicrosoftCXXNameMangler::mangleType(const ConstantMatrixType *T,
+ Qualifiers quals, SourceRange Range) {
+ DiagnosticsEngine &Diags = Context.getDiags();
+ unsigned DiagID = Diags.getCustomDiagID(DiagnosticsEngine::Error,
+ "Cannot mangle this matrix type yet");
+ Diags.Report(Range.getBegin(), DiagID) << Range;
+}
+
+void MicrosoftCXXNameMangler::mangleType(const DependentSizedMatrixType *T,
+ Qualifiers quals, SourceRange Range) {
+ DiagnosticsEngine &Diags = Context.getDiags();
+ unsigned DiagID = Diags.getCustomDiagID(
+ DiagnosticsEngine::Error,
+ "Cannot mangle this dependent-sized matrix type yet");
+ Diags.Report(Range.getBegin(), DiagID) << Range;
+}
+
void MicrosoftCXXNameMangler::mangleType(const DependentAddressSpaceType *T,
Qualifiers, SourceRange Range) {
DiagnosticsEngine &Diags = Context.getDiags();
Index: clang/lib/AST/ItaniumMangle.cpp
===================================================================
--- clang/lib/AST/ItaniumMangle.cpp
+++ clang/lib/AST/ItaniumMangle.cpp
@@ -2079,6 +2079,8 @@
case Type::DependentSizedExtVector:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
+ case Type::DependentSizedMatrix:
case Type::FunctionProto:
case Type::FunctionNoProto:
case Type::Paren:
@@ -3343,6 +3345,31 @@
mangleType(T->getElementType());
}
+void CXXNameMangler::mangleType(const ConstantMatrixType *T) {
+ // Mangle matrix types using a vendor extended type qualifier:
+ // U<Len>matrix_type<Rows><Columns><element type>
+ std::string VendorQualifier = "matrix_type";
+ Out << "U" << VendorQualifier.size() << VendorQualifier;
+ auto &ASTCtx = getASTContext();
+ unsigned BitWidth = ASTCtx.getTypeSize(ASTCtx.getSizeType());
+ llvm::APSInt Rows(BitWidth);
+ Rows = T->getNumRows();
+ mangleIntegerLiteral(ASTCtx.getSizeType(), Rows);
+ llvm::APSInt Columns(BitWidth);
+ Columns = T->getNumColumns();
+ mangleIntegerLiteral(ASTCtx.getSizeType(), Columns);
+ mangleType(T->getElementType());
+}
+
+void CXXNameMangler::mangleType(const DependentSizedMatrixType *T) {
+ // U<Len>matrix_type<row expr><column expr><element type>
+ std::string VendorQualifier = "matrix_type";
+ Out << "U" << VendorQualifier.size() << VendorQualifier;
+ mangleTemplateArg(T->getRowExpr());
+ mangleTemplateArg(T->getColumnExpr());
+ mangleType(T->getElementType());
+}
+
void CXXNameMangler::mangleType(const DependentAddressSpaceType *T) {
SplitQualType split = T->getPointeeType().split();
mangleQualifiers(split.Quals, T);
Index: clang/lib/AST/ExprConstant.cpp
===================================================================
--- clang/lib/AST/ExprConstant.cpp
+++ clang/lib/AST/ExprConstant.cpp
@@ -10350,6 +10350,7 @@
case Type::BlockPointer:
case Type::Vector:
case Type::ExtVector:
+ case Type::ConstantMatrix:
case Type::ObjCObject:
case Type::ObjCInterface:
case Type::ObjCObjectPointer:
Index: clang/lib/AST/ASTStructuralEquivalence.cpp
===================================================================
--- clang/lib/AST/ASTStructuralEquivalence.cpp
+++ clang/lib/AST/ASTStructuralEquivalence.cpp
@@ -617,6 +617,34 @@
break;
}
+ case Type::DependentSizedMatrix: {
+ const DependentSizedMatrixType *Mat1 = cast<DependentSizedMatrixType>(T1);
+ const DependentSizedMatrixType *Mat2 = cast<DependentSizedMatrixType>(T2);
+ // The element types, row and column expressions must be structurally
+ // equivalent.
+ if (!IsStructurallyEquivalent(Context, Mat1->getRowExpr(),
+ Mat2->getRowExpr()) ||
+ !IsStructurallyEquivalent(Context, Mat1->getColumnExpr(),
+ Mat2->getColumnExpr()) ||
+ !IsStructurallyEquivalent(Context, Mat1->getElementType(),
+ Mat2->getElementType()))
+ return false;
+ break;
+ }
+
+ case Type::ConstantMatrix: {
+ const ConstantMatrixType *Mat1 = cast<ConstantMatrixType>(T1);
+ const ConstantMatrixType *Mat2 = cast<ConstantMatrixType>(T2);
+ // The element types must be structurally equivalent and the number of rows
+ // and columns must match.
+ if (!IsStructurallyEquivalent(Context, Mat1->getElementType(),
+ Mat2->getElementType()) ||
+ Mat1->getNumRows() != Mat2->getNumRows() ||
+ Mat1->getNumColumns() != Mat2->getNumColumns())
+ return false;
+ break;
+ }
+
case Type::FunctionProto: {
const auto *Proto1 = cast<FunctionProtoType>(T1);
const auto *Proto2 = cast<FunctionProtoType>(T2);
Index: clang/lib/AST/ASTContext.cpp
===================================================================
--- clang/lib/AST/ASTContext.cpp
+++ clang/lib/AST/ASTContext.cpp
@@ -1932,6 +1932,17 @@
break;
}
+ case Type::ConstantMatrix: {
+ const auto *MT = cast<ConstantMatrixType>(T);
+ TypeInfo ElementInfo = getTypeInfo(MT->getElementType());
+ // The internal layout of a matrix value is implementation defined.
+ // Initially be ABI compatible with arrays with respect to alignment and
+ // size.
+ Width = ElementInfo.Width * MT->getNumRows() * MT->getNumColumns();
+ Align = ElementInfo.Align;
+ break;
+ }
+
case Type::Builtin:
switch (cast<BuiltinType>(T)->getKind()) {
default: llvm_unreachable("Unknown builtin type!");
@@ -3362,6 +3373,8 @@
case Type::DependentVector:
case Type::ExtVector:
case Type::DependentSizedExtVector:
+ case Type::ConstantMatrix:
+ case Type::DependentSizedMatrix:
case Type::DependentAddressSpace:
case Type::ObjCObject:
case Type::ObjCInterface:
@@ -3775,6 +3788,78 @@
return QualType(New, 0);
}
+QualType ASTContext::getConstantMatrixType(QualType ElementTy, unsigned NumRows,
+ unsigned NumColumns) const {
+ llvm::FoldingSetNodeID ID;
+ ConstantMatrixType::Profile(ID, ElementTy, NumRows, NumColumns,
+ Type::ConstantMatrix);
+
+ assert(MatrixType::isValidElementType(ElementTy) &&
+ "need a valid element type");
+ assert(ConstantMatrixType::isDimensionValid(NumRows) &&
+ ConstantMatrixType::isDimensionValid(NumColumns) &&
+ "need valid matrix dimensions");
+ void *InsertPos = nullptr;
+ if (ConstantMatrixType *MTP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos))
+ return QualType(MTP, 0);
+
+ QualType Canonical;
+ if (!ElementTy.isCanonical()) {
+ Canonical =
+ getConstantMatrixType(getCanonicalType(ElementTy), NumRows, NumColumns);
+
+ ConstantMatrixType *NewIP = MatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+ assert(!NewIP && "Matrix type shouldn't already exist in the map");
+ (void)NewIP;
+ }
+
+ auto *New = new (*this, TypeAlignment)
+ ConstantMatrixType(ElementTy, NumRows, NumColumns, Canonical);
+ MatrixTypes.InsertNode(New, InsertPos);
+ Types.push_back(New);
+ return QualType(New, 0);
+}
+
+QualType ASTContext::getDependentSizedMatrixType(QualType ElementTy,
+ Expr *RowExpr,
+ Expr *ColumnExpr,
+ SourceLocation AttrLoc) const {
+ QualType CanonElementTy = getCanonicalType(ElementTy);
+ llvm::FoldingSetNodeID ID;
+ DependentSizedMatrixType::Profile(ID, *this, CanonElementTy, RowExpr,
+ ColumnExpr);
+
+ void *InsertPos = nullptr;
+ DependentSizedMatrixType *Canon =
+ DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+
+ if (!Canon) {
+ Canon = new (*this, TypeAlignment) DependentSizedMatrixType(
+ *this, CanonElementTy, QualType(), RowExpr, ColumnExpr, AttrLoc);
+#ifndef NDEBUG
+ DependentSizedMatrixType *CanonCheck =
+ DependentSizedMatrixTypes.FindNodeOrInsertPos(ID, InsertPos);
+ assert(!CanonCheck && "Dependent-sized matrix canonical type broken");
+#endif
+ DependentSizedMatrixTypes.InsertNode(Canon, InsertPos);
+ Types.push_back(Canon);
+ }
+
+ // A canonical version of the matrix type exists at this point.
+ // If it exactly matches the requested element type, row expression and
+ // column expression, use it directly instead of allocating another node.
+ if (Canon->getElementType() == ElementTy && Canon->getRowExpr() == RowExpr &&
+ Canon->getColumnExpr() == ColumnExpr)
+ return QualType(Canon, 0);
+
+ // Otherwise build a new type that uses Canon as its canonical type.
+ DependentSizedMatrixType *New = new (*this, TypeAlignment)
+ DependentSizedMatrixType(*this, ElementTy, QualType(Canon, 0), RowExpr,
+ ColumnExpr, AttrLoc);
+ Types.push_back(New);
+ return QualType(New, 0);
+}
+
QualType ASTContext::getDependentAddressSpaceType(QualType PointeeType,
Expr *AddrSpaceExpr,
SourceLocation AttrLoc) const {
@@ -7338,6 +7423,11 @@
*NotEncodedT = T;
return;
+ case Type::ConstantMatrix:
+ if (NotEncodedT)
+ *NotEncodedT = T;
+ return;
+
// We could see an undeduced auto type here during error recovery.
// Just ignore it.
case Type::Auto:
@@ -8217,6 +8307,16 @@
LHS->getNumElements() == RHS->getNumElements();
}
+/// areCompatMatrixTypes - Return true if the two specified matrix types are
+/// compatible.
+static bool areCompatMatrixTypes(const ConstantMatrixType *LHS,
+ const ConstantMatrixType *RHS) {
+ assert(LHS->isCanonicalUnqualified() && RHS->isCanonicalUnqualified());
+ return LHS->getElementType() == RHS->getElementType() &&
+ LHS->getNumRows() == RHS->getNumRows() &&
+ LHS->getNumColumns() == RHS->getNumColumns();
+}
+
bool ASTContext::areCompatibleVectorTypes(QualType FirstVec,
QualType SecondVec) {
assert(FirstVec->isVectorType() && "FirstVec should be a vector type");
@@ -9414,6 +9514,11 @@
RHSCan->castAs<VectorType>()))
return LHS;
return {};
+ case Type::ConstantMatrix:
+ if (areCompatMatrixTypes(LHSCan->castAs<ConstantMatrixType>(),
+ RHSCan->castAs<ConstantMatrixType>()))
+ return LHS;
+ return {};
case Type::ObjCObject: {
// Check if the types are assignment compatible.
// FIXME: This should be type compatibility, e.g. whether
Index: clang/include/clang/Serialization/TypeBitCodes.def
===================================================================
--- clang/include/clang/Serialization/TypeBitCodes.def
+++ clang/include/clang/Serialization/TypeBitCodes.def
@@ -60,5 +60,7 @@
TYPE_BIT_CODE(MacroQualified, MACRO_QUALIFIED, 49)
TYPE_BIT_CODE(ExtInt, EXT_INT, 50)
TYPE_BIT_CODE(DependentExtInt, DEPENDENT_EXT_INT, 51)
+TYPE_BIT_CODE(ConstantMatrix, CONSTANT_MATRIX, 52)
+TYPE_BIT_CODE(DependentSizedMatrix, DEPENDENT_SIZE_MATRIX, 53)
#undef TYPE_BIT_CODE
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -1627,6 +1627,9 @@
QualType BuildVectorType(QualType T, Expr *VecSize, SourceLocation AttrLoc);
QualType BuildExtVectorType(QualType T, Expr *ArraySize,
SourceLocation AttrLoc);
+ QualType BuildMatrixType(QualType T, Expr *NumRows, Expr *NumColumns,
+ SourceLocation AttrLoc);
+
QualType BuildAddressSpaceAttr(QualType &T, LangAS ASIdx, Expr *AddrSpace,
SourceLocation AttrLoc);
Index: clang/include/clang/Driver/Options.td
===================================================================
--- clang/include/clang/Driver/Options.td
+++ clang/include/clang/Driver/Options.td
@@ -2007,6 +2007,10 @@
def fno_strict_return : Flag<["-"], "fno-strict-return">, Group<f_Group>,
Flags<[CC1Option]>;
+def fenable_matrix : Flag<["-"], "fenable-matrix">, Group<f_Group>,
+ Flags<[CC1Option]>,
+ HelpText<"Enable matrix data type and related builtin functions">;
+
def fallow_editor_placeholders : Flag<["-"], "fallow-editor-placeholders">,
Group<f_Group>, Flags<[CC1Option]>,
HelpText<"Treat editor placeholders as valid source code">;
Index: clang/include/clang/Basic/TypeNodes.td
===================================================================
--- clang/include/clang/Basic/TypeNodes.td
+++ clang/include/clang/Basic/TypeNodes.td
@@ -69,6 +69,9 @@
def VectorType : TypeNode<Type>;
def DependentVectorType : TypeNode<Type>, AlwaysDependent;
def ExtVectorType : TypeNode<VectorType>;
+def MatrixType : TypeNode<Type, 1>;
+def ConstantMatrixType : TypeNode<MatrixType>;
+def DependentSizedMatrixType : TypeNode<MatrixType>, AlwaysDependent;
def FunctionType : TypeNode<Type, 1>;
def FunctionProtoType : TypeNode<FunctionType>;
def FunctionNoProtoType : TypeNode<FunctionType>;
Index: clang/include/clang/Basic/LangOptions.def
===================================================================
--- clang/include/clang/Basic/LangOptions.def
+++ clang/include/clang/Basic/LangOptions.def
@@ -357,6 +357,8 @@
LANGOPT(RegisterStaticDestructors, 1, 1, "Register C++ static destructors")
+LANGOPT(MatrixTypes, 1, 0, "Enable or disable the builtin matrix type")
+
COMPATIBLE_VALUE_LANGOPT(MaxTokens, 32, 0, "Max number of tokens per TU or 0")
ENUM_LANGOPT(SignReturnAddressScope, SignReturnAddressScopeKind, 2, SignReturnAddressScopeKind::None,
Index: clang/include/clang/Basic/Features.def
===================================================================
--- clang/include/clang/Basic/Features.def
+++ clang/include/clang/Basic/Features.def
@@ -253,6 +253,7 @@
EXTENSION(pragma_clang_attribute_external_declaration, true)
EXTENSION(gnu_asm, LangOpts.GNUAsm)
EXTENSION(gnu_asm_goto_with_outputs, LangOpts.GNUAsm)
+EXTENSION(matrix_types, LangOpts.MatrixTypes)
#undef EXTENSION
#undef FEATURE
Index: clang/include/clang/Basic/DiagnosticSemaKinds.td
===================================================================
--- clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -2776,6 +2776,7 @@
def err_attribute_too_few_arguments : Error<
"%0 attribute takes at least %1 argument%s1">;
def err_attribute_invalid_vector_type : Error<"invalid vector element type %0">;
+def err_attribute_invalid_matrix_type : Error<"invalid matrix element type %0">;
def err_attribute_bad_neon_vector_size : Error<
"Neon vector size must be 64 or 128 bits">;
def err_attribute_requires_positive_integer : Error<
@@ -2879,8 +2880,8 @@
"init methods must return an object pointer type, not %0">;
def err_attribute_invalid_size : Error<
"vector size not an integral multiple of component size">;
-def err_attribute_zero_size : Error<"zero vector size">;
-def err_attribute_size_too_large : Error<"vector size too large">;
+def err_attribute_zero_size : Error<"zero %0 size">;
+def err_attribute_size_too_large : Error<"%0 size too large">;
def err_typecheck_vector_not_convertable_implict_truncation : Error<
"cannot convert between %select{scalar|vector}0 type %1 and vector type"
" %2 as implicit conversion would cause truncation">;
@@ -10722,6 +10723,9 @@
"%select{non-pointer|function pointer|void pointer}0 argument to "
"'__builtin_launder' is not allowed">;
+def err_builtin_matrix_disabled: Error<
+ "matrix types extension is disabled. Pass -fenable-matrix to enable it">;
+
def err_preserve_field_info_not_field : Error<
"__builtin_preserve_field_info argument %0 not a field access">;
def err_preserve_field_info_not_const: Error<
Index: clang/include/clang/Basic/Attr.td
===================================================================
--- clang/include/clang/Basic/Attr.td
+++ clang/include/clang/Basic/Attr.td
@@ -2460,6 +2460,15 @@
let Documentation = [Undocumented];
}
+def MatrixType : TypeAttr {
+ let Spellings = [Clang<"matrix_type">];
+ let Subjects = SubjectList<[TypedefName], ErrorDiag>;
+ let Args = [ExprArgument<"NumRows">, ExprArgument<"NumColumns">];
+ let Documentation = [Undocumented];
+ let ASTNode = 0;
+ let PragmaAttributeSupport = 0;
+}
+
def Visibility : InheritableAttr {
let Clone = 0;
let Spellings = [GCC<"visibility">];
Index: clang/include/clang/AST/TypeProperties.td
===================================================================
--- clang/include/clang/AST/TypeProperties.td
+++ clang/include/clang/AST/TypeProperties.td
@@ -224,6 +224,41 @@
}]>;
}
+let Class = MatrixType in {
+ def : Property<"elementType", QualType> {
+ let Read = [{ node->getElementType() }];
+ }
+}
+
+let Class = ConstantMatrixType in {
+ def : Property<"numRows", UInt32> {
+ let Read = [{ node->getNumRows() }];
+ }
+ def : Property<"numColumns", UInt32> {
+ let Read = [{ node->getNumColumns() }];
+ }
+
+ def : Creator<[{
+ return ctx.getConstantMatrixType(elementType, numRows, numColumns);
+ }]>;
+}
+
+let Class = DependentSizedMatrixType in {
+ def : Property<"rows", ExprRef> {
+ let Read = [{ node->getRowExpr() }];
+ }
+ def : Property<"columns", ExprRef> {
+ let Read = [{ node->getColumnExpr() }];
+ }
+ def : Property<"attributeLoc", SourceLocation> {
+ let Read = [{ node->getAttributeLoc() }];
+ }
+
+ def : Creator<[{
+ return ctx.getDependentSizedMatrixType(elementType, rows, columns, attributeLoc);
+ }]>;
+}
+
let Class = FunctionType in {
def : Property<"returnType", QualType> {
let Read = [{ node->getReturnType() }];
Index: clang/include/clang/AST/TypeLoc.h
===================================================================
--- clang/include/clang/AST/TypeLoc.h
+++ clang/include/clang/AST/TypeLoc.h
@@ -1735,6 +1735,7 @@
void initializeLocal(ASTContext &Context, SourceLocation loc) {
setAttrNameLoc(loc);
+ // Only one location is supplied; use it for the whole operand-parens range.
setAttrOperandParensRange(SourceRange(loc));
setAttrExprOperand(getTypePtr()->getAddrSpaceExpr());
}
@@ -1774,6 +1775,68 @@
DependentSizedExtVectorType> {
};
+struct MatrixTypeLocInfo {
+ SourceLocation AttrLoc;
+ SourceRange OperandParens;
+ Expr *RowOperand;
+ Expr *ColumnOperand;
+};
+
+class MatrixTypeLoc : public ConcreteTypeLoc<UnqualTypeLoc, MatrixTypeLoc,
+ MatrixType, MatrixTypeLocInfo> {
+public:
+ /// The location of the attribute name, i.e.
+ /// float __attribute__((matrix_type(4, 2)))
+ /// ^~~~~~~~~~~~~~~~~
+ SourceLocation getAttrNameLoc() const { return getLocalData()->AttrLoc; }
+ void setAttrNameLoc(SourceLocation loc) { getLocalData()->AttrLoc = loc; }
+
+ /// The attribute's row operand, if it has one.
+ /// float __attribute__((matrix_type(4, 2)))
+ /// ^
+ Expr *getAttrRowOperand() const { return getLocalData()->RowOperand; }
+ void setAttrRowOperand(Expr *e) { getLocalData()->RowOperand = e; }
+
+ /// The attribute's column operand, if it has one.
+ /// float __attribute__((matrix_type(4, 2)))
+ /// ^
+ Expr *getAttrColumnOperand() const { return getLocalData()->ColumnOperand; }
+ void setAttrColumnOperand(Expr *e) { getLocalData()->ColumnOperand = e; }
+
+ /// The location of the parentheses around the operand, if there is
+ /// an operand.
+ /// float __attribute__((matrix_type(4, 2)))
+ /// ^ ^
+ SourceRange getAttrOperandParensRange() const {
+ return getLocalData()->OperandParens;
+ }
+ void setAttrOperandParensRange(SourceRange range) {
+ getLocalData()->OperandParens = range;
+ }
+
+ SourceRange getLocalSourceRange() const {
+ SourceRange range(getAttrNameLoc());
+ range.setEnd(getAttrOperandParensRange().getEnd());
+ return range;
+ }
+
+ void initializeLocal(ASTContext &Context, SourceLocation loc) {
+ setAttrNameLoc(loc);
+ setAttrOperandParensRange(loc);
+ setAttrRowOperand(nullptr);
+ setAttrColumnOperand(nullptr);
+ }
+};
+
+class ConstantMatrixTypeLoc
+ : public InheritingConcreteTypeLoc<MatrixTypeLoc, ConstantMatrixTypeLoc,
+ ConstantMatrixType> {};
+
+class DependentSizedMatrixTypeLoc
+ : public InheritingConcreteTypeLoc<MatrixTypeLoc,
+ DependentSizedMatrixTypeLoc,
+ DependentSizedMatrixType> {};
+
// FIXME: location of the '_Complex' keyword.
class ComplexTypeLoc : public InheritingConcreteTypeLoc<TypeSpecTypeLoc,
ComplexTypeLoc,
Index: clang/include/clang/AST/Type.h
===================================================================
--- clang/include/clang/AST/Type.h
+++ clang/include/clang/AST/Type.h
@@ -1654,6 +1654,19 @@
uint32_t NumElements;
};
+ class ConstantMatrixTypeBitfields {
+ friend class ConstantMatrixType;
+
+ unsigned : NumTypeBits;
+
+ /// Number of rows and columns. Using 20 bits allows supporting very large
+ /// matrixes, while keeping 24 bits to accommodate NumTypeBits.
+ unsigned NumRows : 20;
+ unsigned NumColumns : 20;
+
+ static constexpr uint32_t MaxElementsPerDimension = (1 << 20) - 1;
+ };
+
class AttributedTypeBitfields {
friend class AttributedType;
@@ -1763,6 +1776,7 @@
TypeWithKeywordBitfields TypeWithKeywordBits;
ElaboratedTypeBitfields ElaboratedTypeBits;
VectorTypeBitfields VectorTypeBits;
+ ConstantMatrixTypeBitfields ConstantMatrixTypeBits;
SubstTemplateTypeParmPackTypeBitfields SubstTemplateTypeParmPackTypeBits;
TemplateSpecializationTypeBitfields TemplateSpecializationTypeBits;
DependentTemplateSpecializationTypeBitfields
@@ -2021,6 +2035,7 @@
bool isComplexIntegerType() const; // GCC _Complex integer type.
bool isVectorType() const; // GCC vector type.
bool isExtVectorType() const; // Extended vector type.
+ bool isConstantMatrixType() const; // Matrix type.
bool isDependentAddressSpaceType() const; // value-dependent address space qualifier
bool isObjCObjectPointerType() const; // pointer to ObjC object
bool isObjCRetainableType() const; // ObjC object or block pointer
@@ -3390,6 +3405,130 @@
}
};
+/// Represents a matrix type, as defined in the Matrix Types clang extensions.
+/// __attribute__((matrix_type(rows, columns))), where "rows" specifies
+/// number of rows and "columns" specifies the number of columns.
+class MatrixType : public Type, public llvm::FoldingSetNode {
+protected:
+ friend class ASTContext;
+
+ /// The element type of the matrix.
+ QualType ElementType;
+
+ MatrixType(QualType ElementTy, QualType CanonElementTy);
+
+ MatrixType(TypeClass TypeClass, QualType ElementTy, QualType CanonElementTy,
+ const Expr *RowExpr = nullptr, const Expr *ColumnExpr = nullptr);
+
+public:
+ /// Returns type of the elements being stored in the matrix
+ QualType getElementType() const { return ElementType; }
+
+ /// Valid elements types are the following:
+ /// * an integer type (as in C2x 6.2.5p19), but excluding enumerated types
+ /// and _Bool
+ /// * the standard floating types float or double
+ /// * a half-precision floating point type, if one is supported on the target
+ static bool isValidElementType(QualType T) {
+ return T->isRealType() && !T->isBooleanType() && !T->isEnumeralType();
+ }
+
+ bool isSugared() const { return false; }
+ QualType desugar() const { return QualType(this, 0); }
+
+ static bool classof(const Type *T) {
+ return T->getTypeClass() == ConstantMatrix ||
+ T->getTypeClass() == DependentSizedMatrix;
+ }
+};
+
+/// Represents a concrete matrix type with constant number of rows and columns
+class ConstantMatrixType final : public MatrixType {
+protected:
+ friend class ASTContext;
+
+ // The element type is stored in MatrixType::ElementType; a second member of
+ // that name here would shadow it and never be set by the constructors.
+
+ ConstantMatrixType(QualType MatrixElementType, unsigned NRows,
+ unsigned NColumns, QualType CanonElementType);
+
+ ConstantMatrixType(TypeClass typeClass, QualType MatrixType, unsigned NRows,
+ unsigned NColumns, QualType CanonElementType);
+
+public:
+ /// Returns the number of rows in the matrix.
+ unsigned getNumRows() const { return ConstantMatrixTypeBits.NumRows; }
+
+ /// Returns the number of columns in the matrix.
+ unsigned getNumColumns() const { return ConstantMatrixTypeBits.NumColumns; }
+
+ /// Returns the number of elements required to embed the matrix into a vector.
+ unsigned getNumElementsFlattened() const {
+ return ConstantMatrixTypeBits.NumRows * ConstantMatrixTypeBits.NumColumns;
+ }
+
+ /// Returns true if \p NumElements is a valid matrix dimension.
+ static bool isDimensionValid(uint64_t NumElements) {
+ return NumElements > 0 &&
+ NumElements <= ConstantMatrixTypeBitfields::MaxElementsPerDimension;
+ }
+
+ void Profile(llvm::FoldingSetNodeID &ID) {
+ Profile(ID, getElementType(), getNumRows(), getNumColumns(),
+ getTypeClass());
+ }
+
+ static void Profile(llvm::FoldingSetNodeID &ID, QualType ElementType,
+ unsigned NumRows, unsigned NumColumns,
+ TypeClass TypeClass) {
+ ID.AddPointer(ElementType.getAsOpaquePtr());
+ ID.AddInteger(NumRows);
+ ID.AddInteger(NumColumns);
+ ID.AddInteger(TypeClass);
+ }
+
+ static bool classof(const Type *T) {
+ return T->getTypeClass() == ConstantMatrix;
+ }
+};
+
+/// Represents a matrix type where the type and the number of rows and columns
+/// is dependent on a template.
+class DependentSizedMatrixType final : public MatrixType {
+ friend class ASTContext;
+
+ const ASTContext &Context;
+ Expr *RowExpr;
+ Expr *ColumnExpr;
+
+ SourceLocation loc;
+
+ DependentSizedMatrixType(const ASTContext &Context, QualType ElementType,
+ QualType CanonicalType, Expr *RowExpr,
+ Expr *ColumnExpr, SourceLocation loc);
+
+public:
+ QualType getElementType() const { return ElementType; }
+ Expr *getRowExpr() const { return RowExpr; }
+ Expr *getColumnExpr() const { return ColumnExpr; }
+ SourceLocation getAttributeLoc() const { return loc; }
+
+ bool isSugared() const { return false; }
+ QualType desugar() const { return QualType(this, 0); }
+
+ static bool classof(const Type *T) {
+ return T->getTypeClass() == DependentSizedMatrix;
+ }
+
+ void Profile(llvm::FoldingSetNodeID &ID) {
+ Profile(ID, Context, getElementType(), getRowExpr(), getColumnExpr());
+ }
+
+ static void Profile(llvm::FoldingSetNodeID &ID, const ASTContext &Context,
+ QualType ElementType, Expr *RowExpr, Expr *ColumnExpr);
+};
+
/// FunctionType - C99 6.7.5.3 - Function Declarators. This is the common base
/// class of FunctionNoProtoType and FunctionProtoType.
class FunctionType : public Type {
@@ -6605,6 +6744,10 @@
return isa<ExtVectorType>(CanonicalType);
}
+inline bool Type::isConstantMatrixType() const {
+ return isa<ConstantMatrixType>(CanonicalType);
+}
+
inline bool Type::isDependentAddressSpaceType() const {
return isa<DependentAddressSpaceType>(CanonicalType);
}
Index: clang/include/clang/AST/RecursiveASTVisitor.h
===================================================================
--- clang/include/clang/AST/RecursiveASTVisitor.h
+++ clang/include/clang/AST/RecursiveASTVisitor.h
@@ -1006,6 +1006,17 @@
DEF_TRAVERSE_TYPE(ExtVectorType, { TRY_TO(TraverseType(T->getElementType())); })
+DEF_TRAVERSE_TYPE(ConstantMatrixType,
+ { TRY_TO(TraverseType(T->getElementType())); })
+
+DEF_TRAVERSE_TYPE(DependentSizedMatrixType, {
+ if (T->getRowExpr())
+ TRY_TO(TraverseStmt(T->getRowExpr()));
+ if (T->getColumnExpr())
+ TRY_TO(TraverseStmt(T->getColumnExpr()));
+ TRY_TO(TraverseType(T->getElementType()));
+})
+
DEF_TRAVERSE_TYPE(FunctionNoProtoType,
{ TRY_TO(TraverseType(T->getReturnType())); })
@@ -1258,6 +1269,18 @@
TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
})
+DEF_TRAVERSE_TYPELOC(ConstantMatrixType, {
+ TRY_TO(TraverseStmt(TL.getAttrRowOperand()));
+ TRY_TO(TraverseStmt(TL.getAttrColumnOperand()));
+ TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
+})
+
+DEF_TRAVERSE_TYPELOC(DependentSizedMatrixType, {
+ TRY_TO(TraverseStmt(TL.getAttrRowOperand()));
+ TRY_TO(TraverseStmt(TL.getAttrColumnOperand()));
+ TRY_TO(TraverseType(TL.getTypePtr()->getElementType()));
+})
+
DEF_TRAVERSE_TYPELOC(FunctionNoProtoType,
{ TRY_TO(TraverseTypeLoc(TL.getReturnLoc())); })
Index: clang/include/clang/AST/ASTContext.h
===================================================================
--- clang/include/clang/AST/ASTContext.h
+++ clang/include/clang/AST/ASTContext.h
@@ -194,6 +194,8 @@
DependentAddressSpaceTypes;
mutable llvm::FoldingSet<VectorType> VectorTypes;
mutable llvm::FoldingSet<DependentVectorType> DependentVectorTypes;
+ mutable llvm::FoldingSet<ConstantMatrixType> MatrixTypes;
+ mutable llvm::FoldingSet<DependentSizedMatrixType> DependentSizedMatrixTypes;
mutable llvm::FoldingSet<FunctionNoProtoType> FunctionNoProtoTypes;
mutable llvm::ContextualFoldingSet<FunctionProtoType, ASTContext&>
FunctionProtoTypes;
@@ -1326,6 +1328,20 @@
Expr *SizeExpr,
SourceLocation AttrLoc) const;
+ /// Return the unique reference to the matrix type of the specified element
+ /// type and size
+ ///
+ /// \pre \p ElementType must be a valid matrix element type (see
+ /// MatrixType::isValidElementType).
+ QualType getConstantMatrixType(QualType ElementType, unsigned NumRows,
+ unsigned NumColumns) const;
+
+ /// Return a dependently-sized matrix type with the specified element type
+ /// and row/column size expressions; only its canonical type is uniqued.
+ QualType getDependentSizedMatrixType(QualType ElementType, Expr *RowExpr,
+ Expr *ColumnExpr,
+ SourceLocation AttrLoc) const;
+
QualType getDependentAddressSpaceType(QualType PointeeType,
Expr *AddrSpaceExpr,
SourceLocation AttrLoc) const;
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits