kushanam updated this revision to Diff 525896.
kushanam added a comment.
Removing the commented-out TableGen (.td) entry.
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D144911/new/
https://reviews.llvm.org/D144911
Files:
clang/include/clang/Basic/BuiltinsNVPTX.def
llvm/include/llvm/IR/IntrinsicsNVVM.td
llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
llvm/lib/Target/NVPTX/NVPTXMCExpr.h
llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
llvm/lib/Target/NVPTX/NVPTXSubtarget.h
llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
llvm/test/CodeGen/NVPTX/bf16-instructions.ll
Index: llvm/test/CodeGen/NVPTX/bf16-instructions.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/NVPTX/bf16-instructions.ll
@@ -0,0 +1,88 @@
+; RUN: llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s
+; RUN: %if ptxas-11.0 %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=+ptx70 | %ptxas-verify -arch=sm_80 %}
+
+
+; CHECK-LABEL: test_fadd(
+; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fadd_param_0];
+; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fadd_param_1];
+; CHECK-NEXT: add.rn.bf16 [[R:%h[0-9]+]], [[A]], [[B]];
+; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define bfloat @test_fadd(bfloat %0, bfloat %1) {
+ %3 = fadd bfloat %0, %1
+ ret bfloat %3
+}
+
+; CHECK-LABEL: test_fsub(
+; CHECK-DAG: ld.param.b16 [[A:%h[0-9]+]], [test_fsub_param_0];
+; CHECK-DAG: ld.param.b16 [[B:%h[0-9]+]], [test_fsub_param_1];
+; CHECK-NEXT: sub.rn.bf16 [[R:%h[0-9]+]], [[A]], [[B]];
+; CHECK-NEXT: st.param.b16 [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define bfloat @test_fsub(bfloat %0, bfloat %1) {
+ %3 = fsub bfloat %0, %1
+ ret bfloat %3
+}
+
+; CHECK-LABEL: test_faddx2(
+; CHECK-DAG: ld.param.b32 [[A:%hh[0-9]+]], [test_faddx2_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%hh[0-9]+]], [test_faddx2_param_1];
+; CHECK-NEXT: add.rn.bf16x2 [[R:%hh[0-9]+]], [[A]], [[B]];
+
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+
+define <2 x bfloat> @test_faddx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+ %r = fadd <2 x bfloat> %a, %b
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fsubx2(
+; CHECK-DAG: ld.param.b32 [[A:%hh[0-9]+]], [test_fsubx2_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%hh[0-9]+]], [test_fsubx2_param_1];
+; CHECK-NEXT: sub.rn.bf16x2 [[R:%hh[0-9]+]], [[A]], [[B]];
+
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+
+define <2 x bfloat> @test_fsubx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+ %r = fsub <2 x bfloat> %a, %b
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fmulx2(
+; CHECK-DAG: ld.param.b32 [[A:%hh[0-9]+]], [test_fmulx2_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%hh[0-9]+]], [test_fmulx2_param_1];
+; CHECK-NEXT: mul.rn.bf16x2 [[R:%hh[0-9]+]], [[A]], [[B]];
+
+; CHECK: st.param.b32 [func_retval0+0], [[R]];
+; CHECK: ret;
+
+define <2 x bfloat> @test_fmulx2(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+ %r = fmul <2 x bfloat> %a, %b
+ ret <2 x bfloat> %r
+}
+
+; CHECK-LABEL: test_fdiv(
+; CHECK-DAG: ld.param.b32 [[A:%hh[0-9]+]], [test_fdiv_param_0];
+; CHECK-DAG: ld.param.b32 [[B:%hh[0-9]+]], [test_fdiv_param_1];
+; CHECK-DAG: mov.b32 {[[A0:%h[0-9]+]], [[A1:%h[0-9]+]]}, [[A]]
+; CHECK-DAG: mov.b32 {[[B0:%h[0-9]+]], [[B1:%h[0-9]+]]}, [[B]]
+; CHECK-DAG: cvt.f32.bf16 [[FA0:%f[0-9]+]], [[A0]];
+; CHECK-DAG: cvt.f32.bf16 [[FA1:%f[0-9]+]], [[A1]];
+; CHECK-DAG: cvt.f32.bf16 [[FB0:%f[0-9]+]], [[B0]];
+; CHECK-DAG: cvt.f32.bf16 [[FB1:%f[0-9]+]], [[B1]];
+; CHECK-DAG: div.rn.f32 [[FR0:%f[0-9]+]], [[FA0]], [[FB0]];
+; CHECK-DAG: div.rn.f32 [[FR1:%f[0-9]+]], [[FA1]], [[FB1]];
+; CHECK-DAG: cvt.rn.bf16.f32 [[R0:%h[0-9]+]], [[FR0]];
+; CHECK-DAG: cvt.rn.bf16.f32 [[R1:%h[0-9]+]], [[FR1]];
+; CHECK-NEXT: mov.b32 [[R:%hh[0-9]+]], {[[R0]], [[R1]]}
+; CHECK-NEXT: st.param.b32 [func_retval0+0], [[R]];
+; CHECK-NEXT: ret;
+
+define <2 x bfloat> @test_fdiv(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
+ %r = fdiv <2 x bfloat> %a, %b
+ ret <2 x bfloat> %r
+}
Index: llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -204,6 +204,14 @@
return {Intrinsic::fma, FTZ_MustBeOff, true};
case Intrinsic::nvvm_fma_rn_ftz_f16x2:
return {Intrinsic::fma, FTZ_MustBeOn, true};
+ case Intrinsic::nvvm_fma_rn_bf16:
+ return {Intrinsic::fma, FTZ_MustBeOff, true};
+ case Intrinsic::nvvm_fma_rn_ftz_bf16:
+ return {Intrinsic::fma, FTZ_MustBeOn, true};
+ case Intrinsic::nvvm_fma_rn_bf16x2:
+ return {Intrinsic::fma, FTZ_MustBeOff, true};
+ case Intrinsic::nvvm_fma_rn_ftz_bf16x2:
+ return {Intrinsic::fma, FTZ_MustBeOn, true};
case Intrinsic::nvvm_fmax_d:
return {Intrinsic::maxnum, FTZ_Any};
case Intrinsic::nvvm_fmax_f:
Index: llvm/lib/Target/NVPTX/NVPTXSubtarget.h
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXSubtarget.h
+++ llvm/lib/Target/NVPTX/NVPTXSubtarget.h
@@ -76,7 +76,9 @@
inline bool hasHWROT32() const { return SmVersion >= 32; }
bool hasImageHandles() const;
bool hasFP16Math() const { return SmVersion >= 53; }
+ bool hasBF16Math() const { return SmVersion >= 80; }
bool allowFP16Math() const;
+ bool allowBF16Math() const;
bool hasMaskOperator() const { return PTXVersion >= 71; }
bool hasNoReturn() const { return SmVersion >= 30 && PTXVersion >= 64; }
unsigned int getSmVersion() const { return SmVersion; }
Index: llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
+++ llvm/lib/Target/NVPTX/NVPTXSubtarget.cpp
@@ -26,7 +26,10 @@
NoF16Math("nvptx-no-f16-math", cl::Hidden,
cl::desc("NVPTX Specific: Disable generation of f16 math ops."),
cl::init(false));
-
+static cl::opt<bool>
+ NoBF16Math("nvptx-no-bf16-math", cl::Hidden,
+ cl::desc("NVPTX Specific: Disable generation of bf16 math ops."),
+ cl::init(false));
// Pin the vtable to this file.
void NVPTXSubtarget::anchor() {}
@@ -65,3 +68,7 @@
bool NVPTXSubtarget::allowFP16Math() const {
return hasFP16Math() && NoF16Math == false;
}
+
+bool NVPTXSubtarget::allowBF16Math() const {
+ return hasBF16Math() && NoBF16Math == false;
+}
Index: llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
+++ llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td
@@ -60,8 +60,10 @@
def Int16Regs : NVPTXRegClass<[i16], 16, (add (sequence "RS%u", 0, 4))>;
def Int32Regs : NVPTXRegClass<[i32], 32, (add (sequence "R%u", 0, 4), VRFrame32, VRFrameLocal32)>;
def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
-def Float16Regs : NVPTXRegClass<[f16,bf16], 16, (add (sequence "H%u", 0, 4))>;
-def Float16x2Regs : NVPTXRegClass<[v2f16,v2bf16], 32, (add (sequence "HH%u", 0, 4))>;
+def Float16Regs : NVPTXRegClass<[f16], 16, (add (sequence "H%u", 0, 4))>;
+def Float16x2Regs : NVPTXRegClass<[v2f16], 32, (add (sequence "HH%u", 0, 4))>;
+def BFloat16Regs : NVPTXRegClass<[bf16], 16, (add (sequence "H%u", 0, 4))>;
+def BFloat16x2Regs : NVPTXRegClass<[v2bf16], 32, (add (sequence "HH%u", 0, 4))>;
def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;
def Float64Regs : NVPTXRegClass<[f64], 64, (add (sequence "FL%u", 0, 4))>;
def Int32ArgRegs : NVPTXRegClass<[i32], 32, (add (sequence "ia%u", 0, 4))>;
Index: llvm/lib/Target/NVPTX/NVPTXMCExpr.h
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXMCExpr.h
+++ llvm/lib/Target/NVPTX/NVPTXMCExpr.h
@@ -21,6 +21,7 @@
public:
enum VariantKind {
VK_NVPTX_None,
+ VK_NVPTX_BFLOAT_PREC_FLOAT, // FP constant in bfloat-precision
VK_NVPTX_HALF_PREC_FLOAT, // FP constant in half-precision
VK_NVPTX_SINGLE_PREC_FLOAT, // FP constant in single-precision
VK_NVPTX_DOUBLE_PREC_FLOAT // FP constant in double-precision
@@ -40,6 +41,11 @@
static const NVPTXFloatMCExpr *create(VariantKind Kind, const APFloat &Flt,
MCContext &Ctx);
+ static const NVPTXFloatMCExpr *createConstantBFPHalf(const APFloat &Flt,
+ MCContext &Ctx) {
+ return create(VK_NVPTX_BFLOAT_PREC_FLOAT, Flt, Ctx);
+ }
+
static const NVPTXFloatMCExpr *createConstantFPHalf(const APFloat &Flt,
MCContext &Ctx) {
return create(VK_NVPTX_HALF_PREC_FLOAT, Flt, Ctx);
Index: llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
+++ llvm/lib/Target/NVPTX/NVPTXMCExpr.cpp
@@ -34,6 +34,11 @@
NumHex = 4;
APF.convert(APFloat::IEEEhalf(), APFloat::rmNearestTiesToEven, &Ignored);
break;
+ case VK_NVPTX_BFLOAT_PREC_FLOAT:
+ OS << "0x";
+ NumHex = 4;
+ APF.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, &Ignored);
+ break;
case VK_NVPTX_SINGLE_PREC_FLOAT:
OS << "0f";
NumHex = 8;
Index: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -973,6 +973,18 @@
FMA_TUPLE<"_rn_ftz_relu_f16", int_nvvm_fma_rn_ftz_relu_f16, Float16Regs,
[hasPTX70, hasSM80]>,
+ FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, BFloat16Regs, [hasPTX70, hasSM80]>,
+ FMA_TUPLE<"_rn_ftz_bf16", int_nvvm_fma_rn_ftz_bf16, BFloat16Regs,
+ [hasPTX70, hasSM80]>,
+ FMA_TUPLE<"_rn_sat_bf16", int_nvvm_fma_rn_sat_bf16, BFloat16Regs,
+ [hasPTX70, hasSM80]>,
+ FMA_TUPLE<"_rn_ftz_sat_bf16", int_nvvm_fma_rn_ftz_sat_bf16, BFloat16Regs,
+ [hasPTX70, hasSM80]>,
+ FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, BFloat16Regs,
+ [hasPTX70, hasSM80]>,
+ FMA_TUPLE<"_rn_ftz_relu_bf16", int_nvvm_fma_rn_ftz_relu_bf16, BFloat16Regs,
+ [hasPTX70, hasSM80]>,
+
FMA_TUPLE<"_rn_f16x2", int_nvvm_fma_rn_f16x2, Float16x2Regs,
[hasPTX42, hasSM53]>,
FMA_TUPLE<"_rn_ftz_f16x2", int_nvvm_fma_rn_ftz_f16x2, Float16x2Regs,
@@ -986,13 +998,9 @@
FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2,
Float16x2Regs, [hasPTX70, hasSM80]>,
- FMA_TUPLE<"_rn_bf16", int_nvvm_fma_rn_bf16, Int16Regs, [hasPTX70, hasSM80]>,
- FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, Int16Regs,
- [hasPTX70, hasSM80]>,
-
- FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, Int32Regs,
+ FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, BFloat16x2Regs,
[hasPTX70, hasSM80]>,
- FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, Int32Regs,
+ FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, BFloat16x2Regs,
[hasPTX70, hasSM80]>
] in {
def P.Variant :
@@ -1243,24 +1251,6 @@
def : Pat<(int_nvvm_ff2bf16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
(CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
-def : Pat<(int_nvvm_ff2f16x2_rn Float32Regs:$a, Float32Regs:$b),
- (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
-def : Pat<(int_nvvm_ff2f16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
- (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
-def : Pat<(int_nvvm_ff2f16x2_rz Float32Regs:$a, Float32Regs:$b),
- (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
-def : Pat<(int_nvvm_ff2f16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
- (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
-
-def : Pat<(int_nvvm_f2bf16_rn Float32Regs:$a),
- (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
-def : Pat<(int_nvvm_f2bf16_rn_relu Float32Regs:$a),
- (CVT_bf16_f32 Float32Regs:$a, CvtRN_RELU)>;
-def : Pat<(int_nvvm_f2bf16_rz Float32Regs:$a),
- (CVT_bf16_f32 Float32Regs:$a, CvtRZ)>;
-def : Pat<(int_nvvm_f2bf16_rz_relu Float32Regs:$a),
- (CVT_bf16_f32 Float32Regs:$a, CvtRZ_RELU)>;
-
def CVT_tf32_f32 :
NVPTXInst<(outs Int32Regs:$dest), (ins Float32Regs:$a),
"cvt.rna.tf32.f32 \t$dest, $a;",
@@ -2136,6 +2126,8 @@
defm INT_PTX_LDU_GLOBAL_i64 : LDU_G<"u64 \t$result, [$src];", Int64Regs>;
defm INT_PTX_LDU_GLOBAL_f16 : LDU_G<"b16 \t$result, [$src];", Float16Regs>;
defm INT_PTX_LDU_GLOBAL_f16x2 : LDU_G<"b32 \t$result, [$src];", Float16x2Regs>;
+defm INT_PTX_LDU_GLOBAL_bf16 : LDU_G<"b16 \t$result, [$src];", BFloat16Regs>;
+defm INT_PTX_LDU_GLOBAL_bf16x2 : LDU_G<"b32 \t$result, [$src];", BFloat16x2Regs>;
defm INT_PTX_LDU_GLOBAL_f32 : LDU_G<"f32 \t$result, [$src];", Float32Regs>;
defm INT_PTX_LDU_GLOBAL_f64 : LDU_G<"f64 \t$result, [$src];", Float64Regs>;
defm INT_PTX_LDU_GLOBAL_p32 : LDU_G<"u32 \t$result, [$src];", Int32Regs>;
@@ -2190,6 +2182,10 @@
: VLDU_G_ELE_V2<"v2.b16 \t{{$dst1, $dst2}}, [$src];", Float16Regs>;
defm INT_PTX_LDU_G_v2f16x2_ELE
: VLDU_G_ELE_V2<"v2.b32 \t{{$dst1, $dst2}}, [$src];", Float16x2Regs>;
+defm INT_PTX_LDU_G_v2bf16_ELE
+ : VLDU_G_ELE_V2<"v2.b16 \t{{$dst1, $dst2}}, [$src];", BFloat16Regs>;
+defm INT_PTX_LDU_G_v2bf16x2_ELE
+ : VLDU_G_ELE_V2<"v2.b32 \t{{$dst1, $dst2}}, [$src];", BFloat16x2Regs>;
defm INT_PTX_LDU_G_v2f32_ELE
: VLDU_G_ELE_V2<"v2.f32 \t{{$dst1, $dst2}}, [$src];", Float32Regs>;
defm INT_PTX_LDU_G_v2i64_ELE
@@ -2253,6 +2249,10 @@
: LDG_G<"b16 \t$result, [$src];", Float16Regs>;
defm INT_PTX_LDG_GLOBAL_f16x2
: LDG_G<"b32 \t$result, [$src];", Float16x2Regs>;
+defm INT_PTX_LDG_GLOBAL_bf16
+ : LDG_G<"b16 \t$result, [$src];", BFloat16Regs>;
+defm INT_PTX_LDG_GLOBAL_bf16x2
+ : LDG_G<"b32 \t$result, [$src];", BFloat16x2Regs>;
defm INT_PTX_LDG_GLOBAL_f32
: LDG_G<"f32 \t$result, [$src];", Float32Regs>;
defm INT_PTX_LDG_GLOBAL_f64
Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -19,6 +19,8 @@
let OperandType = "OPERAND_IMMEDIATE" in {
def f16imm : Operand<f16>;
+ def bf16imm : Operand<bf16>;
+
}
// List of vector specific properties
@@ -172,6 +174,7 @@
def useShortPtr : Predicate<"useShortPointers()">;
def useFP16Math: Predicate<"Subtarget->allowFP16Math()">;
+def useBFP16Math: Predicate<"Subtarget->allowBF16Math()">;
// Helper class to aid conversion between ValueType and a matching RegisterClass.
@@ -184,8 +187,8 @@
!eq(name, "i64"): Int64Regs,
!eq(name, "f16"): Float16Regs,
!eq(name, "v2f16"): Float16x2Regs,
- !eq(name, "bf16"): Float16Regs,
- !eq(name, "v2bf16"): Float16x2Regs,
+ !eq(name, "bf16"): BFloat16Regs,
+ !eq(name, "v2bf16"): BFloat16x2Regs,
!eq(name, "f32"): Float32Regs,
!eq(name, "f64"): Float64Regs,
!eq(name, "ai32"): Int32ArgRegs,
@@ -322,6 +325,31 @@
!strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
[(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>,
Requires<[useFP16Math]>;
+ def bf16rr_ftz :
+ NVPTXInst<(outs BFloat16Regs:$dst),
+ (ins BFloat16Regs:$a, BFloat16Regs:$b),
+ !strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"),
+ [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>,
+ Requires<[useBFP16Math, doF32FTZ]>;
+ def bf16rr :
+ NVPTXInst<(outs BFloat16Regs:$dst),
+ (ins BFloat16Regs:$a, BFloat16Regs:$b),
+ !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
+ [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>,
+ Requires<[useBFP16Math]>;
+
+ def bf16x2rr_ftz :
+ NVPTXInst<(outs BFloat16x2Regs:$dst),
+ (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b),
+ !strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"),
+ [(set BFloat16x2Regs:$dst, (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>,
+ Requires<[useBFP16Math, doF32FTZ]>;
+ def bf16x2rr :
+ NVPTXInst<(outs BFloat16x2Regs:$dst),
+ (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b),
+ !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"),
+ [(set BFloat16x2Regs:$dst, (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>,
+ Requires<[useBFP16Math]>;
}
// Template for instructions which take three FP args. The
@@ -396,7 +424,31 @@
!strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
[(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>,
Requires<[useFP16Math, allowFMA]>;
-
+ def bf16rr_ftz :
+ NVPTXInst<(outs BFloat16Regs:$dst),
+ (ins BFloat16Regs:$a, BFloat16Regs:$b),
+ !strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"),
+ [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>,
+ Requires<[useBFP16Math, allowFMA, doF32FTZ]>;
+ def bf16rr :
+ NVPTXInst<(outs BFloat16Regs:$dst),
+ (ins BFloat16Regs:$a, BFloat16Regs:$b),
+ !strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
+ [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>,
+ Requires<[useBFP16Math, allowFMA]>;
+
+ def bf16x2rr_ftz :
+ NVPTXInst<(outs BFloat16x2Regs:$dst),
+ (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b),
+ !strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"),
+ [(set (v2bf16 BFloat16x2Regs:$dst), (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>,
+ Requires<[useBFP16Math, allowFMA, doF32FTZ]>;
+ def bf16x2rr :
+ NVPTXInst<(outs BFloat16x2Regs:$dst),
+ (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b),
+ !strconcat(OpcStr, ".bf16x2 \t$dst, $a, $b;"),
+ [(set BFloat16x2Regs:$dst, (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>,
+ Requires<[useBFP16Math, allowFMA]>;
// These have strange names so we don't perturb existing mir tests.
def _rnf64rr :
NVPTXInst<(outs Float64Regs:$dst),
@@ -458,6 +510,30 @@
!strconcat(OpcStr, ".rn.f16x2 \t$dst, $a, $b;"),
[(set Float16x2Regs:$dst, (OpNode (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>,
Requires<[useFP16Math, noFMA]>;
+ def _rnbf16rr_ftz :
+ NVPTXInst<(outs BFloat16Regs:$dst),
+ (ins BFloat16Regs:$a, BFloat16Regs:$b),
+ !strconcat(OpcStr, ".rn.ftz.bf16 \t$dst, $a, $b;"),
+ [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>,
+ Requires<[useBFP16Math, noFMA, doF32FTZ]>;
+ def _rnbf16rr :
+ NVPTXInst<(outs BFloat16Regs:$dst),
+ (ins BFloat16Regs:$a, BFloat16Regs:$b),
+ !strconcat(OpcStr, ".rn.bf16 \t$dst, $a, $b;"),
+ [(set BFloat16Regs:$dst, (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>,
+ Requires<[useBFP16Math, noFMA]>;
+ def _rnbf16x2rr_ftz :
+ NVPTXInst<(outs BFloat16x2Regs:$dst),
+ (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b),
+ !strconcat(OpcStr, ".rn.ftz.bf16x2 \t$dst, $a, $b;"),
+ [(set BFloat16x2Regs:$dst, (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>,
+ Requires<[useBFP16Math, noFMA, doF32FTZ]>;
+ def _rnbf16x2rr :
+ NVPTXInst<(outs BFloat16x2Regs:$dst),
+ (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b),
+ !strconcat(OpcStr, ".rn.bf16x2 \t$dst, $a, $b;"),
+ [(set BFloat16x2Regs:$dst, (OpNode (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>,
+ Requires<[useBFP16Math, noFMA]>;
}
// Template for operations which take two f32 or f64 operands. Provides three
@@ -534,6 +610,11 @@
(ins Float16Regs:$src, CvtMode:$mode),
!strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
FromName, ".f16 \t$dst, $src;"), []>;
+ def _bf16 :
+ NVPTXInst<(outs RC:$dst),
+ (ins BFloat16Regs:$src, CvtMode:$mode),
+ !strconcat("cvt${mode:base}${mode:ftz}${mode:sat}.",
+ FromName, ".bf16 \t$dst, $src;"), []>;
def _f32 :
NVPTXInst<(outs RC:$dst),
(ins Float32Regs:$src, CvtMode:$mode),
@@ -556,6 +637,7 @@
defm CVT_s64 : CVT_FROM_ALL<"s64", Int64Regs>;
defm CVT_u64 : CVT_FROM_ALL<"u64", Int64Regs>;
defm CVT_f16 : CVT_FROM_ALL<"f16", Float16Regs>;
+ defm CVT_bf16 : CVT_FROM_ALL<"bf16", BFloat16Regs>;
defm CVT_f32 : CVT_FROM_ALL<"f32", Float32Regs>;
defm CVT_f64 : CVT_FROM_ALL<"f64", Float64Regs>;
@@ -574,18 +656,7 @@
def CVT_INREG_s64_s32 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
"cvt.s64.s32 \t$dst, $src;", []>;
-multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> {
- def _f32 :
- NVPTXInst<(outs RC:$dst),
- (ins Float32Regs:$src, CvtMode:$mode),
- !strconcat("cvt${mode:base}${mode:relu}.",
- FromName, ".f32 \t$dst, $src;"), []>,
- Requires<[hasPTX70, hasSM80]>;
- }
-
- defm CVT_bf16 : CVT_FROM_FLOAT_SM80<"bf16", Int16Regs>;
-
- multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
+ multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
def _f32 :
NVPTXInst<(outs RC:$dst),
(ins Float32Regs:$src1, Float32Regs:$src2, CvtMode:$mode),
@@ -594,7 +665,7 @@
Requires<[hasPTX70, hasSM80]>;
}
- defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", Float16x2Regs>;
+  defm CVT_f16x2 : CVT_FROM_FLOAT_V2_SM80<"f16x2", Float16x2Regs>;
defm CVT_bf16x2 : CVT_FROM_FLOAT_V2_SM80<"bf16x2", Int32Regs>;
}
@@ -659,7 +730,7 @@
defm SELP_s64 : SELP<"s64", Int64Regs, i64imm>;
defm SELP_u64 : SELP<"u64", Int64Regs, i64imm>;
defm SELP_f16 : SELP_PATTERN<"b16", f16, Float16Regs, f16imm, fpimm>;
-
+defm SELP_bf16 : SELP_PATTERN<"b16", bf16, BFloat16Regs, bf16imm, fpimm>;
defm SELP_f32 : SELP_PATTERN<"f32", f32, Float32Regs, f32imm, fpimm>;
defm SELP_f64 : SELP_PATTERN<"f64", f64, Float64Regs, f64imm, fpimm>;
@@ -1023,7 +1094,9 @@
def LOAD_CONST_F16 :
NVPTXInst<(outs Float16Regs:$dst), (ins f16imm:$a),
"mov.b16 \t$dst, $a;", []>;
-
+def LOAD_CONST_BF16 :
+ NVPTXInst<(outs BFloat16Regs:$dst), (ins bf16imm:$a),
+ "mov.b16 \t$dst, $a;", []>;
defm FADD : F3_fma_component<"add", fadd>;
defm FSUB : F3_fma_component<"sub", fsub>;
defm FMUL : F3_fma_component<"mul", fmul>;
@@ -1051,6 +1124,20 @@
def FNEG16x2_ftz : FNEG_F16_F16X2<"neg.ftz.f16x2", v2f16, Float16x2Regs, doF32FTZ>;
def FNEG16x2 : FNEG_F16_F16X2<"neg.f16x2", v2f16, Float16x2Regs, True>;
+//
+// BF16 NEG
+//
+
+class FNEG_BF16_F16X2<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> :
+ NVPTXInst<(outs RC:$dst), (ins RC:$src),
+ !strconcat(OpcStr, " \t$dst, $src;"),
+ [(set RC:$dst, (fneg (T RC:$src)))]>,
+  Requires<[useBFP16Math, hasPTX70, hasSM80, Pred]>;
+def BFNEG16_ftz : FNEG_BF16_F16X2<"neg.ftz.bf16", bf16, BFloat16Regs, doF32FTZ>;
+def BFNEG16 : FNEG_BF16_F16X2<"neg.bf16", bf16, BFloat16Regs, True>;
+def BFNEG16x2_ftz : FNEG_BF16_F16X2<"neg.ftz.bf16x2", v2bf16, BFloat16x2Regs, doF32FTZ>;
+def BFNEG16x2 : FNEG_BF16_F16X2<"neg.bf16x2", v2bf16, BFloat16x2Regs, True>;
+
//
// F64 division
//
@@ -1229,10 +1316,21 @@
Requires<[useFP16Math, Pred]>;
}
+multiclass FMA_BF16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {
+ def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
+ !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
+ [(set RC:$dst, (fma (T RC:$a), (T RC:$b), (T RC:$c)))]>,
+ Requires<[useBFP16Math, Pred]>;
+}
+
defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Float16Regs, doF32FTZ>;
defm FMA16 : FMA_F16<"fma.rn.f16", f16, Float16Regs, True>;
defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Float16x2Regs, doF32FTZ>;
defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Float16x2Regs, True>;
+defm BFMA16_ftz : FMA_BF16<"fma.rn.ftz.bf16", bf16, BFloat16Regs, doF32FTZ>;
+defm BFMA16 : FMA_BF16<"fma.rn.bf16", bf16, BFloat16Regs, True>;
+defm BFMA16x2_ftz : FMA_BF16<"fma.rn.ftz.bf16x2", v2bf16, BFloat16x2Regs, doF32FTZ>;
+defm BFMA16x2 : FMA_BF16<"fma.rn.bf16x2", v2bf16, BFloat16x2Regs, True>;
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
@@ -1679,6 +1777,18 @@
"setp${cmp:base}${cmp:ftz}.f16x2 \t$p|$q, $a, $b;",
[]>,
Requires<[useFP16Math]>;
+def SETP_bf16rr :
+ NVPTXInst<(outs Int1Regs:$dst),
+ (ins BFloat16Regs:$a, BFloat16Regs:$b, CmpMode:$cmp),
+ "setp${cmp:base}${cmp:ftz}.bf16 \t$dst, $a, $b;",
+ []>, Requires<[useBFP16Math]>;
+
+def SETP_bf16x2rr :
+ NVPTXInst<(outs Int1Regs:$p, Int1Regs:$q),
+ (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b, CmpMode:$cmp),
+ "setp${cmp:base}${cmp:ftz}.bf16x2 \t$p|$q, $a, $b;",
+ []>,
+ Requires<[useBFP16Math]>;
// FIXME: This doesn't appear to be correct. The "set" mnemonic has the form
@@ -1709,6 +1819,7 @@
defm SET_s64 : SET<"s64", Int64Regs, i64imm>;
defm SET_u64 : SET<"u64", Int64Regs, i64imm>;
defm SET_f16 : SET<"f16", Float16Regs, f16imm>;
+defm SET_bf16 : SET<"bf16", BFloat16Regs, bf16imm>;
defm SET_f32 : SET<"f32", Float32Regs, f32imm>;
defm SET_f64 : SET<"f64", Float64Regs, f64imm>;
@@ -1781,6 +1892,8 @@
def FMOV16rr : NVPTXInst<(outs Float16Regs:$dst), (ins Float16Regs:$src),
// We have to use .b16 here as there's no mov.f16.
"mov.b16 \t$dst, $src;", []>;
+ def BFMOV16rr : NVPTXInst<(outs BFloat16Regs:$dst), (ins BFloat16Regs:$src),
+ "mov.b16 \t$dst, $src;", []>;
def FMOV32rr : NVPTXInst<(outs Float32Regs:$dst), (ins Float32Regs:$src),
"mov.f32 \t$dst, $src;", []>;
def FMOV64rr : NVPTXInst<(outs Float64Regs:$dst), (ins Float64Regs:$src),
@@ -1963,7 +2076,27 @@
(SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>,
Requires<[useFP16Math]>;
- // f32 -> pred
+ // bf16 -> pred
+ def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
+ (SETP_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, ModeFTZ)>,
+ Requires<[useBFP16Math,doF32FTZ]>;
+ def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
+ (SETP_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, Mode)>,
+ Requires<[useBFP16Math]>;
+ def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), fpimm:$b)),
+ (SETP_bf16rr BFloat16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), ModeFTZ)>,
+ Requires<[useBFP16Math,doF32FTZ]>;
+ def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), fpimm:$b)),
+ (SETP_bf16rr BFloat16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), Mode)>,
+ Requires<[useBFP16Math]>;
+ def : Pat<(i1 (OpNode fpimm:$a, (bf16 BFloat16Regs:$b))),
+ (SETP_bf16rr (LOAD_CONST_BF16 fpimm:$a), BFloat16Regs:$b, ModeFTZ)>,
+ Requires<[useBFP16Math,doF32FTZ]>;
+ def : Pat<(i1 (OpNode fpimm:$a, (bf16 BFloat16Regs:$b))),
+ (SETP_bf16rr (LOAD_CONST_BF16 fpimm:$a), BFloat16Regs:$b, Mode)>,
+ Requires<[useBFP16Math]>;
+
+ //f32 -> pred
def : Pat<(i1 (OpNode Float32Regs:$a, Float32Regs:$b)),
(SETP_f32rr Float32Regs:$a, Float32Regs:$b, ModeFTZ)>,
Requires<[doF32FTZ]>;
@@ -2007,6 +2140,26 @@
def : Pat<(i32 (OpNode fpimm:$a, (f16 Float16Regs:$b))),
(SET_f16ir (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>,
Requires<[useFP16Math]>;
+
+ // bf16 -> i32
+ def : Pat<(i32 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
+ (SET_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, ModeFTZ)>,
+ Requires<[useBFP16Math, doF32FTZ]>;
+ def : Pat<(i32 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
+ (SET_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, Mode)>,
+ Requires<[useBFP16Math]>;
+ def : Pat<(i32 (OpNode (bf16 BFloat16Regs:$a), fpimm:$b)),
+ (SET_bf16rr BFloat16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), ModeFTZ)>,
+ Requires<[useBFP16Math, doF32FTZ]>;
+ def : Pat<(i32 (OpNode (bf16 BFloat16Regs:$a), fpimm:$b)),
+ (SET_bf16rr BFloat16Regs:$a, (LOAD_CONST_BF16 fpimm:$b), Mode)>,
+ Requires<[useBFP16Math]>;
+ def : Pat<(i32 (OpNode fpimm:$a, (bf16 BFloat16Regs:$b))),
+ (SET_bf16ir (LOAD_CONST_BF16 fpimm:$a), BFloat16Regs:$b, ModeFTZ)>,
+ Requires<[useBFP16Math, doF32FTZ]>;
+ def : Pat<(i32 (OpNode fpimm:$a, (bf16 BFloat16Regs:$b))),
+ (SET_bf16ir (LOAD_CONST_BF16 fpimm:$a), BFloat16Regs:$b, Mode)>,
+ Requires<[useBFP16Math]>;
// f32 -> i32
def : Pat<(i32 (OpNode Float32Regs:$a, Float32Regs:$b)),
@@ -2296,10 +2449,14 @@
def LoadParamMemV4I8 : LoadParamV4MemInst<Int16Regs, ".b8">;
def LoadParamMemF16 : LoadParamMemInst<Float16Regs, ".b16">;
def LoadParamMemF16x2 : LoadParamMemInst<Float16x2Regs, ".b32">;
+def LoadParamMemBF16 : LoadParamMemInst<BFloat16Regs, ".b16">;
+def LoadParamMemBF16x2 : LoadParamMemInst<BFloat16x2Regs, ".b32">;
def LoadParamMemF32 : LoadParamMemInst<Float32Regs, ".f32">;
def LoadParamMemF64 : LoadParamMemInst<Float64Regs, ".f64">;
def LoadParamMemV2F16 : LoadParamV2MemInst<Float16Regs, ".b16">;
def LoadParamMemV2F16x2: LoadParamV2MemInst<Float16x2Regs, ".b32">;
+def LoadParamMemV2BF16 : LoadParamV2MemInst<BFloat16Regs, ".b16">;
+def LoadParamMemV2BF16x2: LoadParamV2MemInst<BFloat16x2Regs, ".b32">;
def LoadParamMemV2F32 : LoadParamV2MemInst<Float32Regs, ".f32">;
def LoadParamMemV2F64 : LoadParamV2MemInst<Float64Regs, ".f64">;
def LoadParamMemV4F16 : LoadParamV4MemInst<Float16Regs, ".b16">;
@@ -2322,6 +2479,10 @@
def StoreParamF16 : StoreParamInst<Float16Regs, ".b16">;
def StoreParamF16x2 : StoreParamInst<Float16x2Regs, ".b32">;
+
+def StoreParamBF16 : StoreParamInst<BFloat16Regs, ".b16">;
+def StoreParamBF16x2 : StoreParamInst<BFloat16x2Regs, ".b32">;
+
def StoreParamF32 : StoreParamInst<Float32Regs, ".f32">;
def StoreParamF64 : StoreParamInst<Float64Regs, ".f64">;
def StoreParamV2F16 : StoreParamV2Inst<Float16Regs, ".b16">;
@@ -2348,6 +2509,8 @@
def StoreRetvalF32 : StoreRetvalInst<Float32Regs, ".f32">;
def StoreRetvalF16 : StoreRetvalInst<Float16Regs, ".b16">;
def StoreRetvalF16x2 : StoreRetvalInst<Float16x2Regs, ".b32">;
+def StoreRetvalBF16 : StoreRetvalInst<BFloat16Regs, ".b16">;
+def StoreRetvalBF16x2 : StoreRetvalInst<BFloat16x2Regs, ".b32">;
def StoreRetvalV2F64 : StoreRetvalV2Inst<Float64Regs, ".f64">;
def StoreRetvalV2F32 : StoreRetvalV2Inst<Float32Regs, ".f32">;
def StoreRetvalV2F16 : StoreRetvalV2Inst<Float16Regs, ".b16">;
@@ -2450,6 +2613,7 @@
def MoveParamF64 : MoveParamInst<f64, Float64Regs, ".f64">;
def MoveParamF32 : MoveParamInst<f32, Float32Regs, ".f32">;
def MoveParamF16 : MoveParamInst<f16, Float16Regs, ".f16">;
+def MoveParamBF16 : MoveParamInst<bf16, BFloat16Regs, ".bf16">;
class PseudoUseParamInst<NVPTXRegClass regclass> :
NVPTXInst<(outs), (ins regclass:$src),
@@ -2473,11 +2637,11 @@
def ProxyRegI32 : ProxyRegInst<"b32", i32, Int32Regs>;
def ProxyRegI64 : ProxyRegInst<"b64", i64, Int64Regs>;
def ProxyRegF16 : ProxyRegInst<"b16", f16, Float16Regs>;
- def ProxyRegBF16 : ProxyRegInst<"b16", bf16, Float16Regs>;
+ def ProxyRegBF16 : ProxyRegInst<"b16", bf16, BFloat16Regs>;
def ProxyRegF32 : ProxyRegInst<"f32", f32, Float32Regs>;
def ProxyRegF64 : ProxyRegInst<"f64", f64, Float64Regs>;
def ProxyRegF16x2 : ProxyRegInst<"b32", v2f16, Float16x2Regs>;
- def ProxyRegBF16x2 : ProxyRegInst<"b32", v2bf16, Float16x2Regs>;
+ def ProxyRegBF16x2 : ProxyRegInst<"b32", v2bf16, BFloat16x2Regs>;
}
//
@@ -2578,7 +2742,9 @@
defm ST_i32 : ST<Int32Regs>;
defm ST_i64 : ST<Int64Regs>;
defm ST_f16 : ST<Float16Regs>;
+ defm ST_bf16 : ST<BFloat16Regs>;
defm ST_f16x2 : ST<Float16x2Regs>;
+ defm ST_bf16x2 : ST<BFloat16x2Regs>;
defm ST_f32 : ST<Float32Regs>;
defm ST_f64 : ST<Float64Regs>;
}
@@ -2667,6 +2833,8 @@
defm LDV_i64 : LD_VEC<Int64Regs>;
defm LDV_f16 : LD_VEC<Float16Regs>;
defm LDV_f16x2 : LD_VEC<Float16x2Regs>;
+ defm LDV_bf16 : LD_VEC<BFloat16Regs>;
+ defm LDV_bf16x2 : LD_VEC<BFloat16x2Regs>;
defm LDV_f32 : LD_VEC<Float32Regs>;
defm LDV_f64 : LD_VEC<Float64Regs>;
}
@@ -2762,6 +2930,8 @@
defm STV_i64 : ST_VEC<Int64Regs>;
defm STV_f16 : ST_VEC<Float16Regs>;
defm STV_f16x2 : ST_VEC<Float16x2Regs>;
+ defm STV_bf16 : ST_VEC<BFloat16Regs>;
+ defm STV_bf16x2 : ST_VEC<BFloat16x2Regs>;
defm STV_f32 : ST_VEC<Float32Regs>;
defm STV_f64 : ST_VEC<Float64Regs>;
}
@@ -2816,6 +2986,26 @@
def : Pat<(f16 (uint_to_fp Int64Regs:$a)),
(CVT_f16_u64 Int64Regs:$a, CvtRN)>;
+// sint -> bf16
+def : Pat<(bf16 (sint_to_fp Int1Regs:$a)),
+ (CVT_bf16_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
+def : Pat<(bf16 (sint_to_fp Int16Regs:$a)),
+ (CVT_bf16_s16 Int16Regs:$a, CvtRN)>;
+def : Pat<(bf16 (sint_to_fp Int32Regs:$a)),
+ (CVT_bf16_s32 Int32Regs:$a, CvtRN)>;
+def : Pat<(bf16 (sint_to_fp Int64Regs:$a)),
+ (CVT_bf16_s64 Int64Regs:$a, CvtRN)>;
+
+// uint -> bf16
+def : Pat<(bf16 (uint_to_fp Int1Regs:$a)),
+ (CVT_bf16_u32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
+def : Pat<(bf16 (uint_to_fp Int16Regs:$a)),
+ (CVT_bf16_u16 Int16Regs:$a, CvtRN)>;
+def : Pat<(bf16 (uint_to_fp Int32Regs:$a)),
+ (CVT_bf16_u32 Int32Regs:$a, CvtRN)>;
+def : Pat<(bf16 (uint_to_fp Int64Regs:$a)),
+ (CVT_bf16_u64 Int64Regs:$a, CvtRN)>;
+
// sint -> f32
def : Pat<(f32 (sint_to_fp Int1Regs:$a)),
(CVT_f32_s32 (SELP_u32ii 1, 0, Int1Regs:$a), CvtRN)>;
@@ -2877,6 +3067,25 @@
def : Pat<(i64 (fp_to_uint (f16 Float16Regs:$a))),
(CVT_u64_f16 Float16Regs:$a, CvtRZI)>;
+// bf16 -> sint
+def : Pat<(i1 (fp_to_sint (bf16 BFloat16Regs:$a))),
+ (SETP_b16ri (BITCONVERT_16_BF2I BFloat16Regs:$a), 0, CmpEQ)>;
+def : Pat<(i16 (fp_to_sint (bf16 BFloat16Regs:$a))),
+ (CVT_s16_bf16 (bf16 BFloat16Regs:$a), CvtRZI)>;
+def : Pat<(i32 (fp_to_sint (bf16 BFloat16Regs:$a))),
+ (CVT_s32_bf16 (bf16 BFloat16Regs:$a), CvtRZI)>;
+def : Pat<(i64 (fp_to_sint (bf16 BFloat16Regs:$a))),
+ (CVT_s64_bf16 BFloat16Regs:$a, CvtRZI)>;
+
+// bf16 -> uint
+def : Pat<(i1 (fp_to_uint (bf16 BFloat16Regs:$a))),
+ (SETP_b16ri (BITCONVERT_16_BF2I BFloat16Regs:$a), 0, CmpEQ)>;
+def : Pat<(i16 (fp_to_uint (bf16 BFloat16Regs:$a))),
+ (CVT_u16_bf16 BFloat16Regs:$a, CvtRZI)>;
+def : Pat<(i32 (fp_to_uint (bf16 BFloat16Regs:$a))),
+ (CVT_u32_bf16 BFloat16Regs:$a, CvtRZI)>;
+def : Pat<(i64 (fp_to_uint (bf16 BFloat16Regs:$a))),
+ (CVT_u64_bf16 BFloat16Regs:$a, CvtRZI)>;
// f32 -> sint
def : Pat<(i1 (fp_to_sint Float32Regs:$a)),
(SETP_b32ri (BITCONVERT_32_F2I Float32Regs:$a), 0, CmpEQ)>;
@@ -3024,6 +3233,9 @@
def : Pat<(select Int32Regs:$pred, (f16 Float16Regs:$a), (f16 Float16Regs:$b)),
(SELP_f16rr Float16Regs:$a, Float16Regs:$b,
(SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>;
+def : Pat<(select Int32Regs:$pred, (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)),
+ (SELP_bf16rr BFloat16Regs:$a, BFloat16Regs:$b,
+ (SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>;
def : Pat<(select Int32Regs:$pred, Float32Regs:$a, Float32Regs:$b),
(SELP_f32rr Float32Regs:$a, Float32Regs:$b,
(SETP_b32ri (ANDb32ri Int32Regs:$pred, 1), 1, CmpEQ))>;
@@ -3124,6 +3336,18 @@
(ins Int32Regs:$src),
"mov.b32 \t{{$lo, $hi}}, $src;",
[]>;
+ def BF16x2toBF16_0 : NVPTXInst<(outs BFloat16Regs:$dst),
+ (ins BFloat16x2Regs:$src),
+ "{{ .reg .b16 \t%tmp_hi;\n\t"
+ " mov.b32 \t{$dst, %tmp_hi}, $src; }}",
+ [(set BFloat16Regs:$dst,
+ (extractelt (v2bf16 BFloat16x2Regs:$src), 0))]>;
+ def BF16x2toBF16_1 : NVPTXInst<(outs BFloat16Regs:$dst),
+ (ins BFloat16x2Regs:$src),
+ "{{ .reg .b16 \t%tmp_lo;\n\t"
+ " mov.b32 \t{%tmp_lo, $dst}, $src; }}",
+ [(set BFloat16Regs:$dst,
+ (extractelt (v2bf16 BFloat16x2Regs:$src), 1))]>;
}
// Count leading zeros
@@ -3193,10 +3417,17 @@
def : Pat<(f16 (fpround Float32Regs:$a)),
(CVT_f16_f32 Float32Regs:$a, CvtRN)>;
+// fpround f32 -> bf16
+def : Pat<(bf16 (fpround Float32Regs:$a)),
+ (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
+
// fpround f64 -> f16
def : Pat<(f16 (fpround Float64Regs:$a)),
(CVT_f16_f64 Float64Regs:$a, CvtRN)>;
+// fpround f64 -> bf16
+def : Pat<(bf16 (fpround Float64Regs:$a)),
+ (CVT_bf16_f64 Float64Regs:$a, CvtRN)>;
// fpround f64 -> f32
def : Pat<(f32 (fpround Float64Regs:$a)),
(CVT_f32_f64 Float64Regs:$a, CvtRN_FTZ)>, Requires<[doF32FTZ]>;
@@ -3208,11 +3439,20 @@
(CVT_f32_f16 Float16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
def : Pat<(f32 (fpextend (f16 Float16Regs:$a))),
(CVT_f32_f16 Float16Regs:$a, CvtNONE)>;
+// fpextend bf16 -> f32
+def : Pat<(f32 (fpextend (bf16 BFloat16Regs:$a))),
+ (CVT_f32_bf16 BFloat16Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
+def : Pat<(f32 (fpextend (bf16 BFloat16Regs:$a))),
+ (CVT_f32_bf16 BFloat16Regs:$a, CvtNONE)>;
// fpextend f16 -> f64
def : Pat<(f64 (fpextend (f16 Float16Regs:$a))),
(CVT_f64_f16 Float16Regs:$a, CvtNONE)>;
+// fpextend bf16 -> f64
+def : Pat<(f64 (fpextend (bf16 BFloat16Regs:$a))),
+ (CVT_f64_bf16 BFloat16Regs:$a, CvtNONE)>;
+
// fpextend f32 -> f64
def : Pat<(f64 (fpextend Float32Regs:$a)),
(CVT_f64_f32 Float32Regs:$a, CvtNONE_FTZ)>, Requires<[doF32FTZ]>;
@@ -3227,6 +3467,8 @@
multiclass CVT_ROUND<SDNode OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
def : Pat<(OpNode (f16 Float16Regs:$a)),
(CVT_f16_f16 Float16Regs:$a, Mode)>;
+ def : Pat<(OpNode (bf16 BFloat16Regs:$a)),
+ (CVT_bf16_bf16 BFloat16Regs:$a, Mode)>;
def : Pat<(OpNode Float32Regs:$a),
(CVT_f32_f32 Float32Regs:$a, ModeFTZ)>, Requires<[doF32FTZ]>;
def : Pat<(OpNode Float32Regs:$a),
Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -143,6 +143,26 @@
}
}
+static bool Isv2f16Orv2bf16Type(MVT VT) {
+ switch (VT.SimpleTy) {
+ default:
+ return false;
+ case MVT::v2f16:
+ case MVT::v2bf16:
+ return true;
+ }
+}
+
+static bool Isf16Orbf16Type(MVT VT) {
+ switch (VT.SimpleTy) {
+ default:
+ return false;
+ case MVT::f16:
+ case MVT::bf16:
+ return true;
+ }
+}
+
/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
/// into their primitive components.
@@ -193,7 +213,7 @@
// Vectors with an even number of f16 elements will be passed to
// us as an array of v2f16/v2bf16 elements. We must match this so we
// stay in sync with Ins/Outs.
- if ((EltVT == MVT::f16 || EltVT == MVT::bf16) && NumElts % 2 == 0) {
+ if ((Isf16Orbf16Type(EltVT.getSimpleVT())) && NumElts % 2 == 0) {
EltVT = EltVT == MVT::f16 ? MVT::v2f16 : MVT::v2bf16;
NumElts /= 2;
}
@@ -398,6 +418,11 @@
setOperationAction(Op, VT, STI.allowFP16Math() ? Action : NoF16Action);
};
+ auto setBF16OperationAction = [&](unsigned Op, MVT VT, LegalizeAction Action,
+ LegalizeAction NoBF16Action) {
+ setOperationAction(Op, VT, STI.allowBF16Math() ? Action : NoBF16Action);
+ };
+
addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
@@ -406,8 +431,6 @@
addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass);
addRegisterClass(MVT::v2f16, &NVPTX::Float16x2RegsRegClass);
- addRegisterClass(MVT::bf16, &NVPTX::Float16RegsRegClass);
- addRegisterClass(MVT::v2bf16, &NVPTX::Float16x2RegsRegClass);
// Conversion to/from FP16/FP16x2 is always legal.
setOperationAction(ISD::SINT_TO_FP, MVT::f16, Legal);
@@ -420,6 +443,16 @@
setFP16OperationAction(ISD::SETCC, MVT::f16, Legal, Promote);
setFP16OperationAction(ISD::SETCC, MVT::v2f16, Legal, Expand);
+  // Conversion to/from BF16/BF16x2 is always legal.
+ setOperationAction(ISD::SINT_TO_FP, MVT::bf16, Legal);
+ setOperationAction(ISD::FP_TO_SINT, MVT::bf16, Legal);
+ setOperationAction(ISD::BUILD_VECTOR, MVT::v2bf16, Custom);
+ setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2bf16, Custom);
+ setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2bf16, Expand);
+ setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2bf16, Expand);
+
+ setBF16OperationAction(ISD::SETCC, MVT::bf16, Legal, Promote);
+ setBF16OperationAction(ISD::SETCC, MVT::v2bf16, Legal, Expand);
// Operations not directly supported by NVPTX.
for (MVT VT : {MVT::f16, MVT::v2f16, MVT::f32, MVT::f64, MVT::i1, MVT::i8,
MVT::i16, MVT::i32, MVT::i64}) {
@@ -476,17 +509,25 @@
// Turn FP extload into load/fpextend
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
+ setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
+ setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
+ setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
+ setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
+ setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
+ setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
// Turn FP truncstore into trunc + store.
// FIXME: vector types should also be expanded
setTruncStoreAction(MVT::f32, MVT::f16, Expand);
setTruncStoreAction(MVT::f64, MVT::f16, Expand);
+ setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
+ setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
setTruncStoreAction(MVT::f64, MVT::f32, Expand);
// PTX does not support load / store predicate registers
@@ -563,9 +604,9 @@
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
ISD::SREM, ISD::UREM});
- // setcc for f16x2 needs special handling to prevent legalizer's
- // attempt to scalarize it due to v2i1 not being legal.
- if (STI.allowFP16Math())
+ // setcc for f16x2 and bf16x2 needs special handling to prevent
+ // legalizer's attempt to scalarize it due to v2i1 not being legal.
+ if (STI.allowFP16Math() || STI.allowBF16Math())
setTargetDAGCombine(ISD::SETCC);
// Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -579,6 +620,11 @@
setFP16OperationAction(Op, MVT::v2f16, Legal, Expand);
}
+ for (const auto &Op : {ISD::FADD, ISD::FMUL, ISD::FSUB, ISD::FMA}) {
+ setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
+ setBF16OperationAction(Op, MVT::v2bf16, Legal, Expand);
+ }
+
// f16/f16x2 neg was introduced in PTX 60, SM_53.
const bool IsFP16FP16x2NegAvailable = STI.getSmVersion() >= 53 &&
STI.getPTXVersion() >= 60 &&
@@ -587,19 +633,29 @@
setOperationAction(ISD::FNEG, VT,
IsFP16FP16x2NegAvailable ? Legal : Expand);
+ const bool IsBFP16FP16x2NegAvailable = STI.getSmVersion() >= 80 &&
+ STI.getPTXVersion() >= 70 &&
+ STI.allowBF16Math();
+ for (const auto &VT : {MVT::bf16, MVT::v2bf16})
+ setOperationAction(ISD::FNEG, VT,
+ IsBFP16FP16x2NegAvailable ? Legal : Expand);
// (would be) Library functions.
// These map to conversion instructions for scalar FP types.
for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
ISD::FROUNDEVEN, ISD::FTRUNC}) {
+ setOperationAction(Op, MVT::bf16, Legal);
setOperationAction(Op, MVT::f16, Legal);
setOperationAction(Op, MVT::f32, Legal);
setOperationAction(Op, MVT::f64, Legal);
setOperationAction(Op, MVT::v2f16, Expand);
+ setOperationAction(Op, MVT::v2bf16, Expand);
}
setOperationAction(ISD::FROUND, MVT::f16, Promote);
setOperationAction(ISD::FROUND, MVT::v2f16, Expand);
+ setOperationAction(ISD::FROUND, MVT::bf16, Promote);
+ setOperationAction(ISD::FROUND, MVT::v2bf16, Expand);
setOperationAction(ISD::FROUND, MVT::f32, Custom);
setOperationAction(ISD::FROUND, MVT::f64, Custom);
@@ -607,6 +663,8 @@
// 'Expand' implements FCOPYSIGN without calling an external library.
setOperationAction(ISD::FCOPYSIGN, MVT::f16, Expand);
setOperationAction(ISD::FCOPYSIGN, MVT::v2f16, Expand);
+ setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
+ setOperationAction(ISD::FCOPYSIGN, MVT::v2bf16, Expand);
setOperationAction(ISD::FCOPYSIGN, MVT::f32, Expand);
setOperationAction(ISD::FCOPYSIGN, MVT::f64, Expand);
@@ -616,9 +674,11 @@
for (const auto &Op :
{ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FABS}) {
setOperationAction(Op, MVT::f16, Promote);
+ setOperationAction(Op, MVT::bf16, Promote);
setOperationAction(Op, MVT::f32, Legal);
setOperationAction(Op, MVT::f64, Legal);
setOperationAction(Op, MVT::v2f16, Expand);
+ setOperationAction(Op, MVT::v2bf16, Expand);
}
// max.f16, max.f16x2 and max.NaN are supported on sm_80+.
auto GetMinMaxAction = [&](LegalizeAction NotSm80Action) {
@@ -636,6 +696,12 @@
setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
setFP16OperationAction(Op, MVT::v2f16, GetMinMaxAction(Expand), Expand);
}
+ for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
+ setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote);
+ setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
+ setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Expand), Expand);
+ setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
+ }
// No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
// No FPOW or FREM in PTX.
@@ -1252,7 +1318,7 @@
if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
VT.getScalarType() == MVT::i1)
return TypeSplitVector;
- if (VT == MVT::v2f16)
+ if (Isv2f16Orv2bf16Type(VT))
return TypeLegal;
return TargetLoweringBase::getPreferredVectorAction(VT);
}
@@ -1402,7 +1468,7 @@
sz = promoteScalarArgumentSize(sz);
} else if (isa<PointerType>(Ty)) {
sz = PtrVT.getSizeInBits();
- } else if (Ty->isHalfTy())
+ } else if (Ty->isHalfTy() || Ty->isBFloatTy())
// PTX ABI requires all scalar parameters to be at least 32
// bits in size. fp16 normally uses .b16 as its storage type
// in PTX, so its size must be adjusted here, too.
@@ -2037,7 +2103,7 @@
// generates good SASS in both cases.
SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
SelectionDAG &DAG) const {
- if (!(Op->getValueType(0) == MVT::v2f16 &&
+  if (!(Isv2f16Orv2bf16Type(Op->getValueType(0).getSimpleVT()) &&
isa<ConstantFPSDNode>(Op->getOperand(0)) &&
isa<ConstantFPSDNode>(Op->getOperand(1))))
return Op;
@@ -2048,7 +2114,7 @@
cast<ConstantFPSDNode>(Op->getOperand(1))->getValueAPF().bitcastToAPInt();
SDValue Const =
DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
- return DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2f16, Const);
+ return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
}
SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
@@ -2409,7 +2475,7 @@
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
// loads and have to handle it here.
- if (Op.getValueType() == MVT::v2f16) {
+ if (Isv2f16Orv2bf16Type(Op.getValueType().getSimpleVT())) {
LoadSDNode *Load = cast<LoadSDNode>(Op);
EVT MemVT = Load->getMemoryVT();
if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -2454,7 +2520,7 @@
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
// stores and have to handle it here.
- if (VT == MVT::v2f16 &&
+ if ((Isv2f16Orv2bf16Type(VT.getSimpleVT())) &&
!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
VT, *Store->getMemOperand()))
return expandUnalignedStore(Store, DAG);
@@ -2541,7 +2607,7 @@
// v8f16 is a special case. PTX doesn't have st.v8.f16
// instruction. Instead, we split the vector into v2f16 chunks and
// store them with st.v4.b32.
- assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+ assert((Isf16Orbf16Type(EltVT.getSimpleVT())) &&
"Wrong type for the vector.");
Opcode = NVPTXISD::StoreV4;
StoreF16x2 = true;
@@ -2557,11 +2623,12 @@
// Combine f16,f16 -> v2f16
NumElts /= 2;
for (unsigned i = 0; i < NumElts; ++i) {
- SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+ SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
DAG.getIntPtrConstant(i * 2, DL));
- SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
+ SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
DAG.getIntPtrConstant(i * 2 + 1, DL));
- SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2f16, E0, E1);
+ EVT VecVT = EVT::getVectorVT(*DAG.getContext(), EltVT, 2);
+ SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, E0, E1);
Ops.push_back(V2);
}
} else {
@@ -2733,9 +2800,9 @@
EVT LoadVT = EltVT;
if (EltVT == MVT::i1)
LoadVT = MVT::i8;
- else if (EltVT == MVT::v2f16)
+ else if (Isv2f16Orv2bf16Type(EltVT.getSimpleVT()))
// getLoad needs a vector type, but it can't handle
- // vectors which contain v2f16 elements. So we must load
+ // vectors which contain v2f16 or v2bf16 elements. So we must load
// using i32 here and then bitcast back.
LoadVT = MVT::i32;
@@ -5171,7 +5238,7 @@
// v8f16 is a special case. PTX doesn't have ld.v8.f16
// instruction. Instead, we split the vector into v2f16 chunks and
// load them with ld.v4.b32.
- assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+ assert(Isf16Orbf16Type(EltVT.getSimpleVT()) &&
"Unsupported v8 vector type.");
LoadF16x2 = true;
Opcode = NVPTXISD::LoadV4;
Index: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
+++ llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h
@@ -72,6 +72,7 @@
bool trySurfaceIntrinsic(SDNode *N);
bool tryBFE(SDNode *N);
bool tryConstantFP16(SDNode *N);
+ bool tryConstantBF16(SDNode *N);
bool SelectSETP_F16X2(SDNode *N);
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
Index: llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -537,6 +537,16 @@
return true;
}
+bool NVPTXDAGToDAGISel::tryConstantBF16(SDNode *N) {
+ if (N->getValueType(0) != MVT::bf16)
+ return false;
+ SDValue Val = CurDAG->getTargetConstantFP(
+ cast<ConstantFPSDNode>(N)->getValueAPF(), SDLoc(N), MVT::bf16);
+ SDNode *LoadConstBF16 =
+ CurDAG->getMachineNode(NVPTX::LOAD_CONST_BF16, SDLoc(N), MVT::bf16, Val);
+ ReplaceNode(N, LoadConstBF16);
+ return true;
+}
// Map ISD:CONDCODE value to appropriate CmpMode expected by
// NVPTXInstPrinter::printCmpMode()
static unsigned getPTXCmpMode(const CondCodeSDNode &CondCode, bool FTZ) {
@@ -1288,6 +1298,10 @@
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
EltVT = MVT::v2f16;
NumElts /= 2;
+ } else if (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) {
+ assert(NumElts % 2 == 0 && "Vector must have even number of elements");
+ EltVT = MVT::v2bf16;
+ NumElts /= 2;
}
}
Index: llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -267,6 +267,10 @@
MCOp = MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantFPHalf(Val, OutContext));
break;
+ case Type::BFloatTyID:
+ MCOp = MCOperand::createExpr(
+ NVPTXFloatMCExpr::createConstantBFPHalf(Val, OutContext));
+ break;
case Type::FloatTyID:
MCOp = MCOperand::createExpr(
NVPTXFloatMCExpr::createConstantFPSingle(Val, OutContext));
@@ -1353,8 +1357,10 @@
}
break;
}
+ case Type::BFloatTyID:
case Type::HalfTyID:
- // fp16 is stored as .b16 for compatibility with pre-sm_53 PTX assembly.
+ // fp16 and bf16 are stored as .b16 for compatibility with pre-sm_53
+ // PTX assembly.
return "b16";
case Type::FloatTyID:
return "f32";
@@ -1588,7 +1594,7 @@
} else if (PTy) {
assert(PTySizeInBits && "Invalid pointer size");
sz = PTySizeInBits;
- } else if (Ty->isHalfTy())
+ } else if (Ty->isHalfTy() || Ty->isBFloatTy())
// PTX ABI requires all scalar parameters to be at least 32
// bits in size. fp16 normally uses .b16 as its storage type
// in PTX, so its size must be adjusted here, too.
Index: llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
===================================================================
--- llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -61,9 +61,11 @@
OS << "%fd";
break;
case 7:
+ case 9:
OS << "%h";
break;
case 8:
+ case 10:
OS << "%hh";
break;
}
Index: llvm/include/llvm/IR/IntrinsicsNVVM.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -597,16 +597,18 @@
[IntrNoMem, IntrSpeculatable, Commutative]>;
}
- foreach variant = ["_bf16", "_nan_bf16", "_xorsign_abs_bf16",
- "_nan_xorsign_abs_bf16"] in {
+ foreach variant = ["_bf16", "_ftz_bf16", "_nan_bf16", "_ftz_nan_bf16",
+ "_xorsign_abs_bf16", "_ftz_xorsign_abs_bf16", "_nan_xorsign_abs_bf16",
+ "_ftz_nan_xorsign_abs_bf16"] in {
def int_nvvm_f # operation # variant :
ClangBuiltin<!strconcat("__nvvm_f", operation, variant)>,
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_i16_ty, llvm_i16_ty],
[IntrNoMem, IntrSpeculatable, Commutative]>;
}
- foreach variant = ["_bf16x2", "_nan_bf16x2", "_xorsign_abs_bf16x2",
- "_nan_xorsign_abs_bf16x2"] in {
+ foreach variant = ["_bf16x2", "_ftz_bf16x2", "_nan_bf16x2",
+ "_ftz_nan_bf16x2", "_xorsign_abs_bf16x2", "_ftz_xorsign_abs_bf16x2",
+ "_nan_xorsign_abs_bf16x2", "_ftz_nan_xorsign_abs_bf16x2"] in {
def int_nvvm_f # operation # variant :
ClangBuiltin<!strconcat("__nvvm_f", operation, variant)>,
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty],
@@ -874,17 +876,19 @@
[IntrNoMem, IntrSpeculatable]>;
}
- foreach variant = ["_rn_bf16", "_rn_relu_bf16"] in {
+ foreach variant = ["_rn_bf16", "_rn_ftz_bf16", "_rn_sat_bf16",
+ "_rn_ftz_sat_bf16", "_rn_relu_bf16", "_rn_ftz_relu_bf16"] in {
def int_nvvm_fma # variant : ClangBuiltin<!strconcat("__nvvm_fma", variant)>,
- DefaultAttrsIntrinsic<[llvm_i16_ty],
- [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty],
+ DefaultAttrsIntrinsic<[llvm_bfloat_ty],
+ [llvm_bfloat_ty, llvm_bfloat_ty, llvm_bfloat_ty],
[IntrNoMem, IntrSpeculatable]>;
}
- foreach variant = ["_rn_bf16x2", "_rn_relu_bf16x2"] in {
+ foreach variant = ["_rn_bf16x2", "_rn_ftz_bf16x2", "_rn_sat_bf16x2",
+ "_rn_ftz_sat_bf16x2", "_rn_relu_bf16x2", "_rn_ftz_relu_bf16x2"] in {
def int_nvvm_fma # variant : ClangBuiltin<!strconcat("__nvvm_fma", variant)>,
- DefaultAttrsIntrinsic<[llvm_i32_ty],
- [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
+ DefaultAttrsIntrinsic<[llvm_v2bf16_ty],
+ [llvm_v2bf16_ty, llvm_v2bf16_ty, llvm_v2bf16_ty],
[IntrNoMem, IntrSpeculatable]>;
}
@@ -1236,6 +1240,11 @@
def int_nvvm_f2h_rn : ClangBuiltin<"__nvvm_f2h_rn">,
DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_float_ty], [IntrNoMem, IntrSpeculatable]>;
+ def int_nvvm_bf2h_rn_ftz : ClangBuiltin<"__nvvm_bf2h_rn_ftz">,
+ DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_bfloat_ty], [IntrNoMem, IntrSpeculatable]>;
+ def int_nvvm_bf2h_rn : ClangBuiltin<"__nvvm_bf2h_rn">,
+ DefaultAttrsIntrinsic<[llvm_i16_ty], [llvm_bfloat_ty], [IntrNoMem, IntrSpeculatable]>;
+
def int_nvvm_ff2bf16x2_rn : ClangBuiltin<"__nvvm_ff2bf16x2_rn">,
Intrinsic<[llvm_i32_ty], [llvm_float_ty, llvm_float_ty], [IntrNoMem, IntrNoCallback]>;
def int_nvvm_ff2bf16x2_rn_relu : ClangBuiltin<"__nvvm_ff2bf16x2_rn_relu">,
Index: clang/include/clang/Basic/BuiltinsNVPTX.def
===================================================================
--- clang/include/clang/Basic/BuiltinsNVPTX.def
+++ clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -145,12 +145,16 @@
TARGET_BUILTIN(__nvvm_fmin_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "UsUsUs", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
@@ -187,12 +191,16 @@
TARGET_BUILTIN(__nvvm_fmax_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "UsUsUs", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits