https://github.com/bogner updated 
https://github.com/llvm/llvm-project/pull/137877

>From a6c359a02712277ba2feab9e2f1be1caf8fa650e Mon Sep 17 00:00:00 2001
From: Justin Bogner <m...@justinbogner.com>
Date: Tue, 29 Apr 2025 11:59:37 -0700
Subject: [PATCH] [HLSL] Overloads for `lerp` with a scalar weight

This adds overloads for the `lerp` function that accept a scalar for the weight
parameter by splatting it into the appropriate vector.

Fixes #137827
---
 .../lib/Headers/hlsl/hlsl_compat_overloads.h  |  6 +++++
 clang/lib/Sema/SemaHLSL.cpp                   |  3 ++-
 .../CodeGenHLSL/builtins/lerp-overloads.hlsl  | 23 +++++++++++++++++--
 clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl | 22 +++++++++---------
 4 files changed, 40 insertions(+), 14 deletions(-)

diff --git a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
index 47ae34adfe541..4874206d349c0 100644
--- a/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
+++ b/clang/lib/Headers/hlsl/hlsl_compat_overloads.h
@@ -277,6 +277,12 @@ constexpr bool4 isinf(double4 V) { return isinf((float4)V); }
 // lerp builtins overloads
 //===----------------------------------------------------------------------===//
 
+template <typename T, uint N>
+constexpr __detail::enable_if_t<(N > 1 && N <= 4), vector<T, N>>
+lerp(vector<T, N> x, vector<T, N> y, T s) {
+  return lerp(x, y, (vector<T, N>)s);
+}
+
 _DXC_COMPAT_TERNARY_DOUBLE_OVERLOADS(lerp)
 _DXC_COMPAT_TERNARY_INTEGER_OVERLOADS(lerp)
 
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index f51bde4827ad1..70aacaa2aadbe 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2587,7 +2587,8 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
   case Builtin::BI__builtin_hlsl_lerp: {
     if (SemaRef.checkArgCount(TheCall, 3))
       return true;
-    if (CheckVectorElementCallArgs(&SemaRef, TheCall))
+    if (CheckAnyScalarOrVector(&SemaRef, TheCall, 0) ||
+        CheckAllArgsHaveSameType(&SemaRef, TheCall))
       return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(TheCall))
       return true;
diff --git a/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl b/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl
index e80bdb4734487..3cb14f8555cab 100644
--- a/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl
+++ b/clang/test/CodeGenHLSL/builtins/lerp-overloads.hlsl
@@ -1,5 +1,7 @@
-// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple  dxil-pc-shadermodel6.3-library %s -emit-llvm -o - | FileCheck %s --check-prefixes=CHECK -DFNATTRS="noundef nofpclass(nan inf)" -DTARGET=dx
-// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple spirv-unknown-vulkan-compute %s -emit-llvm -o - | FileCheck %s --check-prefixes=CHECK -DFNATTRS="spir_func noundef nofpclass(nan inf)" -DTARGET=spv
+// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple  dxil-pc-shadermodel6.3-library %s -fnative-half-type -emit-llvm -o - | FileCheck %s --check-prefixes=CHECK,NATIVE_HALF -DFNATTRS="noundef nofpclass(nan inf)" -DTARGET=dx
+// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple  dxil-pc-shadermodel6.3-library %s -emit-llvm -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF -DFNATTRS="noundef nofpclass(nan inf)" -DTARGET=dx
+// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple spirv-unknown-vulkan-compute %s -fnative-half-type -emit-llvm -o - | FileCheck %s --check-prefixes=CHECK,NATIVE_HALF -DFNATTRS="spir_func noundef nofpclass(nan inf)" -DTARGET=spv
+// RUN: %clang_cc1 -std=hlsl202x -finclude-default-header -x hlsl -triple spirv-unknown-vulkan-compute %s -emit-llvm -o - | FileCheck %s --check-prefixes=CHECK,NO_HALF -DFNATTRS="spir_func noundef nofpclass(nan inf)" -DTARGET=spv
 
 // CHECK: define [[FNATTRS]] float @_Z16test_lerp_doubled(
 // CHECK:    [[CONV0:%.*]] = fptrunc {{.*}} double %{{.*}} to float
@@ -160,3 +162,20 @@ float3 test_lerp_uint64_t3(uint64_t3 p0) { return lerp(p0, p0, p0); }
 // CHECK:    [[LERP:%.*]] = call {{.*}} <4 x float> @llvm.[[TARGET]].lerp.v4f32(<4 x float> [[CONV0]], <4 x float> [[CONV1]], <4 x float> [[CONV2]])
 // CHECK:    ret <4 x float> [[LERP]]
 float4 test_lerp_uint64_t4(uint64_t4 p0) { return lerp(p0, p0, p0); }
+
+// NATIVE_HALF: define [[FNATTRS]] <3 x [[TY:half]]> @_Z21test_lerp_half_scalarDv3_DhS_Dh{{.*}}(
+// NO_HALF: define [[FNATTRS]] <3 x [[TY:float]]> @_Z21test_lerp_half_scalarDv3_DhS_Dh(
+// CHECK:    [[SPLATINSERT:%.*]] = insertelement <3 x [[TY]]> poison, [[TY]] %{{.*}}, i64 0
+// CHECK:    [[SPLAT:%.*]] = shufflevector <3 x [[TY]]> [[SPLATINSERT]], <3 x [[TY]]> poison, <3 x i32> zeroinitializer
+// CHECK:    [[LERP:%.*]] = call {{.*}} <3 x [[TY]]> @llvm.[[TARGET]].lerp.{{.*}}(<3 x [[TY]]> {{.*}}, <3 x [[TY]]> {{.*}}, <3 x [[TY]]> [[SPLAT]])
+// CHECK:    ret <3 x [[TY]]> [[LERP]]
+half3 test_lerp_half_scalar(half3 x, half3 y, half s) { return lerp(x, y, s); }
+
+// CHECK: define [[FNATTRS]] <3 x float> @_Z22test_lerp_float_scalarDv3_fS_f(
+// CHECK:    [[SPLATINSERT:%.*]] = insertelement <3 x float> poison, float %{{.*}}, i64 0
+// CHECK:    [[SPLAT:%.*]] = shufflevector <3 x float> [[SPLATINSERT]], <3 x float> poison, <3 x i32> zeroinitializer
+// CHECK:    [[LERP:%.*]] = call {{.*}} <3 x float> @llvm.[[TARGET]].lerp.v3f32(<3 x float> {{.*}}, <3 x float> {{.*}}, <3 x float> [[SPLAT]])
+// CHECK:    ret <3 x float> [[LERP]]
+float3 test_lerp_float_scalar(float3 x, float3 y, float s) {
+  return lerp(x, y, s);
+}
diff --git a/clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl
index 398d3c7f938c1..b4734a985f31c 100644
--- a/clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl
+++ b/clang/test/SemaHLSL/BuiltIns/lerp-errors.hlsl
@@ -62,42 +62,42 @@ float2 test_lerp_element_type_mismatch(half2 p0, float2 p1) {
 
 float2 test_builtin_lerp_float2_splat(float p0, float2 p1) {
   return __builtin_hlsl_lerp(p0, p1, p1);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float2 test_builtin_lerp_float2_splat2(double p0, double2 p1) {
   return __builtin_hlsl_lerp(p1, p0, p1);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float2 test_builtin_lerp_float2_splat3(double p0, double2 p1) {
   return __builtin_hlsl_lerp(p1, p1, p0);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float3 test_builtin_lerp_float3_splat(float p0, float3 p1) {
   return __builtin_hlsl_lerp(p0, p1, p1);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float4 test_builtin_lerp_float4_splat(float p0, float4 p1) {
   return __builtin_hlsl_lerp(p0, p1, p1);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float2 test_lerp_float2_int_splat(float2 p0, int p1) {
   return __builtin_hlsl_lerp(p0, p1, p1);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float3 test_lerp_float3_int_splat(float3 p0, int p1) {
   return __builtin_hlsl_lerp(p0, p1, p1);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float2 test_builtin_lerp_int_vect_to_float_vec_promotion(int2 p0, float p1) {
   return __builtin_hlsl_lerp(p0, p1, p1);
-  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must be vectors}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float test_builtin_lerp_bool_type_promotion(bool p0) {
@@ -107,17 +107,17 @@ float test_builtin_lerp_bool_type_promotion(bool p0) {
 
 float builtin_bool_to_float_type_promotion(float p0, bool p1) {
   return __builtin_hlsl_lerp(p0, p0, p1);
-  // expected-error@-1 {{3rd argument must be a scalar or vector of floating-point types (was 'bool')}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float builtin_bool_to_float_type_promotion2(bool p0, float p1) {
   return __builtin_hlsl_lerp(p1, p0, p1);
-  // expected-error@-1 {{2nd argument must be a scalar or vector of floating-point types (was 'bool')}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float builtin_lerp_int_to_float_promotion(float p0, int p1) {
   return __builtin_hlsl_lerp(p0, p0, p1);
-  // expected-error@-1 {{3rd argument must be a scalar or vector of floating-point types (was 'int')}}
+  // expected-error@-1 {{all arguments to '__builtin_hlsl_lerp' must have the same type}}
 }
 
 float4 test_lerp_int4(int4 p0, int4 p1, int4 p2) {

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to