https://github.com/Lukasdoe created https://github.com/llvm/llvm-project/pull/146230

This commit implements the WebAssembly branch hinting proposal, as detailed at
https://webassembly.github.io/branch-hinting/metadata/code/binary.html. The
proposal introduces a mechanism for conveying branch likelihood information to
the WebAssembly engine, which can use it to optimize the generated code.

The proposal specifies a new custom section named `metadata.code.branch_hint`.
For each hinted function, the section stores the function index followed by a
list of hints. Each hint gives the byte offset of a `br_if` or `if` instruction
within that function and a single-byte hint value:
 - `0x00` (`unlikely`): The branch is unlikely to be taken.
 - `0x01` (`likely`): The branch is likely to be taken.
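To make the layout concrete, here is a minimal sketch of an encoder for the bytes that follow the section name (the part `obj2yaml` prints as `Payload`). It is an illustration, not the code in this patch: `appendULEB128`, `BranchHint`, `FunctionHints` and `encodeBranchHintPayload` are hypothetical names. Each hint is written as offset, payload size (always 1) and the hint byte, and the leading vector length is the number of function entries that actually follow.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Appends Value as unsigned LEB128. A non-zero PadTo pads the encoding with
// continuation bytes to exactly PadTo bytes (PadTo must be large enough).
void appendULEB128(std::vector<uint8_t> &Out, uint64_t Value,
                   unsigned PadTo = 0) {
  unsigned Count = 0;
  do {
    uint8_t Byte = Value & 0x7f;
    Value >>= 7;
    ++Count;
    if (Value != 0 || Count < PadTo)
      Byte |= 0x80; // more bytes follow
    Out.push_back(Byte);
  } while (Value != 0);
  if (Count < PadTo) {
    for (; Count < PadTo - 1; ++Count)
      Out.push_back(0x80); // padding byte with the continuation bit set
    Out.push_back(0x00);   // final padding byte
  }
}

// Hypothetical in-memory form of the section contents.
struct BranchHint {
  uint32_t Offset; // byte offset of the br_if/if within the function
  uint8_t Value;   // 0x00 = unlikely, 0x01 = likely
};

struct FunctionHints {
  uint32_t FuncIndex;
  std::vector<BranchHint> Hints; // in increasing Offset order
};

// Encodes the payload that follows the name "metadata.code.branch_hint".
// Funcs is expected to be in increasing FuncIndex order.
std::vector<uint8_t>
encodeBranchHintPayload(const std::vector<FunctionHints> &Funcs) {
  std::vector<uint8_t> Out;
  appendULEB128(Out, Funcs.size()); // number of function entries that follow
  for (const FunctionHints &F : Funcs) {
    appendULEB128(Out, F.FuncIndex);
    appendULEB128(Out, F.Hints.size());
    for (const BranchHint &H : F.Hints) {
      assert((H.Value == 0x00 || H.Value == 0x01) && "hint must be 0 or 1");
      appendULEB128(Out, H.Offset);
      appendULEB128(Out, 1); // hint payload size: always one byte
      Out.push_back(H.Value);
    }
  }
  return Out;
}
```

In object files the patch pads the function count and the function indices to 5 bytes (the `PadTo = 5` case above); a 5-byte ULEB128 field can hold any 32-bit index, so a relocation can patch it in place.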

This implementation includes the following changes:
 - Addition of the `branch-hinting` target feature and the corresponding clang
   flags (`-mbranch-hinting` / `-mno-branch-hinting`)
 - Collection of `br_if` edge probabilities in the CFGStackify pass
 - Emission of the `metadata.code.branch_hint` section in WebAssemblyAsmPrinter
 - Addition of the `WebAssembly::Specifier::S_DEBUG_REF` symbol ref specifier
 - Custom relaxation of LEB128 fragments, used to store ULEB128-encoded
   function indices and instruction offsets
 - Custom handling of code metadata sections in lld: the proposal requires each
   code metadata section to start with a single combined count of function
   entries, followed by the ordered list of those entries, so the per-object
   sections cannot simply be concatenated.
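The lld merge step can be sketched as follows. Because every per-object payload starts with a 5-byte padded function count, merging amounts to summing those counts and concatenating the remaining bytes. This sketch uses plain copies, whereas the patch's `CodeMetaDataSection` overlaps the input chunks in the output buffer; `decodeULEB128At`, `appendULEB128Padded5` and `mergeBranchHintPayloads` are hypothetical names, and the inputs are assumed to already be in increasing function-index order.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Decodes an unsigned LEB128 value at Data[Pos]; Len receives its byte length.
uint64_t decodeULEB128At(const std::vector<uint8_t> &Data, std::size_t Pos,
                         unsigned &Len) {
  uint64_t Value = 0;
  unsigned Shift = 0;
  Len = 0;
  while (true) {
    if (Pos + Len >= Data.size())
      throw std::runtime_error("truncated LEB128");
    uint8_t Byte = Data[Pos + Len];
    ++Len;
    Value |= uint64_t(Byte & 0x7f) << Shift;
    if ((Byte & 0x80) == 0)
      return Value;
    Shift += 7;
    if (Shift >= 64)
      throw std::runtime_error("LEB128 value too large");
  }
}

// Appends Value as a ULEB128 padded to exactly 5 bytes.
void appendULEB128Padded5(std::vector<uint8_t> &Out, uint64_t Value) {
  assert(Value < (uint64_t(1) << 35) && "value does not fit in 5 bytes");
  for (int I = 0; I < 4; ++I) {
    Out.push_back(uint8_t((Value & 0x7f) | 0x80));
    Value >>= 7;
  }
  Out.push_back(uint8_t(Value & 0x7f));
}

// Merges per-object branch hint payloads into a single linked payload.
std::vector<uint8_t>
mergeBranchHintPayloads(const std::vector<std::vector<uint8_t>> &Inputs) {
  uint64_t TotalFuncs = 0;
  std::vector<uint8_t> Entries;
  for (const std::vector<uint8_t> &In : Inputs) {
    unsigned Len = 0;
    TotalFuncs += decodeULEB128At(In, 0, Len);
    if (Len != 5)
      throw std::runtime_error("function count must be padded to 5 bytes");
    // Keep everything after the per-object count.
    Entries.insert(Entries.end(), In.begin() + Len, In.end());
  }
  std::vector<uint8_t> Out;
  appendULEB128Padded5(Out, TotalFuncs);
  Out.insert(Out.end(), Entries.begin(), Entries.end());
  return Out;
}
```

Copying keeps the sketch short; writing the chunks in place, as the patch does, avoids the intermediate buffer.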

This change is purely an optimization and does not alter the semantics of 
WebAssembly programs.

From 96a4d759aaffac07b85c02982ba174a026e32d40 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lukas=20D=C3=B6llerer?= <cont...@lukas-doellerer.de>
Date: Sat, 28 Jun 2025 19:54:37 +0200
Subject: [PATCH] [LLVM][WebAssembly] Implement branch hinting proposal

---
 clang/include/clang/Driver/Options.td         |  2 +
 clang/lib/Basic/Targets/WebAssembly.cpp       | 12 +++
 clang/lib/Basic/Targets/WebAssembly.h         |  1 +
 lld/test/wasm/Inputs/branch-hints.ll          | 29 +++++++
 lld/test/wasm/code-metadata-branch-hints.ll   | 37 +++++++++
 lld/wasm/OutputSections.cpp                   | 56 +++++++++++++
 lld/wasm/OutputSections.h                     |  9 +++
 lld/wasm/Writer.cpp                           | 13 +--
 .../MCTargetDesc/WebAssemblyAsmBackend.cpp    | 26 ++++++
 .../MCTargetDesc/WebAssemblyMCExpr.h          |  1 +
 .../WebAssemblyWasmObjectWriter.cpp           |  5 +-
 llvm/lib/Target/WebAssembly/WebAssembly.td    |  6 +-
 .../WebAssembly/WebAssemblyAsmPrinter.cpp     | 69 ++++++++++++++++
 .../WebAssembly/WebAssemblyAsmPrinter.h       |  9 +++
 .../WebAssembly/WebAssemblyCFGStackify.cpp    | 20 ++++-
 .../WebAssembly/WebAssemblyInstrInfo.td       |  4 +
 .../WebAssemblyMachineFunctionInfo.h          |  2 +
 .../Target/WebAssembly/WebAssemblySubtarget.h |  2 +
 ...branch-hints-custom-high-low-thresholds.ll | 79 +++++++++++++++++++
 llvm/test/MC/WebAssembly/branch-hints.ll      | 66 ++++++++++++++++
 20 files changed, 440 insertions(+), 8 deletions(-)
 create mode 100644 lld/test/wasm/Inputs/branch-hints.ll
 create mode 100644 lld/test/wasm/code-metadata-branch-hints.ll
 create mode 100644 llvm/test/MC/WebAssembly/branch-hints-custom-high-low-thresholds.ll
 create mode 100644 llvm/test/MC/WebAssembly/branch-hints.ll

diff --git a/clang/include/clang/Driver/Options.td b/clang/include/clang/Driver/Options.td
index bd8df8f6a749a..8cb3a875a5c82 100644
--- a/clang/include/clang/Driver/Options.td
+++ b/clang/include/clang/Driver/Options.td
@@ -5224,6 +5224,8 @@ def mtail_call : Flag<["-"], "mtail-call">, Group<m_wasm_Features_Group>;
 def mno_tail_call : Flag<["-"], "mno-tail-call">, Group<m_wasm_Features_Group>;
 def mwide_arithmetic : Flag<["-"], "mwide-arithmetic">, Group<m_wasm_Features_Group>;
 def mno_wide_arithmetic : Flag<["-"], "mno-wide-arithmetic">, Group<m_wasm_Features_Group>;
+def mbranch_hinting : Flag<["-"], "mbranch-hinting">, Group<m_wasm_Features_Group>;
+def mno_branch_hinting : Flag<["-"], "mno-branch-hinting">, Group<m_wasm_Features_Group>;
 def mexec_model_EQ : Joined<["-"], "mexec-model=">, Group<m_wasm_Features_Driver_Group>,
                      Values<"command,reactor">,
                      HelpText<"Execution model (WebAssembly only)">,
diff --git a/clang/lib/Basic/Targets/WebAssembly.cpp b/clang/lib/Basic/Targets/WebAssembly.cpp
index f19c57f1a3a50..14c9d501bc1fa 100644
--- a/clang/lib/Basic/Targets/WebAssembly.cpp
+++ b/clang/lib/Basic/Targets/WebAssembly.cpp
@@ -69,6 +69,7 @@ bool WebAssemblyTargetInfo::hasFeature(StringRef Feature) const {
       .Case("simd128", SIMDLevel >= SIMD128)
       .Case("tail-call", HasTailCall)
       .Case("wide-arithmetic", HasWideArithmetic)
+      .Case("branch-hinting", HasBranchHinting)
       .Default(false);
 }
 
@@ -116,6 +117,8 @@ void WebAssemblyTargetInfo::getTargetDefines(const LangOptions &Opts,
     Builder.defineMacro("__wasm_tail_call__");
   if (HasWideArithmetic)
     Builder.defineMacro("__wasm_wide_arithmetic__");
+  if (HasBranchHinting)
+    Builder.defineMacro("__wasm_branch_hinting__");
 
   Builder.defineMacro("__GCC_HAVE_SYNC_COMPARE_AND_SWAP_1");
   Builder.defineMacro("__GCC_HAVE_SYNC_COMPARE_AND_SWAP_2");
@@ -194,6 +197,7 @@ bool WebAssemblyTargetInfo::initFeatureMap(
     Features["multimemory"] = true;
     Features["tail-call"] = true;
     Features["wide-arithmetic"] = true;
+    Features["branch-hinting"] = true;
     setSIMDLevel(Features, RelaxedSIMD, true);
   };
   if (CPU == "generic") {
@@ -347,6 +351,14 @@ bool WebAssemblyTargetInfo::handleTargetFeatures(
       HasWideArithmetic = false;
       continue;
     }
+    if (Feature == "+branch-hinting") {
+      HasBranchHinting = true;
+      continue;
+    }
+    if (Feature == "-branch-hinting") {
+      HasBranchHinting = false;
+      continue;
+    }
 
     Diags.Report(diag::err_opt_not_valid_with_opt)
         << Feature << "-target-feature";
diff --git a/clang/lib/Basic/Targets/WebAssembly.h b/clang/lib/Basic/Targets/WebAssembly.h
index 04f0cb5df4601..666da09e61636 100644
--- a/clang/lib/Basic/Targets/WebAssembly.h
+++ b/clang/lib/Basic/Targets/WebAssembly.h
@@ -71,6 +71,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyTargetInfo : public TargetInfo {
   bool HasSignExt = false;
   bool HasTailCall = false;
   bool HasWideArithmetic = false;
+  bool HasBranchHinting = false;
 
   std::string ABI;
 
diff --git a/lld/test/wasm/Inputs/branch-hints.ll b/lld/test/wasm/Inputs/branch-hints.ll
new file mode 100644
index 0000000000000..1a92259707171
--- /dev/null
+++ b/lld/test/wasm/Inputs/branch-hints.ll
@@ -0,0 +1,29 @@
+target triple = "wasm32-unknown-unknown"
+
+define i32 @test0(i32 %a) {
+entry:
+  %cmp0 = icmp eq i32 %a, 0
+  ; This metadata hints that the true branch is overwhelmingly likely.
+  br i1 %cmp0, label %if.then, label %ret1, !prof !0
+if.then:
+  %cmp1 = icmp eq i32 %a, 1
+  br i1 %cmp1, label %ret1, label %ret2, !prof !1
+ret1:
+  ret i32 2
+ret2:
+  ret i32 1
+}
+
+define i32 @test1(i32 %a) {
+entry:
+  %cmp = icmp eq i32 %a, 0
+  br i1 %cmp, label %if.then, label %if.else, !prof !1
+if.then:
+  ret i32 1
+if.else:
+  ret i32 2
+}
+
+; The resulting branch hints are inverted relative to these weights, since the IR br is lowered to br_unless, which inverts the branch probabilities.
+!0 = !{!"branch_weights", i32 2000, i32 1}
+!1 = !{!"branch_weights", i32 1, i32 2000}
\ No newline at end of file
diff --git a/lld/test/wasm/code-metadata-branch-hints.ll b/lld/test/wasm/code-metadata-branch-hints.ll
new file mode 100644
index 0000000000000..076aa40881152
--- /dev/null
+++ b/lld/test/wasm/code-metadata-branch-hints.ll
@@ -0,0 +1,37 @@
+; RUN: llc -filetype=obj %s -o %t1.o -mattr=+branch-hinting
+; RUN: llc -filetype=obj %S/Inputs/branch-hints.ll -o %t2.o -mattr=+branch-hinting
+; RUN: wasm-ld --export-all -o %t.wasm %t1.o %t2.o
+; RUN: obj2yaml %t.wasm | FileCheck %s
+
+define i32 @_start(i32 %a) {
+entry:
+  %cmp = icmp eq i32 %a, 0
+  br i1 %cmp, label %if.then, label %if.else, !prof !0
+if.then:
+  ret i32 1
+if.else:
+  ret i32 2
+}
+
+define i32 @test_func1(i32 %a) {
+entry:
+  %cmp = icmp eq i32 %a, 0
+  br i1 %cmp, label %if.then, label %if.else, !prof !1
+if.then:
+  ret i32 1
+if.else:
+  ret i32 2
+}
+
+!0 = !{!"branch_weights", i32 2000, i32 1}
+!1 = !{!"branch_weights", i32 1, i32 2000}
+
+; CHECK:        - Type:            CUSTOM
+; CHECK-NEXT:     Name:            metadata.code.branch_hint
+; CHECK-NEXT:     Payload:         '84808080008180808000010501008280808000010501018380808000020701000E0101848080800001050101'
+
+; CHECK:        - Type:            CUSTOM
+; CHECK:          Name:            target_features
+; CHECK-NEXT:     Features:
+; CHECK:            - Prefix:          USED
+; CHECK-NEXT:         Name:            branch-hinting
\ No newline at end of file
diff --git a/lld/wasm/OutputSections.cpp b/lld/wasm/OutputSections.cpp
index cd80254a18d5c..cd65dd4a3855b 100644
--- a/lld/wasm/OutputSections.cpp
+++ b/lld/wasm/OutputSections.cpp
@@ -271,6 +271,62 @@ void CustomSection::writeTo(uint8_t *buf) {
     section->writeTo(buf);
 }
 
+void CodeMetaDataSection::writeTo(uint8_t *buf) {
+  log("writing " + toString(*this) + " offset=" + Twine(offset) +
+      " size=" + Twine(getSize()) + " chunks=" + Twine(inputSections.size()));
+
+  assert(offset);
+  buf += offset;
+
+  // Write section header
+  memcpy(buf, header.data(), header.size());
+  buf += header.size();
+  memcpy(buf, nameData.data(), nameData.size());
+  buf += nameData.size();
+
+  uint32_t TotalNumHints = 0;
+  for (const InputChunk *section :
+       make_range(inputSections.rbegin(), inputSections.rend())) {
+    section->writeTo(buf);
+    unsigned EncodingSize;
+    uint32_t NumHints =
+        decodeULEB128(buf + section->outSecOff, &EncodingSize, nullptr);
+    if (EncodingSize != 5) {
+      fatal("Unexpected encoding size for function hint vec size in " + name +
+            ": must be exactly 5 bytes.");
+    }
+    TotalNumHints += NumHints;
+  }
+  encodeULEB128(TotalNumHints, buf, 5);
+}
+
+void CodeMetaDataSection::finalizeContents() {
+  finalizeInputSections();
+
+  raw_string_ostream os(nameData);
+  encodeULEB128(name.size(), os);
+  os << name;
+
+  bool firstSection = true;
+  for (InputChunk *section : inputSections) {
+    assert(!section->discarded);
+    payloadSize = alignTo(payloadSize, section->alignment);
+    if (firstSection) {
+      section->outSecOff = payloadSize;
+      payloadSize += section->getSize();
+      firstSection = false;
+    } else {
+      // Overlap this chunk's 5-byte function count with the previous chunk's
+      // tail; writeTo() writes in reverse order, so the tail overwrites it.
+      section->outSecOff = payloadSize - 5;
+      // The overlapped count bytes are not counted again in the payload size.
+      payloadSize += section->getSize() - 5;
+    }
+  }
+
+  createHeader(payloadSize + nameData.size());
+}
+
 uint32_t CustomSection::getNumRelocations() const {
   uint32_t count = 0;
   for (const InputChunk *inputSect : inputSections)
diff --git a/lld/wasm/OutputSections.h b/lld/wasm/OutputSections.h
index 4b0329dd16cf2..6580c71ab6f5a 100644
--- a/lld/wasm/OutputSections.h
+++ b/lld/wasm/OutputSections.h
@@ -132,6 +132,15 @@ class CustomSection : public OutputSection {
   std::string nameData;
 };
 
+class CodeMetaDataSection : public CustomSection {
+public:
+  CodeMetaDataSection(std::string name, ArrayRef<InputChunk *> inputSections)
+      : CustomSection(name, inputSections) {}
+
+  void writeTo(uint8_t *buf) override;
+  void finalizeContents() override;
+};
+
 } // namespace wasm
 } // namespace lld
 
diff --git a/lld/wasm/Writer.cpp b/lld/wasm/Writer.cpp
index 3cc3e0d498979..e9638bce8e86b 100644
--- a/lld/wasm/Writer.cpp
+++ b/lld/wasm/Writer.cpp
@@ -170,14 +170,17 @@ void Writer::createCustomSections() {
   for (auto &pair : customSectionMapping) {
     StringRef name = pair.first;
     LLVM_DEBUG(dbgs() << "createCustomSection: " << name << "\n");
-
-    OutputSection *sec = make<CustomSection>(std::string(name), pair.second);
+    OutputSection *Sec;
+    if (name == "metadata.code.branch_hint")
+      Sec = make<CodeMetaDataSection>(std::string(name), pair.second);
+    else
+      Sec = make<CustomSection>(std::string(name), pair.second);
     if (ctx.arg.relocatable || ctx.arg.emitRelocs) {
-      auto *sym = make<OutputSectionSymbol>(sec);
+      auto *sym = make<OutputSectionSymbol>(Sec);
       out.linkingSec->addToSymtab(sym);
-      sec->sectionSym = sym;
+      Sec->sectionSym = sym;
     }
-    addSection(sec);
+    addSection(Sec);
   }
 }
 
diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp
index 91a1db80deb3c..309954199325a 100644
--- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp
+++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyAsmBackend.cpp
@@ -13,6 +13,7 @@
 
 #include "MCTargetDesc/WebAssemblyFixupKinds.h"
 #include "MCTargetDesc/WebAssemblyMCTargetDesc.h"
+#include "WebAssemblyMCExpr.h"
 #include "llvm/MC/MCAsmBackend.h"
 #include "llvm/MC/MCAssembler.h"
 #include "llvm/MC/MCExpr.h"
@@ -21,6 +22,7 @@
 #include "llvm/MC/MCSubtargetInfo.h"
 #include "llvm/MC/MCSymbol.h"
 #include "llvm/MC/MCWasmObjectWriter.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace llvm;
@@ -46,6 +48,9 @@ class WebAssemblyAsmBackend final : public MCAsmBackend {
   std::unique_ptr<MCObjectTargetWriter>
   createObjectTargetWriter() const override;
 
+  std::pair<bool, bool> relaxLEB128(const MCAssembler &Asm, MCLEBFragment &LF,
+                                    int64_t &Value) const override;
+
   bool writeNopData(raw_ostream &OS, uint64_t Count,
                     const MCSubtargetInfo *STI) const override;
 };
@@ -72,6 +77,27 @@ WebAssemblyAsmBackend::getFixupKindInfo(MCFixupKind Kind) const {
   return Infos[Kind - FirstTargetFixupKind];
 }
 
+std::pair<bool, bool>
+WebAssemblyAsmBackend::relaxLEB128(const MCAssembler &assembler,
+                                   MCLEBFragment &LF, int64_t &Value) const {
+  const MCExpr &Expr = LF.getValue();
+  if (Expr.getKind() == MCExpr::ExprKind::SymbolRef) {
+    const MCSymbolRefExpr &SymExpr = llvm::cast<MCSymbolRefExpr>(Expr);
+    if (static_cast<WebAssembly::Specifier>(SymExpr.getSpecifier()) ==
+        WebAssembly::S_DEBUG_REF) {
+      Value = assembler.getSymbolOffset(SymExpr.getSymbol());
+      return std::make_pair(true, false);
+    }
+  }
+  // currently, this is only used for leb128 encoded function indices
+  // that require relocations
+  LF.getFixups().push_back(
+      MCFixup::create(0, &Expr, WebAssembly::fixup_uleb128_i32, Expr.getLoc()));
+  // ensure that the stored placeholder is large enough to hold any 32-bit value
+  Value = UINT32_MAX;
+  return std::make_pair(true, false);
+}
+
 bool WebAssemblyAsmBackend::writeNopData(raw_ostream &OS, uint64_t Count,
                                          const MCSubtargetInfo *STI) const {
   for (uint64_t I = 0; I < Count; ++I)
diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCExpr.h b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCExpr.h
index f74af06fb84fa..8276fad49baae 100644
--- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCExpr.h
+++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyMCExpr.h
@@ -23,6 +23,7 @@ enum Specifier {
   S_TBREL,     // Table index relative to __table_base
   S_TLSREL,    // Memory address relative to __tls_base
   S_TYPEINDEX, // Reference to a symbol's type (signature)
+  S_DEBUG_REF, // Symbol offset within its section, for metadata.code.* sections
 };
 } // namespace llvm::WebAssembly
 
diff --git a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyWasmObjectWriter.cpp b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyWasmObjectWriter.cpp
index 33cf12e59870c..4a7bb2f4acc1a 100644
--- a/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyWasmObjectWriter.cpp
+++ b/llvm/lib/Target/WebAssembly/MCTargetDesc/WebAssemblyWasmObjectWriter.cpp
@@ -93,10 +93,13 @@ unsigned WebAssemblyWasmObjectWriter::getRelocType(
   case WebAssembly::S_None:
     break;
   case WebAssembly::S_FUNCINDEX:
+    if (static_cast<unsigned>(Fixup.getKind()) ==
+        WebAssembly::fixup_uleb128_i32)
+      return wasm::R_WASM_FUNCTION_INDEX_LEB;
     return wasm::R_WASM_FUNCTION_INDEX_I32;
   }
 
-  switch (unsigned(Fixup.getKind())) {
+  switch (static_cast<unsigned>(Fixup.getKind())) {
   case WebAssembly::fixup_sleb128_i32:
     if (SymA.isFunction())
       return wasm::R_WASM_TABLE_INDEX_SLEB;
diff --git a/llvm/lib/Target/WebAssembly/WebAssembly.td b/llvm/lib/Target/WebAssembly/WebAssembly.td
index 13603f8181198..ec3889e2037e4 100644
--- a/llvm/lib/Target/WebAssembly/WebAssembly.td
+++ b/llvm/lib/Target/WebAssembly/WebAssembly.td
@@ -90,6 +90,10 @@ def FeatureWideArithmetic :
       SubtargetFeature<"wide-arithmetic", "HasWideArithmetic", "true",
                        "Enable wide-arithmetic instructions">;
 
+def FeatureBranchHinting :
+      SubtargetFeature<"branch-hinting", "HasBranchHinting", "true",
+                       "Enable branch hints for branch instructions">;
+
 //===----------------------------------------------------------------------===//
 // Architectures.
 //===----------------------------------------------------------------------===//
@@ -142,7 +146,7 @@ def : ProcessorModel<"bleeding-edge", NoSchedModel,
                        FeatureMultivalue, FeatureMutableGlobals,
                        FeatureNontrappingFPToInt, FeatureRelaxedSIMD,
                        FeatureReferenceTypes, FeatureSIMD128, FeatureSignExt,
-                       FeatureTailCall]>;
+                       FeatureTailCall, FeatureBranchHinting]>;
 
 //===----------------------------------------------------------------------===//
 // Target Declaration
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
index c61ed3c7d5d81..6eaab5939163d 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.cpp
@@ -54,6 +54,15 @@ using namespace llvm;
 #define DEBUG_TYPE "asm-printer"
 
 extern cl::opt<bool> WasmKeepRegisters;
+// Raw BranchProbability numerators over 1 << 31 (0x40000000 == 50%).
+static cl::opt<uint32_t> WasmHighBranchProb(
+    "wasm-branch-prob-high", cl::Hidden,
+    cl::desc("branches with probability > this are hinted likely"),
+    cl::init(0x40000000));
+static cl::opt<uint32_t> WasmLowBranchProb(
+    "wasm-branch-prob-low", cl::Hidden,
+    cl::desc("branches with probability <= this are hinted unlikely"),
+    cl::init(0x40000000));
 
 //===----------------------------------------------------------------------===//
 // Helpers.
@@ -441,6 +450,38 @@ void WebAssemblyAsmPrinter::emitEndOfAsmFile(Module &M) {
   EmitProducerInfo(M);
   EmitTargetFeatures(M);
   EmitFunctionAttributes(M);
+
+  // Subtarget may be null if no functions have been defined in file
+  if (Subtarget && Subtarget->hasBranchHinting())
+    EmitBranchHintSection();
+}
+
+void WebAssemblyAsmPrinter::EmitBranchHintSection() const {
+  MCSectionWasm *BranchHintsSection = OutContext.getWasmSection(
+      "metadata.code.branch_hint", SectionKind::getMetadata());
+  OutStreamer->pushSection();
+  OutStreamer->switchSection(BranchHintsSection);
+  // should we emit empty branch hints section?
+  OutStreamer->emitULEB128IntValue(branchHints.size(), 5);
+  for (const auto &BHR : branchHints) {
+    if (!BHR)
+      continue;
+    // emit relocatable function index for the function symbol
+    OutStreamer->emitULEB128Value(MCSymbolRefExpr::create(
+        BHR->func_sym, WebAssembly::S_FUNCINDEX, OutContext));
+    // Emit the number of hints for this function. It is a constant, so it
+    // does not need relocation handling by the target streamer.
+    OutStreamer->emitULEB128IntValue(BHR->hints.size());
+    for (const auto &[instrSym, hint] : BHR->hints) {
+      assert(hint == 0 || hint == 1);
+      // offset from function start
+      OutStreamer->emitULEB128Value(MCSymbolRefExpr::create(
+          instrSym, WebAssembly::S_DEBUG_REF, OutContext));
+      OutStreamer->emitULEB128IntValue(1); // hint size
+      OutStreamer->emitULEB128IntValue(hint);
+    }
+  }
+  OutStreamer->popSection();
 }
 
 void WebAssemblyAsmPrinter::EmitProducerInfo(Module &M) {
@@ -696,6 +737,34 @@ void WebAssemblyAsmPrinter::emitInstruction(const MachineInstr *MI) {
     WebAssemblyMCInstLower MCInstLowering(OutContext, *this);
     MCInst TmpInst;
     MCInstLowering.lower(MI, TmpInst);
+    if (Subtarget->hasBranchHinting() &&
+        MI->getOpcode() == WebAssembly::BR_IF && MFI &&
+        MFI->BranchProbabilities.contains(MI)) {
+      MCSymbol *BrIfSym = OutContext.createTempSymbol();
+      OutStreamer->emitLabel(BrIfSym);
+
+      constexpr uint8_t HintLikely = 0x01;
+      constexpr uint8_t HintUnlikely = 0x00;
+      const BranchProbability &Prob = MFI->BranchProbabilities[MI];
+      std::optional<uint8_t> HintValue;
+      if (Prob > BranchProbability::getRaw(WasmHighBranchProb.getValue()))
+        HintValue = HintLikely;
+      else if (Prob <= BranchProbability::getRaw(WasmLowBranchProb.getValue()))
+        HintValue = HintUnlikely;
+      // No hint is emitted for probabilities between the two thresholds.
+
+      if (HintValue) {
+        // Branch hints are only emitted for internal functions, so the
+        // symbol can be cast directly without getMCSymbolForFunction.
+        MCSymbol *FuncSym = cast<MCSymbolWasm>(getSymbol(&MF->getFunction()));
+        uint32_t LocalFuncIdx = MF->getFunctionNumber();
+        if (branchHints.size() <= LocalFuncIdx) {
+          branchHints.resize(LocalFuncIdx + 1);
+          branchHints[LocalFuncIdx] = BranchHintRecord{FuncSym, {}};
+        }
+        branchHints[LocalFuncIdx]->hints.emplace_back(BrIfSym, *HintValue);
+      }
+    }
     EmitToStreamer(*OutStreamer, TmpInst);
     break;
   }
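The threshold comparison in `emitInstruction` can be illustrated in isolation. This sketch assumes, following the `cl::opt` comment in this patch, that thresholds are raw numerators over 1 << 31; `toRawProbability` and `classifyBranchHint` are hypothetical helpers, and the round-to-nearest scaling is an assumption rather than a copy of `BranchProbability`'s internals.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Denominator of a raw probability (1 << 31), so 0x40000000 == 50%.
constexpr uint64_t ProbDenominator = uint64_t(1) << 31;

// Scales Num / Den to a raw numerator over ProbDenominator, rounding to
// nearest (an assumption for this sketch).
uint32_t toRawProbability(uint32_t Num, uint32_t Den) {
  assert(Den != 0 && Num <= Den && "expects a probability in [0, 1]");
  return uint32_t((uint64_t(Num) * ProbDenominator + Den / 2) / Den);
}

// Same comparisons as the patch: above High -> likely (0x01),
// at or below Low -> unlikely (0x00), otherwise no hint.
std::optional<uint8_t> classifyBranchHint(uint32_t RawProb, uint32_t High,
                                          uint32_t Low) {
  if (RawProb > High)
    return uint8_t(0x01);
  if (RawProb <= Low)
    return uint8_t(0x00);
  return std::nullopt;
}
```

With `-wasm-branch-prob-high=0x60000000 -wasm-branch-prob-low=0x0`, as in `branch-hints-custom-high-low-thresholds.ll`, a 310:100 edge is hinted likely, a 1:1 edge gets no hint, and a zero-weight edge is hinted unlikely. With the defaults (both thresholds 0x40000000), an edge of exactly 50% is hinted unlikely.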
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h
index 46063bbe0fba1..b6595492cf26c 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyAsmPrinter.h
@@ -18,6 +18,11 @@
 namespace llvm {
 class WebAssemblyTargetStreamer;
 
+struct BranchHintRecord {
+  MCSymbol *func_sym;
+  SmallVector<std::pair<MCSymbol *, uint8_t>, 0> hints;
+};
+
 class LLVM_LIBRARY_VISIBILITY WebAssemblyAsmPrinter final : public AsmPrinter {
 public:
   static char ID;
@@ -28,6 +33,9 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyAsmPrinter final : public AsmPrinter {
   WebAssemblyFunctionInfo *MFI;
   bool signaturesEmitted = false;
 
+  // vec idx == local func_idx
+  std::vector<std::optional<BranchHintRecord>> branchHints;
+
 public:
   explicit WebAssemblyAsmPrinter(TargetMachine &TM,
                                  std::unique_ptr<MCStreamer> Streamer)
@@ -59,6 +67,7 @@ class LLVM_LIBRARY_VISIBILITY WebAssemblyAsmPrinter final : public AsmPrinter {
   void EmitProducerInfo(Module &M);
   void EmitTargetFeatures(Module &M);
   void EmitFunctionAttributes(Module &M);
+  void EmitBranchHintSection() const;
   void emitSymbolType(const MCSymbolWasm *Sym);
   void emitGlobalVariable(const GlobalVariable *GV) override;
   void emitJumpTableInfo() override;
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyCFGStackify.cpp b/llvm/lib/Target/WebAssembly/WebAssemblyCFGStackify.cpp
index 640be5fe8e8c9..84611030e448d 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyCFGStackify.cpp
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyCFGStackify.cpp
@@ -31,6 +31,7 @@
 #include "WebAssemblyUtilities.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/BinaryFormat/Wasm.h"
+#include "llvm/CodeGen/MachineBranchProbabilityInfo.h"
 #include "llvm/CodeGen/MachineDominators.h"
 #include "llvm/CodeGen/MachineInstrBuilder.h"
 #include "llvm/CodeGen/MachineLoopInfo.h"
@@ -48,6 +49,7 @@ STATISTIC(NumCatchUnwindMismatches, "Number of catch unwind mismatches found");
 namespace {
 class WebAssemblyCFGStackify final : public MachineFunctionPass {
   MachineDominatorTree *MDT;
+  MachineBranchProbabilityInfo *MBPI;
 
   StringRef getPassName() const override { return "WebAssembly CFG Stackify"; }
 
@@ -55,6 +57,7 @@ class WebAssemblyCFGStackify final : public MachineFunctionPass {
     AU.addRequired<MachineDominatorTreeWrapperPass>();
     AU.addRequired<MachineLoopInfoWrapperPass>();
     AU.addRequired<WebAssemblyExceptionInfo>();
+    AU.addRequired<MachineBranchProbabilityInfoWrapperPass>();
     MachineFunctionPass::getAnalysisUsage(AU);
   }
 
@@ -2562,8 +2565,22 @@ void WebAssemblyCFGStackify::rewriteDepthImmediates(MachineFunction &MF) {
           MO = MachineOperand::CreateImm(getDelegateDepth(Stack, MO.getMBB()));
         else if (MI.getOpcode() == WebAssembly::RETHROW)
           MO = MachineOperand::CreateImm(getRethrowDepth(Stack, MO.getMBB()));
-        else
+        else {
+          // This is the last place where the branch probabilities can easily
+          // be computed. We do not emit structured `if` instructions, so only
+          // br_if has to be handled here.
+          if (MF.getSubtarget<WebAssemblySubtarget>().hasBranchHinting() &&
+              MI.getOpcode() == WebAssembly::BR_IF &&
+              MI.getParent()->hasSuccessorProbabilities()) {
+            const auto Prob =
+                MBPI->getEdgeProbability(MI.getParent(), MO.getMBB());
+            WebAssemblyFunctionInfo *MFI =
+                MF.getInfo<WebAssemblyFunctionInfo>();
+            assert(!MFI->BranchProbabilities.contains(&MI));
+            MFI->BranchProbabilities[&MI] = Prob;
+          }
           MO = MachineOperand::CreateImm(getBranchDepth(Stack, MO.getMBB()));
+        }
       }
       MI.addOperand(MF, MO);
     }
@@ -2639,6 +2656,7 @@ bool WebAssemblyCFGStackify::runOnMachineFunction(MachineFunction &MF) {
                     << MF.getName() << '\n');
   const MCAsmInfo *MCAI = MF.getTarget().getMCAsmInfo();
   MDT = &getAnalysis<MachineDominatorTreeWrapperPass>().getDomTree();
+  MBPI = &getAnalysis<MachineBranchProbabilityInfoWrapperPass>().getMBPI();
 
   releaseMemory();
 
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td b/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td
index b5e723e2a48d3..f1d4a62535060 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyInstrInfo.td
@@ -96,6 +96,10 @@ def HasWideArithmetic :
     Predicate<"Subtarget->hasWideArithmetic()">,
     AssemblerPredicate<(all_of FeatureWideArithmetic), "wide-arithmetic">;
 
+def HasBranchHinting :
+    Predicate<"Subtarget->hasBranchHinting()">,
+    AssemblerPredicate<(all_of FeatureBranchHinting), "branch-hinting">;
+
 //===----------------------------------------------------------------------===//
 // WebAssembly-specific DAG Node Types.
 //===----------------------------------------------------------------------===//
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h b/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h
index 40ae4aef1d7f2..343168b570bef 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblyMachineFunctionInfo.h
@@ -153,6 +153,8 @@ class WebAssemblyFunctionInfo final : public MachineFunctionInfo {
 
   bool isCFGStackified() const { return CFGStackified; }
   void setCFGStackified(bool Value = true) { CFGStackified = Value; }
+
+  DenseMap<const MachineInstr *, BranchProbability> BranchProbabilities;
 };
 
 void computeLegalValueVTs(const WebAssemblyTargetLowering &TLI,
diff --git a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h
index 591ce25611e3e..96a24a1d40ef7 100644
--- a/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h
+++ b/llvm/lib/Target/WebAssembly/WebAssemblySubtarget.h
@@ -54,6 +54,7 @@ class WebAssemblySubtarget final : public WebAssemblyGenSubtargetInfo {
   bool HasSignExt = false;
   bool HasTailCall = false;
   bool HasWideArithmetic = false;
+  bool HasBranchHinting = false;
 
   /// What processor and OS we're targeting.
   Triple TargetTriple;
@@ -112,6 +113,7 @@ class WebAssemblySubtarget final : public WebAssemblyGenSubtargetInfo {
   bool hasSIMD128() const { return SIMDLevel >= SIMD128; }
   bool hasTailCall() const { return HasTailCall; }
   bool hasWideArithmetic() const { return HasWideArithmetic; }
+  bool hasBranchHinting() const { return HasBranchHinting; }
 
   /// Parses features string setting specified subtarget options. Definition of
   /// function is auto generated by tblgen.
diff --git a/llvm/test/MC/WebAssembly/branch-hints-custom-high-low-thresholds.ll b/llvm/test/MC/WebAssembly/branch-hints-custom-high-low-thresholds.ll
new file mode 100644
index 0000000000000..8fdce544aa06e
--- /dev/null
+++ b/llvm/test/MC/WebAssembly/branch-hints-custom-high-low-thresholds.ll
@@ -0,0 +1,79 @@
+; RUN: llc -mcpu=mvp -filetype=obj %s -mattr=+branch-hinting -wasm-branch-prob-high=0x60000000 -wasm-branch-prob-low=0x0 -o - | obj2yaml | FileCheck %s
+
+; This test checks that branch weight metadata (!prof) is correctly translated to WebAssembly branch hints.
+; The thresholds are set so that "likely" hints are only emitted if prob > 75% and "unlikely" hints
+; only if prob <= 0%.
+
+; CHECK:        - Type:            CUSTOM
+; CHECK-NEXT:     Relocations:
+; CHECK-NEXT:       - Type:            R_WASM_FUNCTION_INDEX_LEB
+; CHECK-NEXT:         Index:           0
+; CHECK-NEXT:         Offset:          0x5
+; CHECK-NEXT:       - Type:            R_WASM_FUNCTION_INDEX_LEB
+; CHECK-NEXT:         Index:           2
+; CHECK-NEXT:         Offset:          0xE
+; CHECK-NEXT:     Name:            metadata.code.branch_hint
+; CHECK-NEXT:     Payload:         '8380808000808080800001050101828080800001050100'
+
+; CHECK:        - Type:            CUSTOM
+; CHECK-NEXT:     Name:            linking
+; CHECK-NEXT:     Version:         2
+; CHECK-NEXT:     SymbolTable:
+; CHECK-NEXT:       - Index:           0
+; CHECK-NEXT:         Kind:            FUNCTION
+; CHECK-NEXT:         Name:            test0
+; CHECK-NEXT:         Flags:           [  ]
+; CHECK-NEXT:         Function:        0
+; CHECK-NEXT:       - Index:           1
+; CHECK-NEXT:         Kind:            FUNCTION
+; CHECK-NEXT:         Name:            test1
+; CHECK-NEXT:         Flags:           [  ]
+; CHECK-NEXT:         Function:        1
+; CHECK-NEXT:       - Index:           2
+; CHECK-NEXT:         Kind:            FUNCTION
+; CHECK-NEXT:         Name:            test2
+; CHECK-NEXT:         Flags:           [  ]
+; CHECK-NEXT:         Function:        2
+
+; CHECK:        - Type:            CUSTOM
+; CHECK-NEXT:     Name:            target_features
+; CHECK-NEXT:     Features:
+; CHECK-NEXT:       - Prefix:          USED
+; CHECK-NEXT:         Name:            branch-hinting
+
+target triple = "wasm32-unknown-unknown"
+
+define i32 @test0(i32 %a) {
+entry:
+  %cmp0 = icmp eq i32 %a, 0
+  br i1 %cmp0, label %if_then, label %if_else, !prof !0
+if_then:
+  ret i32 1
+if_else:
+  ret i32 0
+}
+
+define i32 @test1(i32 %a) {
+entry:
+  %cmp0 = icmp eq i32 %a, 0
+  br i1 %cmp0, label %if_then, label %if_else, !prof !1
+if_then:
+  ret i32 1
+if_else:
+  ret i32 0
+}
+
+define i32 @test2(i32 %a) {
+entry:
+  %cmp0 = icmp eq i32 %a, 0
+  br i1 %cmp0, label %if_then, label %if_else, !prof !2
+if_then:
+  ret i32 1
+if_else:
+  ret i32 0
+}
+
+; The resulting branch hint is reversed because the LLVM br is lowered to br_unless, which inverts the branch probabilities.
+!0 = !{!"branch_weights", !"expected", i32 100, i32 310} ; prob 75.61%
+!1 = !{!"branch_weights", i32 1, i32 1}                  ; prob == 50% (no hint)
+!2 = !{!"branch_weights", i32 1, i32 0}                  ; prob == 0% (unlikely hint)
\ No newline at end of file
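The threshold test above can be read as a simple three-way classification. The sketch below is hypothetical and is not the code from this patch: it assumes the `-wasm-branch-prob-high`/`-wasm-branch-prob-low` values are fixed-point numerators over 2^31 (the scale `llvm::BranchProbability` uses, which makes `0x60000000` equal to 75%), and it assumes the rule stated in the test comment: "likely" only above the high threshold, "unlikely" only at or below the low one. The names `toFixedPoint`, `hintFor`, and `BranchHint` are illustrative.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical sketch: probabilities are numerators over 2^31, so
// 0x60000000 / 0x80000000 == 75% (assumed to match the command-line flags).
constexpr uint64_t kProbDenominator = uint64_t{1} << 31;

// Hint byte values from the branch hinting proposal.
enum class BranchHint : uint8_t { Unlikely = 0x00, Likely = 0x01 };

// Probability that the emitted br_if is taken, given the weight of the edge it
// takes and the weight of the other edge. Requires a nonzero total weight.
// Rounds down; LLVM's own rounding may differ slightly.
uint32_t toFixedPoint(uint32_t takenWeight, uint32_t otherWeight) {
  const uint64_t total = uint64_t{takenWeight} + otherWeight;
  return static_cast<uint32_t>(uint64_t{takenWeight} * kProbDenominator / total);
}

// Assumed rule, taken from the test comment: emit "likely" only if
// prob > high, "unlikely" only if prob <= low, and no hint otherwise.
std::optional<BranchHint> hintFor(uint32_t takenWeight, uint32_t otherWeight,
                                  uint32_t high, uint32_t low) {
  if (uint64_t{takenWeight} + otherWeight == 0)
    return std::nullopt;  // no profile data: emit no hint
  const uint32_t prob = toFixedPoint(takenWeight, otherWeight);
  if (prob > high) return BranchHint::Likely;
  if (prob <= low) return BranchHint::Unlikely;
  return std::nullopt;
}
```

Following the test's note that the branch is inverted, the taken weight is the second `branch_weights` entry. With `high = 0x60000000` and `low = 0x0`, `!0` (310 of 410) yields "likely", `!1` (1 of 2) yields no hint, and `!2` (0 of 1) yields "unlikely". This matches the two function entries (indices 0 and 2) in the expected payload.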
diff --git a/llvm/test/MC/WebAssembly/branch-hints.ll b/llvm/test/MC/WebAssembly/branch-hints.ll
new file mode 100644
index 0000000000000..8b4c96a6f4eff
--- /dev/null
+++ b/llvm/test/MC/WebAssembly/branch-hints.ll
@@ -0,0 +1,66 @@
+; RUN: llc -mcpu=mvp -filetype=obj %s -mattr=+branch-hinting -o - | obj2yaml | FileCheck %s
+
+; This test checks that branch weight metadata (!prof) is correctly lowered to
+; the WebAssembly branch hint custom section.
+
+; CHECK:        - Type:            CUSTOM
+; CHECK-NEXT:     Relocations:
+; CHECK-NEXT:       - Type:            R_WASM_FUNCTION_INDEX_LEB
+; CHECK-NEXT:         Index:           0
+; CHECK-NEXT:         Offset:          0x5
+; CHECK-NEXT:       - Type:            R_WASM_FUNCTION_INDEX_LEB
+; CHECK-NEXT:         Index:           1
+; CHECK-NEXT:         Offset:          0x11
+; CHECK-NEXT:     Name:            metadata.code.branch_hint
+; CHECK-NEXT:     Payload:         '82808080008080808000020701000E0101818080800001050101'
+
+; CHECK:        - Type:            CUSTOM
+; CHECK-NEXT:     Name:            linking
+; CHECK-NEXT:     Version:         2
+; CHECK-NEXT:     SymbolTable:
+; CHECK-NEXT:       - Index:           0
+; CHECK-NEXT:         Kind:            FUNCTION
+; CHECK-NEXT:         Name:            test_unlikely_likely_branch
+; CHECK-NEXT:         Flags:           [  ]
+; CHECK-NEXT:         Function:        0
+; CHECK-NEXT:       - Index:           1
+; CHECK-NEXT:         Kind:            FUNCTION
+; CHECK-NEXT:         Name:            test_likely_branch
+; CHECK-NEXT:         Flags:           [  ]
+; CHECK-NEXT:         Function:        1
+
+; CHECK:        - Type:            CUSTOM
+; CHECK-NEXT:     Name:            target_features
+; CHECK-NEXT:     Features:
+; CHECK-NEXT:       - Prefix:          USED
+; CHECK-NEXT:         Name:            branch-hinting
+
+target triple = "wasm32-unknown-unknown"
+
+define i32 @test_unlikely_likely_branch(i32 %a) {
+entry:
+  %cmp0 = icmp eq i32 %a, 0
+  ; This metadata hints that the true branch is overwhelmingly likely.
+  br i1 %cmp0, label %if.then, label %ret1, !prof !0
+if.then:
+  %cmp1 = icmp eq i32 %a, 1
+  br i1 %cmp1, label %ret1, label %ret2, !prof !1
+ret1:
+  ret i32 2
+ret2:
+  ret i32 1
+}
+
+define i32 @test_likely_branch(i32 %a) {
+entry:
+  %cmp = icmp eq i32 %a, 0
+  br i1 %cmp, label %if.then, label %if.else, !prof !1
+if.then:
+  ret i32 1
+if.else:
+  ret i32 2
+}
+
+; The resulting branch hint is reversed because the LLVM br is lowered to br_unless, which inverts the branch probabilities.
+!0 = !{!"branch_weights", i32 2000, i32 1}
+!1 = !{!"branch_weights", i32 1, i32 2000}
\ No newline at end of file
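To make the expected `Payload` strings easier to review, here is a hypothetical decoder sketch (not part of the patch). It assumes the layout described in the proposal's code metadata binary format, as we read it: a function count, then for each function its index, a hint count, and per hint an instruction offset, a size (always 1 for branch hints), and the hint byte. All integers are ULEB128. The function count and function indices are padded to 5 bytes so that the `R_WASM_FUNCTION_INDEX_LEB` relocations listed above can be patched in place. The names `readULEB128`, `decodeBranchHints`, and `fromHex` are illustrative.

```cpp
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

struct Hint { uint32_t offset; uint8_t value; };  // value: 0 unlikely, 1 likely
struct FunctionHints { uint32_t funcIndex; std::vector<Hint> hints; };

// Reads a u32 ULEB128; accepts padded forms such as 82 80 80 80 00 (== 2).
uint32_t readULEB128(const std::vector<uint8_t>& buf, size_t& pos) {
  uint64_t result = 0;
  for (unsigned shift = 0; shift < 35; shift += 7) {
    if (pos >= buf.size()) throw std::runtime_error("truncated LEB128");
    const uint8_t byte = buf[pos++];
    result |= uint64_t(byte & 0x7f) << shift;
    if ((byte & 0x80) == 0) {
      if (result > UINT32_MAX) throw std::runtime_error("LEB128 overflows u32");
      return static_cast<uint32_t>(result);
    }
  }
  throw std::runtime_error("LEB128 longer than 5 bytes");
}

// Turns an obj2yaml-style hex string into bytes.
std::vector<uint8_t> fromHex(const std::string& hex) {
  std::vector<uint8_t> out;
  for (size_t i = 0; i + 1 < hex.size(); i += 2)
    out.push_back(static_cast<uint8_t>(std::stoul(hex.substr(i, 2), nullptr, 16)));
  return out;
}

// Decodes a metadata.code.branch_hint payload under the assumed layout.
std::vector<FunctionHints> decodeBranchHints(const std::vector<uint8_t>& payload) {
  size_t pos = 0;
  std::vector<FunctionHints> result;
  const uint32_t numFunctions = readULEB128(payload, pos);
  for (uint32_t f = 0; f < numFunctions; ++f) {
    FunctionHints fh{readULEB128(payload, pos), {}};
    const uint32_t numHints = readULEB128(payload, pos);
    for (uint32_t h = 0; h < numHints; ++h) {
      const uint32_t offset = readULEB128(payload, pos);
      const uint32_t size = readULEB128(payload, pos);
      if (size != 1 || pos >= payload.size())
        throw std::runtime_error("malformed branch hint");
      fh.hints.push_back({offset, payload[pos++]});
    }
    result.push_back(std::move(fh));
  }
  if (pos != payload.size()) throw std::runtime_error("trailing bytes");
  return result;
}
```

Applied to the `branch-hints.ll` payload, this sketch yields function 0 with an "unlikely" hint at offset 7 and a "likely" hint at offset 14 (0x0E), and function 1 with a single "likely" hint at offset 5. This lines up with the inverted `!0`/`!1` weights in the test.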

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
