https://github.com/Lukasdoe updated https://github.com/llvm/llvm-project/pull/146230
From d49de354afa4031364fd08d5089672117d7d8f84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lukas=20D=C3=B6llerer?= <cont...@lukas-doellerer.de> Date: Sat, 28 Jun 2025 20:11:56 +0200 Subject: [PATCH] [LLVM][WebAssembly] Implement branch hinting proposal This commit implements the WebAssembly branch hinting proposal, as detailed at https://webassembly.github.io/branch-hinting/metadata/code/binary.html. This proposal introduces a mechanism to convey branch likelihood information to the WebAssembly engine, allowing for more effective performance optimizations. The proposal specifies a new custom section named `metadata.code.branch_hint`. This section can contain a sequence of hints, where each hint is a single byte that applies to a corresponding `br_if` or `if` instruction. The hint values are: - `0x00` (`unlikely`): The branch is unlikely to be taken. - `0x01` (`likely`): The branch is likely to be taken. This implementation includes the following changes: - Addition of the "branch-hinting" feature (flag) - Collection of edge probabilities in the CFGStackify pass - Emission of the `metadata.code.branch_hint` section in WebAssemblyAsmPrinter - Addition of the `WebAssembly::Specifier::S_DEBUG_REF` symbol ref specifier - Custom relaxation of LEB128 fragments for storage of ULEB128-encoded function indices and instruction offsets - Custom handling of code metadata sections in lld, needed because the proposal requires code metadata sections to start with a combined count of function hints, followed by an ordered list of function hints. This change is purely an optimization and does not alter the semantics of WebAssembly programs. 
--- clang/include/clang/Driver/Options.td | 2 + clang/lib/Basic/Targets/WebAssembly.cpp | 12 +++ clang/lib/Basic/Targets/WebAssembly.h | 1 + lld/test/wasm/Inputs/branch-hints.ll | 29 +++++++ lld/test/wasm/code-metadata-branch-hints.ll | 37 +++++++++ lld/wasm/OutputSections.cpp | 56 +++++++++++++ lld/wasm/OutputSections.h | 9 +++ lld/wasm/Writer.cpp | 13 +-- .../MCTargetDesc/WebAssemblyAsmBackend.cpp | 25 ++++++ .../MCTargetDesc/WebAssemblyMCAsmInfo.h | 1 + .../WebAssemblyWasmObjectWriter.cpp | 5 +- llvm/lib/Target/WebAssembly/WebAssembly.td | 6 +- .../WebAssembly/WebAssemblyAsmPrinter.cpp | 69 ++++++++++++++++ .../WebAssembly/WebAssemblyAsmPrinter.h | 9 +++ .../WebAssembly/WebAssemblyCFGStackify.cpp | 20 ++++- .../WebAssembly/WebAssemblyInstrInfo.td | 4 + .../WebAssemblyMachineFunctionInfo.h | 2 + .../Target/WebAssembly/WebAssemblySubtarget.h | 2 + ...branch-hints-custom-high-low-thresholds.ll | 79 +++++++++++++++++++ llvm/test/MC/WebAssembly/branch-hints.ll | 66 ++++++++++++++++ 20 files changed, 439 insertions(+), 8 deletions(-) create mode 100644 lld/test/wasm/Inputs/branch-hints.ll create mode 100644 lld/test/wasm/code-metadata-branch-hints.ll create mode 100644 llvm/test/MC/WebAssembly/branch-hints-custom-high-low-thresholds.ll create mode 100644 llvm/test/MC/WebAssembly/branch-hints.ll diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td index 9911d752966e3..31274aca2a25b 100644 --- a/clang/include/clang/Driver/Options.td +++ b/clang/include/clang/Driver/Options.td @@ -5276,6 +5276,8 @@ def mtail_call : Flag<["-"], "mtail-call">, Group<m_wasm_Features_Group>; def mno_tail_call : Flag<["-"], "mno-tail-call">, Group<m_wasm_Features_Group>; def mwide_arithmetic : Flag<["-"], "mwide-arithmetic">, Group<m_wasm_Features_Group>; def mno_wide_arithmetic : Flag<["-"], "mno-wide-arithmetic">, Group<m_wasm_Features_Group>; +def mbranch_hinting : Flag<["-"], "mbranch-hinting">, Group<m_wasm_Features_Group>; +def mno_branch_hinting : 
Flag<["-"], "mno-branch-hinting">, Group<m_wasm_Features_Group>; def mexec_model_EQ : Joined<["-"], "mexec-model=">, Group<m_wasm_Features_Driver_Group>, Values<"command,reactor">, HelpText<"Execution model (WebAssembly only)">, diff --git a/clang/lib/Basic/Targets/WebAssembly.cpp b/clang/lib/Basic/Targets/WebAssembly.cpp index f19c57f1a3a50..14c9d501bc1fa 100644 --- a/clang/lib/Basic/Targets/WebAssembly.cpp +++ b/clang/lib/Basic/Targets/WebAssembly.cpp @@ -69,6 +69,7 @@ bool WebAssemblyTargetInfo::hasFeature(StringRef Feature) const { .Case("simd128", SIMDLevel >= SIMD128) .Case("tail-call", HasTailCall) .Case("wide-arithmetic", HasWideArithmetic) + .Case("branch-hinting", HasBranchHinting) .Default(false); } @@ -116,6 +117,8 @@ void WebAssemblyTargetInfo::getTargetDefines(const LangOptions &Opts, Builder.defineMacro("__wasm_tail_call__"); if (HasWideArithmetic) Builder.defineMacro("__wasm_wide_arithmetic__"); + if (HasBranchHinting) + Builder.defineMacro("__wasm_branch_hinting__"); Builder.defineMacro("__GCC_HAVE_SYNC_COMPARE_AND_SWAP_1"); Builder.defineMacro("__GCC_HAVE_SYNC_COMPARE_AND_SWAP_2"); @@ -194,6 +197,7 @@ bool WebAssemblyTargetInfo::initFeatureMap( Features["multimemory"] = true; Features["tail-call"] = true; Features["wide-arithmetic"] = true; + Features["branch-hinting"] = true; setSIMDLevel(Features, RelaxedSIMD, true); }; if (CPU == "generic") { @@ -347,6 +351,14 @@ bool WebAssemblyTargetInfo::handleTargetFeatures( HasWideArithmetic = false; continue; } + if (Feature == "+branch-hinting") { + HasBranchHinting = true; + continue; + } + if (Feature == "-branch-hinting") { + HasBranchHinting = false; + continue; + } Diags.Report(diag::err_opt_not_valid_with_opt) << Feature << "-target-feature"; diff --git a/clang/lib/Basic/Targets/WebAssembly.h b/clang/lib/Basic/Targets/WebAssembly.h index d5aee5c0bd0eb..8b7bc6c43bd17 100644 --- a/clang/lib/Basic/Targets/WebAssembly.h +++ b/clang/lib/Basic/Targets/WebAssembly.h @@ -72,6 +72,7 @@ class 
LLVM_LIBRARY_VISIBILITY WebAssemblyTargetInfo : public TargetInfo { bool HasSignExt = false; bool HasTailCall = false; bool HasWideArithmetic = false; + bool HasBranchHinting = false; std::string ABI; diff --git a/lld/test/wasm/Inputs/branch-hints.ll b/lld/test/wasm/Inputs/branch-hints.ll new file mode 100644 index 0000000000000..1a92259707171 --- /dev/null +++ b/lld/test/wasm/Inputs/branch-hints.ll @@ -0,0 +1,29 @@ +target triple = "wasm32-unknown-unknown" + +define i32 @test0(i32 %a) { +entry: + %cmp0 = icmp eq i32 %a, 0 + ; This metadata hints that the true branch is overwhelmingly likely. + br i1 %cmp0, label %if.then, label %ret1, !prof !0 +if.then: + %cmp1 = icmp eq i32 %a, 1 + br i1 %cmp1, label %ret1, label %ret2, !prof !1 +ret1: + ret i32 2 +ret2: + ret i32 1 +} + +define i32 @test1(i32 %a) { +entry: + %cmp = icmp eq i32 %a, 0 + br i1 %cmp, label %if.then, label %if.else, !prof !1 +if.then: + ret i32 1 +if.else: + ret i32 2 +} + +; the resulting branch hint is actually reversed, since llvm-br is turned into br_unless, inverting branch probs +!0 = !{!"branch_weights", i32 2000, i32 1} +!1 = !{!"branch_weights", i32 1, i32 2000} \ No newline at end of file diff --git a/lld/test/wasm/code-metadata-branch-hints.ll b/lld/test/wasm/code-metadata-branch-hints.ll new file mode 100644 index 0000000000000..076aa40881152 --- /dev/null +++ b/lld/test/wasm/code-metadata-branch-hints.ll @@ -0,0 +1,37 @@ +; RUN: llc -filetype=obj %s -o %t1.o -mattr=+branch-hinting +; RUN: llc -filetype=obj %S/Inputs/branch-hints.ll -o %t2.o -mattr=+branch-hinting +; RUN: wasm-ld --export-all -o %t.wasm %t1.o %t2.o +; RUN: obj2yaml %t.wasm | FileCheck %s + +define i32 @_start(i32 %a) { +entry: + %cmp = icmp eq i32 %a, 0 + br i1 %cmp, label %if.then, label %if.else, !prof !0 +if.then: + ret i32 1 +if.else: + ret i32 2 +} + +define i32 @test_func1(i32 %a) { +entry: + %cmp = icmp eq i32 %a, 0 + br i1 %cmp, label %if.then, label %if.else, !prof !1 +if.then: + ret i32 1 +if.else: + ret i32 2 
+} + +!0 = !{!"branch_weights", i32 2000, i32 1} +!1 = !{!"branch_weights", i32 1, i32 2000} + +; CHECK: - Type: CUSTOM +; CHECK-NEXT: Name: metadata.code.branch_hint +; CHECK-NEXT: Payload: '84808080008180808000010501008280808000010501018380808000020701000E0101848080800001050101' + +; CHECK: - Type: CUSTOM +; CHECK: Name: target_features +; CHECK-NEXT: Features: +; CHECK: - Prefix: USED +; CHECK-NEXT: Name: branch-hinting \ No newline at end of file diff --git a/lld/wasm/OutputSections.cpp b/lld/wasm/OutputSections.cpp index 8ccd38f7895cb..c4fed3742dc6f 100644 --- a/lld/wasm/OutputSections.cpp +++ b/lld/wasm/OutputSections.cpp @@ -270,6 +270,62 @@ void CustomSection::writeTo(uint8_t *buf) { section->writeTo(buf); } +void CodeMetaDataSection::writeTo(uint8_t *buf) { + log("writing " + toString(*this) + " offset=" + Twine(offset) + + " size=" + Twine(getSize()) + " chunks=" + Twine(inputSections.size())); + + assert(offset); + buf += offset; + + // Write section header + memcpy(buf, header.data(), header.size()); + buf += header.size(); + memcpy(buf, nameData.data(), nameData.size()); + buf += nameData.size(); + + uint32_t TotalNumHints = 0; + for (const InputChunk *section : + make_range(inputSections.rbegin(), inputSections.rend())) { + section->writeTo(buf); + unsigned EncodingSize; + uint32_t NumHints = + decodeULEB128(buf + section->outSecOff, &EncodingSize, nullptr); + if (EncodingSize != 5) { + fatal("Unexpected encoding size for function hint vec size in " + name + + ": must be exactly 5 bytes."); + } + TotalNumHints += NumHints; + } + encodeULEB128(TotalNumHints, buf, 5); +} + +void CodeMetaDataSection::finalizeContents() { + finalizeInputSections(); + + raw_string_ostream os(nameData); + encodeULEB128(name.size(), os); + os << name; + + bool firstSection = true; + for (InputChunk *section : inputSections) { + assert(!section->discarded); + payloadSize = alignTo(payloadSize, section->alignment); + if (firstSection) { + section->outSecOff = payloadSize; + 
payloadSize += section->getSize(); + firstSection = false; + } else { + // adjust output offset so that each section write overwrites exactly the + // subsequent section's function hint vector size (which deduplicates) + section->outSecOff = payloadSize - 5; + // payload size should not include the hint vector size, which is deduped + payloadSize += section->getSize() - 5; + } + } + + createHeader(payloadSize + nameData.size()); +} + uint32_t CustomSection::getNumRelocations() const { uint32_t count = 0; for (const InputChunk *inputSect : inputSections) diff --git a/lld/wasm/OutputSections.h b/lld/wasm/OutputSections.h index 4b0329dd16cf2..6580c71ab6f5a 100644 --- a/lld/wasm/OutputSections.h +++ b/lld/wasm/OutputSections.h @@ -132,6 +132,15 @@ class CustomSection : public OutputSection { std::string nameData; }; +class CodeMetaDataSection : public CustomSection { +public: + CodeMetaDataSection(std::string name, ArrayRef<InputChunk *> inputSections) + : CustomSection(name, inputSections) {} + + void writeTo(uint8_t *buf) override; + void finalizeContents() override; +}; + } // namespace wasm } // namespace lld diff --git a/lld/wasm/Writer.cpp b/lld/wasm/Writer.cpp index b704677d36c93..e9cfd5eac10db 100644 --- a/lld/wasm/Writer.cpp +++ b/lld/wasm/Writer.cpp @@ -165,14 +165,17 @@ void Writer::createCustomSections() { for (auto &pair : customSectionMapping) { StringRef name = pair.first; LLVM_DEBUG(dbgs() << "createCustomSection: " << name << "\n"); - - OutputSection *sec = make<CustomSection>(std::string(name), pair.second); + OutputSection *Sec; + if (name == "metadata.code.branch_hint") + Sec = make<CodeMetaDataSection>(std::string(name), pair.second); + else + Sec = make<CustomSection>(std::string(name), pair.second); if (ctx.arg.relocatable || ctx.arg.emitRelocs) { - auto *sym = make<OutputSectionSymbol>(sec); + auto *sym = make<OutputSectionSymbol>(Sec); out.linkingSec->addToSymtab(sym); - sec->sectionSym = sym; + Sec->sectionSym = sym; } - addSection(sec); + 
addSection(Sec); } } diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp index 7bc672c069476..7525c6400cc00 100644 --- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp +++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp @@ -21,6 +21,7 @@ #include "llvm/MC/MCSubtargetInfo.h" #include "llvm/MC/MCSymbol.h" #include "llvm/MC/MCWasmObjectWriter.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/raw_ostream.h" using namespace llvm; @@ -44,6 +45,9 @@ class WebAssemblyAsmBackend final : public MCAsmBackend { std::unique_ptr<MCObjectTargetWriter> createObjectTargetWriter() const override; + std::pair<bool, bool> relaxLEB128(const MCAssembler &Asm, MCLEBFragment &LF, + int64_t &Value) const override; + bool writeNopData(raw_ostream &OS, uint64_t Count, const MCSubtargetInfo *STI) const override; }; @@ -70,6 +74,27 @@ WebAssemblyAsmBackend::getFixupKindInfo(MCFixupKind Kind) const { return Infos[Kind - FirstTargetFixupKind]; } +std::pair<bool, bool> +WebAssemblyAsmBackend::relaxLEB128(const MCAssembler &assembler, + MCLEBFragment &LF, int64_t &Value) const { + const MCExpr &Expr = LF.getValue(); + if (Expr.getKind() == MCExpr::ExprKind::SymbolRef) { + const MCSymbolRefExpr &SymExpr = llvm::cast<MCSymbolRefExpr>(Expr); + if (static_cast<WebAssembly::Specifier>(SymExpr.getSpecifier()) == + WebAssembly::S_DEBUG_REF) { + Value = assembler.getSymbolOffset(SymExpr.getSymbol()); + return std::make_pair(true, false); + } + } + // currently, this is only used for leb128 encoded function indices + // that require relocations + LF.getFixups().push_back( + MCFixup::create(0, &Expr, WebAssembly::fixup_uleb128_i32, Expr.getLoc())); + // ensure that the stored placeholder is large enough to hold any 32-bit val + Value = UINT32_MAX; + return std::make_pair(true, false); +} + bool WebAssemblyAsmBackend::writeNopData(raw_ostream &OS, uint64_t 
Count, const MCSubtargetInfo *STI) const { for (uint64_t I = 0; I < Count; ++I) diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCAsmInfo.h b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCAsmInfo.h index 0b3778b36d6ac..bfb5c59b1b80c 100644 --- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCAsmInfo.h +++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCAsmInfo.h @@ -37,6 +37,7 @@ enum Specifier { S_TBREL, // Table index relative to __table_base S_TLSREL, // Memory address relative to __tls_base S_TYPEINDEX, // Reference to a symbol's type (signature) + S_DEBUG_REF, // Marker placed for generation of metadata.code.* section }; } } // end namespace llvm diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyWasmObjectWriter.cpp b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyWasmObjectWriter.cpp index 2cf4bec077385..f2e0c11022423 100644 --- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyWasmObjectWriter.cpp +++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyWasmObjectWriter.cpp @@ -93,10 +93,13 @@ unsigned WebAssemblyWasmObjectWriter::getRelocType( case WebAssembly::S_None: break; case WebAssembly::S_FUNCINDEX: + if (static_cast<unsigned>(Fixup.getKind()) == + WebAssembly::fixup_uleb128_i32) + return wasm::R_WASM_FUNCTION_INDEX_LEB; return wasm::R_WASM_FUNCTION_INDEX_I32; } - switch (unsigned(Fixup.getKind())) { + switch (static_cast<unsigned>(Fixup.getKind())) { case WebAssembly::fixup_sleb128_i32: if (SymA.isFunction()) return wasm::R_WASM_TABLE_INDEX_SLEB; diff --git a/llvm/lib/Target/WebAssembly/WebAssembly.td b/llvm/lib/Target/WebAssembly/WebAssembly.td index 13603f8181198..ec3889e2037e4 100644 --- a/llvm/lib/Target/WebAssembly/WebAssembly.td +++ b/llvm/lib/Target/WebAssembly/WebAssembly.td @@ -90,6 +90,10 @@ def FeatureWideArithmetic : SubtargetFeature<"wide-arithmetic", "HasWideArithmetic", "true", "Enable wide-arithmetic instructions">; +def FeatureBranchHinting : + 
SubtargetFeature<"branch-hinting", "HasBranchHinting", "true", + "Enable branch hints for branch instructions">; + //===----------------------------------------------------------------------===// // Architectures. //===----------------------------------------------------------------------===// @@ -142,7 +146,7 @@ def : ProcessorModel<"bleeding-edge", NoSchedModel, FeatureMultivalue, FeatureMutableGlobals, FeatureNontrappingFPToInt, FeatureRelaxedSIMD, FeatureReferenceTypes, FeatureSIMD128, FeatureSignExt, - FeatureTailCall]>; + FeatureTailCall, FeatureBranchHinting]>; //===----------------------------------------------------------------------===// // Target Declaration diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp index 44a19e4baaf62..55defc3a312e6 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp @@ -55,6 +55,15 @@ using namespace llvm; #define DEBUG_TYPE "asm-printer" extern cl::opt<bool> WasmKeepRegisters; +// values are divided by 1<<31 to calculate the probability +static cl::opt<uint32_t> WasmHighBranchProb( + "wasm-branch-prob-high", cl::Hidden, + cl::desc("lowest branch probability to not be annotated as likely taken"), + cl::init(0x40000000)); +static cl::opt<uint32_t> WasmLowBranchProb( + "wasm-branch-prob-low", cl::Hidden, + cl::desc("highest branch probability to be annotated as unlikely taken"), + cl::init(0x40000000)); //===----------------------------------------------------------------------===// // Helpers. 
@@ -442,6 +451,38 @@ void WebAssemblyAsmPrinter::emitEndOfAsmFile(Module &M) { EmitProducerInfo(M); EmitTargetFeatures(M); EmitFunctionAttributes(M); + + // Subtarget may be null if no functions have been defined in file + if (Subtarget && Subtarget->hasBranchHinting()) + EmitBranchHintSection(); +} + +void WebAssemblyAsmPrinter::EmitBranchHintSection() const { + MCSectionWasm *BranchHintsSection = OutContext.getWasmSection( + "metadata.code.branch_hint", SectionKind::getMetadata()); + OutStreamer->pushSection(); + OutStreamer->switchSection(BranchHintsSection); + // should we emit empty branch hints section? + OutStreamer->emitULEB128IntValue(branchHints.size(), 5); + for (const auto &BHR : branchHints) { + if (!BHR) + continue; + // emit relocatable function index for the function symbol + OutStreamer->emitULEB128Value(MCSymbolRefExpr::create( + BHR->func_sym, WebAssembly::S_FUNCINDEX, OutContext)); + // emit the number of hints for this function (is constant -> does not need + // handling by target streamer for reloc) + OutStreamer->emitULEB128IntValue(BHR->hints.size()); + for (const auto &[instrSym, hint] : BHR->hints) { + assert(hint == 0 || hint == 1); + // offset from function start + OutStreamer->emitULEB128Value(MCSymbolRefExpr::create( + instrSym, WebAssembly::S_DEBUG_REF, OutContext)); + OutStreamer->emitULEB128IntValue(1); // hint size + OutStreamer->emitULEB128IntValue(hint); + } + } + OutStreamer->popSection(); } void WebAssemblyAsmPrinter::EmitProducerInfo(Module &M) { @@ -697,6 +738,34 @@ void WebAssemblyAsmPrinter::emitInstruction(const MachineInstr *MI) { WebAssemblyMCInstLower MCInstLowering(OutContext, *this); MCInst TmpInst; MCInstLowering.lower(MI, TmpInst); + if (Subtarget->hasBranchHinting() && + MI->getOpcode() == WebAssembly::BR_IF && MFI && + MFI->BranchProbabilities.contains(MI)) { + MCSymbol *BrIfSym = OutContext.createTempSymbol(); + OutStreamer->emitLabel(BrIfSym); + + constexpr uint8_t HintLikely = 0x01; + constexpr uint8_t 
HintUnlikely = 0x00; + const BranchProbability &Prob = MFI->BranchProbabilities[MI]; + uint8_t HintValue; + if (Prob > BranchProbability::getRaw(WasmHighBranchProb.getValue())) + HintValue = HintLikely; + else if (Prob <= BranchProbability::getRaw(WasmLowBranchProb.getValue())) + HintValue = HintUnlikely; + else + goto emit; // Don't emit branch hint between thresholds + + // we know that we only emit branch hints for internal functions, + // therefore we can directly cast and don't need getMCSymbolForFunction + MCSymbol *FuncSym = cast<MCSymbolWasm>(getSymbol(&MF->getFunction())); + uint32_t LocalFuncIdx = MF->getFunctionNumber(); + if (branchHints.size() <= LocalFuncIdx) { + branchHints.resize(LocalFuncIdx + 1); + branchHints[LocalFuncIdx] = BranchHintRecord{FuncSym, {}}; + } + branchHints[LocalFuncIdx]->hints.emplace_back(BrIfSym, HintValue); + } + emit: EmitToStreamer(*OutStreamer, TmpInst); break; } diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h index 46063bbe0fba1..b6595492cf26c 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h @@ -18,6 +18,11 @@ namespace llvm { class WebAssemblyTargetStreamer; +struct BranchHintRecord { + MCSymbol *func_sym; + SmallVector<std::pair<MCSymbol *, uint8_t>, 0> hints; +}; + class LLVM_LIBRARY_VISIBILITY WebAssemblyAsmPrinter final : public AsmPrinter { public: static char ID; @@ -28,6 +33,9 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyAsmPrinter final : public AsmPrinter { WebAssemblyFunctionInfo *MFI; bool signaturesEmitted = false; + // vec idx == local func_idx + std::vector<std::optional<BranchHintRecord>> branchHints; + public: explicit WebAssemblyAsmPrinter(TargetMachine &TM, std::unique_ptr<MCStreamer> Streamer) @@ -59,6 +67,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyAsmPrinter final : public AsmPrinter { void EmitProducerInfo(Module &M); void EmitTargetFeatures(Module &M); 
void EmitFunctionAttributes(Module &M); + void EmitBranchHintSection() const; void emitSymbolType(const MCSymbolWasm *Sym); void emitGlobalVariable(const GlobalVariable *GV) override; void emitJumpTableInfo() override; diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyCFGStackify.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyCFGStackify.cpp index 640be5fe8e8c9..84611030e448d 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyCFGStackify.cpp +++ b/llvm/lib/Target/WebAssembly/WebAssemblyCFGStackify.cpp @@ -31,6 +31,7 @@ #include "WebAssemblyUtilities.h" #include "llvm/ADT/Statistic.h" #include "llvm/BinaryFormat/Wasm.h" +#include "llvm/CodeGen/MachineBranchProbabilityInfo.h" #include "llvm/CodeGen/MachineDominators.h" #include "llvm/CodeGen/MachineInstrBuilder.h" #include "llvm/CodeGen/MachineLoopInfo.h" @@ -48,6 +49,7 @@ STATISTIC(NumCatchUnwindMismatches, "Number of catch unwind mismatches found"); namespace { class WebAssemblyCFGStackify final : public MachineFunctionPass { MachineDominatorTree *MDT; + MachineBranchProbabilityInfo *MBPI; StringRef getPassName() const override { return "WebAssembly CFG Stackify"; } @@ -55,6 +57,7 @@ class WebAssemblyCFGStackify final : public MachineFunctionPass { AU.addRequired<MachineDominatorTreeWrapperPass>(); AU.addRequired<MachineLoopInfoWrapperPass>(); AU.addRequired<WebAssemblyExceptionInfo>(); + AU.addRequired<MachineBranchProbabilityInfoWrapperPass>(); MachineFunctionPass::getAnalysisUsage(AU); } @@ -2562,8 +2565,22 @@ void WebAssemblyCFGStackify::rewriteDepthImmediates(MachineFunction &MF) { MO = MachineOperand::CreateImm(getDelegateDepth(Stack, MO.getMBB())); else if (MI.getOpcode() == WebAssembly::RETHROW) MO = MachineOperand::CreateImm(getRethrowDepth(Stack, MO.getMBB())); - else + else { + // this is the last place where we can easily calculate the branch + // probabilities. we do not emit scf-ifs, therefore only br_ifs have + // to be handled here. 
+ if (MF.getSubtarget<WebAssemblySubtarget>().hasBranchHinting() && + MI.getOpcode() == WebAssembly::BR_IF && + MI.getParent()->hasSuccessorProbabilities()) { + const auto Prob = + MBPI->getEdgeProbability(MI.getParent(), MO.getMBB()); + WebAssemblyFunctionInfo *MFI = + MF.getInfo<WebAssemblyFunctionInfo>(); + assert(!MFI->BranchProbabilities.contains(&MI)); + MFI->BranchProbabilities[&MI] = Prob; + } MO = MachineOperand::CreateImm(getBranchDepth(Stack, MO.getMBB())); + } } MI.addOperand(MF, MO); } @@ -2639,6 +2656,7 @@ bool WebAssemblyCFGStackify::runOnMachineFunction(MachineFunction &MF) { << MF.getName() << '\n'); const MCAsmInfo *MCAI = MF.getTarget().getMCAsmInfo(); MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree(); + MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI(); releaseMemory(); diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td index b5e723e2a48d3..f1d4a62535060 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td +++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td @@ -96,6 +96,10 @@ def HasWideArithmetic : Predicate<"Subtarget->hasWideArithmetic()">, AssemblerPredicate<(all_of FeatureWideArithmetic), "wide-arithmetic">; +def HasBranchHinting : + Predicate<"Subtarget->hasBranchHinting()">, + AssemblerPredicate<(all_of FeatureBranchHinting), "branch-hinting">; + //===----------------------------------------------------------------------===// // WebAssembly-specific DAG Node Types. 
//===----------------------------------------------------------------------===// diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h index 40ae4aef1d7f2..343168b570bef 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h @@ -153,6 +153,8 @@ class WebAssemblyFunctionInfo final : public MachineFunctionInfo { bool isCFGStackified() const { return CFGStackified; } void setCFGStackified(bool Value = true) { CFGStackified = Value; } + + DenseMap<const MachineInstr *, BranchProbability> BranchProbabilities; }; void computeLegalValueVTs(const WebAssemblyTargetLowering &TLI, diff --git a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h index 591ce25611e3e..96a24a1d40ef7 100644 --- a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h +++ b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h @@ -54,6 +54,7 @@ class WebAssemblySubtarget final : public WebAssemblyGenSubtargetInfo { bool HasSignExt = false; bool HasTailCall = false; bool HasWideArithmetic = false; + bool HasBranchHinting = false; /// What processor and OS we're targeting. Triple TargetTriple; @@ -112,6 +113,7 @@ class WebAssemblySubtarget final : public WebAssemblyGenSubtargetInfo { bool hasSIMD128() const { return SIMDLevel >= SIMD128; } bool hasTailCall() const { return HasTailCall; } bool hasWideArithmetic() const { return HasWideArithmetic; } + bool hasBranchHinting() const { return HasBranchHinting; } /// Parses features string setting specified subtarget options. Definition of /// function is auto generated by tblgen. 
diff --git a/llvm/test/MC/WebAssembly/branch-hints-custom-high-low-thresholds.ll b/llvm/test/MC/WebAssembly/branch-hints-custom-high-low-thresholds.ll new file mode 100644 index 0000000000000..8fdce544aa06e --- /dev/null +++ b/llvm/test/MC/WebAssembly/branch-hints-custom-high-low-thresholds.ll @@ -0,0 +1,79 @@ +; RUN: llc -mcpu=mvp -filetype=obj %s -mattr=+branch-hinting -wasm-branch-prob-high=0x60000000 -wasm-branch-prob-low=0x0 -o - | obj2yaml | FileCheck %s + +; This test checks that branch weight metadata (!prof) is correctly translated to webassembly branch hints +; We set the prob-thresholds so that "likely" branches are only emitted if prob > 75% and "unlikely" branches +; if prob <= 0%. + +; CHECK: - Type: CUSTOM +; CHECK-NEXT: Relocations: +; CHECK-NEXT: - Type: R_WASM_FUNCTION_INDEX_LEB +; CHECK-NEXT: Index: 0 +; CHECK-NEXT: Offset: 0x5 +; CHECK-NEXT: - Type: R_WASM_FUNCTION_INDEX_LEB +; CHECK-NEXT: Index: 2 +; CHECK-NEXT: Offset: 0xE +; CHECK-NEXT: Name: metadata.code.branch_hint +; CHECK-NEXT: Payload: '8380808000808080800001050101828080800001050100' + +; CHECK: - Type: CUSTOM +; CHECK-NEXT: Name: linking +; CHECK-NEXT: Version: 2 +; CHECK-NEXT: SymbolTable: +; CHECK-NEXT: - Index: 0 +; CHECK-NEXT: Kind: FUNCTION +; CHECK-NEXT: Name: test0 +; CHECK-NEXT: Flags: [ ] +; CHECK-NEXT: Function: 0 +; CHECK-NEXT: - Index: 1 +; CHECK-NEXT: Kind: FUNCTION +; CHECK-NEXT: Name: test1 +; CHECK-NEXT: Flags: [ ] +; CHECK-NEXT: Function: 1 +; CHECK-NEXT: - Index: 2 +; CHECK-NEXT: Kind: FUNCTION +; CHECK-NEXT: Name: test2 +; CHECK-NEXT: Flags: [ ] +; CHECK-NEXT: Function: 2 + +; CHECK: - Type: CUSTOM +; CHECK-NEXT: Name: target_features +; CHECK-NEXT: Features: +; CHECK-NEXT: - Prefix: USED +; CHECK-NEXT: Name: branch-hinting + +target triple = "wasm32-unknown-unknown" + +define i32 @test0(i32 %a) { +entry: + %cmp0 = icmp eq i32 %a, 0 + br i1 %cmp0, label %if_then, label %if_else, !prof !0 +if_then: + ret i32 1 +if_else: + ret i32 0 +} + +define i32 @test1(i32 %a) { 
+entry: + %cmp0 = icmp eq i32 %a, 0 + br i1 %cmp0, label %if_then, label %if_else, !prof !1 +if_then: + ret i32 1 +if_else: + ret i32 0 +} + +define i32 @test2(i32 %a) { +entry: + %cmp0 = icmp eq i32 %a, 0 + br i1 %cmp0, label %if_then, label %if_else, !prof !2 +if_then: + ret i32 1 +if_else: + ret i32 0 +} + +; the resulting branch hint is actually reversed, since llvm-br is turned into br_unless, inverting branch probs +!0 = !{!"branch_weights", !"expected", i32 100, i32 310} ; prob 75.61% +!1 = !{!"branch_weights", i32 1, i32 1} ; prob == 50% (no hint) +!2 = !{!"branch_weights", i32 1, i32 0} ; prob == 0% (unlikely hint) \ No newline at end of file diff --git a/llvm/test/MC/WebAssembly/branch-hints.ll b/llvm/test/MC/WebAssembly/branch-hints.ll new file mode 100644 index 0000000000000..8b4c96a6f4eff --- /dev/null +++ b/llvm/test/MC/WebAssembly/branch-hints.ll @@ -0,0 +1,66 @@ +; RUN: llc -mcpu=mvp -filetype=obj %s -mattr=+branch-hinting -o - | obj2yaml | FileCheck %s + +; This test checks that branch weight metadata (!prof) is correctly lowered to +; the WebAssembly branch hint custom section. 
+ +; CHECK: - Type: CUSTOM +; CHECK-NEXT: Relocations: +; CHECK-NEXT: - Type: R_WASM_FUNCTION_INDEX_LEB +; CHECK-NEXT: Index: 0 +; CHECK-NEXT: Offset: 0x5 +; CHECK-NEXT: - Type: R_WASM_FUNCTION_INDEX_LEB +; CHECK-NEXT: Index: 1 +; CHECK-NEXT: Offset: 0x11 +; CHECK-NEXT: Name: metadata.code.branch_hint +; CHECK-NEXT: Payload: '82808080008080808000020701000E0101818080800001050101' + +; CHECK: - Type: CUSTOM +; CHECK-NEXT: Name: linking +; CHECK-NEXT: Version: 2 +; CHECK-NEXT: SymbolTable: +; CHECK-NEXT: - Index: 0 +; CHECK-NEXT: Kind: FUNCTION +; CHECK-NEXT: Name: test_unlikely_likely_branch +; CHECK-NEXT: Flags: [ ] +; CHECK-NEXT: Function: 0 +; CHECK-NEXT: - Index: 1 +; CHECK-NEXT: Kind: FUNCTION +; CHECK-NEXT: Name: test_likely_branch +; CHECK-NEXT: Flags: [ ] +; CHECK-NEXT: Function: 1 + +; CHECK: - Type: CUSTOM +; CHECK-NEXT: Name: target_features +; CHECK-NEXT: Features: +; CHECK-NEXT: - Prefix: USED +; CHECK-NEXT: Name: branch-hinting + +target triple = "wasm32-unknown-unknown" + +define i32 @test_unlikely_likely_branch(i32 %a) { +entry: + %cmp0 = icmp eq i32 %a, 0 + ; This metadata hints that the true branch is overwhelmingly likely. + br i1 %cmp0, label %if.then, label %ret1, !prof !0 +if.then: + %cmp1 = icmp eq i32 %a, 1 + br i1 %cmp1, label %ret1, label %ret2, !prof !1 +ret1: + ret i32 2 +ret2: + ret i32 1 +} + +define i32 @test_likely_branch(i32 %a) { +entry: + %cmp = icmp eq i32 %a, 0 + br i1 %cmp, label %if.then, label %if.else, !prof !1 +if.then: + ret i32 1 +if.else: + ret i32 2 +} + +; the resulting branch hint is actually reversed, since llvm-br is turned into br_unless, inverting branch probs +!0 = !{!"branch_weights", i32 2000, i32 1} +!1 = !{!"branch_weights", i32 1, i32 2000} \ No newline at end of file _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits