https://github.com/svkeerthy updated 
https://github.com/llvm/llvm-project/pull/145118

>From f1976fa2454846d80822761f7a095b29c2062652 Mon Sep 17 00:00:00 2001
From: svkeerthy <venkatakeer...@google.com>
Date: Fri, 20 Jun 2025 23:00:40 +0000
Subject: [PATCH] Overload operator+, operator- and operator* for Embeddings

---
 llvm/include/llvm/Analysis/IR2Vec.h    |  9 ++++--
 llvm/lib/Analysis/IR2Vec.cpp           | 23 +++++++++++++++
 llvm/unittests/Analysis/IR2VecTest.cpp | 39 ++++++++++++++++++++++++++
 3 files changed, 68 insertions(+), 3 deletions(-)

diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 040cb84ff27a1..d63be227b1849 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -107,9 +107,12 @@ struct Embedding {
   const std::vector<double> &getData() const { return Data; }
 
   /// Arithmetic operators
-  Embedding &operator+=(const Embedding &RHS);
-  Embedding &operator-=(const Embedding &RHS);
-  Embedding &operator*=(double Factor);
+  LLVM_ABI Embedding operator+(const Embedding &RHS) const;
+  LLVM_ABI Embedding &operator+=(const Embedding &RHS);
+  LLVM_ABI Embedding operator-(const Embedding &RHS) const;
+  LLVM_ABI Embedding &operator-=(const Embedding &RHS);
+  LLVM_ABI Embedding operator*(double Factor) const;
+  LLVM_ABI Embedding &operator*=(double Factor);
 
   /// Adds Src Embedding scaled by Factor with the called Embedding.
   /// Called_Embedding += Src * Factor
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 895b3de58a54e..e499ebdd5ed3c 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -71,6 +71,14 @@ inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
 // Embedding
 //===----------------------------------------------------------------------===//
 
+Embedding Embedding::operator+(const Embedding &RHS) const {
+  // Copy, then reuse operator+= so the element-wise logic and the
+  // dimension assertion live in a single place.
+  Embedding Result(*this);
+  Result += RHS;
+  return Result;
+}
+
 Embedding &Embedding::operator+=(const Embedding &RHS) {
   assert(this->size() == RHS.size() && "Vectors must have the same dimension");
   std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
@@ -78,6 +86,14 @@ Embedding &Embedding::operator+=(const Embedding &RHS) {
   return *this;
 }
 
+Embedding Embedding::operator-(const Embedding &RHS) const {
+  // Copy, then reuse operator-= so the element-wise logic and the
+  // dimension assertion live in a single place.
+  Embedding Result(*this);
+  Result -= RHS;
+  return Result;
+}
+
 Embedding &Embedding::operator-=(const Embedding &RHS) {
   assert(this->size() == RHS.size() && "Vectors must have the same dimension");
   std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
@@ -85,6 +101,13 @@ Embedding &Embedding::operator-=(const Embedding &RHS) {
   return *this;
 }
 
+Embedding Embedding::operator*(double Factor) const {
+  // Copy, then reuse operator*= to scale every element by Factor.
+  Embedding Result(*this);
+  Result *= Factor;
+  return Result;
+}
+
 Embedding &Embedding::operator*=(double Factor) {
   std::transform(this->begin(), this->end(), this->begin(),
                  [Factor](double Elem) { return Elem * Factor; });
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 3c97c20ae72d5..70d4808dc6d54 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -109,6 +109,18 @@ TEST(EmbeddingTest, ConstructorsAndAccessors) {
   }
 }
 
+TEST(EmbeddingTest, AddVectorsOutOfPlace) {
+  const Embedding E1 = {1.0, 2.0, 3.0};
+  const Embedding E2 = {0.5, 1.5, -1.0};
+
+  Embedding E3 = E1 + E2; // Requires operator+ to be const-qualified.
+  EXPECT_THAT(E3, ElementsAre(1.5, 3.5, 2.0));
+
+  // Check that E1 and E2 are unchanged.
+  EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
+  EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
+}
+
 TEST(EmbeddingTest, AddVectors) {
   Embedding E1 = {1.0, 2.0, 3.0};
   Embedding E2 = {0.5, 1.5, -1.0};
@@ -120,6 +132,18 @@ TEST(EmbeddingTest, AddVectors) {
   EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
 }
 
+TEST(EmbeddingTest, SubtractVectorsOutOfPlace) {
+  const Embedding E1 = {1.0, 2.0, 3.0};
+  const Embedding E2 = {0.5, 1.5, -1.0};
+
+  Embedding E3 = E1 - E2; // Requires operator- to be const-qualified.
+  EXPECT_THAT(E3, ElementsAre(0.5, 0.5, 4.0));
+
+  // Check that E1 and E2 are unchanged.
+  EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
+  EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
+}
+
 TEST(EmbeddingTest, SubtractVectors) {
   Embedding E1 = {1.0, 2.0, 3.0};
   Embedding E2 = {0.5, 1.5, -1.0};
@@ -137,6 +161,15 @@ TEST(EmbeddingTest, ScaleVector) {
   EXPECT_THAT(E1, ElementsAre(0.5, 1.0, 1.5));
 }
 
+TEST(EmbeddingTest, ScaleVectorOutOfPlace) {
+  const Embedding E1 = {1.0, 2.0, 3.0};
+  Embedding E2 = E1 * 0.5f; // Requires operator* to be const-qualified.
+  EXPECT_THAT(E2, ElementsAre(0.5, 1.0, 1.5));
+
+  // Check that E1 is unchanged.
+  EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
+}
+
 TEST(EmbeddingTest, AddScaledVector) {
   Embedding E1 = {1.0, 2.0, 3.0};
   Embedding E2 = {2.0, 0.5, -1.0};
@@ -180,6 +213,12 @@ TEST(EmbeddingTest, AccessOutOfBounds) {
   EXPECT_DEATH(E[4] = 4.0, "Index out of bounds");
 }
 
+TEST(EmbeddingTest, MismatchedDimensionsAddVectorsOutOfPlace) {
+  const Embedding E1 = {1.0, 2.0};
+  const Embedding E2 = {1.0}; // Size mismatch with E1 must trip the assert.
+  EXPECT_DEATH(E1 + E2, "Vectors must have the same dimension");
+}
+
 TEST(EmbeddingTest, MismatchedDimensionsAddVectors) {
   Embedding E1 = {1.0, 2.0};
   Embedding E2 = {1.0};

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to