llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-backend-directx Author: Justin Bogner (bogner) <details> <summary>Changes</summary> The `@llvm.dx.typedBufferLoad` intrinsic is lowered to `@dx.op.bufferLoad`. There's some complexity here due to translating from a vector return type to a named struct and trying to avoid excessive IR coming out of that. Note that this change includes a bit of a hack in how it deals with `getOverloadKind` for the `dx.types.ResRet` types - we need to adjust how we deal with operation overloads to generate a table directly rather than proxy through the OverloadKind enum, but that's left for a later change. --- Full diff: https://github.com/llvm/llvm-project/pull/104252.diff 7 Files Affected: - (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+4) - (modified) llvm/lib/Target/DirectX/DXIL.td (+15-1) - (modified) llvm/lib/Target/DirectX/DXILOpBuilder.cpp (+25-6) - (modified) llvm/lib/Target/DirectX/DXILOpBuilder.h (+2) - (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+57) - (added) llvm/test/CodeGen/DirectX/BufferLoad.ll (+102) - (modified) llvm/utils/TableGen/DXILEmitter.cpp (+5-1) ``````````diff diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index ca3682fa47767..d817b610fa71a 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -30,6 +30,10 @@ def int_dx_handle_fromBinding [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty], [IntrNoMem]>; +def int_dx_typedBufferLoad + : DefaultAttrsIntrinsic<[llvm_anyvector_ty], + [llvm_any_ty, llvm_i32_ty]>; + // Cast between target extension handle types and dxil-style opaque handles def int_dx_cast_handle : Intrinsic<[llvm_any_ty], [llvm_any_ty]>; diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 31fee04d82158..b114148f84e84 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -40,7 
+40,10 @@ def Int64Ty : DXILOpParamType; def HalfTy : DXILOpParamType; def FloatTy : DXILOpParamType; def DoubleTy : DXILOpParamType; -def ResRetTy : DXILOpParamType; +def ResRetHalfTy : DXILOpParamType; +def ResRetFloatTy : DXILOpParamType; +def ResRetInt16Ty : DXILOpParamType; +def ResRetInt32Ty : DXILOpParamType; def HandleTy : DXILOpParamType; def ResBindTy : DXILOpParamType; def ResPropsTy : DXILOpParamType; @@ -683,6 +686,17 @@ def CreateHandle : DXILOp<57, createHandle> { let stages = [Stages<DXIL1_0, [all_stages]>]; } +def BufferLoad : DXILOp<68, bufferLoad> { + let Doc = "reads from a TypedBuffer"; + // Handle, Coord0, Coord1 + let arguments = [HandleTy, Int32Ty, Int32Ty]; + let result = OverloadTy; + let overloads = + [Overloads<DXIL1_0, + [ResRetHalfTy, ResRetFloatTy, ResRetInt16Ty, ResRetInt32Ty]>]; + let stages = [Stages<DXIL1_0, [all_stages]>]; +} + def ThreadId : DXILOp<93, threadId> { let Doc = "Reads the thread ID"; let LLVMIntrinsic = int_dx_thread_id; diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp index 692af1b359ced..246e32c264dc9 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.cpp +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.cpp @@ -120,8 +120,15 @@ static OverloadKind getOverloadKind(Type *Ty) { } case Type::PointerTyID: return OverloadKind::UserDefineType; - case Type::StructTyID: + case Type::StructTyID: { + // TODO: This is a hack. As described in DXILEmitter.cpp, we need to rework + // how we're handling overloads and remove the `OverloadKind` proxy enum. 
+ StructType *ST = cast<StructType>(Ty); + if (ST->hasName() && ST->getName().starts_with("dx.types.ResRet")) + return getOverloadKind(ST->getElementType(0)); + return OverloadKind::ObjectType; + } default: llvm_unreachable("invalid overload type"); return OverloadKind::VOID; @@ -195,10 +202,11 @@ static StructType *getOrCreateStructType(StringRef Name, return StructType::create(Ctx, EltTys, Name); } -static StructType *getResRetType(Type *OverloadTy, LLVMContext &Ctx) { - OverloadKind Kind = getOverloadKind(OverloadTy); +static StructType *getResRetType(Type *ElementTy) { + LLVMContext &Ctx = ElementTy->getContext(); + OverloadKind Kind = getOverloadKind(ElementTy); std::string TypeName = constructOverloadTypeName(Kind, "dx.types.ResRet."); - Type *FieldTypes[5] = {OverloadTy, OverloadTy, OverloadTy, OverloadTy, + Type *FieldTypes[5] = {ElementTy, ElementTy, ElementTy, ElementTy, Type::getInt32Ty(Ctx)}; return getOrCreateStructType(TypeName, FieldTypes, Ctx); } @@ -248,8 +256,14 @@ static Type *getTypeFromOpParamType(OpParamType Kind, LLVMContext &Ctx, return Type::getInt64Ty(Ctx); case OpParamType::OverloadTy: return OverloadTy; - case OpParamType::ResRetTy: - return getResRetType(OverloadTy, Ctx); + case OpParamType::ResRetHalfTy: + return getResRetType(Type::getHalfTy(Ctx)); + case OpParamType::ResRetFloatTy: + return getResRetType(Type::getFloatTy(Ctx)); + case OpParamType::ResRetInt16Ty: + return getResRetType(Type::getInt16Ty(Ctx)); + case OpParamType::ResRetInt32Ty: + return getResRetType(Type::getInt32Ty(Ctx)); case OpParamType::HandleTy: return getHandleType(Ctx); case OpParamType::ResBindTy: @@ -391,6 +405,7 @@ Expected<CallInst *> DXILOpBuilder::tryCreateOp(dxil::OpCode OpCode, return makeOpError(OpCode, "Wrong number of arguments"); OverloadTy = Args[ArgIndex]->getType(); } + FunctionType *DXILOpFT = getDXILOpFunctionType(OpCode, M.getContext(), OverloadTy); @@ -451,6 +466,10 @@ CallInst *DXILOpBuilder::createOp(dxil::OpCode OpCode, ArrayRef<Value *> 
Args, return *Result; } +StructType *DXILOpBuilder::getResRetType(Type *ElementTy) { + return ::getResRetType(ElementTy); +} + StructType *DXILOpBuilder::getHandleType() { return ::getHandleType(IRB.getContext()); } diff --git a/llvm/lib/Target/DirectX/DXILOpBuilder.h b/llvm/lib/Target/DirectX/DXILOpBuilder.h index 4a55a8ac9eadb..a68f0c43f67af 100644 --- a/llvm/lib/Target/DirectX/DXILOpBuilder.h +++ b/llvm/lib/Target/DirectX/DXILOpBuilder.h @@ -46,6 +46,8 @@ class DXILOpBuilder { Expected<CallInst *> tryCreateOp(dxil::OpCode Op, ArrayRef<Value *> Args, Type *RetTy = nullptr); + /// Get a `%dx.types.ResRet` type with the given element type. + StructType *getResRetType(Type *ElementTy); /// Get the `%dx.types.Handle` type. StructType *getHandleType(); diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index ab18c57efa307..46dfc905b5875 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -236,6 +236,59 @@ class OpLowerer { lowerToBindAndAnnotateHandle(F); } + void lowerTypedBufferLoad(Function &F) { + IRBuilder<> &IRB = OpBuilder.getIRB(); + Type *Int32Ty = IRB.getInt32Ty(); + + replaceFunction(F, [&](CallInst *CI) -> Error { + IRB.SetInsertPoint(CI); + + Value *Handle = + createTmpHandleCast(CI->getArgOperand(0), OpBuilder.getHandleType()); + Value *Index0 = CI->getArgOperand(1); + Value *Index1 = UndefValue::get(Int32Ty); + Type *RetTy = OpBuilder.getResRetType(CI->getType()->getScalarType()); + + std::array<Value *, 3> Args{Handle, Index0, Index1}; + Expected<CallInst *> OpCall = + OpBuilder.tryCreateOp(OpCode::BufferLoad, Args, RetTy); + if (Error E = OpCall.takeError()) + return E; + + std::array<Value *, 4> Extracts = {}; + + // We've switched the return type from a vector to a struct, but at this + // point most vectors have probably already been scalarized. 
Try to + // forward arguments directly rather than inserting into and immediately + // extracting from a vector. + for (Use &U : make_early_inc_range(CI->uses())) + if (auto *EEI = dyn_cast<ExtractElementInst>(U.getUser())) + if (auto *Index = dyn_cast<ConstantInt>(EEI->getIndexOperand())) { + size_t IndexVal = Index->getZExtValue(); + assert(IndexVal < 4 && "Index into buffer load out of range"); + if (!Extracts[IndexVal]) + Extracts[IndexVal] = IRB.CreateExtractValue(*OpCall, IndexVal); + EEI->replaceAllUsesWith(Extracts[IndexVal]); + EEI->eraseFromParent(); + } + + // If there are still uses then we need to create a vector. + if (!CI->use_empty()) { + for (int I = 0, E = 4; I != E; ++I) + if (!Extracts[I]) + Extracts[I] = IRB.CreateExtractValue(*OpCall, I); + + Value *Vec = UndefValue::get(CI->getType()); + for (int I = 0, E = 4; I != E; ++I) + Vec = IRB.CreateInsertElement(Vec, Extracts[I], I); + CI->replaceAllUsesWith(Vec); + } + + CI->eraseFromParent(); + return Error::success(); + }); + } + bool lowerIntrinsics() { bool Updated = false; @@ -253,6 +306,10 @@ class OpLowerer { #include "DXILOperation.inc" case Intrinsic::dx_handle_fromBinding: lowerHandleFromBinding(F); + break; + case Intrinsic::dx_typedBufferLoad: + lowerTypedBufferLoad(F); + break; } Updated = true; } diff --git a/llvm/test/CodeGen/DirectX/BufferLoad.ll b/llvm/test/CodeGen/DirectX/BufferLoad.ll new file mode 100644 index 0000000000000..c3bb96dbdf909 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/BufferLoad.ll @@ -0,0 +1,102 @@ +; RUN: opt -S -dxil-op-lower %s | FileCheck %s + +target triple = "dxil-pc-shadermodel6.6-compute" + +declare void @scalar_user(float) +declare void @vector_user(<4 x float>) + +define void @loadfloats() { + ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding + ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]] + %buffer = call target("dx.TypedBuffer", <4 x float>, 0, 0, 0) + 
@llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_0_0_0( + i32 0, i32 0, i32 1, i32 0, i1 false) + + ; The temporary casts should all have been cleaned up + ; CHECK-NOT: %dx.cast_handle + + ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) + %data0 = call <4 x float> @llvm.dx.typedBufferLoad( + target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 0) + + ; The extract order depends on the users, so don't enforce that here. + ; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 0 + %data0_0 = extractelement <4 x float> %data0, i32 0 + ; CHECK-DAG: extractvalue %dx.types.ResRet.f32 [[DATA0]], 2 + %data0_2 = extractelement <4 x float> %data0, i32 2 + + ; If all of the uses are extracts, we skip creating a vector + ; CHECK-NOT: insertelement + call void @scalar_user(float %data0_0) + call void @scalar_user(float %data0_2) + + ; CHECK: [[DATA4:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 4, i32 undef) + %data4 = call <4 x float> @llvm.dx.typedBufferLoad( + target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 4) + + ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 0 + ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 1 + ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 2 + ; CHECK: extractvalue %dx.types.ResRet.f32 [[DATA4]], 3 + ; CHECK: insertelement <4 x float> undef + ; CHECK: insertelement <4 x float> + ; CHECK: insertelement <4 x float> + ; CHECK: insertelement <4 x float> + call void @vector_user(<4 x float> %data4) + + ; CHECK: [[DATA12:%.*]] = call %dx.types.ResRet.f32 @dx.op.bufferLoad.f32(i32 68, %dx.types.Handle [[HANDLE]], i32 12, i32 undef) + %data12 = call <4 x float> @llvm.dx.typedBufferLoad( + target("dx.TypedBuffer", <4 x float>, 0, 0, 0) %buffer, i32 12) + + ; CHECK: [[DATA12_3:%.*]] = extractvalue %dx.types.ResRet.f32 [[DATA12]], 3 + %data12_3 = extractelement <4 x float> %data12, i32 3 + + ; 
If there are a mix of users we need the vector, but extracts are direct + ; CHECK: call void @scalar_user(float [[DATA12_3]]) + call void @scalar_user(float %data12_3) + call void @vector_user(<4 x float> %data12) + + ret void +} + +define void @loadint() { + ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding + ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]] + %buffer = call target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i32_0_0_0( + i32 0, i32 0, i32 1, i32 0, i1 false) + + ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i32 @dx.op.bufferLoad.i32(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) + %data0 = call <4 x i32> @llvm.dx.typedBufferLoad( + target("dx.TypedBuffer", <4 x i32>, 0, 0, 0) %buffer, i32 0) + + ret void +} + +define void @loadhalf() { + ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding + ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]] + %buffer = call target("dx.TypedBuffer", <4 x half>, 0, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f16_0_0_0( + i32 0, i32 0, i32 1, i32 0, i1 false) + + ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.f16 @dx.op.bufferLoad.f16(i32 68, %dx.types.Handle [[HANDLE]], i32 0, i32 undef) + %data0 = call <4 x half> @llvm.dx.typedBufferLoad( + target("dx.TypedBuffer", <4 x half>, 0, 0, 0) %buffer, i32 0) + + ret void +} + +define void @loadi16() { + ; CHECK: [[BIND:%.*]] = call %dx.types.Handle @dx.op.createHandleFromBinding + ; CHECK: [[HANDLE:%.*]] = call %dx.types.Handle @dx.op.annotateHandle(i32 217, %dx.types.Handle [[BIND]] + %buffer = call target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) + @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4i16_0_0_0( + i32 0, i32 0, i32 1, i32 0, i1 false) + + ; CHECK: [[DATA0:%.*]] = call %dx.types.ResRet.i16 @dx.op.bufferLoad.i16(i32 68, %dx.types.Handle [[HANDLE]], 
i32 0, i32 undef) + %data0 = call <4 x i16> @llvm.dx.typedBufferLoad( + target("dx.TypedBuffer", <4 x i16>, 0, 0, 0) %buffer, i32 0) + + ret void +} diff --git a/llvm/utils/TableGen/DXILEmitter.cpp b/llvm/utils/TableGen/DXILEmitter.cpp index 9cc1b5ccb8acb..332706f7e3e57 100644 --- a/llvm/utils/TableGen/DXILEmitter.cpp +++ b/llvm/utils/TableGen/DXILEmitter.cpp @@ -187,7 +187,11 @@ static StringRef getOverloadKindStr(const Record *R) { .Case("Int8Ty", "OverloadKind::I8") .Case("Int16Ty", "OverloadKind::I16") .Case("Int32Ty", "OverloadKind::I32") - .Case("Int64Ty", "OverloadKind::I64"); + .Case("Int64Ty", "OverloadKind::I64") + .Case("ResRetHalfTy", "OverloadKind::HALF") + .Case("ResRetFloatTy", "OverloadKind::FLOAT") + .Case("ResRetInt16Ty", "OverloadKind::I16") + .Case("ResRetInt32Ty", "OverloadKind::I32"); } /// Return a string representation of valid overload information denoted `````````` </details> https://github.com/llvm/llvm-project/pull/104252 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits