kushanam updated this revision to Diff 523106.
kushanam added a comment.
adding clang directives and removing bf16 registers
Depends on D144911 <https://reviews.llvm.org/D144911>
Differential Revision: https://reviews.llvm.org/D144911
Repository:
rG LLVM Github Monorepo
CHANGES SINCE LAST ACTION
https://reviews.llvm.org/D149976/new/
https://reviews.llvm.org/D149976
Files:
clang/include/clang/Basic/BuiltinsNVPTX.def
llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
Index: llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
+++ llvm/lib/Target/NVPTX/NVPTXRegisterInfo.cpp
@@ -29,14 +29,13 @@
std::string getNVPTXRegClassName(TargetRegisterClass const *RC) {
if (RC == &NVPTX::Float32RegsRegClass)
return ".f32";
- if (RC == &NVPTX::Float16RegsRegClass || RC == &NVPTX::BFloat16RegsRegClass)
+ if (RC == &NVPTX::Float16RegsRegClass)
// Ideally fp16 registers should be .f16, but this syntax is only
// supported on sm_53+. On the other hand, .b16 registers are
// accepted for all supported fp16 instructions on all GPU
// variants, so we can use them instead.
return ".b16";
- if (RC == &NVPTX::Float16x2RegsRegClass ||
- RC == &NVPTX::BFloat16x2RegsRegClass)
+ if (RC == &NVPTX::Float16x2RegsRegClass)
return ".b32";
if (RC == &NVPTX::Float64RegsRegClass)
return ".f64";
@@ -74,10 +73,9 @@
std::string getNVPTXRegClassStr(TargetRegisterClass const *RC) {
if (RC == &NVPTX::Float32RegsRegClass)
return "%f";
- if (RC == &NVPTX::Float16RegsRegClass || RC == &NVPTX::BFloat16RegsRegClass)
+ if (RC == &NVPTX::Float16RegsRegClass)
return "%h";
- if (RC == &NVPTX::Float16x2RegsRegClass ||
- RC == &NVPTX::BFloat16x2RegsRegClass)
+ if (RC == &NVPTX::Float16x2RegsRegClass)
return "%hh";
if (RC == &NVPTX::Float64RegsRegClass)
return "%fd";
Index: llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -998,9 +998,6 @@
FMA_TUPLE<"_rn_ftz_relu_f16x2", int_nvvm_fma_rn_ftz_relu_f16x2,
Float16x2Regs, [hasPTX70, hasSM80]>,
- // FMA_TUPLE<"_rn_relu_bf16", int_nvvm_fma_rn_relu_bf16, BFloat16Regs,
- // [hasPTX70, hasSM80]>,
-
FMA_TUPLE<"_rn_bf16x2", int_nvvm_fma_rn_bf16x2, BFloat16x2Regs,
[hasPTX70, hasSM80]>,
FMA_TUPLE<"_rn_relu_bf16x2", int_nvvm_fma_rn_relu_bf16x2, BFloat16x2Regs,
@@ -1254,24 +1251,6 @@
def : Pat<(int_nvvm_ff2bf16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
(CVT_bf16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
-// def : Pat<(int_nvvm_ff2f16x2_rn Float32Regs:$a, Float32Regs:$b),
-// (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN)>;
-// def : Pat<(int_nvvm_ff2f16x2_rn_relu Float32Regs:$a, Float32Regs:$b),
-// (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRN_RELU)>;
-// def : Pat<(int_nvvm_ff2f16x2_rz Float32Regs:$a, Float32Regs:$b),
-// (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ)>;
-// def : Pat<(int_nvvm_ff2f16x2_rz_relu Float32Regs:$a, Float32Regs:$b),
-// (CVT_f16x2_f32 Float32Regs:$a, Float32Regs:$b, CvtRZ_RELU)>;
-
-// def : Pat<(int_nvvm_f2bf16_rn Float32Regs:$a),
-// (CVT_bf16_f32 Float32Regs:$a, CvtRN)>;
-// def : Pat<(int_nvvm_f2bf16_rn_relu Float32Regs:$a),
-// (CVT_bf16_f32 Float32Regs:$a, CvtRN_RELU)>;
-// def : Pat<(int_nvvm_f2bf16_rz Float32Regs:$a),
-// (CVT_bf16_f32 Float32Regs:$a, CvtRZ)>;
-// def : Pat<(int_nvvm_f2bf16_rz_relu Float32Regs:$a),
-// (CVT_bf16_f32 Float32Regs:$a, CvtRZ_RELU)>;
-
def CVT_tf32_f32 :
NVPTXInst<(outs Int32Regs:$dest), (ins Float32Regs:$a),
"cvt.rna.tf32.f32 \t$dest, $a;",
@@ -1387,11 +1366,6 @@
def : Pat<(int_nvvm_f2h_rn Float32Regs:$a),
(BITCONVERT_16_F2I (CVT_f16_f32 Float32Regs:$a, CvtRN))>;
-// def : Pat<(int_nvvm_bf2h_rn_ftz Float32Regs:$a),
-// (BITCONVERT_16_BF2I (CVT_bf16_f32 Float32Regs:$a, CvtRN_FTZ))>;
-// def : Pat<(int_nvvm_f2h_rn BFloat16Regs:$a),
-// (BITCONVERT_16_BF2I (CVT_bf16_f32 BFloat16Regs:$a, CvtRN))>;
-
//
// Bitcast
//
Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
+++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
@@ -656,15 +656,6 @@
def CVT_INREG_s64_s32 : NVPTXInst<(outs Int64Regs:$dst), (ins Int64Regs:$src),
"cvt.s64.s32 \t$dst, $src;", []>;
-multiclass CVT_FROM_FLOAT_SM80<string FromName, RegisterClass RC> {
- def _f32 :
- NVPTXInst<(outs RC:$dst),
- (ins Float32Regs:$src, CvtMode:$mode),
- !strconcat("cvt${mode:base}${mode:relu}.",
- FromName, ".f32 \t$dst, $src;"), []>,
- Requires<[hasPTX70, hasSM80]>;
- }
-
multiclass CVT_FROM_FLOAT_V2_SM80<string FromName, RegisterClass RC> {
def _f32 :
NVPTXInst<(outs RC:$dst),
@@ -753,12 +744,6 @@
"selp.b32 \t$dst, $a, $b, $p;",
[(set Float16x2Regs:$dst,
(select Int1Regs:$p, (v2f16 Float16x2Regs:$a), (v2f16 Float16x2Regs:$b)))]>;
-def SELP_bf16x2rr :
- NVPTXInst<(outs BFloat16x2Regs:$dst),
- (ins BFloat16x2Regs:$a, BFloat16x2Regs:$b, Int1Regs:$p),
- "selp.b32 \t$dst, $a, $b, $p;",
- [(set BFloat16x2Regs:$dst,
- (select Int1Regs:$p, (v2bf16 BFloat16x2Regs:$a), (v2bf16 BFloat16x2Regs:$b)))]>;
//-----------------------------------
// Test Instructions
@@ -2091,7 +2076,7 @@
(SETP_f16rr (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>,
Requires<[useFP16Math]>;
- //bf16 -> pred
+ // bf16 -> pred
def : Pat<(i1 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
(SETP_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, ModeFTZ)>,
Requires<[useBFP16Math,doF32FTZ]>;
@@ -2156,7 +2141,7 @@
(SET_f16ir (LOAD_CONST_F16 fpimm:$a), Float16Regs:$b, Mode)>,
Requires<[useFP16Math]>;
- // bf16 -> i32
+ // bf16 -> i32
def : Pat<(i32 (OpNode (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b))),
(SET_bf16rr BFloat16Regs:$a, BFloat16Regs:$b, ModeFTZ)>,
Requires<[useBFP16Math, doF32FTZ]>;
@@ -2707,9 +2692,7 @@
defm LD_i32 : LD<Int32Regs>;
defm LD_i64 : LD<Int64Regs>;
defm LD_f16 : LD<Float16Regs>;
- defm LD_bf16 : LD<BFloat16Regs>;
defm LD_f16x2 : LD<Float16x2Regs>;
- defm LD_bf16x2 : LD<BFloat16x2Regs>;
defm LD_f32 : LD<Float32Regs>;
defm LD_f64 : LD<Float64Regs>;
}
@@ -3366,29 +3349,29 @@
[(set BFloat16Regs:$dst,
(extractelt (v2bf16 BFloat16x2Regs:$src), 1))]>;
- // Coalesce two bf16 registers into bf16x2
- def BuildBF16x2 : NVPTXInst<(outs BFloat16x2Regs:$dst),
- (ins BFloat16Regs:$a, BFloat16Regs:$b),
- "mov.b32 \t$dst, {{$a, $b}};",
- [(set (v2bf16 BFloat16x2Regs:$dst),
- (build_vector (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>;
-
- // Directly initializing underlying the b32 register is one less SASS
- // instruction than than vector-packing move.
- def BuildBF16x2i : NVPTXInst<(outs BFloat16x2Regs:$dst), (ins i32imm:$src),
- "mov.b32 \t$dst, $src;",
- []>;
-
- // Split f16x2 into two f16 registers.
- def SplitBF16x2 : NVPTXInst<(outs BFloat16Regs:$lo, BFloat16Regs:$hi),
- (ins BFloat16x2Regs:$src),
- "mov.b32 \t{{$lo, $hi}}, $src;",
- []>;
- // Split an i32 into two f16
- def SplitI32toBF16x2 : NVPTXInst<(outs BFloat16Regs:$lo, BFloat16Regs:$hi),
- (ins Int32Regs:$src),
- "mov.b32 \t{{$lo, $hi}}, $src;",
- []>;
+ // // Coalesce two bf16 registers into bf16x2
+ // def BuildBF16x2 : NVPTXInst<(outs BFloat16x2Regs:$dst),
+ // (ins BFloat16Regs:$a, BFloat16Regs:$b),
+ // "mov.b32 \t$dst, {{$a, $b}};",
+ // [(set (v2bf16 BFloat16x2Regs:$dst),
+ // (build_vector (bf16 BFloat16Regs:$a), (bf16 BFloat16Regs:$b)))]>;
+
+ // // Directly initializing underlying the b32 register is one less SASS
+ // // instruction than than vector-packing move.
+ // def BuildBF16x2i : NVPTXInst<(outs BFloat16x2Regs:$dst), (ins i32imm:$src),
+ // "mov.b32 \t$dst, $src;",
+ // []>;
+
+ // // Split f16x2 into two f16 registers.
+ // def SplitBF16x2 : NVPTXInst<(outs BFloat16Regs:$lo, BFloat16Regs:$hi),
+ // (ins BFloat16x2Regs:$src),
+ // "mov.b32 \t{{$lo, $hi}}, $src;",
+ // []>;
+ // // Split an i32 into two f16
+ // def SplitI32toBF16x2 : NVPTXInst<(outs BFloat16Regs:$lo, BFloat16Regs:$hi),
+ // (ins Int32Regs:$src),
+ // "mov.b32 \t{{$lo, $hi}}, $src;",
+ // []>;
}
// Count leading zeros
Index: llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
+++ llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp
@@ -56,11 +56,6 @@
: NVPTX::BITCONVERT_16_I2F);
} else if (DestRC == &NVPTX::Float16x2RegsRegClass) {
Op = NVPTX::IMOV32rr;
- } else if (DestRC == &NVPTX::BFloat16RegsRegClass) {
- Op = (SrcRC == &NVPTX::BFloat16RegsRegClass ? NVPTX::BFMOV16rr
- : NVPTX::BITCONVERT_16_I2BF);
- } else if (DestRC == &NVPTX::BFloat16x2RegsRegClass) {
- Op = NVPTX::IMOV32rr;
} else if (DestRC == &NVPTX::Float32RegsRegClass) {
Op = (SrcRC == &NVPTX::Float32RegsRegClass ? NVPTX::FMOV32rr
: NVPTX::BITCONVERT_32_I2F);
Index: llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -143,6 +143,26 @@
}
}
+static bool Isv2f16Orv2bf16Type(MVT VT) {
+ switch (VT.SimpleTy) {
+ default:
+ return false;
+ case MVT::v2f16:
+ case MVT::v2bf16:
+ return true;
+ }
+}
+
+static bool Isf16Orbf16Type(MVT VT) {
+ switch (VT.SimpleTy) {
+ default:
+ return false;
+ case MVT::f16:
+ case MVT::bf16:
+ return true;
+ }
+}
+
/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
/// into their primitive components.
@@ -193,7 +213,7 @@
// Vectors with an even number of f16 elements will be passed to
// us as an array of v2f16/v2bf16 elements. We must match this so we
// stay in sync with Ins/Outs.
- if ((EltVT == MVT::f16 || EltVT == MVT::bf16) && NumElts % 2 == 0) {
+ if ((Isf16Orbf16Type(EltVT.getSimpleVT())) && NumElts % 2 == 0) {
EltVT = EltVT == MVT::f16 ? MVT::v2f16 : MVT::v2bf16;
NumElts /= 2;
}
@@ -411,8 +431,6 @@
addRegisterClass(MVT::f64, &NVPTX::Float64RegsRegClass);
addRegisterClass(MVT::f16, &NVPTX::Float16RegsRegClass);
addRegisterClass(MVT::v2f16, &NVPTX::Float16x2RegsRegClass);
- addRegisterClass(MVT::bf16, &NVPTX::BFloat16RegsRegClass);
- addRegisterClass(MVT::v2bf16, &NVPTX::BFloat16x2RegsRegClass);
// Conversion to/from FP16/FP16x2 is always legal.
setOperationAction(ISD::SINT_TO_FP, MVT::f16, Legal);
@@ -586,7 +604,7 @@
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
ISD::SREM, ISD::UREM});
- // setcc for f16x2 and bf16x2 needs special handling to prevent
+ // setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
if (STI.allowFP16Math() || STI.allowBF16Math())
setTargetDAGCombine(ISD::SETCC);
@@ -616,8 +634,8 @@
IsFP16FP16x2NegAvailable ? Legal : Expand);
const bool IsBFP16FP16x2NegAvailable = STI.getSmVersion() >= 80 &&
- STI.getPTXVersion() >= 70 &&
- STI.allowBF16Math();
+ STI.getPTXVersion() >= 70 &&
+ STI.allowBF16Math();
for (const auto &VT : {MVT::bf16, MVT::v2bf16})
setOperationAction(ISD::FNEG, VT,
IsBFP16FP16x2NegAvailable ? Legal : Expand);
@@ -631,6 +649,7 @@
setOperationAction(Op, MVT::f32, Legal);
setOperationAction(Op, MVT::f64, Legal);
setOperationAction(Op, MVT::v2f16, Expand);
+ setOperationAction(Op, MVT::v2bf16, Expand);
}
setOperationAction(ISD::FROUND, MVT::f16, Promote);
@@ -680,12 +699,10 @@
for (const auto &Op : {ISD::FMINNUM, ISD::FMAXNUM}) {
setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Promote), Promote);
setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
- }
- for (const auto &Op : {ISD::FMINIMUM, ISD::FMAXIMUM}) {
setBF16OperationAction(Op, MVT::bf16, GetMinMaxAction(Expand), Expand);
- setOperationAction(Op, MVT::f32, GetMinMaxAction(Expand));
setBF16OperationAction(Op, MVT::v2bf16, GetMinMaxAction(Expand), Expand);
}
+
// No FEXP2, FLOG2. The PTX ex2 and log2 functions are always approximate.
// No FPOW or FREM in PTX.
@@ -1301,7 +1318,7 @@
if (!VT.isScalableVector() && VT.getVectorNumElements() != 1 &&
VT.getScalarType() == MVT::i1)
return TypeSplitVector;
- if (VT == MVT::v2f16 || VT == MVT::v2bf16)
+ if (Isv2f16Orv2bf16Type(VT))
return TypeLegal;
return TargetLoweringBase::getPreferredVectorAction(VT);
}
@@ -2086,8 +2103,7 @@
// generates good SASS in both cases.
SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
SelectionDAG &DAG) const {
- if (!((Op->getValueType(0) == MVT::v2f16 ||
- Op->getValueType(0) == MVT::v2bf16) &&
+ if (!(Isv2f16Orv2bf16Type(Op->getOperand(0).getValueType().getSimpleVT()) &&
isa<ConstantFPSDNode>(Op->getOperand(0)) &&
isa<ConstantFPSDNode>(Op->getOperand(1))))
return Op;
@@ -2098,9 +2114,7 @@
cast<ConstantFPSDNode>(Op->getOperand(1))->getValueAPF().bitcastToAPInt();
SDValue Const =
DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
- return Op->getValueType(0) == MVT::v2bf16
- ? DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2bf16, Const)
- : DAG.getNode(ISD::BITCAST, SDLoc(Op), MVT::v2f16, Const);
+ return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
}
SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
@@ -2461,7 +2475,7 @@
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
// loads and have to handle it here.
- if (Op.getValueType() == MVT::v2f16 || Op.getValueType() == MVT::v2bf16) {
+ if (Isv2f16Orv2bf16Type(Op.getValueType().getSimpleVT())) {
LoadSDNode *Load = cast<LoadSDNode>(Op);
EVT MemVT = Load->getMemoryVT();
if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -2506,7 +2520,7 @@
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
// stores and have to handle it here.
- if ((VT == MVT::v2f16 || VT == MVT::v2bf16) &&
+ if ((Isv2f16Orv2bf16Type(VT.getSimpleVT())) &&
!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
VT, *Store->getMemOperand()))
return expandUnalignedStore(Store, DAG);
@@ -2593,7 +2607,7 @@
// v8f16 is a special case. PTX doesn't have st.v8.f16
// instruction. Instead, we split the vector into v2f16 chunks and
// store them with st.v4.b32.
- assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+ assert((Isf16Orbf16Type(EltVT.getSimpleVT())) &&
"Wrong type for the vector.");
Opcode = NVPTXISD::StoreV4;
StoreF16x2 = true;
@@ -2608,24 +2622,14 @@
if (StoreF16x2) {
// Combine f16,f16 -> v2f16
NumElts /= 2;
- if (EltVT == MVT::f16) {
- for (unsigned i = 0; i < NumElts; ++i) {
- SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
- DAG.getIntPtrConstant(i * 2, DL));
- SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f16, Val,
- DAG.getIntPtrConstant(i * 2 + 1, DL));
- SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2f16, E0, E1);
- Ops.push_back(V2);
- }
- } else {
- for (unsigned i = 0; i < NumElts; ++i) {
- SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::bf16, Val,
- DAG.getIntPtrConstant(i * 2, DL));
- SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::bf16, Val,
- DAG.getIntPtrConstant(i * 2 + 1, DL));
- SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, MVT::v2bf16, E0, E1);
- Ops.push_back(V2);
- }
+ for (unsigned i = 0; i < NumElts; ++i) {
+ SDValue E0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
+ DAG.getIntPtrConstant(i * 2, DL));
+ SDValue E1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
+ DAG.getIntPtrConstant(i * 2 + 1, DL));
+ EVT VecVT = EVT::getVectorVT(*DAG.getContext(), EltVT, 2);
+ SDValue V2 = DAG.getNode(ISD::BUILD_VECTOR, DL, VecVT, E0, E1);
+ Ops.push_back(V2);
}
} else {
// Then the split values
@@ -2796,7 +2800,7 @@
EVT LoadVT = EltVT;
if (EltVT == MVT::i1)
LoadVT = MVT::i8;
- else if (EltVT == MVT::v2f16 || EltVT == MVT::v2bf16)
+ else if (Isv2f16Orv2bf16Type(EltVT.getSimpleVT()))
// getLoad needs a vector type, but it can't handle
// vectors which contain v2f16 or v2bf16 elements. So we must load
// using i32 here and then bitcast back.
@@ -5234,7 +5238,7 @@
// v8f16 is a special case. PTX doesn't have ld.v8.f16
// instruction. Instead, we split the vector into v2f16 chunks and
// load them with ld.v4.b32.
- assert((EltVT == MVT::f16 || EltVT == MVT::bf16) &&
+ assert(Isf16Orbf16Type(EltVT.getSimpleVT()) &&
"Unsupported v8 vector type.");
LoadF16x2 = true;
Opcode = NVPTXISD::LoadV4;
Index: llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
===================================================================
--- llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
+++ llvm/lib/Target/NVPTX/NVPTXAsmPrinter.cpp
@@ -312,10 +312,6 @@
Ret = (7 << 28);
} else if (RC == &NVPTX::Float16x2RegsRegClass) {
Ret = (8 << 28);
- } else if (RC == &NVPTX::BFloat16RegsRegClass) {
- Ret = (9 << 28);
- } else if (RC == &NVPTX::BFloat16x2RegsRegClass) {
- Ret = (10 << 28);
} else {
report_fatal_error("Bad register class");
}
Index: clang/include/clang/Basic/BuiltinsNVPTX.def
===================================================================
--- clang/include/clang/Basic/BuiltinsNVPTX.def
+++ clang/include/clang/Basic/BuiltinsNVPTX.def
@@ -145,12 +145,16 @@
TARGET_BUILTIN(__nvvm_fmin_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "UsUsUs", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
@@ -187,12 +191,16 @@
TARGET_BUILTIN(__nvvm_fmax_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "UsUsUs", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
+TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
AND(SM_86, PTX72))
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits