https://github.com/spall created https://github.com/llvm/llvm-project/pull/108121
Implement elementwise popcount to support HLSL function 'countbits'. Closes #99094 From 365886deae6e35ee2761c2fae2a28caa0e214880 Mon Sep 17 00:00:00 2001 From: Sarah Spall <sp...@planetbauer.com> Date: Fri, 6 Sep 2024 21:03:05 +0000 Subject: [PATCH] implement elementwise popcount to implement countbits --- clang/docs/LanguageExtensions.rst | 1 + clang/include/clang/Basic/Builtins.td | 6 ++ clang/lib/CodeGen/CGBuiltin.cpp | 3 + clang/lib/Headers/hlsl/hlsl_intrinsics.h | 71 +++++++++++++++++++ clang/lib/Sema/SemaChecking.cpp | 2 +- clang/lib/Sema/SemaHLSL.cpp | 8 +++ .../test/CodeGen/builtins-elementwise-math.c | 37 ++++++++++ clang/test/Sema/builtins-elementwise-math.c | 21 ++++++ .../SemaCXX/builtins-elementwise-math.cpp | 8 +++ llvm/lib/Target/DirectX/DXIL.td | 11 +++ llvm/test/CodeGen/DirectX/countbits.ll | 31 ++++++++ llvm/test/CodeGen/DirectX/countbits_error.ll | 10 +++ .../SPIRV/hlsl-intrinsics/countbits.ll | 21 ++++++ 13 files changed, 229 insertions(+), 1 deletion(-) create mode 100644 llvm/test/CodeGen/DirectX/countbits.ll create mode 100644 llvm/test/CodeGen/DirectX/countbits_error.ll create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/countbits.ll diff --git a/clang/docs/LanguageExtensions.rst b/clang/docs/LanguageExtensions.rst index c08697282cbfe8..f62f90fb9650a9 100644 --- a/clang/docs/LanguageExtensions.rst +++ b/clang/docs/LanguageExtensions.rst @@ -667,6 +667,7 @@ Unless specified otherwise operation(±0) = ±0 and operation(±infinity) = ±in T __builtin_elementwise_log(T x) return the natural logarithm of x floating point types T __builtin_elementwise_log2(T x) return the base 2 logarithm of x floating point types T __builtin_elementwise_log10(T x) return the base 10 logarithm of x floating point types + T __builtin_elementwise_popcount(T x) return the number of 1 bits in x integer types T __builtin_elementwise_pow(T x, T y) return x raised to the power of y floating point types T __builtin_elementwise_bitreverse(T x) return the integer 
represented after reversing the bits of x integer types T __builtin_elementwise_exp(T x) returns the base-e exponential, e^x, of the specified value floating point types diff --git a/clang/include/clang/Basic/Builtins.td b/clang/include/clang/Basic/Builtins.td index 92118418d9d459..6281fa144bae35 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -1322,6 +1322,12 @@ def ElementwiseLog10 : Builtin { let Prototype = "void(...)"; } +def ElementwisePopcount : Builtin { + let Spellings = ["__builtin_elementwise_popcount"]; + let Attributes = [NoThrow, Const, CustomTypeChecking]; + let Prototype = "void(...)"; +} + def ElementwisePow : Builtin { let Spellings = ["__builtin_elementwise_pow"]; let Attributes = [NoThrow, Const, CustomTypeChecking]; diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index da7a1a55da5313..c5d50e57fa638c 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -3834,6 +3834,9 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID, case Builtin::BI__builtin_elementwise_floor: return RValue::get(emitBuiltinWithOneOverloadedType<1>( *this, E, llvm::Intrinsic::floor, "elt.floor")); + case Builtin::BI__builtin_elementwise_popcount: + return RValue::get(emitBuiltinWithOneOverloadedType<1>( + *this, E, llvm::Intrinsic::ctpop, "elt.ctpop")); case Builtin::BI__builtin_elementwise_roundeven: return RValue::get(emitBuiltinWithOneOverloadedType<1>( *this, E, llvm::Intrinsic::roundeven, "elt.roundeven")); diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h index 5c08a45a35377d..9d667bb61b74ae 100644 --- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h @@ -614,6 +614,77 @@ float3 cosh(float3); _HLSL_BUILTIN_ALIAS(__builtin_elementwise_cosh) float4 cosh(float4); +//===----------------------------------------------------------------------===// 
+// count bits builtins +//===----------------------------------------------------------------------===// + +/// \fn T countbits(T Val) +/// \brief Return the number of bits (per component) set in the input integer. +/// \param Val The input value. + +#ifdef __HLSL_ENABLE_16_BIT +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +int16_t countbits(int16_t); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +int16_t2 countbits(int16_t2); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +int16_t3 countbits(int16_t3); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +int16_t4 countbits(int16_t4); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +uint16_t countbits(uint16_t); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +uint16_t2 countbits(uint16_t2); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +uint16_t3 countbits(uint16_t3); +_HLSL_AVAILABILITY(shadermodel, 6.2) +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +uint16_t4 countbits(uint16_t4); +#endif + +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +int countbits(int); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +int2 countbits(int2); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +int3 countbits(int3); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +int4 countbits(int4); + +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +uint countbits(uint); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +uint2 countbits(uint2); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +uint3 countbits(uint3); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +uint4 countbits(uint4); + +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +int64_t countbits(int64_t); 
+_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +int64_t2 countbits(int64_t2); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +int64_t3 countbits(int64_t3); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +int64_t4 countbits(int64_t4); + +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +uint64_t countbits(uint64_t); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +uint64_t2 countbits(uint64_t2); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +uint64_t3 countbits(uint64_t3); +_HLSL_BUILTIN_ALIAS(__builtin_elementwise_popcount) +uint64_t4 countbits(uint64_t4); + //===----------------------------------------------------------------------===// // dot product builtins //===----------------------------------------------------------------------===// diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp index 99500daca295c9..d2570119c3432d 100644 --- a/clang/lib/Sema/SemaChecking.cpp +++ b/clang/lib/Sema/SemaChecking.cpp @@ -2795,7 +2795,7 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID, if (BuiltinElementwiseMath(TheCall)) return ExprError(); break; - + case Builtin::BI__builtin_elementwise_popcount: case Builtin::BI__builtin_elementwise_bitreverse: { if (PrepareBuiltinElementwiseMathOneArgCall(TheCall)) return ExprError(); diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 3b40769939f12f..dcefb3428a4afc 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1522,6 +1522,14 @@ bool CheckUnsignedIntRepresentation(Sema *S, CallExpr *TheCall) { checkAllUnsignedTypes); } +bool CheckIntRepresentation(Sema *S, CallExpr *TheCall) { + auto checkAllIntTypes = [](clang::QualType PassedType) -> bool { + return !PassedType->hasIntegerRepresentation(); + }; + return CheckArgsTypesAreCorrect(S, TheCall, S->Context.IntTy, + checkAllIntTypes); +} + void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall, QualType ReturnType) { auto *VecTyA = 
TheCall->getArg(0)->getType()->getAs<VectorType>(); diff --git a/clang/test/CodeGen/builtins-elementwise-math.c b/clang/test/CodeGen/builtins-elementwise-math.c index 8fb52992c0fe68..7e094a52653ef0 100644 --- a/clang/test/CodeGen/builtins-elementwise-math.c +++ b/clang/test/CodeGen/builtins-elementwise-math.c @@ -570,6 +570,43 @@ void test_builtin_elementwise_log2(float f1, float f2, double d1, double d2, vf2 = __builtin_elementwise_log2(vf1); } +void test_builtin_elementwise_popcount(si8 vi1, si8 vi2, + long long int i1, long long int i2, short si, + _BitInt(31) bi1, _BitInt(31) bi2) { + + + // CHECK: [[I1:%.+]] = load i64, ptr %i1.addr, align 8 + // CHECK-NEXT: call i64 @llvm.ctpop.i64(i64 [[I1]]) + i2 = __builtin_elementwise_popcount(i1); + + // CHECK: [[VI1:%.+]] = load <8 x i16>, ptr %vi1.addr, align 16 + // CHECK-NEXT: call <8 x i16> @llvm.ctpop.v8i16(<8 x i16> [[VI1]]) + vi2 = __builtin_elementwise_popcount(vi1); + + // CHECK: [[CVI2:%.+]] = load <8 x i16>, ptr %cvi2, align 16 + // CHECK-NEXT: call <8 x i16> @llvm.ctpop.v8i16(<8 x i16> [[CVI2]]) + const si8 cvi2 = vi2; + vi2 = __builtin_elementwise_popcount(cvi2); + + // CHECK: [[BI1:%.+]] = load i32, ptr %bi1.addr, align 4 + // CHECK-NEXT: [[LOADEDV:%.+]] = trunc i32 [[BI1]] to i31 + // CHECK-NEXT: call i31 @llvm.ctpop.i31(i31 [[LOADEDV]]) + bi2 = __builtin_elementwise_popcount(bi1); + + // CHECK: [[IA1:%.+]] = load i32, ptr addrspace(1) @int_as_one, align 4 + // CHECK-NEXT: call i32 @llvm.ctpop.i32(i32 [[IA1]]) + b = __builtin_elementwise_popcount(int_as_one); + + // CHECK: call i32 @llvm.ctpop.i32(i32 -10) + b = __builtin_elementwise_popcount(-10); + + // CHECK: [[SI:%.+]] = load i16, ptr %si.addr, align 2 + // CHECK-NEXT: [[SI_EXT:%.+]] = sext i16 [[SI]] to i32 + // CHECK-NEXT: [[RES:%.+]] = call i32 @llvm.ctpop.i32(i32 [[SI_EXT]]) + // CHECK-NEXT: = trunc i32 [[RES]] to i16 + si = __builtin_elementwise_popcount(si); +} + void test_builtin_elementwise_pow(float f1, float f2, double d1, double d2, float4 
vf1, float4 vf2) { diff --git a/clang/test/Sema/builtins-elementwise-math.c b/clang/test/Sema/builtins-elementwise-math.c index 2673f1f519af69..5ae86bf891b658 100644 --- a/clang/test/Sema/builtins-elementwise-math.c +++ b/clang/test/Sema/builtins-elementwise-math.c @@ -505,6 +505,27 @@ void test_builtin_elementwise_log2(int i, float f, double d, float4 v, int3 iv, // expected-error@-1 {{1st argument must be a floating point type (was 'unsigned4' (vector of 4 'unsigned int' values))}} } +void test_builtin_elementwise_popcount(int i, float f, double d, float4 v, int3 iv, unsigned u, unsigned4 uv) { + + struct Foo s = __builtin_elementwise_popcount(i); + // expected-error@-1 {{initializing 'struct Foo' with an expression of incompatible type 'int'}} + + i = __builtin_elementwise_popcount(); + // expected-error@-1 {{too few arguments to function call, expected 1, have 0}} + + i = __builtin_elementwise_popcount(f); + // expected-error@-1 {{1st argument must be a vector of integers (was 'float')}} + + i = __builtin_elementwise_popcount(f, f); + // expected-error@-1 {{too many arguments to function call, expected 1, have 2}} + + u = __builtin_elementwise_popcount(d); + // expected-error@-1 {{1st argument must be a vector of integers (was 'double')}} + + v = __builtin_elementwise_popcount(v); + // expected-error@-1 {{1st argument must be a vector of integers (was 'float4' (vector of 4 'float' values))}} +} + void test_builtin_elementwise_pow(int i, short s, double d, float4 v, int3 iv, unsigned3 uv, int *p) { i = __builtin_elementwise_pow(p, d); // expected-error@-1 {{arguments are of different types ('int *' vs 'double')}} diff --git a/clang/test/SemaCXX/builtins-elementwise-math.cpp b/clang/test/SemaCXX/builtins-elementwise-math.cpp index 898d869f4c81be..c3d8bc593c0bbc 100644 --- a/clang/test/SemaCXX/builtins-elementwise-math.cpp +++ b/clang/test/SemaCXX/builtins-elementwise-math.cpp @@ -269,3 +269,11 @@ void test_builtin_elementwise_bitreverse() { 
static_assert(!is_const<decltype(__builtin_elementwise_bitreverse(a))>::value); static_assert(!is_const<decltype(__builtin_elementwise_bitreverse(b))>::value); } + +void test_builtin_elementwise_popcount() { + const int a = 2; + int b = 1; + static_assert(!is_const<decltype(__builtin_elementwise_popcount(a))>::value); + static_assert(!is_const<decltype(__builtin_elementwise_popcount(b))>::value); +} + diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 4e3ecf4300d825..08978aaa8a8da6 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -550,6 +550,17 @@ def Rbits : DXILOp<30, unary> { let attributes = [Attributes<DXIL1_0, [ReadNone]>]; } +def CBits : DXILOp<31, unary> { + let Doc = "Returns the number of 1 bits in the specified value."; + let LLVMIntrinsic = int_ctpop; + let arguments = [OverloadTy]; + let result = OverloadTy; + let overloads = + [Overloads<DXIL1_0, [Int16Ty, Int32Ty, Int64Ty]>]; + let stages = [Stages<DXIL1_0, [all_stages]>]; + let attributes = [Attributes<DXIL1_0, [ReadNone]>]; +} + def FMax : DXILOp<35, binary> { let Doc = "Float maximum. FMax(a,b) = a > b ? a : b"; let LLVMIntrinsic = int_maxnum; diff --git a/llvm/test/CodeGen/DirectX/countbits.ll b/llvm/test/CodeGen/DirectX/countbits.ll new file mode 100644 index 00000000000000..9ebce58109e871 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/countbits.ll @@ -0,0 +1,31 @@ +; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s + +; Make sure dxil operation function calls for countbits are generated for all integer types. 
+ +; Function Attrs: nounwind +define noundef i16 @test_countbits_short(i16 noundef %a) { +entry: +; CHECK:call i16 @dx.op.unary.i16(i32 31, i16 %{{.*}}) + %elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a) + ret i16 %elt.ctpop +} + +; Function Attrs: nounwind +define noundef i32 @test_countbits_int(i32 noundef %a) { +entry: +; CHECK:call i32 @dx.op.unary.i32(i32 31, i32 %{{.*}}) + %elt.ctpop = call i32 @llvm.ctpop.i32(i32 %a) + ret i32 %elt.ctpop +} + +; Function Attrs: nounwind +define noundef i64 @test_countbits_long(i64 noundef %a) { +entry: +; CHECK:call i64 @dx.op.unary.i64(i32 31, i64 %{{.*}}) + %elt.ctpop = call i64 @llvm.ctpop.i64(i64 %a) + ret i64 %elt.ctpop +} + +declare i16 @llvm.ctpop.i16(i16) +declare i32 @llvm.ctpop.i32(i32) +declare i64 @llvm.ctpop.i64(i64) diff --git a/llvm/test/CodeGen/DirectX/countbits_error.ll b/llvm/test/CodeGen/DirectX/countbits_error.ll new file mode 100644 index 00000000000000..e7adb103eaae7c --- /dev/null +++ b/llvm/test/CodeGen/DirectX/countbits_error.ll @@ -0,0 +1,10 @@ +; RUN: not opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s 2>&1 | FileCheck %s + +; DXIL operation ctpop does not support double overload type +; CHECK: invalid intrinsic signature + +define noundef double @countbits_double(double noundef %a) { +entry: + %elt.ctpop = call double @llvm.ctpop.f64(double %a) + ret double %elt.ctpop +} diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/countbits.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/countbits.ll new file mode 100644 index 00000000000000..57ec0bda2e1890 --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/countbits.ll @@ -0,0 +1,21 @@ +; RUN: llc -verify-machineinstrs -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %} + +; CHECK: OpMemoryModel Logical GLSL450 + +define noundef i32 @countbits_i32(i32 noundef %a) { +entry: +; CHECK: %[[#]] = OpBitCount %[[#]] %[[#]] + 
%elt.bitreverse = call i32 @llvm.ctpop.i32(i32 %a) + ret i32 %elt.bitreverse +} + +define noundef i16 @countbits_i16(i16 noundef %a) { +entry: +; CHECK: %[[#]] = OpBitCount %[[#]] %[[#]] + %elt.ctpop = call i16 @llvm.ctpop.i16(i16 %a) + ret i16 %elt.ctpop +} + +declare i16 @llvm.ctpop.i16(i16) +declare i32 @llvm.ctpop.i32(i32) _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits