This is an automated email from the ASF dual-hosted git repository.
sarutak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-rust.git
The following commit(s) were added to refs/heads/master by this push:
new f749758 [SPARK-53554] Fix `sum_distinct`
f749758 is described below
commit f749758e33e5f029e06542e775a0e441f1ec2f1d
Author: Kousuke Saruta <[email protected]>
AuthorDate: Wed Oct 1 21:13:14 2025 +0900
[SPARK-53554] Fix `sum_distinct`
### What changes were proposed in this pull request?
This PR aims to fix `sum_distinct`.
In the current spark-connect-rust, `sum_distinct` doesn't work correctly.
If we run the following code,
```
let df = spark
.sql("SELECT * FROM VALUES (1), (2), (3), (1), (2), (3) AS data(value)")
.await?;
df.select([sum_distinct(col("value")).alias("sum")])
.show(None, None, None)
.await?;
```
We will get the following error.
```
Error: AnalysisException("[UNRESOLVED_ROUTINE] Cannot resolve routine
`sum_distinct` on search path [`system`.`builtin`, `system`.`session`,
`spark_catalog`.`default`]. SQLSTATE: 42883")
```
### Why are the changes needed?
Bug fix.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Added new tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #8 from sarutak/fix-sum-distinct.
Authored-by: Kousuke Saruta <[email protected]>
Signed-off-by: Kousuke Saruta <[email protected]>
---
crates/connect/src/functions/mod.rs | 54 ++++++++++++++++++++++++++++++++++++-
1 file changed, 53 insertions(+), 1 deletion(-)
diff --git a/crates/connect/src/functions/mod.rs b/crates/connect/src/functions/mod.rs
index 4884939..89892ba 100644
--- a/crates/connect/src/functions/mod.rs
+++ b/crates/connect/src/functions/mod.rs
@@ -1108,7 +1108,21 @@ gen_func!(stddev, [col: Column], "Alias for stddev_samp.");
gen_func!(stddev_pop, [col: Column], "Returns population standard deviation of the expression in a group.");
gen_func!(stddev_samp, [col: Column], "Returns the unbiased sample standard deviation of the expression in a group.");
gen_func!(sum, [col: Column], "Returns the sum of all values in the expression.");
-gen_func!(sum_distinct, [col: Column], "Returns the sum of distinct values in the expression.");
+
+/// "Returns the sum of distinct values in the expression."
+pub fn sum_distinct(col: impl Into<Column>) -> Column {
+ Column::from(spark::Expression {
+ expr_type: Some(spark::expression::ExprType::UnresolvedFunction(
+ spark::expression::UnresolvedFunction {
+ function_name: "sum".to_string(),
+ arguments: VecExpression::from_iter(vec![col]).into(),
+ is_distinct: true,
+ is_user_defined_function: false,
+ },
+ )),
+ })
+}
+
gen_func!(var_pop, [col: Column], "Returns the population variance of the values in a group.");
gen_func!(var_samp, [col: Column], "Returns the unbiased sample variance of the values in a group.");
gen_func!(variance, [col: Column], "Alias for var_samp");
@@ -2546,4 +2560,42 @@ mod tests {
assert_eq!(expected, res);
Ok(())
}
+
+ // Test aggregate functions
+ #[tokio::test]
+ async fn test_func_sum_distinct() -> Result<(), SparkError> {
+ let spark = setup().await;
+ let select_func = |df: DataFrame| {
+ df.select([sum_distinct(col("value")).alias("sum")])
+ .collect()
+ };
+ let record_batch_func =
+ |col: ArrayRef| RecordBatch::try_from_iter_with_nullable(vec![("sum", col, true)]);
+
+ let df = spark
+ .sql("SELECT * FROM VALUES (1), (2), (3), (1), (2), (3) AS data(value)")
+ .await?;
+ let res = select_func(df).await?;
+ let expected_col: ArrayRef = Arc::new(Int64Array::from(vec![6]));
+ let expected = record_batch_func(expected_col)?;
+ assert_eq!(expected, res);
+
+ let df = spark
+ .sql("SELECT * FROM VALUES (1), (2), (3), (null), (1), (2), (3) AS data(value)")
+ .await?;
+ let res = select_func(df).await?;
+ let expected_col: ArrayRef = Arc::new(Int64Array::from(vec![6]));
+ let expected = record_batch_func(expected_col)?;
+ assert_eq!(expected, res);
+
+ let df = spark
+ .sql("SELECT * FROM VALUES (null), (null), (null) AS data(value)")
+ .await?;
+ let res = select_func(df).await?;
+ let expected_col: ArrayRef = Arc::new(Float64Array::from(vec![None]));
+ let expected = record_batch_func(expected_col)?;
+ assert_eq!(expected, res);
+
+ Ok(())
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]