================
@@ -883,99 +868,54 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *E) {
         return false;
     }
 
+    // The scalar to be splatted is stored in a local to be repeatedly loaded
+    // once for every scalar element of the destination.
     PrimType SrcElemT = classifyPrim(SubExpr->getType());
     unsigned SrcOffset =
-        allocateLocalPrimitive(SubExpr, DestElemT, /*IsConst=*/true);
+        allocateLocalPrimitive(SubExpr, SrcElemT, /*IsConst=*/true);
 
     if (!this->visit(SubExpr))
       return false;
-    if (SrcElemT != DestElemT) {
-      if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, E))
-        return false;
-    }
-    if (!this->emitSetLocal(DestElemT, SrcOffset, E))
+    if (!this->emitSetLocal(SrcElemT, SrcOffset, E))
       return false;
 
-    for (unsigned I = 0; I != NumElems; ++I) {
-      if (!this->emitGetLocal(DestElemT, SrcOffset, E))
-        return false;
-      if (!this->emitInitElem(DestElemT, I, E))
-        return false;
-    }
-    return true;
+    // Recursively splat the scalar into every element of the destination.
+    return emitHLSLAggregateSplat(SrcElemT, SrcOffset, E->getType(), E);
   }
 
   case CK_HLSLElementwiseCast: {
-    // Elementwise cast: flatten source elements of one aggregate type and store
-    // to a destination scalar or aggregate type of the same or fewer number of
-    // elements, while inserting casts as necessary.
-    // TODO: Elementwise cast to structs, nested arrays, and arrays of composite
-    // types
+    // Elementwise cast: flatten the elements of one aggregate source type and
+    // store to a destination scalar or aggregate type of the same or fewer
+    // number of elements. Casts are inserted element-wise to convert each
+    // source scalar element to its corresponding destination scalar element.
     QualType SrcType = SubExpr->getType();
     QualType DestType = E->getType();
 
-    // Allowed SrcTypes
-    const auto *SrcVT = SrcType->getAs<VectorType>();
-    const auto *SrcMT = SrcType->getAs<ConstantMatrixType>();
-    const auto *SrcAT = SrcType->getAsArrayTypeUnsafe();
-    const auto *SrcCAT = SrcAT ? dyn_cast<ConstantArrayType>(SrcAT) : nullptr;
-
-    // Allowed DestTypes
-    const auto *DestVT = DestType->getAs<VectorType>();
-    const auto *DestMT = DestType->getAs<ConstantMatrixType>();
-    const auto *DestAT = DestType->getAsArrayTypeUnsafe();
-    const auto *DestCAT =
-        DestAT ? dyn_cast<ConstantArrayType>(DestAT) : nullptr;
-    const OptPrimType DestPT = classify(DestType);
-
-    if (!SrcVT && !SrcMT && !SrcCAT)
-      return false;
-    if (!DestVT && !DestMT && !DestCAT && !DestPT)
-      return false;
-
-    unsigned SrcNumElems;
-    PrimType SrcElemT;
-    if (SrcVT) {
-      SrcNumElems = SrcVT->getNumElements();
-      SrcElemT = classifyPrim(SrcVT->getElementType());
-    } else if (SrcMT) {
-      SrcNumElems = SrcMT->getNumElementsFlattened();
-      SrcElemT = classifyPrim(SrcMT->getElementType());
-    } else if (SrcCAT) {
-      SrcNumElems = SrcCAT->getZExtSize();
-      SrcElemT = classifyPrim(SrcCAT->getElementType());
-    }
-
-    if (DestPT) {
-      // Scalar destination: extract element 0 and cast.
+    const OptPrimType DestT = classify(DestType);
----------------
tbaederr wrote:

```suggestion
    OptPrimType DestT = classify(DestType);
```

https://github.com/llvm/llvm-project/pull/189126
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to