================
@@ -883,99 +868,54 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *E) {
return false;
}
+ // The scalar to be splatted is stored in a local to be repeatedly loaded
+ // once for every scalar element of the destination.
PrimType SrcElemT = classifyPrim(SubExpr->getType());
unsigned SrcOffset =
- allocateLocalPrimitive(SubExpr, DestElemT, /*IsConst=*/true);
+ allocateLocalPrimitive(SubExpr, SrcElemT, /*IsConst=*/true);
if (!this->visit(SubExpr))
return false;
- if (SrcElemT != DestElemT) {
- if (!this->emitPrimCast(SrcElemT, DestElemT, DestElemType, E))
- return false;
- }
- if (!this->emitSetLocal(DestElemT, SrcOffset, E))
+ if (!this->emitSetLocal(SrcElemT, SrcOffset, E))
return false;
- for (unsigned I = 0; I != NumElems; ++I) {
- if (!this->emitGetLocal(DestElemT, SrcOffset, E))
- return false;
- if (!this->emitInitElem(DestElemT, I, E))
- return false;
- }
- return true;
+ // Recursively splat the scalar into every element of the destination.
+ return emitHLSLAggregateSplat(SrcElemT, SrcOffset, E->getType(), E);
}
case CK_HLSLElementwiseCast: {
- // Elementwise cast: flatten source elements of one aggregate type and store
- // to a destination scalar or aggregate type of the same or fewer number of
- // elements, while inserting casts as necessary.
- // TODO: Elementwise cast to structs, nested arrays, and arrays of composite
- // types
+ // Elementwise cast: flatten the elements of one aggregate source type and
+ // store to a destination scalar or aggregate type of the same or fewer
+ // number of elements. Casts are inserted element-wise to convert each
+ // source scalar element to its corresponding destination scalar element.
QualType SrcType = SubExpr->getType();
QualType DestType = E->getType();
- // Allowed SrcTypes
- const auto *SrcVT = SrcType->getAs<VectorType>();
- const auto *SrcMT = SrcType->getAs<ConstantMatrixType>();
- const auto *SrcAT = SrcType->getAsArrayTypeUnsafe();
- const auto *SrcCAT = SrcAT ? dyn_cast<ConstantArrayType>(SrcAT) : nullptr;
-
- // Allowed DestTypes
- const auto *DestVT = DestType->getAs<VectorType>();
- const auto *DestMT = DestType->getAs<ConstantMatrixType>();
- const auto *DestAT = DestType->getAsArrayTypeUnsafe();
- const auto *DestCAT =
- DestAT ? dyn_cast<ConstantArrayType>(DestAT) : nullptr;
- const OptPrimType DestPT = classify(DestType);
-
- if (!SrcVT && !SrcMT && !SrcCAT)
- return false;
- if (!DestVT && !DestMT && !DestCAT && !DestPT)
- return false;
-
- unsigned SrcNumElems;
- PrimType SrcElemT;
- if (SrcVT) {
- SrcNumElems = SrcVT->getNumElements();
- SrcElemT = classifyPrim(SrcVT->getElementType());
- } else if (SrcMT) {
- SrcNumElems = SrcMT->getNumElementsFlattened();
- SrcElemT = classifyPrim(SrcMT->getElementType());
- } else if (SrcCAT) {
- SrcNumElems = SrcCAT->getZExtSize();
- SrcElemT = classifyPrim(SrcCAT->getElementType());
- }
-
- if (DestPT) {
- // Scalar destination: extract element 0 and cast.
+ const OptPrimType DestT = classify(DestType);
----------------
tbaederr wrote:
```suggestion
OptPrimType DestT = classify(DestType);
```
https://github.com/llvm/llvm-project/pull/189126
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits