arsenm created this revision. arsenm added reviewers: fhahn, junaire, bob80905, python3kgae, RKSimon, aaron.ballman, erichkeane, scanon. Herald added a project: All. arsenm requested review of this revision. Herald added a subscriber: wdng.
I didn't understand why the other builtins have promotion logic, or how it would apply to a ternary operation. Implicit conversions are evil to begin with, and even more so when the purpose is to get an exact IR intrinsic. This checks that all the arguments have the same type. https://reviews.llvm.org/D140992 Files: clang/docs/LanguageExtensions.rst clang/include/clang/Basic/Builtins.def clang/include/clang/Sema/Sema.h clang/lib/CodeGen/CGBuiltin.cpp clang/lib/Sema/SemaChecking.cpp clang/test/CodeGen/builtins-elementwise-math.c clang/test/Sema/builtins-elementwise-math.c
Index: clang/test/Sema/builtins-elementwise-math.c =================================================================== --- clang/test/Sema/builtins-elementwise-math.c +++ clang/test/Sema/builtins-elementwise-math.c @@ -4,6 +4,8 @@ typedef double double4 __attribute__((ext_vector_type(4))); typedef float float2 __attribute__((ext_vector_type(2))); typedef float float4 __attribute__((ext_vector_type(4))); + +typedef int int2 __attribute__((ext_vector_type(2))); typedef int int3 __attribute__((ext_vector_type(3))); typedef unsigned unsigned3 __attribute__((ext_vector_type(3))); typedef unsigned unsigned4 __attribute__((ext_vector_type(4))); @@ -509,3 +511,84 @@ float2 tmp9 = __builtin_elementwise_copysign(v4f32, v4f32); // expected-error@-1 {{initializing 'float2' (vector of 2 'float' values) with an expression of incompatible type 'float4' (vector of 4 'float' values)}} } + +void test_builtin_elementwise_fma(int i32, int2 v2i32, short i16, + double f64, double2 v2f64, double2 v3f64, + float f32, float2 v2f32, float v3f32, float4 v4f32, + const float4 c_v4f32, + int3 v3i32, int *ptr) { + + f32 = __builtin_elementwise_fma(); + // expected-error@-1 {{too few arguments to function call, expected 3, have 0}} + + f32 = __builtin_elementwise_fma(f32); + // expected-error@-1 {{too few arguments to function call, expected 3, have 1}} + + f32 = __builtin_elementwise_fma(f32, f32); + // expected-error@-1 {{too few arguments to function call, expected 3, have 2}} + + f32 = __builtin_elementwise_fma(f32, f32, f32, f32); + // expected-error@-1 {{too many arguments to function call, expected 3, have 4}} + + f32 = __builtin_elementwise_fma(f64, f32, f32); + // expected-error@-1 {{arguments are of different types ('double' vs 'float')}} + + f32 = __builtin_elementwise_fma(f32, f64, f32); + // expected-error@-1 {{arguments are of different types ('float' vs 'double')}} + + f32 = __builtin_elementwise_fma(f32, f32, f64); + // expected-error@-1 {{arguments are of different types 
('float' vs 'double')}} + + f32 = __builtin_elementwise_fma(f32, f32, f64); + // expected-error@-1 {{arguments are of different types ('float' vs 'double')}} + + f64 = __builtin_elementwise_fma(f64, f32, f32); + // expected-error@-1 {{arguments are of different types ('double' vs 'float')}} + + f64 = __builtin_elementwise_fma(f64, f64, f32); + // expected-error@-1 {{arguments are of different types ('double' vs 'float')}} + + f64 = __builtin_elementwise_fma(f64, f32, f64); + // expected-error@-1 {{arguments are of different types ('double' vs 'float')}} + + v2f64 = __builtin_elementwise_fma(v2f32, f64, f64); + // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double'}} + + v2f64 = __builtin_elementwise_fma(v2f32, v2f64, f64); + // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double2' (vector of 2 'double' values)}} + + v2f64 = __builtin_elementwise_fma(v2f32, f64, v2f64); + // expected-error@-1 {{arguments are of different types ('float2' (vector of 2 'float' values) vs 'double'}} + + v2f64 = __builtin_elementwise_fma(f64, v2f32, v2f64); + // expected-error@-1 {{arguments are of different types ('double' vs 'float2' (vector of 2 'float' values)}} + + v2f64 = __builtin_elementwise_fma(f64, v2f64, v2f64); + // expected-error@-1 {{arguments are of different types ('double' vs 'double2' (vector of 2 'double' values)}} + + i32 = __builtin_elementwise_fma(i32, i32, i32); + // expected-error@-1 {{1st argument must be a floating point type (was 'int')}} + + v2i32 = __builtin_elementwise_fma(v2i32, v2i32, v2i32); + // expected-error@-1 {{1st argument must be a floating point type (was 'int2' (vector of 2 'int' values))}} + + f32 = __builtin_elementwise_fma(f32, f32, i32); + // expected-error@-1 {{3rd argument must be a floating point type (was 'int')}} + + f32 = __builtin_elementwise_fma(f32, i32, f32); + // expected-error@-1 {{2nd argument must be a floating point type (was 
'int')}} + + f32 = __builtin_elementwise_fma(f32, f32, i32); + // expected-error@-1 {{3rd argument must be a floating point type (was 'int')}} + + + _Complex float c1, c2, c3; + c1 = __builtin_elementwise_fma(c1, f32, f32); + // expected-error@-1 {{1st argument must be a floating point type (was '_Complex float')}} + + c2 = __builtin_elementwise_fma(f32, c2, f32); + // expected-error@-1 {{2nd argument must be a floating point type (was '_Complex float')}} + + c3 = __builtin_elementwise_fma(f32, f32, c3); + // expected-error@-1 {{3rd argument must be a floating point type (was '_Complex float')}} +} Index: clang/test/CodeGen/builtins-elementwise-math.c =================================================================== --- clang/test/CodeGen/builtins-elementwise-math.c +++ clang/test/CodeGen/builtins-elementwise-math.c @@ -1,5 +1,9 @@ // RUN: %clang_cc1 -triple x86_64-apple-darwin %s -emit-llvm -disable-llvm-passes -o - | FileCheck %s +typedef _Float16 half; + +typedef half half2 __attribute__((ext_vector_type(2))); +typedef float float2 __attribute__((ext_vector_type(2))); typedef float float4 __attribute__((ext_vector_type(4))); typedef short int si8 __attribute__((ext_vector_type(8))); typedef unsigned int u4 __attribute__((ext_vector_type(4))); @@ -477,3 +481,77 @@ // CHECK-NEXT: call <2 x double> @llvm.copysign.v2f64(<2 x double> <double 1.000000e+00, double 1.000000e+00>, <2 x double> [[V2F64]]) v2f64 = __builtin_elementwise_copysign((double2)1.0, v2f64); } + +void test_builtin_elementwise_fma(float f32, double f64, + float2 v2f32, float4 v4f32, + double2 v2f64, double3 v3f64, + const float4 c_v4f32, + half f16, half2 v2f16) { + // CHECK-LABEL: define void @test_builtin_elementwise_fma( + // CHECK: [[F32_0:%.+]] = load float, ptr %f32.addr + // CHECK-NEXT: [[F32_1:%.+]] = load float, ptr %f32.addr + // CHECK-NEXT: [[F32_2:%.+]] = load float, ptr %f32.addr + // CHECK-NEXT: call float @llvm.fma.f32(float [[F32_0]], float [[F32_1]], float [[F32_2]]) + float f2 = 
__builtin_elementwise_fma(f32, f32, f32); + + // CHECK: [[F64_0:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: [[F64_1:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: [[F64_2:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: call double @llvm.fma.f64(double [[F64_0]], double [[F64_1]], double [[F64_2]]) + double d2 = __builtin_elementwise_fma(f64, f64, f64); + + // CHECK: [[V4F32_0:%.+]] = load <4 x float>, ptr %v4f32.addr + // CHECK-NEXT: [[V4F32_1:%.+]] = load <4 x float>, ptr %v4f32.addr + // CHECK-NEXT: [[V4F32_2:%.+]] = load <4 x float>, ptr %v4f32.addr + // CHECK-NEXT: call <4 x float> @llvm.fma.v4f32(<4 x float> [[V4F32_0]], <4 x float> [[V4F32_1]], <4 x float> [[V4F32_2]]) + float4 tmp_v4f32 = __builtin_elementwise_fma(v4f32, v4f32, v4f32); + + + // FIXME: Are we really still doing the 3 vector load workaround + // CHECK: [[V3F64_LOAD_0:%.+]] = load <4 x double>, ptr %v3f64.addr + // CHECK-NEXT: [[V3F64_0:%.+]] = shufflevector + // CHECK-NEXT: [[V3F64_LOAD_1:%.+]] = load <4 x double>, ptr %v3f64.addr + // CHECK-NEXT: [[V3F64_1:%.+]] = shufflevector + // CHECK-NEXT: [[V3F64_LOAD_2:%.+]] = load <4 x double>, ptr %v3f64.addr + // CHECK-NEXT: [[V3F64_2:%.+]] = shufflevector + // CHECK-NEXT: call <3 x double> @llvm.fma.v3f64(<3 x double> [[V3F64_0]], <3 x double> [[V3F64_1]], <3 x double> [[V3F64_2]]) + v3f64 = __builtin_elementwise_fma(v3f64, v3f64, v3f64); + + // CHECK: [[F64_0:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: [[F64_1:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: [[F64_2:%.+]] = load double, ptr %f64.addr + // CHECK-NEXT: call double @llvm.fma.f64(double [[F64_0]], double [[F64_1]], double [[F64_2]]) + v2f64 = __builtin_elementwise_fma(f64, f64, f64); + + // CHECK: [[V4F32_0:%.+]] = load <4 x float>, ptr %c_v4f32.addr + // CHECK-NEXT: [[V4F32_1:%.+]] = load <4 x float>, ptr %c_v4f32.addr + // CHECK-NEXT: [[V4F32_2:%.+]] = load <4 x float>, ptr %c_v4f32.addr + // CHECK-NEXT: call <4 x float> @llvm.fma.v4f32(<4 x float> 
[[V4F32_0]], <4 x float> [[V4F32_1]], <4 x float> [[V4F32_2]]) + v4f32 = __builtin_elementwise_fma(c_v4f32, c_v4f32, c_v4f32); + + // CHECK: [[F16_0:%.+]] = load half, ptr %f16.addr + // CHECK-NEXT: [[F16_1:%.+]] = load half, ptr %f16.addr + // CHECK-NEXT: [[F16_2:%.+]] = load half, ptr %f16.addr + // CHECK-NEXT: call half @llvm.fma.f16(half [[F16_0]], half [[F16_1]], half [[F16_2]]) + half tmp_f16 = __builtin_elementwise_fma(f16, f16, f16); + + // CHECK: [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: [[V2F16_2:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> [[V2F16_2]]) + half2 tmp0_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, v2f16); + + // CHECK: [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: [[F16_2:%.+]] = load half, ptr %f16.addr + // CHECK-NEXT: [[V2F16_2_INSERT:%.+]] = insertelement + // CHECK-NEXT: [[V2F16_2:%.+]] = shufflevector <2 x half> [[V2F16_2_INSERT]], <2 x half> poison, <2 x i32> zeroinitializer + // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> [[V2F16_2]]) + half2 tmp1_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, (half2)f16); + + // CHECK: [[V2F16_0:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: [[V2F16_1:%.+]] = load <2 x half>, ptr %v2f16.addr + // CHECK-NEXT: call <2 x half> @llvm.fma.v2f16(<2 x half> [[V2F16_0]], <2 x half> [[V2F16_1]], <2 x half> <half 0xH4400, half 0xH4400>) + half2 tmp2_v2f16 = __builtin_elementwise_fma(v2f16, v2f16, (half2)4.0); + +} Index: clang/lib/Sema/SemaChecking.cpp =================================================================== --- clang/lib/Sema/SemaChecking.cpp +++ clang/lib/Sema/SemaChecking.cpp @@ -2612,20 +2612,16 @@ return ExprError(); QualType ArgTy = 
TheCall->getArg(0)->getType(); - QualType EltTy = ArgTy; - - if (auto *VecTy = EltTy->getAs<VectorType>()) - EltTy = VecTy->getElementType(); - if (!EltTy->isFloatingType()) { - Diag(TheCall->getArg(0)->getBeginLoc(), - diag::err_builtin_invalid_arg_type) - << 1 << /* float ty*/ 5 << ArgTy; - + if (checkFPMathBuiltinElementType(*this, TheCall->getArg(0)->getBeginLoc(), + ArgTy, 1)) + return ExprError(); + break; + } + case Builtin::BI__builtin_elementwise_fma: { + if (SemaBuiltinElementwiseTernaryMath(TheCall)) return ExprError(); - } break; } - // These builtins restrict the element type to integer // types only. case Builtin::BI__builtin_elementwise_add_sat: @@ -17757,6 +17753,41 @@ return false; } +bool Sema::SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall) { + if (checkArgCount(*this, TheCall, 3)) + return true; + + Expr *Args[3]; + for (int I = 0; I < 3; ++I) { + ExprResult Converted = UsualUnaryConversions(TheCall->getArg(I)); + if (Converted.isInvalid()) + return true; + Args[I] = Converted.get(); + } + + int ArgOrdinal = 1; + for (Expr *Arg : Args) { + if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(), + ArgOrdinal++)) + return true; + } + + for (int I = 1; I < 3; ++I) { + if (Args[0]->getType().getCanonicalType() != + Args[I]->getType().getCanonicalType()) { + return Diag(Args[0]->getBeginLoc(), + diag::err_typecheck_call_different_arg_types) + << Args[0]->getType() << Args[I]->getType(); + } + } + + for (int I = 0; I < 3; ++I) + TheCall->setArg(I, Args[I]); + + TheCall->setType(Args[0]->getType()); + return false; +} + bool Sema::PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall) { if (checkArgCount(*this, TheCall, 1)) return true; Index: clang/lib/CodeGen/CGBuiltin.cpp =================================================================== --- clang/lib/CodeGen/CGBuiltin.cpp +++ clang/lib/CodeGen/CGBuiltin.cpp @@ -3097,6 +3097,8 @@ emitUnaryBuiltin(*this, E, llvm::Intrinsic::canonicalize, "elt.trunc")); case 
Builtin::BI__builtin_elementwise_copysign: return RValue::get(emitBinaryBuiltin(*this, E, llvm::Intrinsic::copysign)); + case Builtin::BI__builtin_elementwise_fma: + return RValue::get(emitTernaryBuiltin(*this, E, llvm::Intrinsic::fma)); case Builtin::BI__builtin_elementwise_add_sat: case Builtin::BI__builtin_elementwise_sub_sat: { Value *Op0 = EmitScalarExpr(E->getArg(0)); Index: clang/include/clang/Sema/Sema.h =================================================================== --- clang/include/clang/Sema/Sema.h +++ clang/include/clang/Sema/Sema.h @@ -13444,6 +13444,7 @@ bool CheckPPCMMAType(QualType Type, SourceLocation TypeLoc); bool SemaBuiltinElementwiseMath(CallExpr *TheCall); + bool SemaBuiltinElementwiseTernaryMath(CallExpr *TheCall); bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall); bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall); Index: clang/include/clang/Basic/Builtins.def =================================================================== --- clang/include/clang/Basic/Builtins.def +++ clang/include/clang/Basic/Builtins.def @@ -667,6 +667,7 @@ BUILTIN(__builtin_elementwise_trunc, "v.", "nct") BUILTIN(__builtin_elementwise_canonicalize, "v.", "nct") BUILTIN(__builtin_elementwise_copysign, "v.", "nct") +BUILTIN(__builtin_elementwise_fma, "v.", "nct") BUILTIN(__builtin_elementwise_add_sat, "v.", "nct") BUILTIN(__builtin_elementwise_sub_sat, "v.", "nct") BUILTIN(__builtin_reduce_max, "v.", "nct") Index: clang/docs/LanguageExtensions.rst =================================================================== --- clang/docs/LanguageExtensions.rst +++ clang/docs/LanguageExtensions.rst @@ -631,6 +631,7 @@ =========================================== ================================================================ ========================================= T __builtin_elementwise_abs(T x) return the absolute value of a number x; the absolute value of signed integer and floating point types the most negative integer remains the most negative 
integer + T __builtin_elementwise_fma(T x, T y, T z) fused multiply add, (x * y) + z. floating point types T __builtin_elementwise_ceil(T x) return the smallest integral value greater than or equal to x floating point types T __builtin_elementwise_sin(T x) return the sine of x interpreted as an angle in radians floating point types T __builtin_elementwise_cos(T x) return the cosine of x interpreted as an angle in radians floating point types
_______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits