================
@@ -7992,6 +7926,420 @@ bool Compiler<Emitter>::emitBuiltinBitCast(const CastExpr *E) {
return true;
}
+/// Replicate a scalar value into every scalar element of an aggregate
+/// (HLSL aggregate splat).
+///
+/// The scalar has primitive type \p SrcT and is stored in a local at
+/// \p SrcOffset. It is re-loaded with GetLocal once for every element it is
+/// written to, so the local itself is never consumed. A pointer to the
+/// destination, whose type is \p DestType, must be on top of the interpreter
+/// stack. Each element receives the scalar cast (PrimCast) to its own type.
+///
+/// Supported destinations:
+/// - vectors and constant matrices, filled as flat element sequences;
+/// - constant arrays, filled element by element, recursing for composite
+/// elements;
+/// - records, filling base classes first and then named fields in
+/// declaration order, recursing for composite members.
+/// Any other destination type makes the function return false.
+///
+/// Every sub-object pointer pushed here for recursion is popped again
+/// (PopPtr / FinishInitPop). Presumably the destination pointer is therefore
+/// left on the stack for the caller; confirm against the call site.
+///
+/// \returns false if emitting any opcode fails, if the record layout is
+/// unavailable, or if \p DestType is not one of the supported shapes.
+template <class Emitter>
+bool Compiler<Emitter>::emitHLSLAggregateSplat(PrimType SrcT,
+ unsigned SrcOffset,
+ QualType DestType,
+ const Expr *E) {
+ // Vectors and matrices are treated as flat sequences of elements.
+ // NumElems stays 0 for every other type, which selects the paths below.
+ unsigned NumElems = 0;
+ QualType ElemType;
+ if (const auto *VT = DestType->getAs<VectorType>()) {
+ NumElems = VT->getNumElements();
+ ElemType = VT->getElementType();
+ } else if (const auto *MT = DestType->getAs<ConstantMatrixType>()) {
+ NumElems = MT->getNumElementsFlattened();
+ ElemType = MT->getElementType();
+ }
+ if (NumElems > 0) {
+ // classifyPrim (not classify) is used, i.e. vector/matrix element types
+ // are assumed to always be primitive.
+ PrimType ElemT = classifyPrim(ElemType);
+ for (unsigned I = 0; I != NumElems; ++I) {
+ // Load the scalar, convert it to the element type, store at index I.
+ if (!this->emitGetLocal(SrcT, SrcOffset, E))
+ return false;
+ if (!this->emitPrimCast(SrcT, ElemT, ElemType, E))
+ return false;
+ if (!this->emitInitElem(ElemT, I, E))
+ return false;
+ }
+ return true;
+ }
+
+ // Arrays: primitive elements are filled directly; composite elements
+ // require recursion into each sub-aggregate.
+ if (const auto *AT = DestType->getAsArrayTypeUnsafe()) {
+ // cast<> asserts that this is a ConstantArrayType. Incomplete or
+ // variable-length arrays are not expected here.
+ // NOTE(review): a dyn_cast + `return false` would fail gracefully instead
+ // of asserting if another array kind ever reaches this point.
+ const auto *CAT = cast<ConstantArrayType>(AT);
+ QualType ArrElemType = CAT->getElementType();
+ unsigned ArrSize = CAT->getZExtSize();
+
+ if (OptPrimType ElemT = classify(ArrElemType)) {
+ // Primitive elements: same load/cast/store sequence as for vectors.
+ for (unsigned I = 0; I != ArrSize; ++I) {
+ if (!this->emitGetLocal(SrcT, SrcOffset, E))
+ return false;
+ if (!this->emitPrimCast(SrcT, *ElemT, ArrElemType, E))
+ return false;
+ if (!this->emitInitElem(*ElemT, I, E))
+ return false;
+ }
+ } else {
+ // Composite elements: push a pointer to element I, splat into it
+ // recursively, then finish its initialization and pop that pointer.
+ for (unsigned I = 0; I != ArrSize; ++I) {
+ if (!this->emitConstUint32(I, E))
+ return false;
+ if (!this->emitArrayElemPtrUint32(E))
+ return false;
+ if (!emitHLSLAggregateSplat(SrcT, SrcOffset, ArrElemType, E))
+ return false;
+ if (!this->emitFinishInitPop(E))
+ return false;
+ }
+ }
+ return true;
+ }
+
+ // Records: fill base classes first, then named fields in declaration
+ // order.
+ if (DestType->isRecordType()) {
+ // A missing record layout is reported as failure.
+ const Record *R = getRecord(DestType);
+ if (!R)
+ return false;
+
+ if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(R->getDecl())) {
+ // Each base: push a pointer to the base sub-object, splat into it
+ // recursively, then finish its initialization and pop that pointer.
+ for (const CXXBaseSpecifier &BS : CXXRD->bases()) {
+ const Record::Base *B = R->getBase(BS.getType());
+ assert(B);
+ if (!this->emitGetPtrBase(B->Offset, E))
+ return false;
+ if (!emitHLSLAggregateSplat(SrcT, SrcOffset, BS.getType(), E))
+ return false;
+ if (!this->emitFinishInitPop(E))
+ return false;
+ }
+ }
+
+ for (const Record::Field &F : R->fields()) {
+ // Unnamed bit-fields are padding and receive no value.
+ if (F.isUnnamedBitField())
+ continue;
+
+ QualType FieldType = F.Decl->getType();
+ if (OptPrimType FieldT = classify(FieldType)) {
+ // Primitive field: load the scalar, cast it to the field type, and
+ // store it, using the bit-field variant when needed.
+ if (!this->emitGetLocal(SrcT, SrcOffset, E))
+ return false;
+ if (!this->emitPrimCast(SrcT, *FieldT, FieldType, E))
+ return false;
+ if (F.isBitField()) {
+ if (!this->emitInitBitField(*FieldT, F.Offset, F.bitWidth(), E))
+ return false;
+ } else {
+ if (!this->emitInitField(*FieldT, F.Offset, E))
+ return false;
+ }
+ } else {
+ // Composite field: push a pointer to the field, splat into it
+ // recursively, then pop that pointer.
+ // NOTE(review): this uses emitPopPtr, whereas bases and composite
+ // array elements above use emitFinishInitPop. The recursive record
+ // case never emits FinishInit itself, so a record-typed field may
+ // never be marked initialized. Confirm whether emitFinishInitPop is
+ // intended here.
+ if (!this->emitGetPtrField(F.Offset, E))
+ return false;
+ if (!emitHLSLAggregateSplat(SrcT, SrcOffset, FieldType, E))
+ return false;
+ if (!this->emitPopPtr(E))
+ return false;
+ }
+ }
+ return true;
+ }
+
+ // Any other destination type (e.g. a bare scalar) is not handled.
+ return false;
+}
+
+/// Return the total number of scalar elements in a type. This is used
+/// to cap how many source elements are extracted during an elementwise cast,
+/// so we never flatten more than the destination can hold.
+template <class Emitter>
+unsigned Compiler<Emitter>::countHLSLFlatElements(QualType Ty) {
+ // Vector and matrix types are treated as flat sequences of elements.
+ if (const auto *VT = Ty->getAs<VectorType>())
+ return VT->getNumElements();
+ if (const auto *MT = Ty->getAs<ConstantMatrixType>())
+ return MT->getNumElementsFlattened();
+ // Arrays: total count is array size * scalar elements per element.
+ if (const auto *AT = Ty->getAsArrayTypeUnsafe()) {
+ const auto *CAT = cast<ConstantArrayType>(AT);
+ return CAT->getZExtSize() * countHLSLFlatElements(CAT->getElementType());
+ }
+ // Records: sum scalar element counts of base classes and named fields.
+ if (Ty->isRecordType()) {
+ const Record *R = getRecord(Ty);
+ if (!R)
+ return 0;
+ unsigned Count = 0;
+ if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(R->getDecl())) {
+ for (const CXXBaseSpecifier &BS : CXXRD->bases())
+ Count += countHLSLFlatElements(BS.getType());
+ }
+ for (const Record::Field &F : R->fields()) {
+ if (F.isUnnamedBitField())
+ continue;
+ Count += countHLSLFlatElements(F.Decl->getType());
+ }
+ return Count;
+ }
+ // Scalar primitive types contribute one element.
+ if (classify(Ty))
----------------
tbaederr wrote:
```suggestion
if (canClassify(Ty))
```
https://github.com/llvm/llvm-project/pull/189126
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits