================ @@ -39,186 +39,201 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" -#include <algorithm> #define NVVM_REFLECT_FUNCTION "__nvvm_reflect" #define NVVM_REFLECT_OCL_FUNCTION "__nvvm_reflect_ocl" +// Argument of reflect call to retrieve arch number +#define CUDA_ARCH_NAME "__CUDA_ARCH" +// Argument of reflect call to retrieve ftz mode +#define CUDA_FTZ_NAME "__CUDA_FTZ" +// Name of module metadata where ftz mode is stored +#define CUDA_FTZ_MODULE_NAME "nvvm-reflect-ftz" using namespace llvm; -#define DEBUG_TYPE "nvptx-reflect" +#define DEBUG_TYPE "nvvm-reflect" + +namespace llvm { +void initializeNVVMReflectLegacyPassPass(PassRegistry &); +} // namespace llvm namespace { -class NVVMReflect : public FunctionPass { +class NVVMReflect { + // Map from reflect function call arguments to the value to replace the call + // with. Should include __CUDA_FTZ and __CUDA_ARCH values. + StringMap<unsigned> ReflectMap; + bool handleReflectFunction(Module &M, StringRef ReflectName); + void populateReflectMap(Module &M); + void foldReflectCall(CallInst *Call, Constant *NewValue); + public: - static char ID; - unsigned int SmVersion; - NVVMReflect() : NVVMReflect(0) {} - explicit NVVMReflect(unsigned int Sm) : FunctionPass(ID), SmVersion(Sm) {} + // __CUDA_FTZ is assigned in `runOnModule` by checking nvvm-reflect-ftz module + // metadata. 
+ explicit NVVMReflect(unsigned SmVersion) + : ReflectMap({{CUDA_ARCH_NAME, SmVersion * 10}}) {} + bool runOnModule(Module &M); +}; - bool runOnFunction(Function &) override; +class NVVMReflectLegacyPass : public ModulePass { + NVVMReflect Impl; + +public: + static char ID; + NVVMReflectLegacyPass(unsigned SmVersion) : ModulePass(ID), Impl(SmVersion) {} + bool runOnModule(Module &M) override; }; } // namespace -FunctionPass *llvm::createNVVMReflectPass(unsigned int SmVersion) { - return new NVVMReflect(SmVersion); +ModulePass *llvm::createNVVMReflectPass(unsigned SmVersion) { + return new NVVMReflectLegacyPass(SmVersion); } static cl::opt<bool> NVVMReflectEnabled("nvvm-reflect-enable", cl::init(true), cl::Hidden, cl::desc("NVVM reflection, enabled by default")); -char NVVMReflect::ID = 0; -INITIALIZE_PASS(NVVMReflect, "nvvm-reflect", +char NVVMReflectLegacyPass::ID = 0; +INITIALIZE_PASS(NVVMReflectLegacyPass, "nvvm-reflect", "Replace occurrences of __nvvm_reflect() calls with 0/1", false, false) -static bool runNVVMReflect(Function &F, unsigned SmVersion) { - if (!NVVMReflectEnabled) - return false; - - if (F.getName() == NVVM_REFLECT_FUNCTION || - F.getName() == NVVM_REFLECT_OCL_FUNCTION) { - assert(F.isDeclaration() && "_reflect function should not have a body"); - assert(F.getReturnType()->isIntegerTy() && - "_reflect's return type should be integer"); - return false; +// Allow users to specify additional key/value pairs to reflect. These key/value +// pairs are the last to be added to the ReflectMap, and therefore will take +// precedence over initial values (i.e. __CUDA_FTZ from module metadata and +// __CUDA_ARCH from SmVersion). +static cl::list<std::string> ReflectList( + "nvvm-reflect-add", cl::value_desc("name=<int>"), cl::Hidden, + cl::desc("A key=value pair. Replace __nvvm_reflect(name) with value."), + cl::ValueRequired); + +// Set the ReflectMap with, first, the value of __CUDA_FTZ from module metadata, and +// then the key/value pairs from the command line. 
+void NVVMReflect::populateReflectMap(Module &M) { + if (auto *Flag = mdconst::extract_or_null<ConstantInt>( + M.getModuleFlag(CUDA_FTZ_MODULE_NAME))) + ReflectMap[CUDA_FTZ_NAME] = Flag->getSExtValue(); + + for (auto &Option : ReflectList) { + LLVM_DEBUG(dbgs() << "ReflectOption : " << Option << "\n"); + StringRef OptionRef(Option); + auto [Name, Val] = OptionRef.split('='); + if (Name.empty()) + report_fatal_error(Twine("Empty name in nvvm-reflect-add option '") + + Option + "'"); + if (Val.empty()) + report_fatal_error(Twine("Missing value in nvvm-reflect-add option '") + + Option + "'"); + unsigned ValInt; + if (!to_integer(Val.trim(), ValInt, 10)) + report_fatal_error( + Twine("integer value expected in nvvm-reflect-add option '") + + Option + "'"); + ReflectMap[Name] = ValInt; } +} - SmallVector<Instruction *, 4> ToRemove; - SmallVector<Instruction *, 4> ToSimplify; - - // Go through the calls in this function. Each call to __nvvm_reflect or - // llvm.nvvm.reflect should be a CallInst with a ConstantArray argument. - // First validate that. If the c-string corresponding to the ConstantArray can - // be found successfully, see if it can be found in VarMap. If so, replace the - // uses of CallInst with the value found in VarMap. If not, replace the use - // with value 0. - - // The IR for __nvvm_reflect calls differs between CUDA versions. - // - // CUDA 6.5 and earlier uses this sequence: - // %ptr = tail call i8* @llvm.nvvm.ptr.constant.to.gen.p0i8.p4i8 - // (i8 addrspace(4)* getelementptr inbounds - // ([8 x i8], [8 x i8] addrspace(4)* @str, i32 0, i32 0)) - // %reflect = tail call i32 @__nvvm_reflect(i8* %ptr) - // - // The value returned by Sym->getOperand(0) is a Constant with a - // ConstantDataSequential operand which can be converted to string and used - // for lookup. 
- // - // CUDA 7.0 does it slightly differently: - // %reflect = call i32 @__nvvm_reflect(i8* addrspacecast - // (i8 addrspace(1)* getelementptr inbounds - // ([8 x i8], [8 x i8] addrspace(1)* @str, i32 0, i32 0) to i8*)) - // - // In this case, we get a Constant with a GlobalVariable operand and we need - // to dig deeper to find its initializer with the string we'll use for lookup. - for (Instruction &I : instructions(F)) { - CallInst *Call = dyn_cast<CallInst>(&I); +/// Process a reflect function by finding all its calls and replacing them with +/// appropriate constant values. For __CUDA_FTZ, uses the module flag value. +/// For __CUDA_ARCH, uses SmVersion * 10. For all other strings, uses 0. +bool NVVMReflect::handleReflectFunction(Module &M, StringRef ReflectName) { + Function *F = M.getFunction(ReflectName); + if (!F) + return false; + assert(F->isDeclaration() && "_reflect function should not have a body"); + assert(F->getReturnType()->isIntegerTy() && + "_reflect's return type should be integer"); + + const bool Changed = F->getNumUses() > 0; + for (User *U : make_early_inc_range(F->users())) { + // Reflect function calls look like: + // @arch = private unnamed_addr addrspace(1) constant [12 x i8] + // c"__CUDA_ARCH\00" call i32 @__nvvm_reflect(ptr addrspacecast (ptr + // addrspace(1) @arch to ptr)) We need to extract the string argument from + // the call (i.e. "__CUDA_ARCH") + auto *Call = dyn_cast<CallInst>(U); if (!Call) - continue; - Function *Callee = Call->getCalledFunction(); - if (!Callee || (Callee->getName() != NVVM_REFLECT_FUNCTION && - Callee->getName() != NVVM_REFLECT_OCL_FUNCTION && - Callee->getIntrinsicID() != Intrinsic::nvvm_reflect)) - continue; - - // FIXME: Improve error handling here and elsewhere in this pass. - assert(Call->getNumOperands() == 2 && - "Wrong number of operands to __nvvm_reflect function"); - - // In cuda 6.5 and earlier, we will have an extra constant-to-generic - // conversion of the string. 
- const Value *Str = Call->getArgOperand(0); - if (const CallInst *ConvCall = dyn_cast<CallInst>(Str)) { - // FIXME: Add assertions about ConvCall. - Str = ConvCall->getArgOperand(0); - } - // Pre opaque pointers we have a constant expression wrapping the constant - // string. - Str = Str->stripPointerCasts(); - assert(isa<Constant>(Str) && - "Format of __nvvm_reflect function not recognized"); - - const Value *Operand = cast<Constant>(Str)->getOperand(0); - if (const GlobalVariable *GV = dyn_cast<GlobalVariable>(Operand)) { - // For CUDA-7.0 style __nvvm_reflect calls, we need to find the operand's - // initializer. - assert(GV->hasInitializer() && - "Format of _reflect function not recognized"); - const Constant *Initializer = GV->getInitializer(); - Operand = Initializer; - } - - assert(isa<ConstantDataSequential>(Operand) && - "Format of _reflect function not recognized"); - assert(cast<ConstantDataSequential>(Operand)->isCString() && - "Format of _reflect function not recognized"); - - StringRef ReflectArg = cast<ConstantDataSequential>(Operand)->getAsString(); - ReflectArg = ReflectArg.substr(0, ReflectArg.size() - 1); - LLVM_DEBUG(dbgs() << "Arg of _reflect : " << ReflectArg << "\n"); - - int ReflectVal = 0; // The default value is 0 - if (ReflectArg == "__CUDA_FTZ") { - // Try to pull __CUDA_FTZ from the nvvm-reflect-ftz module flag. Our - // choice here must be kept in sync with AutoUpgrade, which uses the same - // technique to detect whether ftz is enabled. - if (auto *Flag = mdconst::extract_or_null<ConstantInt>( - F.getParent()->getModuleFlag("nvvm-reflect-ftz"))) - ReflectVal = Flag->getSExtValue(); - } else if (ReflectArg == "__CUDA_ARCH") { - ReflectVal = SmVersion * 10; - } - - // If the immediate user is a simple comparison we want to simplify it. 
- for (User *U : Call->users()) - if (Instruction *I = dyn_cast<Instruction>(U)) - ToSimplify.push_back(I); - - Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal)); - ToRemove.push_back(Call); + report_fatal_error( + "__nvvm_reflect can only be used in a call instruction"); + if (Call->getNumOperands() != 2) + report_fatal_error("__nvvm_reflect requires exactly one argument"); + + auto *GlobalStr = + dyn_cast<Constant>(Call->getArgOperand(0)->stripPointerCasts()); + if (!GlobalStr) + report_fatal_error("__nvvm_reflect argument must be a constant string"); + + auto *ConstantStr = + dyn_cast<ConstantDataSequential>(GlobalStr->getOperand(0)); + if (!ConstantStr) + report_fatal_error("__nvvm_reflect argument must be a string constant"); + if (!ConstantStr->isCString()) + report_fatal_error( + "__nvvm_reflect argument must be a null-terminated string"); + + StringRef ReflectArg = ConstantStr->getAsString().drop_back(); + if (ReflectArg.empty()) + report_fatal_error("__nvvm_reflect argument cannot be empty"); + // Now that we have extracted the string argument, we can look it up in the + // ReflectMap + unsigned ReflectVal = 0; // The default value is 0 + if (ReflectMap.contains(ReflectArg)) + ReflectVal = ReflectMap[ReflectArg]; + + LLVM_DEBUG(dbgs() << "Replacing call of reflect function " << F->getName() + << "(" << ReflectArg << ") with value " << ReflectVal + << "\n"); + auto *NewValue = ConstantInt::get(Call->getType(), ReflectVal); + foldReflectCall(Call, NewValue); + Call->eraseFromParent(); } - // The code guarded by __nvvm_reflect may be invalid for the target machine. - // Traverse the use-def chain, continually simplifying constant expressions - // until we find a terminator that we can then remove. 
- while (!ToSimplify.empty()) { - Instruction *I = ToSimplify.pop_back_val(); - if (Constant *C = ConstantFoldInstruction(I, F.getDataLayout())) { - for (User *U : I->users()) - if (Instruction *I = dyn_cast<Instruction>(U)) - ToSimplify.push_back(I); - - I->replaceAllUsesWith(C); - if (isInstructionTriviallyDead(I)) { - ToRemove.push_back(I); - } + // Remove the __nvvm_reflect function from the module + F->eraseFromParent(); + return Changed; +} + +void NVVMReflect::foldReflectCall(CallInst *Call, Constant *NewValue) { + SmallVector<Instruction *, 8> Worklist; + // Replace an instruction with a constant and add all users of the instruction + // to the worklist + auto ReplaceInstructionWithConst = [&](Instruction *I, Constant *C) { + for (auto *U : I->users()) + if (auto *UI = dyn_cast<Instruction>(U)) + Worklist.push_back(UI); + I->replaceAllUsesWith(C); + }; + + ReplaceInstructionWithConst(Call, NewValue); + + auto &DL = Call->getModule()->getDataLayout(); + while (!Worklist.empty()) { + auto *I = Worklist.pop_back_val(); + if (auto *C = ConstantFoldInstruction(I, DL)) { + ReplaceInstructionWithConst(I, C); + if (isInstructionTriviallyDead(I)) + I->eraseFromParent(); } else if (I->isTerminator()) { ConstantFoldTerminator(I->getParent()); } } - - // Removing via isInstructionTriviallyDead may add duplicates to the ToRemove - // array. Filter out the duplicates before starting to erase from parent. 
- std::sort(ToRemove.begin(), ToRemove.end()); - auto NewLastIter = llvm::unique(ToRemove); - ToRemove.erase(NewLastIter, ToRemove.end()); - - for (Instruction *I : ToRemove) - I->eraseFromParent(); - - return ToRemove.size() > 0; } -bool NVVMReflect::runOnFunction(Function &F) { - return runNVVMReflect(F, SmVersion); +bool NVVMReflect::runOnModule(Module &M) { + if (!NVVMReflectEnabled) + return false; + populateReflectMap(M); + bool Changed = true; + Changed |= handleReflectFunction(M, NVVM_REFLECT_FUNCTION); + Changed |= handleReflectFunction(M, NVVM_REFLECT_OCL_FUNCTION); + Changed |= + handleReflectFunction(M, Intrinsic::getName(Intrinsic::nvvm_reflect)); + return Changed; } -NVVMReflectPass::NVVMReflectPass() : NVVMReflectPass(0) {} - -PreservedAnalyses NVVMReflectPass::run(Function &F, - FunctionAnalysisManager &AM) { - return runNVVMReflect(F, SmVersion) ? PreservedAnalyses::none() - : PreservedAnalyses::all(); +bool NVVMReflectLegacyPass::runOnModule(Module &M) { + return Impl.runOnModule(M); } + +PreservedAnalyses NVVMReflectPass::run(Module &M, ModuleAnalysisManager &AM) { + return NVVMReflect(SmVersion).runOnModule(M) ? PreservedAnalyses::none() + : PreservedAnalyses::all(); +} ---------------- Artem-B wrote:
Nit: missing newline. https://github.com/llvm/llvm-project/pull/134416 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits