uriel.k created this revision.
uriel.k added reviewers: craig.topper, igorb.

Separates the non-rounding versions of the sqrt builtins from the rounding
versions, and catches them in CGBuiltin.cpp to replace the builtins with IR.


https://reviews.llvm.org/D41168

Files:
  include/clang/Basic/BuiltinsX86.def
  lib/CodeGen/CGBuiltin.cpp
  lib/Headers/avx512fintrin.h
  test/CodeGen/avx512f-builtins.c

Index: test/CodeGen/avx512f-builtins.c
===================================================================
--- test/CodeGen/avx512f-builtins.c
+++ test/CodeGen/avx512f-builtins.c
@@ -5,21 +5,25 @@
 __m512d test_mm512_sqrt_pd(__m512d a)
 {
   // CHECK-LABEL: @test_mm512_sqrt_pd
-  // CHECK: @llvm.x86.avx512.mask.sqrt.pd.512
+  // CHECK: @llvm.sqrt.v8f64
   return _mm512_sqrt_pd(a);
 }
 
 __m512d test_mm512_mask_sqrt_pd (__m512d __W, __mmask8 __U, __m512d __A)
 {
   // CHECK-LABEL: @test_mm512_mask_sqrt_pd 
-  // CHECK: @llvm.x86.avx512.mask.sqrt.pd.512
+  // CHECK: @llvm.sqrt.v8f64
+  // CHECK: bitcast
+  // CHECK: select
   return _mm512_mask_sqrt_pd (__W,__U,__A);
 }
 
 __m512d test_mm512_maskz_sqrt_pd (__mmask8 __U, __m512d __A)
 {
   // CHECK-LABEL: @test_mm512_maskz_sqrt_pd 
-  // CHECK: @llvm.x86.avx512.mask.sqrt.pd.512
+  // CHECK: @llvm.sqrt.v8f64
+  // CHECK: bitcast
+  // CHECK: select
   return _mm512_maskz_sqrt_pd (__U,__A);
 }
 
@@ -47,21 +51,25 @@
 __m512 test_mm512_sqrt_ps(__m512 a)
 {
   // CHECK-LABEL: @test_mm512_sqrt_ps
-  // CHECK: @llvm.x86.avx512.mask.sqrt.ps.512
+  // CHECK: @llvm.sqrt.v16f32
   return _mm512_sqrt_ps(a);
 }
 
 __m512 test_mm512_mask_sqrt_ps(__m512 __W, __mmask16 __U, __m512 __A)
 {
   // CHECK-LABEL: @test_mm512_mask_sqrt_ps
-  // CHECK: @llvm.x86.avx512.mask.sqrt.ps.512
+  // CHECK: @llvm.sqrt.v16f32
+  // CHECK: bitcast
+  // CHECK: select
   return _mm512_mask_sqrt_ps( __W, __U, __A);
 }
 
 __m512 test_mm512_maskz_sqrt_ps( __mmask16 __U, __m512 __A)
 {
   // CHECK-LABEL: @test_mm512_maskz_sqrt_ps
-  // CHECK: @llvm.x86.avx512.mask.sqrt.ps.512
+  // CHECK: @llvm.sqrt.v16f32
+  // CHECK: bitcast
+  // CHECK: select
   return _mm512_maskz_sqrt_ps(__U ,__A);
 }
 
@@ -4618,7 +4626,13 @@
 }
 
 __m128d test_mm_mask_sqrt_sd(__m128d __W, __mmask8 __U, __m128d __A, __m128d __B){
-  // CHECK: @llvm.x86.avx512.mask.sqrt.sd
+  // CHECK: extractelement
+  // CHECK: extractelement
+  // CHECK: bitcast
+  // CHECK: extractelement
+  // CHECK: llvm.sqrt.f64
+  // CHECK: select
+  // CHECK: insertelement
     return _mm_mask_sqrt_sd(__W,__U,__A,__B);
 }
 
@@ -4628,7 +4642,12 @@
 }
 
 __m128d test_mm_maskz_sqrt_sd(__mmask8 __U, __m128d __A, __m128d __B){
-  // CHECK: @llvm.x86.avx512.mask.sqrt.sd
+  // CHECK: extractelement
+  // CHECK: bitcast
+  // CHECK: extractelement
+  // CHECK: llvm.sqrt.f64
+  // CHECK: select
+  // CHECK: insertelement
     return _mm_maskz_sqrt_sd(__U,__A,__B);
 }
 
@@ -4644,7 +4663,13 @@
 }
 
 __m128 test_mm_mask_sqrt_ss(__m128 __W, __mmask8 __U, __m128 __A, __m128 __B){
-  // CHECK: @llvm.x86.avx512.mask.sqrt.ss
+  // CHECK: extractelement
+  // CHECK: extractelement
+  // CHECK: bitcast
+  // CHECK: extractelement
+  // CHECK: llvm.sqrt.f32
+  // CHECK: select
+  // CHECK: insertelement
     return _mm_mask_sqrt_ss(__W,__U,__A,__B);
 }
 
@@ -4654,7 +4679,12 @@
 }
 
 __m128 test_mm_maskz_sqrt_ss(__mmask8 __U, __m128 __A, __m128 __B){
-  // CHECK: @llvm.x86.avx512.mask.sqrt.ss
+  // CHECK: extractelement
+  // CHECK: bitcast
+  // CHECK: extractelement
+  // CHECK: llvm.sqrt.f32
+  // CHECK: select
+  // CHECK: insertelement
     return _mm_maskz_sqrt_ss(__U,__A,__B);
 }
 
Index: lib/Headers/avx512fintrin.h
===================================================================
--- lib/Headers/avx512fintrin.h
+++ lib/Headers/avx512fintrin.h
@@ -1601,29 +1601,26 @@
 static  __inline__ __m512d __DEFAULT_FN_ATTRS
 _mm512_sqrt_pd(__m512d __a)
 {
-  return (__m512d)__builtin_ia32_sqrtpd512_mask((__v8df)__a,
+  return (__m512d)__builtin_ia32_sqrtpd512_mask_nr ((__v8df)__a,
                                                 (__v8df) _mm512_setzero_pd (),
-                                                (__mmask8) -1,
-                                                _MM_FROUND_CUR_DIRECTION);
+                                                (__mmask8) -1);
 }
 
 static __inline__ __m512d __DEFAULT_FN_ATTRS
 _mm512_mask_sqrt_pd (__m512d __W, __mmask8 __U, __m512d __A)
 {
-  return (__m512d) __builtin_ia32_sqrtpd512_mask ((__v8df) __A,
+  return (__m512d) __builtin_ia32_sqrtpd512_mask_nr ((__v8df) __A,
                    (__v8df) __W,
-                   (__mmask8) __U,
-                   _MM_FROUND_CUR_DIRECTION);
+                   (__mmask8) __U);
 }
 
 static __inline__ __m512d __DEFAULT_FN_ATTRS
 _mm512_maskz_sqrt_pd (__mmask8 __U, __m512d __A)
 {
-  return (__m512d) __builtin_ia32_sqrtpd512_mask ((__v8df) __A,
+  return (__m512d) __builtin_ia32_sqrtpd512_mask_nr ((__v8df) __A,
                    (__v8df)
                    _mm512_setzero_pd (),
-                   (__mmask8) __U,
-                   _MM_FROUND_CUR_DIRECTION);
+                   (__mmask8) __U);
 }
 
 #define _mm512_mask_sqrt_round_ps(W, U, A, R) __extension__ ({ \
@@ -1644,28 +1641,25 @@
 static  __inline__ __m512 __DEFAULT_FN_ATTRS
 _mm512_sqrt_ps(__m512 __a)
 {
-  return (__m512)__builtin_ia32_sqrtps512_mask((__v16sf)__a,
+  return (__m512)__builtin_ia32_sqrtps512_mask_nr ((__v16sf)__a,
                                                (__v16sf) _mm512_setzero_ps (),
-                                               (__mmask16) -1,
-                                               _MM_FROUND_CUR_DIRECTION);
+                                               (__mmask16) -1);
 }
 
 static  __inline__ __m512 __DEFAULT_FN_ATTRS
 _mm512_mask_sqrt_ps(__m512 __W, __mmask16 __U, __m512 __A)
 {
-  return (__m512)__builtin_ia32_sqrtps512_mask((__v16sf)__A,
+  return (__m512)__builtin_ia32_sqrtps512_mask_nr ((__v16sf)__A,
                                                (__v16sf) __W,
-                                               (__mmask16) __U,
-                                               _MM_FROUND_CUR_DIRECTION);
+                                               (__mmask16) __U);
 }
 
 static  __inline__ __m512 __DEFAULT_FN_ATTRS
 _mm512_maskz_sqrt_ps( __mmask16 __U, __m512 __A)
 {
-  return (__m512)__builtin_ia32_sqrtps512_mask((__v16sf)__A,
+  return (__m512)__builtin_ia32_sqrtps512_mask_nr ((__v16sf)__A,
                                                (__v16sf) _mm512_setzero_ps (),
-                                               (__mmask16) __U,
-                                               _MM_FROUND_CUR_DIRECTION);
+                                               (__mmask16) __U);
 }
 
 static  __inline__ __m512d __DEFAULT_FN_ATTRS
@@ -7102,11 +7096,10 @@
 static __inline__ __m128d __DEFAULT_FN_ATTRS
 _mm_mask_sqrt_sd (__m128d __W, __mmask8 __U, __m128d __A, __m128d __B)
 {
- return (__m128d) __builtin_ia32_sqrtsd_round_mask ( (__v2df) __A,
+ return (__m128d) __builtin_ia32_sqrtsd_mask ( (__v2df) __A,
                  (__v2df) __B,
                 (__v2df) __W,
-                (__mmask8) __U,
-                _MM_FROUND_CUR_DIRECTION);
+                (__mmask8) __U);
 }
 
 #define _mm_mask_sqrt_round_sd(W, U, A, B, R) __extension__ ({ \
@@ -7118,11 +7111,10 @@
 static __inline__ __m128d __DEFAULT_FN_ATTRS
 _mm_maskz_sqrt_sd (__mmask8 __U, __m128d __A, __m128d __B)
 {
- return (__m128d) __builtin_ia32_sqrtsd_round_mask ( (__v2df) __A,
+ return (__m128d) __builtin_ia32_sqrtsd_mask ( (__v2df) __A,
                  (__v2df) __B,
                 (__v2df) _mm_setzero_pd (),
-                (__mmask8) __U,
-                _MM_FROUND_CUR_DIRECTION);
+                (__mmask8) __U);
 }
 
 #define _mm_maskz_sqrt_round_sd(U, A, B, R) __extension__ ({ \
@@ -7140,11 +7132,10 @@
 static __inline__ __m128 __DEFAULT_FN_ATTRS
 _mm_mask_sqrt_ss (__m128 __W, __mmask8 __U, __m128 __A, __m128 __B)
 {
- return (__m128) __builtin_ia32_sqrtss_round_mask ( (__v4sf) __A,
+ return (__m128) __builtin_ia32_sqrtss_mask ( (__v4sf) __A,
                  (__v4sf) __B,
                 (__v4sf) __W,
-                (__mmask8) __U,
-                _MM_FROUND_CUR_DIRECTION);
+                (__mmask8) __U);
 }
 
 #define _mm_mask_sqrt_round_ss(W, U, A, B, R) __extension__ ({ \
@@ -7156,11 +7147,10 @@
 static __inline__ __m128 __DEFAULT_FN_ATTRS
 _mm_maskz_sqrt_ss (__mmask8 __U, __m128 __A, __m128 __B)
 {
- return (__m128) __builtin_ia32_sqrtss_round_mask ( (__v4sf) __A,
+ return (__m128) __builtin_ia32_sqrtss_mask ( (__v4sf) __A,
                  (__v4sf) __B,
                 (__v4sf) _mm_setzero_ps (),
-                (__mmask8) __U,
-                _MM_FROUND_CUR_DIRECTION);
+                (__mmask8) __U);
 }
 
 #define _mm_maskz_sqrt_round_ss(U, A, B, R) __extension__ ({ \
Index: lib/CodeGen/CGBuiltin.cpp
===================================================================
--- lib/CodeGen/CGBuiltin.cpp
+++ lib/CodeGen/CGBuiltin.cpp
@@ -8128,6 +8128,26 @@
                          Ops[1]);
   }
 
+  case X86::BI__builtin_ia32_sqrtsd_mask:
+  case X86::BI__builtin_ia32_sqrtss_mask: {
+    llvm::Value *C0 = llvm::ConstantInt::get(SizeTy, 0);
+    Value *A = Builder.CreateExtractElement(Ops[0], C0, "extract");
+    Function *F = CGM.getIntrinsic(Intrinsic::sqrt, A->getType());
+    Value *Src = Builder.CreateExtractElement(Ops[2], C0, "extract");
+    int MaskSize = Ops[3]->getType()->getScalarSizeInBits();
+    llvm::Type *MaskTy = llvm::VectorType::get(Builder.getInt1Ty(), MaskSize);
+    Value *Mask = Builder.CreateBitCast(Ops[3], MaskTy);
+    Mask = Builder.CreateExtractElement(Mask, C0, "extract");
+    A = Builder.CreateSelect(Mask, Builder.CreateCall(F, {A}), Src);
+    return Builder.CreateInsertElement(Ops[1], A, C0);
+  }
+  case X86::BI__builtin_ia32_sqrtpd512_mask_nr:
+  case X86::BI__builtin_ia32_sqrtps512_mask_nr: {
+    Function *F = CGM.getIntrinsic(Intrinsic::sqrt, Ops[0]->getType());
+    return EmitX86Select(*this, Ops[2], Builder.CreateCall(F, {Ops[0]}),
+                         Ops[1]);
+  }
+
   case X86::BI__builtin_ia32_pabsb128:
   case X86::BI__builtin_ia32_pabsw128:
   case X86::BI__builtin_ia32_pabsd128:
Index: include/clang/Basic/BuiltinsX86.def
===================================================================
--- include/clang/Basic/BuiltinsX86.def
+++ include/clang/Basic/BuiltinsX86.def
@@ -875,7 +875,9 @@
 
 // AVX-512
 TARGET_BUILTIN(__builtin_ia32_sqrtpd512_mask, "V8dV8dV8dUcIi", "", "avx512f")
+TARGET_BUILTIN(__builtin_ia32_sqrtpd512_mask_nr, "V8dV8dV8dUc", "", "avx512f")
 TARGET_BUILTIN(__builtin_ia32_sqrtps512_mask, "V16fV16fV16fUsIi", "", "avx512f")
+TARGET_BUILTIN(__builtin_ia32_sqrtps512_mask_nr, "V16fV16fV16fUs", "", "avx512f")
 TARGET_BUILTIN(__builtin_ia32_rsqrt14sd_mask, "V2dV2dV2dV2dUc", "", "avx512f")
 TARGET_BUILTIN(__builtin_ia32_rsqrt14ss_mask, "V4fV4fV4fV4fUc", "", "avx512f")
 TARGET_BUILTIN(__builtin_ia32_rsqrt14pd512_mask, "V8dV8dV8dUc", "", "avx512f")
@@ -1494,6 +1496,8 @@
 TARGET_BUILTIN(__builtin_ia32_shuf_i64x2_256_mask, "V4LLiV4LLiV4LLiIiV4LLiUc","","avx512vl")
 TARGET_BUILTIN(__builtin_ia32_sqrtsd_round_mask, "V2dV2dV2dV2dUcIi","","avx512f")
 TARGET_BUILTIN(__builtin_ia32_sqrtss_round_mask, "V4fV4fV4fV4fUcIi","","avx512f")
+TARGET_BUILTIN(__builtin_ia32_sqrtsd_mask, "V2dV2dV2dV2dUc","","avx512f")
+TARGET_BUILTIN(__builtin_ia32_sqrtss_mask, "V4fV4fV4fV4fUc","","avx512f")
 TARGET_BUILTIN(__builtin_ia32_rsqrt14pd128_mask, "V2dV2dV2dUc","","avx512vl")
 TARGET_BUILTIN(__builtin_ia32_rsqrt14pd256_mask, "V4dV4dV4dUc","","avx512vl")
 TARGET_BUILTIN(__builtin_ia32_rsqrt14ps128_mask, "V4fV4fV4fUc","","avx512vl")
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to