This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a3b918eb94c1 [SPARK-49540][PS] Unify the usage of `distributed_sequence_id`
a3b918eb94c1 is described below
commit a3b918eb94c1ad49bf8bdfddf31d40a346e0fafb
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Sep 9 12:26:23 2024 +0800
[SPARK-49540][PS] Unify the usage of `distributed_sequence_id`
### What changes were proposed in this pull request?
In PySpark Classic, `distributed_sequence_id` was invoked via a private DataFrame method, `withSequenceColumn`, while in PySpark Connect it was invoked directly as an internal function.
This PR unifies the usage of `distributed_sequence_id` behind a single helper in `python/pyspark/pandas/spark/functions.py`, which dispatches to the appropriate backend.
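For illustration, a minimal sketch of the unified call path (the session setup and column name here are illustrative, not part of this diff):

```python
from pyspark.sql import SparkSession
from pyspark.pandas.spark import functions as SF

spark = SparkSession.builder.getOrCreate()

# One code path for both PySpark Classic and PySpark Connect;
# SF.distributed_sequence_id() dispatches internally via is_remote().
sdf = spark.range(5)
indexed = sdf.select(SF.distributed_sequence_id().alias("default_index"), "*")
```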
### Why are the changes needed?
Code refactoring; no behavior change is intended.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Updated existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #48028 from zhengruifeng/func_withSequenceColumn.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/pandas/internal.py | 18 +++++-------------
python/pyspark/pandas/spark/functions.py | 12 ++++++++++++
.../src/main/scala/org/apache/spark/sql/Dataset.scala | 8 --------
.../apache/spark/sql/api/python/PythonSQLUtils.scala | 7 +++++++
.../org/apache/spark/sql/DataFrameSelfJoinSuite.scala | 5 +++--
.../scala/org/apache/spark/sql/DataFrameSuite.scala | 4 +++-
6 files changed, 30 insertions(+), 24 deletions(-)
diff --git a/python/pyspark/pandas/internal.py b/python/pyspark/pandas/internal.py
index 92d4a3357319..4be345201ba6 100644
--- a/python/pyspark/pandas/internal.py
+++ b/python/pyspark/pandas/internal.py
@@ -43,6 +43,7 @@ from pyspark.sql.types import ( # noqa: F401
)
from pyspark.sql.utils import is_timestamp_ntz_preferred, is_remote
from pyspark import pandas as ps
+from pyspark.pandas.spark import functions as SF
from pyspark.pandas._typing import Label
from pyspark.pandas.spark.utils import as_nullable_spark_type, force_decimal_precision_scale
from pyspark.pandas.data_type_ops.base import DataTypeOps
@@ -938,19 +939,10 @@ class InternalFrame:
+--------+---+
"""
if len(sdf.columns) > 0:
- if is_remote():
- from pyspark.sql.connect.column import Column as ConnectColumn
- from pyspark.sql.connect.expressions import DistributedSequenceID
-
- return sdf.select(
- ConnectColumn(DistributedSequenceID()).alias(column_name),
- "*",
- )
- else:
- return PySparkDataFrame(
- sdf._jdf.toDF().withSequenceColumn(column_name),
- sdf.sparkSession,
- )
+ return sdf.select(
+ SF.distributed_sequence_id().alias(column_name),
+ "*",
+ )
else:
cnt = sdf.count()
if cnt > 0:
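The hunk above collapses the separate Classic/Connect branches into one call. For context, a hedged sketch of the behavior the distributed-sequence default index relies on, which the updated `DataFrameSuite` test below also asserts: the attached column enumerates rows consecutively from 0 to n-1 across all partitions, with no gaps between partitions (session setup and names are illustrative):

```python
from pyspark.sql import SparkSession
from pyspark.pandas.spark import functions as SF

spark = SparkSession.builder.getOrCreate()
sdf = spark.range(10).repartition(5)

# Attach the distributed-sequence column: consecutive IDs 0..n-1,
# unlike monotonically_increasing_id(), which leaves gaps between partitions.
indexed = sdf.select(SF.distributed_sequence_id().alias("seq"), "*")
assert sorted(r["seq"] for r in indexed.collect()) == list(range(10))
```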
diff --git a/python/pyspark/pandas/spark/functions.py b/python/pyspark/pandas/spark/functions.py
index 6aaa63956c14..4bcf07f6f650 100644
--- a/python/pyspark/pandas/spark/functions.py
+++ b/python/pyspark/pandas/spark/functions.py
@@ -174,6 +174,18 @@ def null_index(col: Column) -> Column:
return Column(sc._jvm.PythonSQLUtils.nullIndex(col._jc))
+def distributed_sequence_id() -> Column:
+ if is_remote():
+ from pyspark.sql.connect.functions.builtin import _invoke_function
+
+ return _invoke_function("distributed_sequence_id")
+ else:
+ from pyspark import SparkContext
+
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.PythonSQLUtils.distributed_sequence_id())
+
+
def collect_top_k(col: Column, num: int, reverse: bool) -> Column:
if is_remote():
from pyspark.sql.connect.functions.builtin import _invoke_function_over_columns
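The new `distributed_sequence_id` helper above mirrors the `is_remote()` dispatch used by its neighbors such as `null_index` and `collect_top_k`. A hedged sketch of that pattern, using a hypothetical function name `some_internal_fn`:

```python
from pyspark.sql.column import Column
from pyspark.sql.utils import is_remote


def some_internal_fn() -> Column:  # hypothetical name, for illustration
    if is_remote():
        # Connect: build an unresolved function call, resolved server-side.
        from pyspark.sql.connect.functions.builtin import _invoke_function

        return _invoke_function("some_internal_fn")
    else:
        # Classic: reach the JVM-side PythonSQLUtils helper through Py4J.
        from pyspark import SparkContext

        sc = SparkContext._active_spark_context
        return Column(sc._jvm.PythonSQLUtils.some_internal_fn())
```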
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 870571b533d0..0fab60a94842 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2010,14 +2010,6 @@ class Dataset[T] private[sql](
// For Python API
////////////////////////////////////////////////////////////////////////////
- /**
- * It adds a new long column with the name `name` that increases one by one.
- * This is for 'distributed-sequence' default index in pandas API on Spark.
- */
- private[sql] def withSequenceColumn(name: String) = {
- select(column(DistributedSequenceID()).alias(name), col("*"))
- }
-
/**
* Converts a JavaRDD to a PythonRDD.
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 7dbc586f6473..93082740cca6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -176,6 +176,13 @@ private[sql] object PythonSQLUtils extends Logging {
def pandasCovar(col1: Column, col2: Column, ddof: Int): Column =
Column.internalFn("pandas_covar", col1, col2, lit(ddof))
+ /**
+ * A long column that increases one by one.
+ * This is for 'distributed-sequence' default index in pandas API on Spark.
+ */
+ def distributed_sequence_id(): Column =
+ Column.internalFn("distributed_sequence_id")
+
def unresolvedNamedLambdaVariable(name: String): Column =
Column(internal.UnresolvedNamedLambdaVariable.apply(name))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
index 310b5a62c908..d888b09d76ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSelfJoinSuite.scala
@@ -18,10 +18,11 @@
package org.apache.spark.sql
import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, AttributeReference, PythonUDF, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Expand, Generate, ScriptInputOutputSchema, ScriptTransformation, Window => WindowPlan}
import org.apache.spark.sql.expressions.Window
-import org.apache.spark.sql.functions.{count, explode, sum, year}
+import org.apache.spark.sql.functions.{col, count, explode, sum, year}
import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
@@ -404,7 +405,7 @@ class DataFrameSelfJoinSuite extends QueryTest with SharedSparkSession {
assertAmbiguousSelfJoin(df12.join(df11, df11("x") === df12("y")))
// Test for AttachDistributedSequence
- val df13 = df1.withSequenceColumn("seq")
+ val df13 = df1.select(distributed_sequence_id().alias("seq"), col("*"))
val df14 = df13.filter($"value" === "A2")
assertAmbiguousSelfJoin(df13.join(df14, df13("key1") === df14("key2")))
assertAmbiguousSelfJoin(df14.join(df13, df13("key1") === df14("key2")))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index b1c41033fd76..9bfbdda33c36 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -29,6 +29,7 @@ import org.scalatest.matchers.should.Matchers._
import org.apache.spark.SparkException
import org.apache.spark.api.python.PythonEvalType
+import org.apache.spark.sql.api.python.PythonSQLUtils.distributed_sequence_id
import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -2316,7 +2317,8 @@ class DataFrameSuite extends QueryTest
}
test("SPARK-36338: DataFrame.withSequenceColumn should append unique
sequence IDs") {
- val ids = spark.range(10).repartition(5).withSequenceColumn("default_index")
+ val ids = spark.range(10).repartition(5).select(
+ distributed_sequence_id().alias("default_index"), col("id"))
assert(ids.collect().map(_.getLong(0)).toSet === Range(0, 10).toSet)
assert(ids.take(5).map(_.getLong(0)).toSet === Range(0, 5).toSet)
}