https://github.com/svkeerthy updated https://github.com/llvm/llvm-project/pull/145118
>From f1976fa2454846d80822761f7a095b29c2062652 Mon Sep 17 00:00:00 2001 From: svkeerthy <venkatakeer...@google.com> Date: Fri, 20 Jun 2025 23:00:40 +0000 Subject: [PATCH] Overloading operator+ for Embeddings --- llvm/include/llvm/Analysis/IR2Vec.h | 9 ++++-- llvm/lib/Analysis/IR2Vec.cpp | 23 +++++++++++++++ llvm/unittests/Analysis/IR2VecTest.cpp | 39 ++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 3 deletions(-) diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h index 040cb84ff27a1..d63be227b1849 100644 --- a/llvm/include/llvm/Analysis/IR2Vec.h +++ b/llvm/include/llvm/Analysis/IR2Vec.h @@ -107,9 +107,12 @@ struct Embedding { const std::vector<double> &getData() const { return Data; } /// Arithmetic operators - Embedding &operator+=(const Embedding &RHS); - Embedding &operator-=(const Embedding &RHS); - Embedding &operator*=(double Factor); + LLVM_ABI Embedding operator+(const Embedding &RHS) const; + LLVM_ABI Embedding &operator+=(const Embedding &RHS); + LLVM_ABI Embedding operator-(const Embedding &RHS) const; + LLVM_ABI Embedding &operator-=(const Embedding &RHS); + LLVM_ABI Embedding operator*(double Factor) const; + LLVM_ABI Embedding &operator*=(double Factor); /// Adds Src Embedding scaled by Factor with the called Embedding. 
/// Called_Embedding += Src * Factor diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp index 895b3de58a54e..e499ebdd5ed3c 100644 --- a/llvm/lib/Analysis/IR2Vec.cpp +++ b/llvm/lib/Analysis/IR2Vec.cpp @@ -71,6 +71,14 @@ inline bool fromJSON(const llvm::json::Value &E, Embedding &Out, // Embedding //===----------------------------------------------------------------------===// +Embedding Embedding::operator+(const Embedding &RHS) const { + assert(this->size() == RHS.size() && "Vectors must have the same dimension"); + Embedding Result(*this); + std::transform(this->begin(), this->end(), RHS.begin(), Result.begin(), + std::plus<double>()); + return Result; +} + Embedding &Embedding::operator+=(const Embedding &RHS) { assert(this->size() == RHS.size() && "Vectors must have the same dimension"); std::transform(this->begin(), this->end(), RHS.begin(), this->begin(), @@ -78,6 +86,14 @@ Embedding &Embedding::operator+=(const Embedding &RHS) { return *this; } +Embedding Embedding::operator-(const Embedding &RHS) const { + assert(this->size() == RHS.size() && "Vectors must have the same dimension"); + Embedding Result(*this); + std::transform(this->begin(), this->end(), RHS.begin(), Result.begin(), + std::minus<double>()); + return Result; +} + Embedding &Embedding::operator-=(const Embedding &RHS) { assert(this->size() == RHS.size() && "Vectors must have the same dimension"); std::transform(this->begin(), this->end(), RHS.begin(), this->begin(), @@ -85,6 +101,13 @@ Embedding &Embedding::operator-=(const Embedding &RHS) { return *this; } +Embedding Embedding::operator*(double Factor) const { + Embedding Result(*this); + std::transform(this->begin(), this->end(), Result.begin(), + [Factor](double Elem) { return Elem * Factor; }); + return Result; +} + Embedding &Embedding::operator*=(double Factor) { std::transform(this->begin(), this->end(), this->begin(), [Factor](double Elem) { return Elem * Factor; }); diff --git 
a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp index 3c97c20ae72d5..70d4808dc6d54 100644 --- a/llvm/unittests/Analysis/IR2VecTest.cpp +++ b/llvm/unittests/Analysis/IR2VecTest.cpp @@ -109,6 +109,18 @@ TEST(EmbeddingTest, ConstructorsAndAccessors) { } } +TEST(EmbeddingTest, AddVectorsOutOfPlace) { + Embedding E1 = {1.0, 2.0, 3.0}; + Embedding E2 = {0.5, 1.5, -1.0}; + + Embedding E3 = E1 + E2; + EXPECT_THAT(E3, ElementsAre(1.5, 3.5, 2.0)); + + // Check that E1 and E2 are unchanged + EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0)); + EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0)); +} + TEST(EmbeddingTest, AddVectors) { Embedding E1 = {1.0, 2.0, 3.0}; Embedding E2 = {0.5, 1.5, -1.0}; @@ -120,6 +132,18 @@ TEST(EmbeddingTest, AddVectors) { EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0)); } +TEST(EmbeddingTest, SubtractVectorsOutOfPlace) { + Embedding E1 = {1.0, 2.0, 3.0}; + Embedding E2 = {0.5, 1.5, -1.0}; + + Embedding E3 = E1 - E2; + EXPECT_THAT(E3, ElementsAre(0.5, 0.5, 4.0)); + + // Check that E1 and E2 are unchanged + EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0)); + EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0)); +} + TEST(EmbeddingTest, SubtractVectors) { Embedding E1 = {1.0, 2.0, 3.0}; Embedding E2 = {0.5, 1.5, -1.0}; @@ -137,6 +161,15 @@ TEST(EmbeddingTest, ScaleVector) { EXPECT_THAT(E1, ElementsAre(0.5, 1.0, 1.5)); } +TEST(EmbeddingTest, ScaleVectorOutOfPlace) { + Embedding E1 = {1.0, 2.0, 3.0}; + Embedding E2 = E1 * 0.5f; + EXPECT_THAT(E2, ElementsAre(0.5, 1.0, 1.5)); + + // Check that E1 is unchanged + EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0)); +} + TEST(EmbeddingTest, AddScaledVector) { Embedding E1 = {1.0, 2.0, 3.0}; Embedding E2 = {2.0, 0.5, -1.0}; @@ -180,6 +213,12 @@ TEST(EmbeddingTest, AccessOutOfBounds) { EXPECT_DEATH(E[4] = 4.0, "Index out of bounds"); } +TEST(EmbeddingTest, MismatchedDimensionsAddVectorsOutOfPlace) { + Embedding E1 = {1.0, 2.0}; + Embedding E2 = {1.0}; + EXPECT_DEATH(E1 + E2, "Vectors must have the same 
dimension"); +} + TEST(EmbeddingTest, MismatchedDimensionsAddVectors) { Embedding E1 = {1.0, 2.0}; Embedding E2 = {1.0}; _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits