Author: Deric C.
Date: 2026-01-14T08:44:55-08:00
New Revision: 58a9dc01be8a5529c2d9676b1a32d9ed09a4cd70

URL: 
https://github.com/llvm/llvm-project/commit/58a9dc01be8a5529c2d9676b1a32d9ed09a4cd70
DIFF: 
https://github.com/llvm/llvm-project/commit/58a9dc01be8a5529c2d9676b1a32d9ed09a4cd70.diff

LOG: [HLSL][Matrix] Add type conversions to support bool matrix single subscript operators (#175633)

Fixes #172711

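Fixes the type mismatch issues that prevented single matrix subscript
getters and setters from working with boolean matrices. Bool matrix
elements are stored in memory as i32, while bool values in IR are i1.
Row reads now use the memory element type and convert the result with
EmitFromMemory, and i1 rows are zero-extended before being stored.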

The changes in this PR also happen to make matrix splats work for
boolean matrices, but adding tests for that and (re)introducing
boolean-matrix-specific Sema checks will be left to a separate PR.

Added: 
    

Modified: 
    clang/lib/CodeGen/CGExpr.cpp
    clang/lib/CodeGen/CGExprScalar.cpp
    clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptGetter.hlsl
    clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptSetter.hlsl

Removed: 
    


################################################################################
diff  --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 999726340aaed..91407b233e890 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -2744,6 +2744,20 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, 
LValue Dst,
 
       llvm::Value *Row = Dst.getMatrixRowIdx();
       llvm::Value *RowVal = Src.getScalarVal(); // <NumCols x T>
+
+      if (RowVal->getType()->isIntOrIntVectorTy(1)) {
+        // NOTE: If matrix single subscripting becomes a feature in languages
+        // other than HLSL, the following assert should be removed and the
+        // assert condition should be made part of the enclosing if-statement
+        // condition as is the case for similar logic for Dst.isMatrixElt()
+        assert(getLangOpts().HLSL);
+        auto *RowValVecTy = cast<llvm::FixedVectorType>(RowVal->getType());
+        llvm::Type *StorageElmTy =
+            llvm::FixedVectorType::get(MatrixVec->getType()->getScalarType(),
+                                       RowValVecTy->getNumElements());
+        RowVal = Builder.CreateZExt(RowVal, StorageElmTy);
+      }
+
       llvm::MatrixBuilder MB(Builder);
 
       llvm::Constant *ColConstsIndices = nullptr;

diff  --git a/clang/lib/CodeGen/CGExprScalar.cpp 
b/clang/lib/CodeGen/CGExprScalar.cpp
index e48d316d337b0..35e2c65a8e112 100644
--- a/clang/lib/CodeGen/CGExprScalar.cpp
+++ b/clang/lib/CodeGen/CGExprScalar.cpp
@@ -2130,7 +2130,7 @@ Value *ScalarExprEmitter::VisitMatrixSingleSubscriptExpr(
     MB.CreateIndexAssumption(RowIdx, NumRows);
 
   Value *FlatMatrix = Visit(E->getBase());
-  llvm::Type *ElemTy = CGF.ConvertType(MatrixTy->getElementType());
+  llvm::Type *ElemTy = CGF.ConvertTypeForMem(MatrixTy->getElementType());
   auto *ResultTy = llvm::FixedVectorType::get(ElemTy, NumColumns);
   Value *RowVec = llvm::PoisonValue::get(ResultTy);
 
@@ -2146,7 +2146,7 @@ Value *ScalarExprEmitter::VisitMatrixSingleSubscriptExpr(
     RowVec = Builder.CreateInsertElement(RowVec, Elt, Lane, "matrix_row_ins");
   }
 
-  return RowVec;
+  return CGF.EmitFromMemory(RowVec, E->getType());
 }
 
 Value *ScalarExprEmitter::VisitMatrixSubscriptExpr(MatrixSubscriptExpr *E) {

diff  --git 
a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptGetter.hlsl 
b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptGetter.hlsl
index 341a5bbaf0147..df724d217fe6b 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptGetter.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptGetter.hlsl
@@ -203,3 +203,68 @@ float4 AddFloatMatrixConstant(float4x4 M) {
 int4 AddIntMatrixConstant(int4x4 M) {
    return M[0] + M[1] + M[2] + M[3];
 }
+
+// CHECK-LABEL: define hidden noundef <3 x i1> 
@_Z23getBoolVecMatrixDynamicu11matrix_typeILm2ELm3EbEi(
+// CHECK-SAME: <6 x i1> noundef [[M:%.*]], i32 noundef [[INDEX:%.*]]) 
#[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M_ADDR:%.*]] = alloca [6 x i32], align 4
+// CHECK-NEXT:    [[INDEX_ADDR:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = zext <6 x i1> [[M]] to <6 x i32>
+// CHECK-NEXT:    store <6 x i32> [[TMP0]], ptr [[M_ADDR]], align 4
+// CHECK-NEXT:    store i32 [[INDEX]], ptr [[INDEX_ADDR]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load i32, ptr [[INDEX_ADDR]], align 4
+// CHECK-NEXT:    [[TMP2:%.*]] = load <6 x i32>, ptr [[M_ADDR]], align 4
+// CHECK-NEXT:    [[TMP3:%.*]] = add i32 0, [[TMP1]]
+// CHECK-NEXT:    [[MATRIX_ELEM:%.*]] = extractelement <6 x i32> [[TMP2]], i32 
[[TMP3]]
+// CHECK-NEXT:    [[MATRIX_ROW_INS:%.*]] = insertelement <3 x i32> poison, i32 
[[MATRIX_ELEM]], i32 0
+// CHECK-NEXT:    [[TMP4:%.*]] = add i32 2, [[TMP1]]
+// CHECK-NEXT:    [[MATRIX_ELEM1:%.*]] = extractelement <6 x i32> [[TMP2]], 
i32 [[TMP4]]
+// CHECK-NEXT:    [[MATRIX_ROW_INS2:%.*]] = insertelement <3 x i32> 
[[MATRIX_ROW_INS]], i32 [[MATRIX_ELEM1]], i32 1
+// CHECK-NEXT:    [[TMP5:%.*]] = add i32 4, [[TMP1]]
+// CHECK-NEXT:    [[MATRIX_ELEM3:%.*]] = extractelement <6 x i32> [[TMP2]], 
i32 [[TMP5]]
+// CHECK-NEXT:    [[MATRIX_ROW_INS4:%.*]] = insertelement <3 x i32> 
[[MATRIX_ROW_INS2]], i32 [[MATRIX_ELEM3]], i32 2
+// CHECK-NEXT:    [[LOADEDV:%.*]] = trunc <3 x i32> [[MATRIX_ROW_INS4]] to <3 
x i1>
+// CHECK-NEXT:    ret <3 x i1> [[LOADEDV]]
+//
+bool3 getBoolVecMatrixDynamic(bool2x3 M, int index) {
+    return M[index];
+}
+
+// CHECK-LABEL: define hidden noundef <4 x i1> 
@_Z24getBoolVecMatrixConstantu11matrix_typeILm4ELm4EbE(
+// CHECK-SAME: <16 x i1> noundef [[M:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M_ADDR:%.*]] = alloca [16 x i32], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = zext <16 x i1> [[M]] to <16 x i32>
+// CHECK-NEXT:    store <16 x i32> [[TMP0]], ptr [[M_ADDR]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr [[M_ADDR]], align 4
+// CHECK-NEXT:    [[MATRIX_ELEM:%.*]] = extractelement <16 x i32> [[TMP1]], 
i32 0
+// CHECK-NEXT:    [[MATRIX_ROW_INS:%.*]] = insertelement <4 x i32> poison, i32 
[[MATRIX_ELEM]], i32 0
+// CHECK-NEXT:    [[MATRIX_ELEM1:%.*]] = extractelement <16 x i32> [[TMP1]], 
i32 4
+// CHECK-NEXT:    [[MATRIX_ROW_INS2:%.*]] = insertelement <4 x i32> 
[[MATRIX_ROW_INS]], i32 [[MATRIX_ELEM1]], i32 1
+// CHECK-NEXT:    [[MATRIX_ELEM3:%.*]] = extractelement <16 x i32> [[TMP1]], 
i32 8
+// CHECK-NEXT:    [[MATRIX_ROW_INS4:%.*]] = insertelement <4 x i32> 
[[MATRIX_ROW_INS2]], i32 [[MATRIX_ELEM3]], i32 2
+// CHECK-NEXT:    [[MATRIX_ELEM5:%.*]] = extractelement <16 x i32> [[TMP1]], 
i32 12
+// CHECK-NEXT:    [[MATRIX_ROW_INS6:%.*]] = insertelement <4 x i32> 
[[MATRIX_ROW_INS4]], i32 [[MATRIX_ELEM5]], i32 3
+// CHECK-NEXT:    [[LOADEDV:%.*]] = trunc <4 x i32> [[MATRIX_ROW_INS6]] to <4 
x i1>
+// CHECK-NEXT:    ret <4 x i1> [[LOADEDV]]
+//
+bool4 getBoolVecMatrixConstant(bool4x4 M) {
+    return M[0];
+}
+
+// CHECK-LABEL: define hidden noundef i1 
@_Z27getBoolScalarMatrixConstantu11matrix_typeILm3ELm1EbE(
+// CHECK-SAME: <3 x i1> noundef [[M:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M_ADDR:%.*]] = alloca [3 x i32], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = zext <3 x i1> [[M]] to <3 x i32>
+// CHECK-NEXT:    store <3 x i32> [[TMP0]], ptr [[M_ADDR]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load <3 x i32>, ptr [[M_ADDR]], align 4
+// CHECK-NEXT:    [[MATRIX_ELEM:%.*]] = extractelement <3 x i32> [[TMP1]], i32 
1
+// CHECK-NEXT:    [[MATRIX_ROW_INS:%.*]] = insertelement <1 x i32> poison, i32 
[[MATRIX_ELEM]], i32 0
+// CHECK-NEXT:    [[LOADEDV:%.*]] = trunc <1 x i32> [[MATRIX_ROW_INS]] to <1 x 
i1>
+// CHECK-NEXT:    [[CAST_VTRUNC:%.*]] = extractelement <1 x i1> [[LOADEDV]], 
i32 0
+// CHECK-NEXT:    ret i1 [[CAST_VTRUNC]]
+//
+bool getBoolScalarMatrixConstant(bool3x1 M) {
+    return M[1];
+}

diff  --git 
a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptSetter.hlsl 
b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptSetter.hlsl
index 49746531ddccc..d314f3a87d619 100644
--- a/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptSetter.hlsl
+++ b/clang/test/CodeGenHLSL/BasicFeatures/MatrixSingleSubscriptSetter.hlsl
@@ -58,6 +58,69 @@ void setMatrixScalar(out float2x1 M, int index, float S) {
     M[index] = S;
 }
 
+// CHECK-LABEL: define hidden void 
@_Z13setBoolMatrixRu11matrix_typeILm4ELm4EbEiDv4_b(
+// CHECK-SAME: ptr noalias noundef nonnull align 4 dereferenceable(64) 
[[M:%.*]], i32 noundef [[INDEX:%.*]], <4 x i1> noundef [[V:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M_ADDR:%.*]] = alloca ptr, align 4
+// CHECK-NEXT:    [[INDEX_ADDR:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    [[V_ADDR:%.*]] = alloca <4 x i32>, align 16
+// CHECK-NEXT:    store ptr [[M]], ptr [[M_ADDR]], align 4
+// CHECK-NEXT:    store i32 [[INDEX]], ptr [[INDEX_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = zext <4 x i1> [[V]] to <4 x i32>
+// CHECK-NEXT:    store <4 x i32> [[TMP0]], ptr [[V_ADDR]], align 16
+// CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i32>, ptr [[V_ADDR]], align 16
+// CHECK-NEXT:    [[LOADEDV:%.*]] = trunc <4 x i32> [[TMP1]] to <4 x i1>
+// CHECK-NEXT:    [[TMP2:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull 
[[META3]], !align [[META4]]
+// CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr [[INDEX_ADDR]], align 4
+// CHECK-NEXT:    [[MATRIX_LOAD:%.*]] = load <16 x i32>, ptr [[TMP2]], align 4
+// CHECK-NEXT:    [[TMP4:%.*]] = zext <4 x i1> [[LOADEDV]] to <4 x i32>
+// CHECK-NEXT:    [[TMP5:%.*]] = add i32 0, [[TMP3]]
+// CHECK-NEXT:    [[TMP6:%.*]] = extractelement <4 x i32> [[TMP4]], i32 0
+// CHECK-NEXT:    [[TMP7:%.*]] = insertelement <16 x i32> [[MATRIX_LOAD]], i32 
[[TMP6]], i32 [[TMP5]]
+// CHECK-NEXT:    [[TMP8:%.*]] = add i32 4, [[TMP3]]
+// CHECK-NEXT:    [[TMP9:%.*]] = extractelement <4 x i32> [[TMP4]], i32 1
+// CHECK-NEXT:    [[TMP10:%.*]] = insertelement <16 x i32> [[TMP7]], i32 
[[TMP9]], i32 [[TMP8]]
+// CHECK-NEXT:    [[TMP11:%.*]] = add i32 8, [[TMP3]]
+// CHECK-NEXT:    [[TMP12:%.*]] = extractelement <4 x i32> [[TMP4]], i32 2
+// CHECK-NEXT:    [[TMP13:%.*]] = insertelement <16 x i32> [[TMP10]], i32 
[[TMP12]], i32 [[TMP11]]
+// CHECK-NEXT:    [[TMP14:%.*]] = add i32 12, [[TMP3]]
+// CHECK-NEXT:    [[TMP15:%.*]] = extractelement <4 x i32> [[TMP4]], i32 3
+// CHECK-NEXT:    [[TMP16:%.*]] = insertelement <16 x i32> [[TMP13]], i32 
[[TMP15]], i32 [[TMP14]]
+// CHECK-NEXT:    store <16 x i32> [[TMP16]], ptr [[TMP2]], align 4
+// CHECK-NEXT:    ret void
+//
+void setBoolMatrix(out bool4x4 M, int index, bool4 V) {
+    M[index] = V;
+}
+
+// CHECK-LABEL: define hidden void 
@_Z19setBoolMatrixScalarRu11matrix_typeILm2ELm1EbEib(
+// CHECK-SAME: ptr noalias noundef nonnull align 4 dereferenceable(8) 
[[M:%.*]], i32 noundef [[INDEX:%.*]], i1 noundef [[S:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M_ADDR:%.*]] = alloca ptr, align 4
+// CHECK-NEXT:    [[INDEX_ADDR:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    [[S_ADDR:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    store ptr [[M]], ptr [[M_ADDR]], align 4
+// CHECK-NEXT:    store i32 [[INDEX]], ptr [[INDEX_ADDR]], align 4
+// CHECK-NEXT:    [[STOREDV:%.*]] = zext i1 [[S]] to i32
+// CHECK-NEXT:    store i32 [[STOREDV]], ptr [[S_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[S_ADDR]], align 4
+// CHECK-NEXT:    [[LOADEDV:%.*]] = trunc i32 [[TMP0]] to i1
+// CHECK-NEXT:    [[SPLAT_SPLATINSERT:%.*]] = insertelement <1 x i1> poison, 
i1 [[LOADEDV]], i64 0
+// CHECK-NEXT:    [[SPLAT_SPLAT:%.*]] = shufflevector <1 x i1> 
[[SPLAT_SPLATINSERT]], <1 x i1> poison, <1 x i32> zeroinitializer
+// CHECK-NEXT:    [[TMP1:%.*]] = load ptr, ptr [[M_ADDR]], align 4, !nonnull 
[[META3]], !align [[META4]]
+// CHECK-NEXT:    [[TMP2:%.*]] = load i32, ptr [[INDEX_ADDR]], align 4
+// CHECK-NEXT:    [[MATRIX_LOAD:%.*]] = load <2 x i32>, ptr [[TMP1]], align 4
+// CHECK-NEXT:    [[TMP3:%.*]] = zext <1 x i1> [[SPLAT_SPLAT]] to <1 x i32>
+// CHECK-NEXT:    [[TMP4:%.*]] = add i32 0, [[TMP2]]
+// CHECK-NEXT:    [[TMP5:%.*]] = extractelement <1 x i32> [[TMP3]], i32 0
+// CHECK-NEXT:    [[TMP6:%.*]] = insertelement <2 x i32> [[MATRIX_LOAD]], i32 
[[TMP5]], i32 [[TMP4]]
+// CHECK-NEXT:    store <2 x i32> [[TMP6]], ptr [[TMP1]], align 4
+// CHECK-NEXT:    ret void
+//
+void setBoolMatrixScalar(out bool2x1 M, int index, bool S) {
+    M[index] = S;
+}
+
 // CHECK-LABEL: define hidden void 
@_Z19setMatrixConstIndexRu11matrix_typeILm4ELm4EiES_(
 // CHECK-SAME: ptr noalias noundef nonnull align 4 dereferenceable(64) 
[[M:%.*]], <16 x i32> noundef [[N:%.*]]) #[[ATTR0]] {
 // CHECK-NEXT:  [[ENTRY:.*:]]


        
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to