================
@@ -7992,6 +7926,420 @@ bool Compiler<Emitter>::emitBuiltinBitCast(const CastExpr *E) {
   return true;
 }
 
+/// Replicate a scalar value into every scalar element of an aggregate.
+/// The scalar is stored in a local at \p SrcOffset and a pointer to the
+/// destination must be on top of the interpreter stack. Each element receives
+/// the scalar, cast to its own type.
+///
+/// The local is re-read (emitGetLocal) once per destination element and
+/// converted with emitPrimCast to that element's primitive type. Composite
+/// array elements, base classes and composite fields are handled by pushing a
+/// pointer to the sub-object and recursing; that pointer is popped again after
+/// the recursive call, so the destination pointer is presumably expected to
+/// still be on the stack when this function returns.
+///
+/// \param SrcT      primitive type of the scalar held in the local.
+/// \param SrcOffset offset of the local that holds the scalar.
+/// \param DestType  vector, matrix, constant array or record type to fill.
+/// \param E         expression passed to every emitted op.
+/// \returns false if emitting any op fails, if no Record is available for a
+///          record \p DestType, or if \p DestType is none of the handled
+///          kinds (vector, matrix, array, record).
+template <class Emitter>
+bool Compiler<Emitter>::emitHLSLAggregateSplat(PrimType SrcT,
+                                               unsigned SrcOffset,
+                                               QualType DestType,
+                                               const Expr *E) {
+  // Vectors and matrices are treated as flat sequences of elements.
+  // NumElems stays 0 for any other type, which falls through below.
+  unsigned NumElems = 0;
+  QualType ElemType;
+  if (const auto *VT = DestType->getAs<VectorType>()) {
+    NumElems = VT->getNumElements();
+    ElemType = VT->getElementType();
+  } else if (const auto *MT = DestType->getAs<ConstantMatrixType>()) {
+    NumElems = MT->getNumElementsFlattened();
+    ElemType = MT->getElementType();
+  }
+  if (NumElems > 0) {
+    // Vector/matrix element types are assumed to be primitive (classifyPrim).
+    PrimType ElemT = classifyPrim(ElemType);
+    for (unsigned I = 0; I != NumElems; ++I) {
+      // Load the scalar, convert it to the element type, store into slot I.
+      if (!this->emitGetLocal(SrcT, SrcOffset, E))
+        return false;
+      if (!this->emitPrimCast(SrcT, ElemT, ElemType, E))
+        return false;
+      if (!this->emitInitElem(ElemT, I, E))
+        return false;
+    }
+    return true;
+  }
+
+  // Arrays: primitive elements are filled directly; composite elements
+  // require recursion into each sub-aggregate.
+  if (const auto *AT = DestType->getAsArrayTypeUnsafe()) {
+    // cast<> assumes only constant-size arrays reach this point.
+    const auto *CAT = cast<ConstantArrayType>(AT);
+    QualType ArrElemType = CAT->getElementType();
+    unsigned ArrSize = CAT->getZExtSize();
+
+    if (OptPrimType ElemT = classify(ArrElemType)) {
+      for (unsigned I = 0; I != ArrSize; ++I) {
+        if (!this->emitGetLocal(SrcT, SrcOffset, E))
+          return false;
+        if (!this->emitPrimCast(SrcT, *ElemT, ArrElemType, E))
+          return false;
+        if (!this->emitInitElem(*ElemT, I, E))
+          return false;
+      }
+    } else {
+      for (unsigned I = 0; I != ArrSize; ++I) {
+        // Push a pointer to element I, fill it recursively, then finish
+        // its initialization and pop that element pointer.
+        if (!this->emitConstUint32(I, E))
+          return false;
+        if (!this->emitArrayElemPtrUint32(E))
+          return false;
+        if (!emitHLSLAggregateSplat(SrcT, SrcOffset, ArrElemType, E))
+          return false;
+        if (!this->emitFinishInitPop(E))
+          return false;
+      }
+    }
+    return true;
+  }
+
+  // Records: fill base classes first, then named fields in declaration
+  // order.
+  if (DestType->isRecordType()) {
+    const Record *R = getRecord(DestType);
+    if (!R)
+      return false;
+
+    if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(R->getDecl())) {
+      for (const CXXBaseSpecifier &BS : CXXRD->bases()) {
+        // Every base listed in the AST is expected to have a Record::Base.
+        const Record::Base *B = R->getBase(BS.getType());
+        assert(B);
+        // Base subobjects are finished with FinishInitPop (composite fields
+        // below use PopPtr instead).
+        if (!this->emitGetPtrBase(B->Offset, E))
+          return false;
+        if (!emitHLSLAggregateSplat(SrcT, SrcOffset, BS.getType(), E))
+          return false;
+        if (!this->emitFinishInitPop(E))
+          return false;
+      }
+    }
+
+    for (const Record::Field &F : R->fields()) {
+      // Unnamed bit-fields are skipped. NOTE(review): it is unclear whether
+      // HLSL can produce unnamed bit-fields at all; confirm, as this skip
+      // may be purely defensive.
+      if (F.isUnnamedBitField())
+        continue;
+
+      QualType FieldType = F.Decl->getType();
+      if (OptPrimType FieldT = classify(FieldType)) {
+        // Primitive field: load, convert, and store (bit-fields use
+        // InitBitField with the field's bit width).
+        if (!this->emitGetLocal(SrcT, SrcOffset, E))
+          return false;
+        if (!this->emitPrimCast(SrcT, *FieldT, FieldType, E))
+          return false;
+        if (F.isBitField()) {
+          if (!this->emitInitBitField(*FieldT, F.Offset, F.bitWidth(), E))
+            return false;
+        } else {
+          if (!this->emitInitField(*FieldT, F.Offset, E))
+            return false;
+        }
+      } else {
+        // Composite field: push a pointer to it, fill it recursively, then
+        // pop that field pointer.
+        if (!this->emitGetPtrField(F.Offset, E))
+          return false;
+        if (!emitHLSLAggregateSplat(SrcT, SrcOffset, FieldType, E))
+          return false;
+        if (!this->emitPopPtr(E))
+          return false;
+      }
+    }
+    return true;
+  }
+
+  // Any other destination type (e.g. a plain scalar) is not handled here.
+  return false;
+}
+
+/// Count how many scalar leaves \p Ty has once it is flattened for HLSL.
+/// Elementwise casts use this as an upper bound on the number of source
+/// elements to extract, so that no more are read than the destination holds.
+///
+/// \returns the element count for vectors and matrices, the recursive total
+/// for constant arrays and records, 1 for a non-record type the interpreter
+/// classifies as primitive, and 0 otherwise (including records for which
+/// getRecord() returns null).
+/// NOTE(review): the array-size product is not overflow-checked.
+template <class Emitter>
+unsigned Compiler<Emitter>::countHLSLFlatElements(QualType Ty) {
+  // Vectors and matrices contribute one scalar per element.
+  if (const auto *Vec = Ty->getAs<VectorType>())
+    return Vec->getNumElements();
+  if (const auto *Mat = Ty->getAs<ConstantMatrixType>())
+    return Mat->getNumElementsFlattened();
+
+  // Arrays: every element flattens to the same number of scalars. cast<>
+  // assumes only constant-size arrays reach this point.
+  if (const auto *Arr = Ty->getAsArrayTypeUnsafe()) {
+    const auto *ConstArr = cast<ConstantArrayType>(Arr);
+    const unsigned PerElement =
+        countHLSLFlatElements(ConstArr->getElementType());
+    return ConstArr->getZExtSize() * PerElement;
+  }
+
+  // Non-record leaves: primitives count once, everything else is ignored.
+  if (!Ty->isRecordType())
+    return classify(Ty) ? 1 : 0;
+
+  const Record *Rec = getRecord(Ty);
+  if (!Rec)
+    return 0;
+
+  unsigned Total = 0;
+  // Base-class subobjects are summed first...
+  if (const auto *ClassDecl = dyn_cast<CXXRecordDecl>(Rec->getDecl())) {
+    for (const CXXBaseSpecifier &Base : ClassDecl->bases())
+      Total += countHLSLFlatElements(Base.getType());
+  }
+  // ...followed by the fields; unnamed bit-fields contribute nothing.
+  for (const Record::Field &Fld : Rec->fields()) {
+    if (!Fld.isUnnamedBitField())
+      Total += countHLSLFlatElements(Fld.Decl->getType());
+  }
+  return Total;
+}
+
+/// Walk a source aggregate and extract every scalar element into its own local
+/// variable. The results are appended to \p Elements in declaration order,
+/// stopping once \p MaxElements have been collected. A pointer to the
+/// source aggregate must be stored in the local at \p SrcOffset.
+template <class Emitter>
+bool Compiler<Emitter>::emitHLSLFlattenAggregate(
+    QualType SrcType, unsigned SrcOffset,
+    SmallVectorImpl<HLSLFlatElement> &Elements, unsigned MaxElements,
+    const Expr *E) {
+
+  // Save a scalar value from the stack into a new local and record it.
+  auto saveToLocal = [&](PrimType T) -> bool {
+    unsigned Offset = allocateLocalPrimitive(E, T, /*IsConst=*/true);
+    if (!this->emitSetLocal(T, Offset, E))
+      return false;
+    Elements.push_back({Offset, T});
+    return true;
+  };
+
+  // Save a pointer from the stack into a new local for later use.
+  auto savePtrToLocal = [&]() -> UnsignedOrNone {
+    unsigned Offset = allocateLocalPrimitive(E, PT_Ptr, /*IsConst=*/true);
+    if (!this->emitSetLocal(PT_Ptr, Offset, E))
+      return std::nullopt;
+    return Offset;
+  };
+
+  // Vectors and matrices are flat sequences of elements.
+  unsigned NumElems = 0;
+  QualType ElemType;
+  if (const auto *VT = SrcType->getAs<VectorType>()) {
+    NumElems = VT->getNumElements();
+    ElemType = VT->getElementType();
+  } else if (const auto *MT = SrcType->getAs<ConstantMatrixType>()) {
+    NumElems = MT->getNumElementsFlattened();
+    ElemType = MT->getElementType();
+  }
+  if (NumElems > 0) {
+    PrimType ElemT = classifyPrim(ElemType);
+    for (unsigned I = 0; I != NumElems && Elements.size() < MaxElements; ++I) {
+      if (!this->emitGetLocal(PT_Ptr, SrcOffset, E))
+        return false;
+      if (!this->emitArrayElemPop(ElemT, I, E))
+        return false;
+      if (!saveToLocal(ElemT))
+        return false;
+    }
+    return true;
+  }
+
+  // Arrays: primitive elements are extracted directly; composite elements
+  // require recursion into each sub-aggregate.
+  if (const auto *AT = SrcType->getAsArrayTypeUnsafe()) {
+    const auto *CAT = cast<ConstantArrayType>(AT);
+    QualType ArrElemType = CAT->getElementType();
+    unsigned ArrSize = CAT->getZExtSize();
+
+    if (OptPrimType ElemT = classify(ArrElemType)) {
+      for (unsigned I = 0; I != ArrSize && Elements.size() < MaxElements; ++I) 
{
+        if (!this->emitGetLocal(PT_Ptr, SrcOffset, E))
+          return false;
+        if (!this->emitArrayElemPop(*ElemT, I, E))
+          return false;
+        if (!saveToLocal(*ElemT))
+          return false;
+      }
+    } else {
+      for (unsigned I = 0; I != ArrSize && Elements.size() < MaxElements; ++I) 
{
+        if (!this->emitGetLocal(PT_Ptr, SrcOffset, E))
+          return false;
+        if (!this->emitConstUint32(I, E))
+          return false;
+        if (!this->emitArrayElemPtrPopUint32(E))
+          return false;
+        UnsignedOrNone ElemPtrOffset = savePtrToLocal();
+        if (!ElemPtrOffset)
+          return false;
+        if (!emitHLSLFlattenAggregate(ArrElemType, *ElemPtrOffset, Elements,
+                                      MaxElements, E))
+          return false;
+      }
+    }
+    return true;
+  }
+
+  // Records: base classes come first, then named fields in declaration
+  // order.
+  if (SrcType->isRecordType()) {
+    const Record *R = getRecord(SrcType);
+    if (!R)
+      return false;
+
+    if (const auto *CXXRD = dyn_cast<CXXRecordDecl>(R->getDecl())) {
+      for (const CXXBaseSpecifier &BS : CXXRD->bases()) {
+        if (Elements.size() >= MaxElements)
+          break;
+        const Record::Base *B = R->getBase(BS.getType());
+        assert(B);
+        if (!this->emitGetLocal(PT_Ptr, SrcOffset, E))
+          return false;
+        if (!this->emitGetPtrBasePop(B->Offset, /*NullOK=*/false, E))
+          return false;
+        UnsignedOrNone BasePtrOffset = savePtrToLocal();
+        if (!BasePtrOffset)
+          return false;
+        if (!emitHLSLFlattenAggregate(BS.getType(), *BasePtrOffset, Elements,
+                                      MaxElements, E))
+          return false;
+      }
+    }
+
+    for (const Record::Field &F : R->fields()) {
+      if (Elements.size() >= MaxElements)
+        break;
+      if (F.isUnnamedBitField())
----------------
tbaederr wrote:

Do unnamed bitfields exist in HLSL?

https://github.com/llvm/llvm-project/pull/189126
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to