https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/145118
>From 10019cae162bb53e147797b655da75aac33b0a20 Mon Sep 17 00:00:00 2001 From: svkeerthy <venkatakeer...@google.com> Date: Fri, 20 Jun 2025 23:00:40 +0000 Subject: [PATCH] Overloading operator+ for Embeddings --- llvm/include/llvm/Analysis/IR2Vec.h | 9 ++++-- llvm/lib/Analysis/IR2Vec.cpp | 19 ++++++++++++- llvm/unittests/Analysis/IR2VecTest.cpp | 39 ++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 040cb84ff27a1..ef8f630d7feb1 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -107,9 +107,12 @@ struct Embedding { const std::vector<double> &getData() const { return Data; } /// Arithmetic operators - Embedding &operator+=(const Embedding &RHS); - Embedding &operator-=(const Embedding &RHS); - Embedding &operator*=(double Factor); + LLVM_ABI Embedding &operator+=(const Embedding &RHS); + LLVM_ABI Embedding operator+(const Embedding &RHS) const; + LLVM_ABI Embedding &operator-=(const Embedding &RHS); + LLVM_ABI Embedding operator-(const Embedding &RHS) const; + LLVM_ABI Embedding &operator*=(double Factor); + LLVM_ABI Embedding operator*(double Factor) const; /// Adds Src Embedding scaled by Factor with the called Embedding. 
/// Called_Embedding += Src * Factor diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 895b3de58a54e..bf456102bb618 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -70,7 +70,6 @@ inline bool fromJSON(const llvm::json::Value &E, Embedding &Out, // ==----------------------------------------------------------------------===// // Embedding //===----------------------------------------------------------------------===// - Embedding &Embedding::operator+=(const Embedding &RHS) { assert(this->size() == RHS.size() && "Vectors must have the same dimension"); std::transform(this->begin(), this->end(), RHS.begin(), this->begin(), @@ -78,6 +77,12 @@ Embedding &Embedding::operator+=(const Embedding &RHS) { return *this; } +Embedding Embedding::operator+(const Embedding &RHS) const { + Embedding Result(*this); + Result += RHS; + return Result; +} + Embedding &Embedding::operator-=(const Embedding &RHS) { assert(this->size() == RHS.size() && "Vectors must have the same dimension"); std::transform(this->begin(), this->end(), RHS.begin(), this->begin(), @@ -85,12 +90,24 @@ Embedding &Embedding::operator-=(const Embedding &RHS) { return *this; } +Embedding Embedding::operator-(const Embedding &RHS) const { + Embedding Result(*this); + Result -= RHS; + return Result; +} + Embedding &Embedding::operator*=(double Factor) { std::transform(this->begin(), this->end(), this->begin(), [Factor](double Elem) { return Elem * Factor; }); return *this; } +Embedding Embedding::operator*(double Factor) const { + Embedding Result(*this); + Result *= Factor; + return Result; +} + Embedding &Embedding::scaleAndAdd(const Embedding &Src, float Factor) { assert(this->size() == Src.size() && "Vectors must have the same dimension"); for (size_t Itr = 0; Itr < this->size(); ++Itr) diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index 3c97c20ae72d5..70d4808dc6d54 100644 --- 
a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -109,6 +109,18 @@ TEST(EmbeddingTest, ConstructorsAndAccessors) { } } +TEST(EmbeddingTest, AddVectorsOutOfPlace) { + Embedding E1 = {1.0, 2.0, 3.0}; + Embedding E2 = {0.5, 1.5, -1.0}; + + Embedding E3 = E1 + E2; + EXPECT_THAT(E3, ElementsAre(1.5, 3.5, 2.0)); + + // Check that E1 and E2 are unchanged + EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0)); + EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0)); +} + TEST(EmbeddingTest, AddVectors) { Embedding E1 = {1.0, 2.0, 3.0}; Embedding E2 = {0.5, 1.5, -1.0}; @@ -120,6 +132,18 @@ TEST(EmbeddingTest, AddVectors) { EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0)); } +TEST(EmbeddingTest, SubtractVectorsOutOfPlace) { + Embedding E1 = {1.0, 2.0, 3.0}; + Embedding E2 = {0.5, 1.5, -1.0}; + + Embedding E3 = E1 - E2; + EXPECT_THAT(E3, ElementsAre(0.5, 0.5, 4.0)); + + // Check that E1 and E2 are unchanged + EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0)); + EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0)); +} + TEST(EmbeddingTest, SubtractVectors) { Embedding E1 = {1.0, 2.0, 3.0}; Embedding E2 = {0.5, 1.5, -1.0}; @@ -137,6 +161,15 @@ TEST(EmbeddingTest, ScaleVector) { EXPECT_THAT(E1, ElementsAre(0.5, 1.0, 1.5)); } +TEST(EmbeddingTest, ScaleVectorOutOfPlace) { + Embedding E1 = {1.0, 2.0, 3.0}; + Embedding E2 = E1 * 0.5f; + EXPECT_THAT(E2, ElementsAre(0.5, 1.0, 1.5)); + + // Check that E1 is unchanged + EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0)); +} + TEST(EmbeddingTest, AddScaledVector) { Embedding E1 = {1.0, 2.0, 3.0}; Embedding E2 = {2.0, 0.5, -1.0}; @@ -180,6 +213,12 @@ TEST(EmbeddingTest, AccessOutOfBounds) { EXPECT_DEATH(E[4] = 4.0, "Index out of bounds"); } +TEST(EmbeddingTest, MismatchedDimensionsAddVectorsOutOfPlace) { + Embedding E1 = {1.0, 2.0}; + Embedding E2 = {1.0}; + EXPECT_DEATH(E1 + E2, "Vectors must have the same dimension"); +} + TEST(EmbeddingTest, MismatchedDimensionsAddVectors) { Embedding E1 = {1.0, 2.0}; Embedding E2 = {1.0}; 
_______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits