This is an automated email from the ASF dual-hosted git repository.

yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d3a8b303c5c0 [SPARK-46787][CONNECT] `bloomFilter` function should throw `AnalysisException` for invalid input
d3a8b303c5c0 is described below

commit d3a8b303c5c056ec0863d20b33de6f1a5865dfae
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Jan 25 11:11:18 2024 +0800

    [SPARK-46787][CONNECT] `bloomFilter` function should throw `AnalysisException` for invalid input
    
    ### What changes were proposed in this pull request?
    `bloomFilter` function should throw `AnalysisException` for invalid input
    
    ### Why are the changes needed?
    
    1. `BloomFilterAggregate` already validates its input and throws meaningful errors, so the Planner should not duplicate that validation and throw `InvalidPlanInput` for invalid input.
    2. To be consistent with the vanilla Scala API and with other functions.
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Invalid input to `bloomFilter` now raises `AnalysisException` instead of `InvalidPlanInput`.
    
    ### How was this patch tested?
    Updated the existing tests in `ClientDataFrameStatSuite` to expect `AnalysisException`.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #44821 from zhengruifeng/connect_bloom_filter_agg_error.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: yangjie01 <[email protected]>
---
 .../apache/spark/sql/DataFrameStatFunctions.scala  | 28 ++++------------------
 .../spark/sql/ClientDataFrameStatSuite.scala       | 20 ++++++++--------
 .../sql/connect/planner/SparkConnectPlanner.scala  | 25 +------------------
 3 files changed, 16 insertions(+), 57 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 4daa9fa88e66..4eef26da706f 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -22,7 +22,6 @@ import java.io.ByteArrayInputStream
 
 import scala.jdk.CollectionConverters._
 
-import org.apache.spark.SparkException
 import org.apache.spark.connect.proto.{Relation, StatSampleBy}
 import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, 
BinaryEncoder, PrimitiveDoubleEncoder}
@@ -599,7 +598,7 @@ final class DataFrameStatFunctions private[sql] 
(sparkSession: SparkSession, roo
    * @since 3.5.0
    */
   def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): 
BloomFilter = {
-    buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp)
+    bloomFilter(Column(colName), expectedNumItems, fpp)
   }
 
   /**
@@ -614,7 +613,8 @@ final class DataFrameStatFunctions private[sql] 
(sparkSession: SparkSession, roo
    * @since 3.5.0
    */
   def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): 
BloomFilter = {
-    buildBloomFilter(col, expectedNumItems, -1L, fpp)
+    val numBits = BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
+    bloomFilter(col, expectedNumItems, numBits)
   }
 
   /**
@@ -629,7 +629,7 @@ final class DataFrameStatFunctions private[sql] 
(sparkSession: SparkSession, roo
    * @since 3.5.0
    */
   def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): 
BloomFilter = {
-    buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN)
+    bloomFilter(Column(colName), expectedNumItems, numBits)
   }
 
   /**
@@ -644,25 +644,7 @@ final class DataFrameStatFunctions private[sql] 
(sparkSession: SparkSession, roo
    * @since 3.5.0
    */
   def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): 
BloomFilter = {
-    buildBloomFilter(col, expectedNumItems, numBits, Double.NaN)
-  }
-
-  private def buildBloomFilter(
-      col: Column,
-      expectedNumItems: Long,
-      numBits: Long,
-      fpp: Double): BloomFilter = {
-    def numBitsValue: Long = if (!fpp.isNaN) {
-      BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
-    } else {
-      numBits
-    }
-
-    if (fpp <= 0d || fpp >= 1d) {
-      throw new SparkException("False positive probability must be within 
range (0.0, 1.0)")
-    }
-    val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), 
lit(numBitsValue))
-
+    val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), 
lit(numBits))
     val ds = sparkSession.newDataset(BinaryEncoder) { builder =>
       builder.getProjectBuilder
         .setInput(root)
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala
index d0a89f672f75..299ff7ff4fe3 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala
@@ -21,7 +21,7 @@ import java.util.Random
 
 import org.scalatest.matchers.must.Matchers._
 
-import org.apache.spark.{SparkException, SparkIllegalArgumentException}
+import org.apache.spark.SparkIllegalArgumentException
 import org.apache.spark.sql.test.RemoteSparkSession
 
 class ClientDataFrameStatSuite extends RemoteSparkSession {
@@ -248,19 +248,19 @@ class ClientDataFrameStatSuite extends RemoteSparkSession 
{
 
   test("Bloom filter test invalid inputs") {
     val df = spark.range(1000).toDF("id")
-    val message1 = intercept[SparkException] {
+    val error1 = intercept[AnalysisException] {
       df.stat.bloomFilter("id", -1000, 100)
-    }.getMessage
-    assert(message1.contains("Expected insertions must be positive"))
+    }
+    assert(error1.getErrorClass === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE")
 
-    val message2 = intercept[SparkException] {
+    val error2 = intercept[AnalysisException] {
       df.stat.bloomFilter("id", 1000, -100)
-    }.getMessage
-    assert(message2.contains("Number of bits must be positive"))
+    }
+    assert(error2.getErrorClass === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE")
 
-    val message3 = intercept[SparkException] {
+    val error3 = intercept[AnalysisException] {
       df.stat.bloomFilter("id", 1000, -1.0)
-    }.getMessage
-    assert(message3.contains("False positive probability must be within range 
(0.0, 1.0)"))
+    }
+    assert(error3.getErrorClass === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE")
   }
 }
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index e7c47856d9ae..3e59b2644755 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1805,31 +1805,8 @@ class SparkConnectPlanner(
       case "bloom_filter_agg" if fun.getArgumentsCount == 3 =>
         // [col, expectedNumItems: Long, numBits: Long]
         val children = fun.getArgumentsList.asScala.map(transformExpression)
-
-        // Check expectedNumItems is LongType and value greater than 0L
-        val expectedNumItemsExpr = children(1)
-        val expectedNumItems = expectedNumItemsExpr match {
-          case Literal(l: Long, LongType) => l
-          case _ =>
-            throw InvalidPlanInput("Expected insertions must be long literal.")
-        }
-        if (expectedNumItems <= 0L) {
-          throw InvalidPlanInput("Expected insertions must be positive.")
-        }
-
-        val numBitsExpr = children(2)
-        // Check numBits is LongType and value greater than 0L
-        numBitsExpr match {
-          case Literal(numBits: Long, LongType) =>
-            if (numBits <= 0L) {
-              throw InvalidPlanInput("Number of bits must be positive.")
-            }
-          case _ =>
-            throw InvalidPlanInput("Number of bits must be long literal.")
-        }
-
         Some(
-          new BloomFilterAggregate(children.head, expectedNumItemsExpr, 
numBitsExpr)
+          new BloomFilterAggregate(children(0), children(1), children(2))
             .toAggregateExpression())
 
       case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) =>


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to