This is an automated email from the ASF dual-hosted git repository.
yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d3a8b303c5c0 [SPARK-46787][CONNECT] `bloomFilter` function should throw `AnalysisException` for invalid input
d3a8b303c5c0 is described below
commit d3a8b303c5c056ec0863d20b33de6f1a5865dfae
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Jan 25 11:11:18 2024 +0800
[SPARK-46787][CONNECT] `bloomFilter` function should throw `AnalysisException` for invalid input
### What changes were proposed in this pull request?
`bloomFilter` function should throw `AnalysisException` for invalid input
### Why are the changes needed?
1. `BloomFilterAggregate` itself validates the input and throws meaningful errors; the Planner should not intercept such invalid input and throw `InvalidPlanInput`.
2. To be consistent with the vanilla Scala API and other functions.
### Does this PR introduce _any_ user-facing change?
Yes: errors raised for invalid input change from `InvalidPlanInput` to `AnalysisException`, as sketched below.
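For illustration, a minimal sketch of the new behavior (assuming a Connect `spark` session; the error class matches what the updated suite asserts):

```scala
import org.apache.spark.sql.AnalysisException

val df = spark.range(1000).toDF("id")
try {
  // A negative expectedNumItems is invalid and now fails during analysis.
  df.stat.bloomFilter("id", -1000, 100)
} catch {
  case e: AnalysisException =>
    assert(e.getErrorClass == "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE")
}
```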
### How was this patch tested?
Updated the existing tests in `ClientDataFrameStatSuite`.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #44821 from zhengruifeng/connect_bloom_filter_agg_error.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: yangjie01 <[email protected]>
---
.../apache/spark/sql/DataFrameStatFunctions.scala | 28 ++++------------------
.../spark/sql/ClientDataFrameStatSuite.scala | 20 ++++++++--------
.../sql/connect/planner/SparkConnectPlanner.scala | 25 +------------------
3 files changed, 16 insertions(+), 57 deletions(-)
diff --git a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 4daa9fa88e66..4eef26da706f 100644
--- a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -22,7 +22,6 @@ import java.io.ByteArrayInputStream
import scala.jdk.CollectionConverters._
-import org.apache.spark.SparkException
import org.apache.spark.connect.proto.{Relation, StatSampleBy}
import org.apache.spark.sql.DataFrameStatFunctions.approxQuantileResultEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, BinaryEncoder, PrimitiveDoubleEncoder}
@@ -599,7 +598,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo
* @since 3.5.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, fpp: Double): BloomFilter = {
- buildBloomFilter(Column(colName), expectedNumItems, -1L, fpp)
+ bloomFilter(Column(colName), expectedNumItems, fpp)
}
/**
@@ -614,7 +613,8 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo
* @since 3.5.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, fpp: Double): BloomFilter = {
- buildBloomFilter(col, expectedNumItems, -1L, fpp)
+ val numBits = BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
+ bloomFilter(col, expectedNumItems, numBits)
}
/**
@@ -629,7 +629,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo
* @since 3.5.0
*/
def bloomFilter(colName: String, expectedNumItems: Long, numBits: Long): BloomFilter = {
- buildBloomFilter(Column(colName), expectedNumItems, numBits, Double.NaN)
+ bloomFilter(Column(colName), expectedNumItems, numBits)
}
/**
@@ -644,25 +644,7 @@ final class DataFrameStatFunctions private[sql] (sparkSession: SparkSession, roo
* @since 3.5.0
*/
def bloomFilter(col: Column, expectedNumItems: Long, numBits: Long): BloomFilter = {
- buildBloomFilter(col, expectedNumItems, numBits, Double.NaN)
- }
-
- private def buildBloomFilter(
- col: Column,
- expectedNumItems: Long,
- numBits: Long,
- fpp: Double): BloomFilter = {
- def numBitsValue: Long = if (!fpp.isNaN) {
- BloomFilter.optimalNumOfBits(expectedNumItems, fpp)
- } else {
- numBits
- }
-
- if (fpp <= 0d || fpp >= 1d) {
- throw new SparkException("False positive probability must be within range (0.0, 1.0)")
- }
- val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBitsValue))
-
+ val agg = Column.fn("bloom_filter_agg", col, lit(expectedNumItems), lit(numBits))
val ds = sparkSession.newDataset(BinaryEncoder) { builder =>
builder.getProjectBuilder
.setInput(root)
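Note: the `fpp` overload above now just derives `numBits` and delegates to the `numBits` overload. A hedged sketch of that sizing step (the standard Bloom filter bound, exposed by Spark's sketch library as `BloomFilter.optimalNumOfBits`; values are illustrative):

```scala
import org.apache.spark.util.sketch.BloomFilter

// Standard sizing bound: numBits ≈ -n * ln(p) / ln(2)^2.
// E.g. 1M expected items at a 3% false-positive rate needs ~7.3M bits.
val numBits = BloomFilter.optimalNumOfBits(1000000L, 0.03)
```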
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala
index d0a89f672f75..299ff7ff4fe3 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientDataFrameStatSuite.scala
@@ -21,7 +21,7 @@ import java.util.Random
import org.scalatest.matchers.must.Matchers._
-import org.apache.spark.{SparkException, SparkIllegalArgumentException}
+import org.apache.spark.SparkIllegalArgumentException
import org.apache.spark.sql.test.RemoteSparkSession
class ClientDataFrameStatSuite extends RemoteSparkSession {
@@ -248,19 +248,19 @@ class ClientDataFrameStatSuite extends RemoteSparkSession {
test("Bloom filter test invalid inputs") {
val df = spark.range(1000).toDF("id")
- val message1 = intercept[SparkException] {
+ val error1 = intercept[AnalysisException] {
df.stat.bloomFilter("id", -1000, 100)
- }.getMessage
- assert(message1.contains("Expected insertions must be positive"))
+ }
+ assert(error1.getErrorClass === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE")
- val message2 = intercept[SparkException] {
+ val error2 = intercept[AnalysisException] {
df.stat.bloomFilter("id", 1000, -100)
- }.getMessage
- assert(message2.contains("Number of bits must be positive"))
+ }
+ assert(error2.getErrorClass === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE")
- val message3 = intercept[SparkException] {
+ val error3 = intercept[AnalysisException] {
df.stat.bloomFilter("id", 1000, -1.0)
- }.getMessage
- assert(message3.contains("False positive probability must be within range (0.0, 1.0)"))
+ }
+ assert(error3.getErrorClass === "DATATYPE_MISMATCH.VALUE_OUT_OF_RANGE")
}
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index e7c47856d9ae..3e59b2644755 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -1805,31 +1805,8 @@ class SparkConnectPlanner(
case "bloom_filter_agg" if fun.getArgumentsCount == 3 =>
// [col, expectedNumItems: Long, numBits: Long]
val children = fun.getArgumentsList.asScala.map(transformExpression)
-
- // Check expectedNumItems is LongType and value greater than 0L
- val expectedNumItemsExpr = children(1)
- val expectedNumItems = expectedNumItemsExpr match {
- case Literal(l: Long, LongType) => l
- case _ =>
- throw InvalidPlanInput("Expected insertions must be long literal.")
- }
- if (expectedNumItems <= 0L) {
- throw InvalidPlanInput("Expected insertions must be positive.")
- }
-
- val numBitsExpr = children(2)
- // Check numBits is LongType and value greater than 0L
- numBitsExpr match {
- case Literal(numBits: Long, LongType) =>
- if (numBits <= 0L) {
- throw InvalidPlanInput("Number of bits must be positive.")
- }
- case _ =>
- throw InvalidPlanInput("Number of bits must be long literal.")
- }
-
Some(
- new BloomFilterAggregate(children.head, expectedNumItemsExpr, numBitsExpr)
+ new BloomFilterAggregate(children(0), children(1), children(2))
.toAggregateExpression())
case "window" if Seq(2, 3, 4).contains(fun.getArgumentsCount) =>