https://github.com/joaosaffran updated https://github.com/llvm/llvm-project/pull/109331
>From 50d21754119ac10c2ee2376ed8f79d12f73cd137 Mon Sep 17 00:00:00 2001 From: Joao Saffran <jdereze...@microsoft.com> Date: Thu, 19 Sep 2024 00:13:51 +0000 Subject: [PATCH 1/3] Codegen builtin --- clang/include/clang/Basic/Builtins.td | 6 ++ clang/lib/CodeGen/CGBuiltin.cpp | 38 ++++++++++++ clang/lib/CodeGen/CGCall.cpp | 5 ++ clang/lib/CodeGen/CGExpr.cpp | 15 ++++- clang/lib/CodeGen/CodeGenFunction.h | 10 +++- clang/lib/Headers/hlsl/hlsl_intrinsics.h | 20 +++++++ clang/lib/Sema/SemaHLSL.cpp | 58 ++++++++++++++++--- .../builtins/asuint-splitdouble.hlsl | 10 ++++ llvm/include/llvm/IR/IntrinsicsDirectX.td | 5 ++ llvm/lib/Target/DirectX/DXIL.td | 1 + .../Target/DirectX/DXILIntrinsicExpansion.cpp | 13 +++++ 11 files changed, 167 insertions(+), 14 deletions(-) create mode 100644 clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index 8c5d7ad763bf97..b38957f6e3f15d 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -4788,6 +4788,12 @@ def HLSLStep: LangBuiltin<"HLSL_LANG"> { let Prototype = "void(...)"; } +def HLSLAsUintSplitDouble: LangBuiltin<"HLSL_LANG"> { + let Spellings = ["__builtin_hlsl_asuint_splitdouble"]; + let Attributes = [NoThrow, Const]; + let Prototype = "void(...)"; +} + // Builtins for XRay. def XRayCustomEvent : Builtin { let Spellings = ["__xray_customevent"]; diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index 249aead33ad73d..f7695b8693f3dc 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -18843,6 +18843,44 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: { retType, CGM.getHLSLRuntime().getSignIntrinsic(), ArrayRef<Value *>{Op0}, nullptr, "hlsl.sign"); } + // This should only be called when targeting DXIL + case Builtin::BI__builtin_hlsl_asuint_splitdouble: { + + assert((E->getArg(0)->getType()->hasFloatingRepresentation() && + E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() && + E->getArg(2)->getType()->hasUnsignedIntegerRepresentation()) && + "asuint operands types mismatch"); + + Value *Op0 = EmitScalarExpr(E->getArg(0)); + const HLSLOutArgExpr *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1)); + const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2)); + + CallArgList Args; + LValue Op1TmpLValue = EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType()); + LValue Op2TmpLValue = EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType()); + + llvm::Type *retType = llvm::StructType::get(Int32Ty, Int32Ty); + if (Op0->getType()->isVectorTy()) { + auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>(); + + llvm::VectorType *i32VecTy = llvm::VectorType::get( + Int32Ty, ElementCount::getFixed(XVecTy->getNumElements())); + + retType = llvm::StructType::get(i32VecTy, i32VecTy); + } + + CallInst *CI = + Builder.CreateIntrinsic(retType, llvm::Intrinsic::dx_asuint_splitdouble, + {Op0}, nullptr, "hlsl.asuint"); + + Value *arg0 = Builder.CreateExtractValue(CI, 0); + Value *arg1 = Builder.CreateExtractValue(CI, 1); + + Builder.CreateStore(arg0, Op1TmpLValue.getAddress()); + auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress()); + EmitWritebacks(*this, Args); + return s; + } } return nullptr; } diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp index 4ae981e4013e9c..096bbafa4cc694 100644 --- a/clang/lib/CodeGen/CGCall.cpp +++ b/clang/lib/CodeGen/CGCall.cpp @@ -4681,6 +4681,11 @@ void CallArg::copyInto(CodeGenFunction &CGF, Address Addr) const { IsUsed = true; } +void CodeGenFunction::EmitWritebacks(CodeGenFunction &CGF, + const CallArgList &args) { + emitWritebacks(CGF, args); +} + void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E, QualType type) { DisableDebugLocationUpdates Dis(*this, E); diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 9166db4c74128c..1c299c4a932ca0 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -19,6 +19,7 @@ #include "CGObjCRuntime.h" #include "CGOpenMPRuntime.h" #include "CGRecordLayout.h" +#include "CGValue.h" #include "CodeGenFunction.h" #include "CodeGenModule.h" #include "ConstantEmitter.h" @@ -28,6 +29,7 @@ #include "clang/AST/DeclObjC.h" #include "clang/AST/NSAPI.h" #include "clang/AST/StmtVisitor.h" +#include "clang/AST/Type.h" #include "clang/Basic/Builtins.h" #include "clang/Basic/CodeGenOptions.h" #include "clang/Basic/SourceManager.h" @@ -5458,9 +5460,8 @@ LValue CodeGenFunction::EmitOpaqueValueLValue(const OpaqueValueExpr *e) { return getOrCreateOpaqueLValueMapping(e); } -void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, - CallArgList &Args, QualType Ty) { - +std::pair<LValue, LValue> +CodeGenFunction::EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty) { // Emitting the casted temporary through an opaque value. LValue BaseLV = EmitLValue(E->getArgLValue()); OpaqueValueMappingData::bind(*this, E->getOpaqueArgLValue(), BaseLV); @@ -5474,6 +5475,13 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, TempLV); OpaqueValueMappingData::bind(*this, E->getCastedTemporary(), TempLV); + return std::make_pair(BaseLV, TempLV); +} + +LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, + CallArgList &Args, QualType Ty) { + + auto [BaseLV, TempLV] = EmitHLSLOutArgLValues(E, Ty); llvm::Value *Addr = TempLV.getAddress().getBasePointer(); llvm::Type *ElTy = ConvertTypeForMem(TempLV.getType()); @@ -5486,6 +5494,7 @@ void CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, Args.addWriteback(BaseLV, TmpAddr, nullptr, E->getWritebackCast(), LifetimeSize); Args.add(RValue::get(TmpAddr, *this), Ty); + return TempLV; } LValue diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index 3e2abbd9bc1094..ad7c2635500d93 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -4296,8 +4296,11 @@ class CodeGenFunction : public CodeGenTypeCache { LValue EmitCastLValue(const CastExpr *E); LValue EmitMaterializeTemporaryExpr(const MaterializeTemporaryExpr *E); LValue EmitOpaqueValueLValue(const OpaqueValueExpr *e); - void EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args, - QualType Ty); + + std::pair<LValue, LValue> EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, + QualType Ty); + LValue EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args, + QualType Ty); Address EmitExtVectorElementLValue(LValue V); @@ -5147,6 +5150,9 @@ class CodeGenFunction : public CodeGenTypeCache { SourceLocation ArgLoc, AbstractCallee AC, unsigned ParmNum); + /// EmitWriteback - Emit callbacks for function. + void EmitWritebacks(CodeGenFunction &CGF, const CallArgList &args); + /// EmitCallArg - Emit a single call argument. void EmitCallArg(CallArgList &args, const Expr *E, QualType ArgType); diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h index b139f9eb7d999b..e8a1e97f344559 100644 --- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h @@ -422,6 +422,26 @@ template <typename T> constexpr uint asuint(T F) { return __detail::bit_cast<uint, T>(F); } +//===----------------------------------------------------------------------===// +// asuint splitdouble builtins +//===----------------------------------------------------------------------===// + +/// \fn void asuint(double D, out uint lowbits, out int highbits) +/// \brief Split and interprets the lowbits and highbits of double D into uints. +/// \param D The input double. +/// \param lowbits The output lowbits of D. +/// \param highbits The highbits lowbits D. +#if __is_target_arch(dxil) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble) +void asuint(double, out uint, out uint); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble) +void asuint(double2, out uint2, out uint2); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble) +void asuint(double3, out uint3, out uint3); +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble) +void asuint(double4, out uint4, out uint4); +#endif + //===----------------------------------------------------------------------===// // atan builtins //===----------------------------------------------------------------------===// diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index ebe76185cbb2d5..20878442c92338 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1467,18 +1467,27 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) { return true; } -static bool CheckArgsTypesAreCorrect( +bool CheckArgTypeIsCorrect( + Sema *S, Expr *Arg, QualType ExpectedType, + llvm::function_ref<bool(clang::QualType PassedType)> Check) { + QualType PassedType = Arg->getType(); + if (Check(PassedType)) { + if (auto *VecTyA = PassedType->getAs<VectorType>()) + ExpectedType = S->Context.getVectorType( + ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind()); + S->Diag(Arg->getBeginLoc(), diag::err_typecheck_convert_incompatible) + << PassedType << ExpectedType << 1 << 0 << 0; + return true; + } + return false; +} + +bool CheckArgsTypesAreCorrect( Sema *S, CallExpr *TheCall, QualType ExpectedType, llvm::function_ref<bool(clang::QualType PassedType)> Check) { for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) { - QualType PassedType = TheCall->getArg(i)->getType(); - if (Check(PassedType)) { - if (auto *VecTyA = PassedType->getAs<VectorType>()) - ExpectedType = S->Context.getVectorType( - ExpectedType, VecTyA->getNumElements(), VecTyA->getVectorKind()); - S->Diag(TheCall->getArg(0)->getBeginLoc(), - diag::err_typecheck_convert_incompatible) - << PassedType << ExpectedType << 1 << 0 << 0; + Expr *Arg = TheCall->getArg(i); + if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) { return true; } } @@ -1762,6 +1771,37 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { return true; break; } + case Builtin::BI__builtin_hlsl_asuint_splitdouble: { + if (SemaRef.checkArgCount(TheCall, 3)) + return true; + + // Expr *Op0 = TheCall->getArg(0); + + // auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool { + // return !PassedType->isDoubleType(); + // }; + + // if (CheckArgTypeIsCorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy, + // CheckIsNotDouble)) { + // return true; + // } + + // Expr *Op1 = TheCall->getArg(1); + // Expr *Op2 = TheCall->getArg(2); + + // auto CheckIsNotUint = [](clang::QualType PassedType) -> bool { + // return !PassedType->isUnsignedIntegerType(); + // }; + + // if (CheckArgTypeIsCorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy, + // CheckIsNotUint) || + // CheckArgTypeIsCorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy, + // CheckIsNotUint)) { + // return true; + // } + + break; + } case Builtin::BI__builtin_elementwise_acos: case Builtin::BI__builtin_elementwise_asin: case Builtin::BI__builtin_elementwise_atan: diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl new file mode 100644 index 00000000000000..e359354dc3a6df --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl @@ -0,0 +1,10 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O0 -o - | FileCheck %s + +// CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]], i32 {{.*}} [[VAL2:%.*]], i32 {{.*}} [[VAL3:%.*]]){{.*}} +// CHECK: [[VALD:%.*]] = load double, ptr [[VAL1]].addr{{.*}} +// CHECK: call { i32, i32 } @llvm.dx.asuint.splitdouble.{{.*}}(double [[VALD]]) +float fn(double D) { + uint A, B; + asuint(D, A, B); + return A + B; +} diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 3ce7b8b987ef86..d8092397881550 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -88,4 +88,9 @@ def int_dx_rsqrt : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>] def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrConvergent]>; def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>; def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>; + +def int_dx_asuint_splitdouble : DefaultAttrsIntrinsic< + [llvm_anyint_ty, LLVMMatchType<0>], + [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], + [IntrNoMem, IntrWillReturn]>; } diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 9aa0af3e3a6b17..06c52da5fc07c8 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -778,6 +778,7 @@ def FlattenedThreadIdInGroup : DXILOp<96, flattenedThreadIdInGroup> { let stages = [Stages<DXIL1_0, [compute, mesh, amplification, node]>]; let attributes = [Attributes<DXIL1_0, [ReadNone]>]; } +// def AnnotateHandle : DXILOp<217, annotateHandle> { let Doc = "annotate handle with resource properties"; diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp index 926cbe97f24fda..09e87d5035093b 100644 --- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp +++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp @@ -12,6 +12,7 @@ #include "DXILIntrinsicExpansion.h" #include "DirectX.h" +#include "llvm-c/Core.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/DXILResource.h" @@ -395,6 +396,15 @@ static Value *expandStepIntrinsic(CallInst *Orig) { return Builder.CreateSelect(Cond, Zero, One); } +// static Value *expandSplitdoubleIntrinsic(CallInst *Orig) { +// Value *X = Orig->getOperand(0); +// Type *Ty = X->getType(); +// IRBuilder<> Builder(Orig); + +// Builder.CreateIntrinsic() + +// } + static Intrinsic::ID getMaxForClamp(Type *ElemTy, Intrinsic::ID ClampIntrinsic) { if (ClampIntrinsic == Intrinsic::dx_uclamp) @@ -511,6 +521,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) { break; case Intrinsic::dx_step: Result = expandStepIntrinsic(Orig); + break; + // case Intrinsic::dx_asuint_splitdouble: + // Result = expandSplitdoubleIntrinsic(Orig); } if (Result) { Orig->replaceAllUsesWith(Result); >From 9a094f4fec017dd2e990b41caf343c1d5081cada Mon Sep 17 00:00:00 2001 From: Joao Saffran <jdereze...@microsoft.com> Date: Mon, 23 Sep 2024 21:19:12 +0000 Subject: [PATCH 2/3] adding vector case for splitdouble --- clang/lib/CodeGen/CGBuiltin.cpp | 62 ++++++++++++++----- clang/lib/CodeGen/CGExpr.cpp | 8 ++- clang/lib/CodeGen/CodeGenFunction.h | 4 +- .../builtins/asuint-splitdouble.hlsl | 4 +- 4 files changed, 57 insertions(+), 21 deletions(-) diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index f7695b8693f3dc..e0c97270f6d1f1 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -34,12 +34,14 @@ #include "clang/Frontend/FrontendDiagnostic.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/FloatingPointMode.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/InlineAsm.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsAArch64.h" #include "llvm/IR/IntrinsicsAMDGPU.h" @@ -67,6 +69,7 @@ #include "llvm/TargetParser/X86TargetParser.h" #include <optional> #include <sstream> +#include <utility> using namespace clang; using namespace CodeGen; @@ -18855,29 +18858,60 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: { const HLSLOutArgExpr *OutArg1 = dyn_cast<HLSLOutArgExpr>(E->getArg(1)); const HLSLOutArgExpr *OutArg2 = dyn_cast<HLSLOutArgExpr>(E->getArg(2)); + auto emitSplitDouble = + [](CGBuilderTy *Builder, llvm::Value *arg, + llvm::Type *retType) -> std::pair<Value *, Value *> { + CallInst *CI = Builder->CreateIntrinsic( + retType, llvm::Intrinsic::dx_asuint_splitdouble, {arg}, nullptr, + "hlsl.asuint"); + + Value *arg0 = Builder->CreateExtractValue(CI, 0); + Value *arg1 = Builder->CreateExtractValue(CI, 1); + + return std::make_pair(arg0, arg1); + }; + CallArgList Args; - LValue Op1TmpLValue = EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType()); - LValue Op2TmpLValue = EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType()); + auto [Op1BaseLValue, Op1TmpLValue] = + EmitHLSLOutArgExpr(OutArg1, Args, OutArg1->getType()); + auto [Op2BaseLValue, Op2TmpLValue] = + EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType()); llvm::Type *retType = llvm::StructType::get(Int32Ty, Int32Ty); - if (Op0->getType()->isVectorTy()) { - auto *XVecTy = E->getArg(0)->getType()->getAs<VectorType>(); - llvm::VectorType *i32VecTy = llvm::VectorType::get( - Int32Ty, ElementCount::getFixed(XVecTy->getNumElements())); + if (!Op0->getType()->isVectorTy()) { + auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType); + + Builder.CreateStore(arg0, Op1TmpLValue.getAddress()); + auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress()); - retType = llvm::StructType::get(i32VecTy, i32VecTy); + EmitWritebacks(*this, Args); + return s; } - CallInst *CI = - Builder.CreateIntrinsic(retType, llvm::Intrinsic::dx_asuint_splitdouble, - {Op0}, nullptr, "hlsl.asuint"); + auto *Op0VecTy = E->getArg(0)->getType()->getAs<VectorType>(); + + llvm::VectorType *i32VecTy = llvm::VectorType::get( + Int32Ty, ElementCount::getFixed(Op0VecTy->getNumElements())); - Value *arg0 = Builder.CreateExtractValue(CI, 0); - Value *arg1 = Builder.CreateExtractValue(CI, 1); + std::pair<Value *, Value *> inserts = std::make_pair(nullptr, nullptr); + + for (uint64_t idx = 0; idx < Op0VecTy->getNumElements(); idx++) { + Value *op = Builder.CreateExtractElement(Op0, idx); + + auto [arg0, arg1] = emitSplitDouble(&Builder, op, retType); + + if (idx == 0) { + inserts.first = Builder.CreateInsertElement(i32VecTy, arg0, idx); + inserts.second = Builder.CreateInsertElement(i32VecTy, arg1, idx); + } else { + inserts.first = Builder.CreateInsertElement(inserts.first, arg0, idx); + inserts.second = Builder.CreateInsertElement(inserts.second, arg0, idx); + } + } - Builder.CreateStore(arg0, Op1TmpLValue.getAddress()); - auto *s = Builder.CreateStore(arg1, Op2TmpLValue.getAddress()); + Builder.CreateStore(inserts.first, Op1TmpLValue.getAddress()); + auto *s = Builder.CreateStore(inserts.second, Op2TmpLValue.getAddress()); EmitWritebacks(*this, Args); return s; } diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp index 1c299c4a932ca0..53b60ad477a68b 100644 --- a/clang/lib/CodeGen/CGExpr.cpp +++ b/clang/lib/CodeGen/CGExpr.cpp @@ -54,6 +54,7 @@ #include <optional> #include <string> +#include <utility> using namespace clang; using namespace CodeGen; @@ -5478,8 +5479,9 @@ CodeGenFunction::EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty) { return std::make_pair(BaseLV, TempLV); } -LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, - CallArgList &Args, QualType Ty) { +std::pair<LValue, LValue> +CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args, + QualType Ty) { auto [BaseLV, TempLV] = EmitHLSLOutArgLValues(E, Ty); @@ -5494,7 +5496,7 @@ LValue CodeGenFunction::EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, Args.addWriteback(BaseLV, TmpAddr, nullptr, E->getWritebackCast(), LifetimeSize); Args.add(RValue::get(TmpAddr, *this), Ty); - return TempLV; + return std::make_pair(BaseLV, TempLV); } LValue diff --git a/clang/lib/CodeGen/CodeGenFunction.h b/clang/lib/CodeGen/CodeGenFunction.h index ad7c2635500d93..7372faa5656121 100644 --- a/clang/lib/CodeGen/CodeGenFunction.h +++ b/clang/lib/CodeGen/CodeGenFunction.h @@ -4299,8 +4299,8 @@ class CodeGenFunction : public CodeGenTypeCache { std::pair<LValue, LValue> EmitHLSLOutArgLValues(const HLSLOutArgExpr *E, QualType Ty); - LValue EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, CallArgList &Args, - QualType Ty); + std::pair<LValue, LValue> EmitHLSLOutArgExpr(const HLSLOutArgExpr *E, + CallArgList &Args, QualType Ty); Address EmitExtVectorElementLValue(LValue V); diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl index e359354dc3a6df..4326612db96b0f 100644 --- a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl +++ b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl @@ -3,8 +3,8 @@ // CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]], i32 {{.*}} [[VAL2:%.*]], i32 {{.*}} [[VAL3:%.*]]){{.*}} // CHECK: [[VALD:%.*]] = load double, ptr [[VAL1]].addr{{.*}} // CHECK: call { i32, i32 } @llvm.dx.asuint.splitdouble.{{.*}}(double [[VALD]]) -float fn(double D) { - uint A, B; +float2 fn(double2 D) { + uint2 A, B; asuint(D, A, B); return A + B; } >From 37fb42a2dc8a358c34f11052ae1b6fce3a7797a4 Mon Sep 17 00:00:00 2001 From: Joao Saffran <jdereze...@microsoft.com> Date: Tue, 24 Sep 2024 00:50:10 +0000 Subject: [PATCH 3/3] adding lowering to dxil --- clang/include/clang/Basic/Builtins.td | 4 +- clang/lib/CodeGen/CGBuiltin.cpp | 15 +++--- clang/lib/Headers/hlsl/hlsl_intrinsics.h | 8 +-- clang/lib/Sema/SemaHLSL.cpp | 44 ++++++++-------- .../builtins/asuint-splitdouble.hlsl | 25 +++++++--- .../test/SemaHLSL/BuiltIns/asuint-errors.hlsl | 4 ++ llvm/include/llvm/IR/IntrinsicsDirectX.td | 2 +- llvm/lib/Target/DirectX/DXIL.td | 11 +++- .../Target/DirectX/DXILIntrinsicExpansion.cpp | 13 ----- llvm/lib/Target/DirectX/DXILOpBuilder.cpp | 13 +++++ llvm/lib/Target/DirectX/DXILOpBuilder.h | 4 ++ llvm/lib/Target/DirectX/DXILOpLowering.cpp | 50 +++++++++++++++++++ 12 files changed, 135 insertions(+), 58 deletions(-) diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index b38957f6e3f15d..4e0566615b5fef 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -4788,8 +4788,8 @@ def HLSLStep: LangBuiltin<"HLSL_LANG"> { let Prototype = "void(...)"; } -def HLSLAsUintSplitDouble: LangBuiltin<"HLSL_LANG"> { - let Spellings = ["__builtin_hlsl_asuint_splitdouble"]; +def HLSLSplitDouble: LangBuiltin<"HLSL_LANG"> { + let Spellings = ["__builtin_hlsl_splitdouble"]; let Attributes = [NoThrow, Const]; let Prototype = "void(...)"; } diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index e0c97270f6d1f1..e9c44be58289af 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -34,14 +34,12 @@ #include "clang/Frontend/FrontendDiagnostic.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" -#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/FloatingPointMode.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/InlineAsm.h" -#include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsAArch64.h" #include "llvm/IR/IntrinsicsAMDGPU.h" @@ -69,7 +67,6 @@ #include "llvm/TargetParser/X86TargetParser.h" #include <optional> #include <sstream> -#include <utility> using namespace clang; using namespace CodeGen; @@ -18847,7 +18844,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: { ArrayRef<Value *>{Op0}, nullptr, "hlsl.sign"); } // This should only be called when targeting DXIL - case Builtin::BI__builtin_hlsl_asuint_splitdouble: { + case Builtin::BI__builtin_hlsl_splitdouble: { assert((E->getArg(0)->getType()->hasFloatingRepresentation() && E->getArg(1)->getType()->hasUnsignedIntegerRepresentation() && @@ -18861,9 +18858,9 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: { auto emitSplitDouble = [](CGBuilderTy *Builder, llvm::Value *arg, llvm::Type *retType) -> std::pair<Value *, Value *> { - CallInst *CI = Builder->CreateIntrinsic( - retType, llvm::Intrinsic::dx_asuint_splitdouble, {arg}, nullptr, - "hlsl.asuint"); + CallInst *CI = + Builder->CreateIntrinsic(retType, llvm::Intrinsic::dx_splitdouble, + {arg}, nullptr, "hlsl.asuint"); Value *arg0 = Builder->CreateExtractValue(CI, 0); Value *arg1 = Builder->CreateExtractValue(CI, 1); @@ -18877,7 +18874,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: { auto [Op2BaseLValue, Op2TmpLValue] = EmitHLSLOutArgExpr(OutArg2, Args, OutArg2->getType()); - llvm::Type *retType = llvm::StructType::get(Int32Ty, Int32Ty); + llvm::StructType *retType = llvm::StructType::get(Int32Ty, Int32Ty); if (!Op0->getType()->isVectorTy()) { auto [arg0, arg1] = emitSplitDouble(&Builder, Op0, retType); @@ -18906,7 +18903,7 @@ case Builtin::BI__builtin_hlsl_elementwise_isinf: { inserts.second = Builder.CreateInsertElement(i32VecTy, arg1, idx); } else { inserts.first = Builder.CreateInsertElement(inserts.first, arg0, idx); - inserts.second = Builder.CreateInsertElement(inserts.second, arg0, idx); + inserts.second = Builder.CreateInsertElement(inserts.second, arg1, idx); } } diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h index e8a1e97f344559..dede9583d1bc58 100644 --- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h @@ -432,13 +432,13 @@ template <typename T> constexpr uint asuint(T F) { /// \param lowbits The output lowbits of D. /// \param highbits The highbits lowbits D. #if __is_target_arch(dxil) -_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble) void asuint(double, out uint, out uint); -_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble) void asuint(double2, out uint2, out uint2); -_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble) void asuint(double3, out uint3, out uint3); -_HLSL_BUILTIN_ALIAS(__builtin_hlsl_asuint_splitdouble) +_HLSL_BUILTIN_ALIAS(__builtin_hlsl_splitdouble) void asuint(double4, out uint4, out uint4); #endif diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 20878442c92338..c0bbe00fe47008 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1467,7 +1467,7 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) { return true; } -bool CheckArgTypeIsCorrect( +bool CheckArgTypeIsIncorrect( Sema *S, Expr *Arg, QualType ExpectedType, llvm::function_ref<bool(clang::QualType PassedType)> Check) { QualType PassedType = Arg->getType(); @@ -1487,7 +1487,7 @@ bool CheckArgsTypesAreCorrect( llvm::function_ref<bool(clang::QualType PassedType)> Check) { for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) { Expr *Arg = TheCall->getArg(i); - if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) { + if (CheckArgTypeIsIncorrect(S, Arg, ExpectedType, Check)) { return true; } } @@ -1771,34 +1771,34 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { return true; break; } - case Builtin::BI__builtin_hlsl_asuint_splitdouble: { + case Builtin::BI__builtin_hlsl_splitdouble: { if (SemaRef.checkArgCount(TheCall, 3)) return true; - // Expr *Op0 = TheCall->getArg(0); + Expr *Op0 = TheCall->getArg(0); - // auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool { - // return !PassedType->isDoubleType(); - // }; + auto CheckIsNotDouble = [](clang::QualType PassedType) -> bool { + return !PassedType->hasFloatingRepresentation(); + }; - // if (CheckArgTypeIsCorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy, - // CheckIsNotDouble)) { - // return true; - // } + if (CheckArgTypeIsIncorrect(&SemaRef, Op0, SemaRef.Context.DoubleTy, + CheckIsNotDouble)) { + return true; + } - // Expr *Op1 = TheCall->getArg(1); - // Expr *Op2 = TheCall->getArg(2); + Expr *Op1 = TheCall->getArg(1); + Expr *Op2 = TheCall->getArg(2); - // auto CheckIsNotUint = [](clang::QualType PassedType) -> bool { - // return !PassedType->isUnsignedIntegerType(); - // }; + auto CheckIsNotUint = [](clang::QualType PassedType) -> bool { + return !PassedType->hasUnsignedIntegerRepresentation(); + }; - // if (CheckArgTypeIsCorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy, - // CheckIsNotUint) || - // CheckArgTypeIsCorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy, - // CheckIsNotUint)) { - // return true; - // } + if (CheckArgTypeIsIncorrect(&SemaRef, Op1, SemaRef.Context.UnsignedIntTy, + CheckIsNotUint) || + CheckArgTypeIsIncorrect(&SemaRef, Op2, SemaRef.Context.UnsignedIntTy, + CheckIsNotUint)) { + return true; + } break; } diff --git a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl index 4326612db96b0f..1711c344792aee 100644 --- a/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl +++ b/clang/test/CodeGenHLSL/builtins/asuint-splitdouble.hlsl @@ -1,10 +1,23 @@ -// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O0 -o - | FileCheck %s +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -O1 -o - | FileCheck %s -// CHECK: define {{.*}}test_scalar{{.*}}(double {{.*}} [[VAL1:%.*]], i32 {{.*}} [[VAL2:%.*]], i32 {{.*}} [[VAL3:%.*]]){{.*}} -// CHECK: [[VALD:%.*]] = load double, ptr [[VAL1]].addr{{.*}} -// CHECK: call { i32, i32 } @llvm.dx.asuint.splitdouble.{{.*}}(double [[VALD]]) -float2 fn(double2 D) { - uint2 A, B; + +// CHECK: define {{.*}} float {{.*}}test_scalar{{.*}}(double {{.*}} [[VALD:%.*]]) +// CHECK: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALD]]) +// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0 +// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1 +float test_scalar(double D) { + uint A, B; + asuint(D, A, B); + return A + B; +} + +// CHECK: define {{.*}} <3 x float> {{.*}}test_vector{{.*}}(<3 x double> {{.*}} [[VALD:%.*]]) +// CHECK-COUNT-3: [[VALREG:%.*]] = extractelement <3 x double> [[VALD]], i64 [[VALIDX:[0-3]]] +// CHECK-NEXT: [[VALRET:%.*]] = tail call { i32, i32 } @llvm.dx.splitdouble.i32(double [[VALREG]]) +// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 0 +// CHECK-NEXT: extractvalue { i32, i32 } [[VALRET]], 1 +float3 test_vector(double3 D) { + uint3 A, B; asuint(D, A, B); return A + B; } diff --git a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl index 8c56fdddb1c24c..b9a920f9f1b4d0 100644 --- a/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl +++ b/clang/test/SemaHLSL/BuiltIns/asuint-errors.hlsl @@ -6,6 +6,10 @@ uint4 test_asuint_too_many_arg(float p0, float p1) { // expected-error@-1 {{no matching function for call to 'asuint'}} // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires single argument 'V', but 2 arguments were provided}} // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function template not viable: requires single argument 'F', but 2 arguments were provided}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}} + // expected-note@hlsl/hlsl_intrinsics.h:* {{candidate function not viable: requires 3 arguments, but 2 were provided}} } uint test_asuint_double(double p1) { diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index d8092397881550..04dd26ea54ca80 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -89,7 +89,7 @@ def int_dx_wave_is_first_lane : DefaultAttrsIntrinsic<[llvm_i1_ty], [], [IntrCon def int_dx_sign : DefaultAttrsIntrinsic<[LLVMScalarOrSameVectorWidth<0, llvm_i32_ty>], [llvm_any_ty], [IntrNoMem]>; def int_dx_step : DefaultAttrsIntrinsic<[LLVMMatchType<0>], [llvm_anyfloat_ty, LLVMMatchType<0>], [IntrNoMem]>; -def int_dx_asuint_splitdouble : DefaultAttrsIntrinsic< +def int_dx_splitdouble : DefaultAttrsIntrinsic< [llvm_anyint_ty, LLVMMatchType<0>], [LLVMScalarOrSameVectorWidth<0, llvm_double_ty>], [IntrNoMem, IntrWillReturn]>; diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 06c52da5fc07c8..912d385fe285a2 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -47,6 +47,7 @@ def ResRetInt32Ty : DXILOpParamType; def HandleTy : DXILOpParamType; def ResBindTy : DXILOpParamType; def ResPropsTy : DXILOpParamType; +def ResSplitDoubleTy : DXILOpParamType; class DXILOpClass; @@ -778,7 +779,15 @@ def FlattenedThreadIdInGroup : DXILOp<96, flattenedThreadIdInGroup> { let stages = [Stages<DXIL1_0, [compute, mesh, amplification, node]>]; let attributes = [Attributes<DXIL1_0, [ReadNone]>]; } -// + +def SplitDouble : DXILOp<102, splitDouble> { + let Doc = "Splits a double into 2 uints"; + let arguments = [OverloadTy]; + let result = ResSplitDoubleTy; + let overloads = [Overloads<DXIL1_0, [DoubleTy]>]; + let stages = [Stages<DXIL1_0, [all_stages]>]; + let attributes = [Attributes<DXIL1_0, [ReadNone]>]; +} def AnnotateHandle : DXILOp<217, annotateHandle> { let Doc = "annotate handle with resource properties"; diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp index 09e87d5035093b..926cbe97f24fda 100644 --- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp +++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp @@ -12,7 +12,6 @@ #include "DXILIntrinsicExpansion.h" #include "DirectX.h" -#include "llvm-c/Core.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/DXILResource.h" @@ -396,15 +395,6 @@ static Value *expandStepIntrinsic(CallInst *Orig) { return Builder.CreateSelect(Cond, Zero, One); } -// static Value *expandSplitdoubleIntrinsic(CallInst *Orig) { -// Value *X = Orig->getOperand(0); -// Type *Ty = X->getType(); -// IRBuilder<> Builder(Orig); - -// Builder.CreateIntrinsic() - -// } - static Intrinsic::ID getMaxForClamp(Type *ElemTy, Intrinsic::ID ClampIntrinsic) { if (ClampIntrinsic == Intrinsic::dx_uclamp) @@ -521,9 +511,6 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) { break; case Intrinsic::dx_step: Result = expandStepIntrinsic(Orig); - break; - // case Intrinsic::dx_asuint_splitdouble: - // Result = expandSplitdoubleIntrinsic(Orig); } if (Result) { Orig->replaceAllUsesWith(Result); diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp index 7719d6b1079110..982d7849d9bb8b 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp @@ -229,6 +229,13 @@ static StructType *getResPropsType(LLVMContext &Context) { return StructType::create({Int32Ty, Int32Ty}, "dx.types.ResourceProperties"); } +static StructType *getResSplitDoubleType(LLVMContext &Context) { + if (auto *ST = StructType::getTypeByName(Context, "dx.types.splitdouble")) + return ST; + Type *Int32Ty = Type::getInt32Ty(Context); + return StructType::create({Int32Ty, Int32Ty}, "dx.types.splitdouble"); +} + static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx, Type *OverloadTy) { switch (Kind) { @@ -266,6 +273,8 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx, return getResBindType(Ctx); case OpParamType::ResPropsTy: return getResPropsType(Ctx); + case OpParamType::ResSplitDoubleTy: + return getResSplitDoubleType(Ctx); } llvm_unreachable("Invalid parameter kind"); return nullptr; @@ -467,6 +476,10 @@ StructType *DXILOpBuilder::getResRetType(Type *ElementTy) { return ::getResRetType(ElementTy); } +StructType *DXILOpBuilder::getResSplitDoubleType(LLVMContext &Context) { + return ::getResSplitDoubleType(Context); +} + StructType *DXILOpBuilder::getHandleType() { return ::getHandleType(IRB.getContext()); } diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h index 037ae3822cfb90..8b1e87c283146c 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.h +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h @@ -49,6 +49,10 @@ class DXILOpBuilder { /// Get a `%dx.types.ResRet` type with the given element type. StructType *getResRetType(Type *ElementTy); + + /// Get the `%dx.types.splitdouble` type. + StructType *getResSplitDoubleType(LLVMContext &Context); + /// Get the `%dx.types.Handle` type. StructType *getHandleType(); diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index 3ee3ee05563c24..83c6b7f6d503dc 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -17,6 +17,7 @@ #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/IntrinsicsDirectX.h" #include "llvm/IR/Module.h" @@ -264,6 +265,31 @@ class OpLowerer { return lowerToBindAndAnnotateHandle(F); } + Error replaceSplitDoubleCallUsages(CallInst *Intrin, CallInst *Op) { + IRBuilder<> &IRB = OpBuilder.getIRB(); + + for (Use &U : Intrin->uses()) { + if (auto *EVI = dyn_cast<ExtractValueInst>(U.getUser())) { + + assert(EVI->getNumIndices() == 1 && + "splitdouble result should be indexed individually."); + if (EVI->getNumIndices() != 1) + return make_error<StringError>( + "splitdouble result should be indexed individually.", + inconvertibleErrorCode()); + + unsigned int IndexVal = EVI->getIndices()[0]; + + auto *OpEVI = IRB.CreateExtractValue(Op, IndexVal); + EVI->replaceAllUsesWith(OpEVI); + EVI->eraseFromParent(); + } + } + Intrin->eraseFromParent(); + + return Error::success(); + } + /// Replace uses of \c Intrin with the values in the `dx.ResRet` of \c Op. /// Since we expect to be post-scalarization, make an effort to avoid vectors. Error replaceResRetUses(CallInst *Intrin, CallInst *Op, bool HasCheckBit) { @@ -461,6 +487,27 @@ class OpLowerer { }); } + [[nodiscard]] bool lowerSplitDouble(Function &F) { + IRBuilder<> &IRB = OpBuilder.getIRB(); + return replaceFunction(F, [&](CallInst *CI) -> Error { + IRB.SetInsertPoint(CI); + + Value *Arg0 = CI->getArgOperand(0); + + Type *NewRetTy = OpBuilder.getResSplitDoubleType(M.getContext()); + + std::array<Value *, 1> Args{Arg0}; + Expected<CallInst *> OpCall = OpBuilder.tryCreateOp( + OpCode::SplitDouble, Args, CI->getName(), NewRetTy); + if (Error E = OpCall.takeError()) + return E; + if (Error E = replaceSplitDoubleCallUsages(CI, *OpCall)) + return E; + + return Error::success(); + }); + } + bool lowerIntrinsics() { bool Updated = false; bool HasErrors = false; @@ -489,6 +536,9 @@ class OpLowerer { case Intrinsic::dx_typedBufferStore: HasErrors |= lowerTypedBufferStore(F); break; + case Intrinsic::dx_splitdouble: + HasErrors |= lowerSplitDouble(F); + break; } Updated = true; } _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits