This is an automated email from the ASF dual-hosted git repository.

sarutak pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark-connect-rust.git


The following commit(s) were added to refs/heads/master by this push:
     new f749758  [SPARK-53554] Fix `sum_distinct`
f749758 is described below

commit f749758e33e5f029e06542e775a0e441f1ec2f1d
Author: Kousuke Saruta <[email protected]>
AuthorDate: Wed Oct 1 21:13:14 2025 +0900

    [SPARK-53554] Fix `sum_distinct`
    
    ### What changes were proposed in this pull request?
    This PR aims to fix `sum_distinct`.
    In the current spark-connect-rust, `sum_distinct` doesn't work correctly.
    If we run the following code,
    
    ```
    let df = spark
        .sql("SELECT * FROM VALUES (1), (2), (3), (1), (2), (3) AS data(value)")
        .await?;
    df.select([sum_distinct(col("value")).alias("sum")])
        .show(None, None, None)
        .await?;
    ```
    
    we get the following error:
    ```
    Error: AnalysisException("[UNRESOLVED_ROUTINE] Cannot resolve routine `sum_distinct` on search path [`system`.`builtin`, `system`.`session`, `spark_catalog`.`default`]. SQLSTATE: 42883")
    ```
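
    The failure happens because the old binding was generated by the `gen_func!` macro, which forwards the Rust function name verbatim as the server-side routine name, and no routine named `sum_distinct` exists. As a rough sketch, the removed macro invocation expanded to something like the following (the exact expansion is an assumption, inferred from the replacement code in the diff below):

    ```
    // Hypothetical expansion of the removed
    // gen_func!(sum_distinct, [col: Column], ...) invocation.
    pub fn sum_distinct(col: impl Into<Column>) -> Column {
        Column::from(spark::Expression {
            expr_type: Some(spark::expression::ExprType::UnresolvedFunction(
                spark::expression::UnresolvedFunction {
                    // The literal name `sum_distinct` is sent to the server,
                    // which raises the UNRESOLVED_ROUTINE error above.
                    function_name: "sum_distinct".to_string(),
                    arguments: VecExpression::from_iter(vec![col]).into(),
                    is_distinct: false,
                    is_user_defined_function: false,
                },
            )),
        })
    }
    ```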
    
    ### Why are the changes needed?
    Bug fix. The macro-generated binding asked the server to resolve a routine literally named `sum_distinct`, which does not exist; Spark expresses distinct aggregation as `sum` with the DISTINCT flag. The fix replaces the macro invocation with a hand-written function that sets `function_name: "sum"` and `is_distinct: true`.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    
    Added a new test, `test_func_sum_distinct`.
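
    With the fix, the snippet from the description returns the distinct sum (1 + 2 + 3 = 6) instead of the analysis error. The test also verifies that nulls are ignored when mixed with values and that an all-null column yields a null result, matching the expected record batches in the diff below.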
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #8 from sarutak/fix-sum-distinct.
    
    Authored-by: Kousuke Saruta <[email protected]>
    Signed-off-by: Kousuke Saruta <[email protected]>
---
 crates/connect/src/functions/mod.rs | 54 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 53 insertions(+), 1 deletion(-)

diff --git a/crates/connect/src/functions/mod.rs b/crates/connect/src/functions/mod.rs
index 4884939..89892ba 100644
--- a/crates/connect/src/functions/mod.rs
+++ b/crates/connect/src/functions/mod.rs
@@ -1108,7 +1108,21 @@ gen_func!(stddev, [col: Column], "Alias for stddev_samp.");
 gen_func!(stddev_pop, [col: Column], "Returns population standard deviation of the expression in a group.");
 gen_func!(stddev_samp, [col: Column], "Returns the unbiased sample standard deviation of the expression in a group.");
 gen_func!(sum, [col: Column], "Returns the sum of all values in the expression.");
-gen_func!(sum_distinct, [col: Column], "Returns the sum of distinct values in the expression.");
+
+/// "Returns the sum of distinct values in the expression."
+pub fn sum_distinct(col: impl Into<Column>) -> Column {
+    Column::from(spark::Expression {
+        expr_type: Some(spark::expression::ExprType::UnresolvedFunction(
+            spark::expression::UnresolvedFunction {
+                function_name: "sum".to_string(),
+                arguments: VecExpression::from_iter(vec![col]).into(),
+                is_distinct: true,
+                is_user_defined_function: false,
+            },
+        )),
+    })
+}
+
 gen_func!(var_pop, [col: Column], "Returns the population variance of the values in a group.");
 gen_func!(var_samp, [col: Column], "Returns the unbiased sample variance of the values in a group.");
 gen_func!(variance, [col: Column], "Alias for var_samp");
@@ -2546,4 +2560,42 @@ mod tests {
         assert_eq!(expected, res);
         Ok(())
     }
+
+    // Test aggregate functions
+    #[tokio::test]
+    async fn test_func_sum_distinct() -> Result<(), SparkError> {
+        let spark = setup().await;
+        let select_func = |df: DataFrame| {
+            df.select([sum_distinct(col("value")).alias("sum")])
+                .collect()
+        };
+        let record_batch_func =
+            |col: ArrayRef| RecordBatch::try_from_iter_with_nullable(vec![("sum", col, true)]);
+
+        let df = spark
+            .sql("SELECT * FROM VALUES (1), (2), (3), (1), (2), (3) AS 
data(value)")
+            .await?;
+        let res = select_func(df).await?;
+        let expected_col: ArrayRef = Arc::new(Int64Array::from(vec![6]));
+        let expected = record_batch_func(expected_col)?;
+        assert_eq!(expected, res);
+
+        let df = spark
+            .sql("SELECT * FROM VALUES (1), (2), (3), (null), (1), (2), (3) AS 
data(value)")
+            .await?;
+        let res = select_func(df).await?;
+        let expected_col: ArrayRef = Arc::new(Int64Array::from(vec![6]));
+        let expected = record_batch_func(expected_col)?;
+        assert_eq!(expected, res);
+
+        let df = spark
+            .sql("SELECT * FROM VALUES (null), (null), (null) AS data(value)")
+            .await?;
+        let res = select_func(df).await?;
+        let expected_col: ArrayRef = Arc::new(Float64Array::from(vec![None]));
+        let expected = record_batch_func(expected_col)?;
+        assert_eq!(expected, res);
+
+        Ok(())
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
