This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 17813f8b86fd [SPARK-50920][ML][PYTHON][CONNECT] Support NaiveBayes on Connect
17813f8b86fd is described below

commit 17813f8b86fd9c77066105c308990d6e76150771
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sun Jan 26 12:06:00 2025 +0800

    [SPARK-50920][ML][PYTHON][CONNECT] Support NaiveBayes on Connect
    
    ### What changes were proposed in this pull request?
    Support NaiveBayes on Connect
    
    ### Why are the changes needed?
    feature parity
    
    ### Does this PR introduce _any_ user-facing change?
    yes, new algorithm supported on connect
    
    ### How was this patch tested?
    added tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #49672 from zhengruifeng/ml_connect_nb.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit e0437e0021efceff9a76118de76807f0ddc26b43)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../org/apache/spark/ml/linalg/Matrices.scala      |  2 +
 .../services/org.apache.spark.ml.Estimator         |  1 +
 .../services/org.apache.spark.ml.Transformer       |  1 +
 .../spark/ml/classification/NaiveBayes.scala       |  3 +
 python/pyspark/ml/tests/test_classification.py     | 64 +++++++++++++++++++++-
 .../org/apache/spark/sql/connect/ml/MLUtils.scala  |  1 +
 6 files changed, 71 insertions(+), 1 deletion(-)

diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala
index a5ac0f24f385..ad8869f8a81f 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/linalg/Matrices.scala
@@ -1048,6 +1048,8 @@ object SparseMatrix {
 @Since("2.0.0")
 object Matrices {
 
+  private[ml] val empty = new DenseMatrix(0, 0, Array.emptyDoubleArray)
+
   private[ml] def fromVectors(vectors: Seq[Vector]): Matrix = {
     val numRows = vectors.length
     val numCols = vectors.head.size
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
index 1183f50ae7f3..5d811598095b 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Estimator
@@ -19,6 +19,7 @@
 # So register the supported estimator here if you're trying to add a new one.
 
 # classification
+org.apache.spark.ml.classification.NaiveBayes
 org.apache.spark.ml.classification.LinearSVC
 org.apache.spark.ml.classification.LogisticRegression
 org.apache.spark.ml.classification.DecisionTreeClassifier
diff --git a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
index ce880bb2ef31..d2a8d6036d4e 100644
--- a/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
+++ b/mllib/src/main/resources/META-INF/services/org.apache.spark.ml.Transformer
@@ -35,6 +35,7 @@ org.apache.spark.ml.feature.HashingTF
 
 ########### Model for loading
 # classification
+org.apache.spark.ml.classification.NaiveBayesModel
 org.apache.spark.ml.classification.LinearSVCModel
 org.apache.spark.ml.classification.LogisticRegressionModel
 org.apache.spark.ml.classification.DecisionTreeClassificationModel
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 4a511581d31a..de2023899ee5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -401,6 +401,9 @@ class NaiveBayesModel private[ml] (
 
   import NaiveBayes._
 
+  private[ml] def this() = this(Identifiable.randomUID("nb"),
+    Vectors.empty, Matrices.empty, Matrices.empty)
+
   /**
    * mllib NaiveBayes is a wrapper of ml implementation currently.
    * Input labels of mllib could be {-1, +1} and mllib NaiveBayesModel exposes 
labels,
diff --git a/python/pyspark/ml/tests/test_classification.py b/python/pyspark/ml/tests/test_classification.py
index bcf376007198..8ee2dcac5c12 100644
--- a/python/pyspark/ml/tests/test_classification.py
+++ b/python/pyspark/ml/tests/test_classification.py
@@ -22,8 +22,10 @@ from shutil import rmtree
 import numpy as np
 
 from pyspark.ml.linalg import Vectors, Matrices
-from pyspark.sql import SparkSession, DataFrame
+from pyspark.sql import SparkSession, DataFrame, Row
 from pyspark.ml.classification import (
+    NaiveBayes,
+    NaiveBayesModel,
     LinearSVC,
     LinearSVCModel,
     LinearSVCSummary,
@@ -46,6 +48,66 @@ from pyspark.ml.classification import (
 
 
 class ClassificationTestsMixin:
+    def test_naive_bayes(self):
+        spark = self.spark
+        df = spark.createDataFrame(
+            [
+                Row(label=0.0, weight=0.1, features=Vectors.dense([0.0, 0.0])),
+                Row(label=0.0, weight=0.5, features=Vectors.dense([0.0, 1.0])),
+                Row(label=1.0, weight=1.0, features=Vectors.dense([1.0, 0.0])),
+            ]
+        )
+
+        nb = NaiveBayes(smoothing=1.0, modelType="multinomial", weightCol="weight")
+        self.assertEqual(nb.getSmoothing(), 1.0)
+        self.assertEqual(nb.getModelType(), "multinomial")
+        self.assertEqual(nb.getWeightCol(), "weight")
+
+        model = nb.fit(df)
+        self.assertEqual(model.numClasses, 2)
+        self.assertEqual(model.numFeatures, 2)
+        self.assertTrue(
+            np.allclose(model.pi.toArray(), [-0.81093022, -0.58778666], atol=1e-4), model.pi
+        )
+        self.assertTrue(
+            np.allclose(
+                model.theta.toArray(),
+                [[-0.91629073, -0.51082562], [-0.40546511, -1.09861229]],
+                atol=1e-4,
+            ),
+            model.theta,
+        )
+        self.assertTrue(np.allclose(model.sigma.toArray(), [], atol=1e-4), model.sigma)
+
+        vec = Vectors.dense(0.0, 5.0)
+        self.assertEqual(model.predict(vec), 0.0)
+        pred = model.predictRaw(vec)
+        self.assertTrue(np.allclose(pred.toArray(), [-3.36505834, -6.08084811], atol=1e-4), pred)
+        pred = model.predictProbability(vec)
+        self.assertTrue(np.allclose(pred.toArray(), [0.93795196, 0.06204804], atol=1e-4), pred)
+
+        output = model.transform(df)
+        expected_cols = [
+            "label",
+            "weight",
+            "features",
+            "rawPrediction",
+            "probability",
+            "prediction",
+        ]
+        self.assertEqual(output.columns, expected_cols)
+        self.assertEqual(output.count(), 3)
+
+        # Model save & load
+        with tempfile.TemporaryDirectory(prefix="naive_bayes") as d:
+            nb.write().overwrite().save(d)
+            nb2 = NaiveBayes.load(d)
+            self.assertEqual(str(nb), str(nb2))
+
+            model.write().overwrite().save(d)
+            model2 = NaiveBayesModel.load(d)
+            self.assertEqual(str(model), str(model2))
+
     def test_binomial_logistic_regression_with_bound(self):
         df = self.spark.createDataFrame(
             [
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
index cd6e13f33d2b..d6e13d301c7e 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLUtils.scala
@@ -523,6 +523,7 @@ private[ml] object MLUtils {
    (classOf[GBTRegressionModel], Set("featureImportances", "evaluateEachIteration")),
 
     // Classification Models
+    (classOf[NaiveBayesModel], Set("pi", "theta", "sigma")),
     (classOf[LinearSVCModel], Set("intercept", "coefficients", "evaluate")),
     (
       classOf[LogisticRegressionModel],


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to