This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 3c57180038f [SPARK-41772][CONNECT][PYTHON] Fix incorrect column name in `withField`'s doctest
3c57180038f is described below
commit 3c57180038f8ddfcc184a74a5e387a3637e01cc2
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sun Jan 22 17:53:43 2023 +0800
    [SPARK-41772][CONNECT][PYTHON] Fix incorrect column name in `withField`'s doctest
### What changes were proposed in this pull request?
Fix incorrect column name in `withField`'s doctest
```
pyspark.sql.connect.column.Column.withField
Failed example:
    df.withColumn('a', df['a'].withField('b', lit(3))).select('a.b').show()
Expected:
    +---+
    |  b|
    +---+
    |  3|
    +---+
Got:
    +---+
    |a.b|
    +---+
    |  3|
    +---+
    <BLANKLINE>
```
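The root cause is server-side: `SparkConnectPlanner` wrapped every projection expression in `UnresolvedAlias`, so an expression that already carries a name (here, the resolved `a.b` attribute) was re-aliased under its fully qualified name. The fix hoists the existing `toNamedExpression` helper out of the aggregate path and reuses it for projections, keeping `NamedExpression`s as-is. A minimal sketch of the expected naming, assuming a plain (non-Connect) `SparkSession` bound to `spark`:
```
from pyspark.sql import Row
from pyspark.sql.functions import lit

df = spark.createDataFrame([Row(a=Row(b=1, c=2))])
# The selected struct field keeps its own name `b`, not the
# fully qualified `a.b`:
df.withColumn('a', df['a'].withField('b', lit(3))).select('a.b').show()
# +---+
# |  b|
# +---+
# |  3|
# +---+
```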
### Why are the changes needed?
For parity with vanilla PySpark, which names the selected struct field `b` rather than the fully qualified `a.b`.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Added a unit test and enabled the previously skipped `withField` doctest.
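For reference, the module's own doctest driver can re-run the enabled example locally (a sketch; assumes a working Spark Connect dev environment):
```
from pyspark.sql.connect import column

# `_test()` is the doctest entry point visible in the diff below; it
# runs the module doctests and exits non-zero on any failure.
column._test()
```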
Closes #39699 from zhengruifeng/connect_fix_41772.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../sql/connect/planner/SparkConnectPlanner.scala | 19 ++++++++------
python/pyspark/sql/connect/column.py | 3 ---
.../sql/tests/connect/test_connect_column.py | 29 +++++++++++++++-------
3 files changed, 31 insertions(+), 20 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index f65fc2c8d0f..f95f065c5b3 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -691,9 +691,12 @@ class SparkConnectPlanner(val session: SparkSession) {
} else {
logical.OneRowRelation()
}
- val projection =
- rel.getExpressionsList.asScala.map(transformExpression).map(UnresolvedAlias(_))
- logical.Project(projectList = projection.toSeq, child = baseRel)
+
+ val projection = rel.getExpressionsList.asScala.toSeq
+ .map(transformExpression)
+ .map(toNamedExpression)
+
+ logical.Project(projectList = projection, child = baseRel)
}
private def transformUnresolvedExpression(exp: proto.Expression): UnresolvedAttribute = {
@@ -745,6 +748,11 @@ class SparkConnectPlanner(val session: SparkSession) {
}
}
+ private def toNamedExpression(expr: Expression): NamedExpression = expr match {
+ case named: NamedExpression => named
+ case expr => UnresolvedAlias(expr)
+ }
+
private def transformExpressionPlugin(extension: ProtoAny): Expression = {
SparkConnectPluginRegistry.expressionRegistry
// Lazily traverse the collection.
@@ -1245,11 +1253,6 @@ class SparkConnectPlanner(val session: SparkSession) {
}
val input = transformRelation(rel.getInput)
- def toNamedExpression(expr: Expression): NamedExpression = expr match {
- case named: NamedExpression => named
- case expr => UnresolvedAlias(expr)
- }
-
val groupingExprs = rel.getGroupingExpressionsList.asScala.toSeq.map(transformExpression)
val aggExprs = rel.getAggregateExpressionsList.asScala.toSeq.map(transformExpression)
val aliasedAgg = (groupingExprs ++ aggExprs).map(toNamedExpression)
diff --git a/python/pyspark/sql/connect/column.py b/python/pyspark/sql/connect/column.py
index d2c334ae67f..44200e21495 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -439,9 +439,6 @@ def _test() -> None:
.getOrCreate()
)
- # TODO(SPARK-41772): Enable pyspark.sql.connect.column.Column.withField doctest
- del pyspark.sql.connect.column.Column.withField.__doc__
-
(failure_count, test_count) = doctest.testmod(
pyspark.sql.connect.column,
globs=globs,
diff --git a/python/pyspark/sql/tests/connect/test_connect_column.py b/python/pyspark/sql/tests/connect/test_connect_column.py
index ffee64706d5..1e0609c480c 100644
--- a/python/pyspark/sql/tests/connect/test_connect_column.py
+++ b/python/pyspark/sql/tests/connect/test_connect_column.py
@@ -33,6 +33,7 @@ from pyspark.sql.connect.types import (
)
from pyspark.sql.types import (
+ Row,
StructField,
StructType,
ArrayType,
@@ -58,7 +59,8 @@ from pyspark.sql.connect.client import SparkConnectException
if should_test_connect:
import pandas as pd
- from pyspark.sql.connect.functions import lit
+ from pyspark.sql import functions as SF
+ from pyspark.sql.connect import functions as CF
class SparkConnectColumnTests(SparkConnectSQLTestCase):
@@ -83,7 +85,7 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
def test_column_operator(self):
# SPARK-41351: Column needs to support !=
df = self.connect.range(10)
- self.assertEqual(9, len(df.filter(df.id != lit(1)).collect()))
+ self.assertEqual(9, len(df.filter(df.id != CF.lit(1)).collect()))
def test_columns(self):
# SPARK-41036: test `columns` API for python client.
@@ -133,8 +135,6 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
def test_column_with_null(self):
# SPARK-41751: test isNull, isNotNull, eqNullSafe
- from pyspark.sql import functions as SF
- from pyspark.sql.connect import functions as CF
query = """
SELECT * FROM VALUES
@@ -313,9 +313,6 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
def test_none(self):
# SPARK-41783: test none
- from pyspark.sql import functions as SF
- from pyspark.sql.connect import functions as CF
-
query = """
SELECT * FROM VALUES
(1, 1, NULL), (2, NULL, 1), (NULL, 3, 4)
@@ -348,8 +345,10 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
def test_simple_binary_expressions(self):
"""Test complex expression"""
- df = self.connect.read.table(self.tbl_name)
- pdf = df.select(df.id).where(df.id % lit(30) == lit(0)).sort(df.id.asc()).toPandas()
+ cdf = self.connect.read.table(self.tbl_name)
+ pdf = (
+ cdf.select(cdf.id).where(cdf.id % CF.lit(30) == CF.lit(0)).sort(cdf.id.asc()).toPandas()
+ )
self.assertEqual(len(pdf.index), 4)
res = pd.DataFrame(data={"id": [0, 30, 60, 90]})
@@ -964,6 +963,18 @@ class SparkConnectColumnTests(SparkConnectSQLTestCase):
).toPandas(),
)
+ def test_with_field_column_name(self):
+ data = [Row(a=Row(b=1, c=2))]
+
+ cdf = self.connect.createDataFrame(data)
+ cdf1 = cdf.withColumn("a", cdf["a"].withField("b", CF.lit(3))).select("a.b")
+
+ sdf = self.spark.createDataFrame(data)
+ sdf1 = sdf.withColumn("a", sdf["a"].withField("b", SF.lit(3))).select("a.b")
+
+ self.assertEqual(cdf1.schema, sdf1.schema)
+ self.assertEqual(cdf1.collect(), sdf1.collect())
+
if __name__ == "__main__":
import os