This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new d44fd2b57110 [SPARK-50694][SQL] Support renames in subqueries
d44fd2b57110 is described below
commit d44fd2b5711095a3ea39b6d6e0fcc0dbc7118727
Author: Takuya Ueshin <[email protected]>
AuthorDate: Mon Jan 6 12:27:47 2025 -0800
[SPARK-50694][SQL] Support renames in subqueries
### What changes were proposed in this pull request?
Supports renames in subqueries:
- `sub.toDF(...)`
- `sub.alias(...)`
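For example, the following now works (a sketch adapted from the tests added in this PR; `t1` is a test view with columns `c1` and `c2`):

```scala
// Rename the subquery's output with toDF, give it a subquery alias with
// as/alias, and reference the new names from the enclosing query.
// outer() marks c1/c2 as references to the outer relation t1.
t1.lateralJoin(
    spark.range(1)
      .select($"c1".outer(), $"c2".outer())
      .toDF("a", "b") // rename: previously failed on an unanalyzed plan
      .as("s")) // subquery alias
  .select("a", "b")
// Equivalent SQL: SELECT a, b FROM t1, LATERAL (SELECT c1, c2) s(a, b)
```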
### Why are the changes needed?
When a query is used as a subquery (by adding `col.outer()` references to the enclosing query), `toDF` and `alias` don't work, because both need an analyzed plan, and a plan containing outer references cannot be analyzed on its own.
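Concretely, `Dataset.toDF` was implemented by reading `schema` and `logicalPlan.output`, which forces analysis. The fix (quoted from the `Dataset.scala` hunk below) defers the rename to the analyzer by reusing `UnresolvedSubqueryColumnAliases`, the node SQL subquery column aliases like `s(a, b)` resolve through:

```scala
// New implementation: wrap the plan in an unresolved node that the
// analyzer resolves once the subquery is planned inside its outer query.
def toDF(colNames: String*): DataFrame = withPlan {
  UnresolvedSubqueryColumnAliases(colNames, logicalPlan)
}
```

`as`/`alias` similarly go through the new `withSameTypedPlan` helper, which skips eager encoder binding when the plan is lazily analyzed.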
### Does this PR introduce _any_ user-facing change?
Yes, those APIs are available in subqueries.
### How was this patch tested?
Added / modified the related tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #49336 from ueshin/issues/SPARK-50694/renames.
Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Takuya Ueshin <[email protected]>
---
.../apache/spark/sql/DataFrameSubquerySuite.scala | 37 ++++++++++-
.../sql/DataFrameTableValuedFunctionsSuite.scala | 73 +++++++++++-----------
python/pyspark/sql/dataframe.py | 3 +-
python/pyspark/sql/tests/test_subquery.py | 42 ++++++++++++-
python/pyspark/sql/tests/test_tvf.py | 58 +++++++++--------
.../sql/connect/planner/SparkConnectPlanner.scala | 9 ++-
.../main/scala/org/apache/spark/sql/Dataset.scala | 67 +++++++++++++-------
.../apache/spark/sql/DataFrameSubquerySuite.scala | 40 +++++++++++-
.../sql/DataFrameTableValuedFunctionsSuite.scala | 47 +++++++-------
9 files changed, 255 insertions(+), 121 deletions(-)
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index fc37444f7719..4b36d36983a5 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -354,6 +354,28 @@ class DataFrameSubquerySuite extends QueryTest with RemoteSparkSession {
}
}
+ test("lateral join with star expansion") {
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ checkAnswer(
+ t1.lateralJoin(spark.range(1).select().select($"*")),
+ sql("SELECT * FROM t1, LATERAL (SELECT *)"))
+ checkAnswer(
+ t1.lateralJoin(t2.select($"*")).toDF("c1", "c2", "c3", "c4"),
+      sql("SELECT * FROM t1, LATERAL (SELECT * FROM t2)").toDF("c1", "c2", "c3", "c4"))
+ checkAnswer(
+ t1.lateralJoin(t2.select($"t1.*".outer(), $"t2.*"))
+ .toDF("c1", "c2", "c3", "c4", "c5", "c6"),
+ sql("SELECT * FROM t1, LATERAL (SELECT t1.*, t2.* FROM t2)")
+ .toDF("c1", "c2", "c3", "c4", "c5", "c6"))
+ checkAnswer(
+      t1.lateralJoin(t2.alias("t1").select($"t1.*")).toDF("c1", "c2", "c3", "c4"),
+      sql("SELECT * FROM t1, LATERAL (SELECT t1.* FROM t2 AS t1)").toDF("c1", "c2", "c3", "c4"))
+ }
+ }
+
test("lateral join with different join types") {
withView("t1") {
val t1 = table1()
@@ -375,6 +397,17 @@ class DataFrameSubquerySuite extends QueryTest with RemoteSparkSession {
}
}
+ test("lateral join with subquery alias") {
+ withView("t1") {
+ val t1 = table1()
+
+ checkAnswer(
+      t1.lateralJoin(spark.range(1).select($"c1".outer(), $"c2".outer()).toDF("a", "b").as("s"))
+ .select("a", "b"),
+ sql("SELECT a, b FROM t1, LATERAL (SELECT c1, c2) s(a, b)"))
+ }
+ }
+
test("lateral join with correlated equality / non-equality predicates") {
withView("t1", "t2") {
val t1 = table1()
@@ -441,8 +474,8 @@ class DataFrameSubquerySuite extends QueryTest with RemoteSparkSession {
val t2 = table2()
checkAnswer(
-      t1.lateralJoin(t2.where($"t1.c1".outer() === $"t2.c1").select($"c2"), "left")
- .join(t1.as("t3"), $"t2.c2" === $"t3.c2", "left")
+      t1.lateralJoin(t2.where($"t1.c1".outer() === $"t2.c1").select($"c2").as("s"), "left")
+ .join(t1.as("t3"), $"s.c2" === $"t3.c2", "left")
.toDF("c1", "c2", "c3", "c4", "c5"),
sql("""
|SELECT * FROM t1
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
index aeef2e8f0fcf..12a49ad21676 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
@@ -61,10 +61,11 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with RemoteSparkSessi
val t3 = spark.table("t3")
checkAnswer(
- t1.lateralJoin(spark.tvf.explode(array($"c1".outer(), $"c2".outer()))),
+ t1.lateralJoin(
+        spark.tvf.explode(array($"c1".outer(), $"c2".outer())).toDF("c3").as("t2")),
sql("SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, c2)) t2(c3)"))
checkAnswer(
- t3.lateralJoin(spark.tvf.explode($"c2".outer())),
+ t3.lateralJoin(spark.tvf.explode($"c2".outer()).toDF("v").as("t2")),
sql("SELECT * FROM t3, LATERAL EXPLODE(c2) t2(v)"))
checkAnswer(
spark.tvf
@@ -113,10 +114,11 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with RemoteSparkSessi
val t3 = spark.table("t3")
checkAnswer(
-      t1.lateralJoin(spark.tvf.explode_outer(array($"c1".outer(), $"c2".outer()))),
+ t1.lateralJoin(
+        spark.tvf.explode_outer(array($"c1".outer(), $"c2".outer())).toDF("c3").as("t2")),
sql("SELECT * FROM t1, LATERAL EXPLODE_OUTER(ARRAY(c1, c2)) t2(c3)"))
checkAnswer(
- t3.lateralJoin(spark.tvf.explode_outer($"c2".outer())),
+      t3.lateralJoin(spark.tvf.explode_outer($"c2".outer()).toDF("v").as("t2")),
sql("SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) t2(v)"))
checkAnswer(
spark.tvf
@@ -161,7 +163,10 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with RemoteSparkSessi
arrayStruct.lateralJoin(spark.tvf.inline($"arr".outer())),
sql("SELECT * FROM array_struct JOIN LATERAL INLINE(arr)"))
checkAnswer(
-      arrayStruct.lateralJoin(spark.tvf.inline($"arr".outer()), $"id" === $"col1", "left"),
+ arrayStruct.lateralJoin(
+ spark.tvf.inline($"arr".outer()).toDF("k", "v").as("t"),
+ $"id" === $"k",
+ "left"),
sql("SELECT * FROM array_struct LEFT JOIN LATERAL INLINE(arr) t(k, v)
ON id = k"))
}
}
@@ -202,8 +207,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with RemoteSparkSessi
sql("SELECT * FROM array_struct JOIN LATERAL INLINE_OUTER(arr)"))
checkAnswer(
arrayStruct.lateralJoin(
- spark.tvf.inline_outer($"arr".outer()),
- $"id" === $"col1",
+ spark.tvf.inline_outer($"arr".outer()).toDF("k", "v").as("t"),
+ $"id" === $"k",
"left"),
sql("SELECT * FROM array_struct LEFT JOIN LATERAL INLINE_OUTER(arr)
t(k, v) ON id = k"))
}
@@ -238,30 +243,27 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with RemoteSparkSessi
jsonTable
.as("t1")
.lateralJoin(
- spark.tvf.json_tuple(
- $"t1.jstring".outer(),
- lit("f1"),
- lit("f2"),
- lit("f3"),
- lit("f4"),
- lit("f5")))
- .select($"key", $"c0", $"c1", $"c2", $"c3", $"c4"),
+ spark.tvf
+ .json_tuple(
+ $"t1.jstring".outer(),
+ lit("f1"),
+ lit("f2"),
+ lit("f3"),
+ lit("f4"),
+ lit("f5"))
+ .as("t2"))
+ .select($"t1.key", $"t2.*"),
sql(
"SELECT t1.key, t2.* FROM json_table t1, " +
"LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2"))
checkAnswer(
jsonTable
.as("t1")
- .lateralJoin(
- spark.tvf.json_tuple(
- $"jstring".outer(),
- lit("f1"),
- lit("f2"),
- lit("f3"),
- lit("f4"),
- lit("f5")))
- .where($"c0".isNotNull)
- .select($"key", $"c0", $"c1", $"c2", $"c3", $"c4"),
+ .lateralJoin(spark.tvf
+        .json_tuple($"jstring".outer(), lit("f1"), lit("f2"), lit("f3"), lit("f4"), lit("f5"))
+ .as("t2"))
+ .where($"t2.c0".isNotNull)
+ .select($"t1.key", $"t2.*"),
sql(
"SELECT t1.key, t2.* FROM json_table t1, " +
"LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2 "
+
@@ -390,17 +392,18 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with RemoteSparkSessi
checkAnswer(
t1.lateralJoin(
-        spark.tvf.stack(lit(2), lit("Key"), $"c1".outer(), lit("Value"), $"c2".outer()))
- .select($"col0", $"col1"),
+        spark.tvf.stack(lit(2), lit("Key"), $"c1".outer(), lit("Value"), $"c2".outer()).as("t"))
+ .select($"t.*"),
sql("SELECT t.* FROM t1, LATERAL stack(2, 'Key', c1, 'Value', c2) t"))
checkAnswer(
- t1.lateralJoin(spark.tvf.stack(lit(1), $"c1".outer(), $"c2".outer()))
- .select($"col0".as("x"), $"col1".as("y")),
+ t1.lateralJoin(
+        spark.tvf.stack(lit(1), $"c1".outer(), $"c2".outer()).toDF("x", "y").as("t"))
+ .select($"t.*"),
sql("SELECT t.* FROM t1 JOIN LATERAL stack(1, c1, c2) t(x, y)"))
checkAnswer(
t1.join(t3, $"t1.c1" === $"t3.c1")
-        .lateralJoin(spark.tvf.stack(lit(1), $"t1.c2".outer(), $"t3.c2".outer()))
- .select($"col0", $"col1"),
+        .lateralJoin(spark.tvf.stack(lit(1), $"t1.c2".outer(), $"t3.c2".outer()).as("t"))
+ .select($"t.*"),
sql("SELECT t.* FROM t1 JOIN t3 ON t1.c1 = t3.c1 JOIN LATERAL stack(1,
t1.c2, t3.c2) t"))
}
}
@@ -463,8 +466,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with RemoteSparkSessi
checkAnswer(
variantTable
.as("t1")
- .lateralJoin(spark.tvf.variant_explode($"v".outer()))
- .select($"id", $"pos", $"key", $"value"),
+ .lateralJoin(spark.tvf.variant_explode($"v".outer()).as("t"))
+ .select($"t1.id", $"t.*"),
sql("SELECT t1.id, t.* FROM variant_table AS t1, LATERAL
variant_explode(v) AS t"))
}
}
@@ -515,8 +518,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with RemoteSparkSessi
checkAnswer(
variantTable
.as("t1")
- .lateralJoin(spark.tvf.variant_explode_outer($"v".outer()))
- .select($"id", $"pos", $"key", $"value"),
+ .lateralJoin(spark.tvf.variant_explode_outer($"v".outer()).as("t"))
+ .select($"t1.id", $"t.*"),
sql("SELECT t1.id, t.* FROM variant_table AS t1, LATERAL
variant_explode_outer(v) AS t"))
}
}
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 660f577f56f8..e321f2c8d755 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2713,11 +2713,10 @@ class DataFrame:
>>> customers.alias("c").lateralJoin(
... orders.alias("o")
    ...     .where(sf.col("o.customer_id") == sf.col("c.customer_id").outer())
+ ... .select("order_id", "order_date")
... .orderBy(sf.col("order_date").desc())
... .limit(2),
... how="left"
- ... ).select(
- ... "c.customer_id", "name", "order_id", "order_date"
... ).orderBy("customer_id", "order_id").show()
+-----------+-------+--------+----------+
|customer_id| name|order_id|order_date|
diff --git a/python/pyspark/sql/tests/test_subquery.py b/python/pyspark/sql/tests/test_subquery.py
index 0f431589b461..99a22d7c2966 100644
--- a/python/pyspark/sql/tests/test_subquery.py
+++ b/python/pyspark/sql/tests/test_subquery.py
@@ -518,6 +518,28 @@ class SubqueryTestsMixin:
self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT t1.c1 +
t2.c1 FROM t2)"""),
)
+ def test_lateral_join_with_star_expansion(self):
+ with self.tempView("t1", "t2"):
+ t1 = self.table1()
+ t2 = self.table2()
+
+ assertDataFrameEqual(
+                t1.lateralJoin(self.spark.range(1).select().select(sf.col("*"))),
+ self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT *)"""),
+ )
+ assertDataFrameEqual(
+ t1.lateralJoin(t2.select(sf.col("*"))),
+                self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT * FROM t2)"""),
+ )
+ assertDataFrameEqual(
+                t1.lateralJoin(t2.select(sf.col("t1.*").outer(), sf.col("t2.*"))),
+                self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT t1.*, t2.* FROM t2)"""),
+ )
+ assertDataFrameEqual(
+ t1.lateralJoin(t2.alias("t1").select(sf.col("t1.*"))),
+                self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT t1.* FROM t2 AS t1)"""),
+ )
+
def test_lateral_join_with_different_join_types(self):
with self.tempView("t1"):
t1 = self.table1()
@@ -572,6 +594,20 @@ class SubqueryTestsMixin:
},
)
+ def test_lateral_join_with_subquery_alias(self):
+ with self.tempView("t1"):
+ t1 = self.table1()
+
+ assertDataFrameEqual(
+ t1.lateralJoin(
+ self.spark.range(1)
+ .select(sf.col("c1").outer(), sf.col("c2").outer())
+ .toDF("a", "b")
+ .alias("s")
+ ).select("a", "b"),
+                self.spark.sql("""SELECT a, b FROM t1, LATERAL (SELECT c1, c2) s(a, b)"""),
+ )
+
def test_lateral_join_with_correlated_predicates(self):
with self.tempView("t1", "t2"):
t1 = self.table1()
@@ -661,9 +697,11 @@ class SubqueryTestsMixin:
assertDataFrameEqual(
t1.lateralJoin(
-                t2.where(sf.col("t1.c1").outer() == sf.col("t2.c1")).select(sf.col("c2")),
+ t2.where(sf.col("t1.c1").outer() == sf.col("t2.c1"))
+ .select(sf.col("c2"))
+ .alias("s"),
how="left",
-            ).join(t1.alias("t3"), sf.col("t2.c2") == sf.col("t3.c2"), how="left"),
+            ).join(t1.alias("t3"), sf.col("s.c2") == sf.col("t3.c2"), how="left"),
self.spark.sql(
"""
SELECT * FROM t1
diff --git a/python/pyspark/sql/tests/test_tvf.py b/python/pyspark/sql/tests/test_tvf.py
index ea20cbf9b8f3..c7274c0810cf 100644
--- a/python/pyspark/sql/tests/test_tvf.py
+++ b/python/pyspark/sql/tests/test_tvf.py
@@ -65,11 +65,13 @@ class TVFTestsMixin:
assertDataFrameEqual(
t1.lateralJoin(
                self.spark.tvf.explode(sf.array(sf.col("c1").outer(), sf.col("c2").outer()))
- ).toDF("c1", "c2", "c3"),
+ .toDF("c3")
+ .alias("t2")
+ ),
self.spark.sql("""SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1,
c2)) t2(c3)"""),
)
assertDataFrameEqual(
-            t3.lateralJoin(self.spark.tvf.explode(sf.col("c2").outer())).toDF("c1", "c2", "v"),
+            t3.lateralJoin(self.spark.tvf.explode(sf.col("c2").outer()).toDF("v").alias("t2")),
self.spark.sql("""SELECT * FROM t3, LATERAL EXPLODE(c2)
t2(v)"""),
)
assertDataFrameEqual(
@@ -127,12 +129,14 @@ class TVFTestsMixin:
self.spark.tvf.explode_outer(
sf.array(sf.col("c1").outer(), sf.col("c2").outer())
)
- ).toDF("c1", "c2", "c3"),
+ .toDF("c3")
+ .alias("t2")
+ ),
self.spark.sql("""SELECT * FROM t1, LATERAL
EXPLODE_OUTER(ARRAY(c1, c2)) t2(c3)"""),
)
assertDataFrameEqual(
-            t3.lateralJoin(self.spark.tvf.explode_outer(sf.col("c2").outer())).toDF(
- "c1", "c2", "v"
+ t3.lateralJoin(
+                self.spark.tvf.explode_outer(sf.col("c2").outer()).toDF("v").alias("t2")
),
self.spark.sql("""SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2)
t2(v)"""),
)
@@ -193,10 +197,10 @@ class TVFTestsMixin:
)
assertDataFrameEqual(
array_struct.lateralJoin(
- self.spark.tvf.inline(sf.col("arr").outer()),
- sf.col("id") == sf.col("col1"),
+                self.spark.tvf.inline(sf.col("arr").outer()).toDF("k", "v").alias("t"),
+ sf.col("id") == sf.col("k"),
"left",
- ).toDF("id", "arr", "k", "v"),
+ ),
self.spark.sql(
"""
                SELECT * FROM array_struct LEFT JOIN LATERAL INLINE(arr) t(k, v) ON id = k
@@ -252,10 +256,10 @@ class TVFTestsMixin:
)
assertDataFrameEqual(
array_struct.lateralJoin(
- self.spark.tvf.inline_outer(sf.col("arr").outer()),
- sf.col("id") == sf.col("col1"),
+                self.spark.tvf.inline_outer(sf.col("arr").outer()).toDF("k", "v").alias("t"),
+ sf.col("id") == sf.col("k"),
"left",
- ).toDF("id", "arr", "k", "v"),
+ ),
self.spark.sql(
"""
                SELECT * FROM array_struct LEFT JOIN LATERAL INLINE_OUTER(arr) t(k, v) ON id = k
@@ -302,9 +306,9 @@ class TVFTestsMixin:
sf.lit("f3"),
sf.lit("f4"),
sf.lit("f5"),
- )
+ ).alias("t2")
)
- .select("key", "c0", "c1", "c2", "c3", "c4"),
+ .select("t1.key", "t2.*"),
self.spark.sql(
"""
SELECT t1.key, t2.* FROM json_table t1,
@@ -322,10 +326,10 @@ class TVFTestsMixin:
sf.lit("f3"),
sf.lit("f4"),
sf.lit("f5"),
- )
+ ).alias("t2")
)
- .where(sf.col("c0").isNotNull())
- .select("key", "c0", "c1", "c2", "c3", "c4"),
+ .where(sf.col("t2.c0").isNotNull())
+ .select("t1.key", "t2.*"),
self.spark.sql(
"""
SELECT t1.key, t2.* FROM json_table t1,
@@ -485,8 +489,8 @@ class TVFTestsMixin:
sf.col("c1").outer(),
sf.lit("Value"),
sf.col("c2").outer(),
- )
- ).select("col0", "col1"),
+ ).alias("t")
+ ).select("t.*"),
self.spark.sql(
"""SELECT t.* FROM t1, LATERAL stack(2, 'Key', c1,
'Value', c2) t"""
),
@@ -494,17 +498,19 @@ class TVFTestsMixin:
assertDataFrameEqual(
t1.lateralJoin(
                self.spark.tvf.stack(sf.lit(1), sf.col("c1").outer(), sf.col("c2").outer())
- ).select("col0", "col1"),
-            self.spark.sql("""SELECT t.* FROM t1 JOIN LATERAL stack(1, c1, c2) t"""),
+ .toDF("x", "y")
+ .alias("t")
+ ).select("t.*"),
+            self.spark.sql("""SELECT t.* FROM t1 JOIN LATERAL stack(1, c1, c2) t(x, y)"""),
)
assertDataFrameEqual(
t1.join(t3, sf.col("t1.c1") == sf.col("t3.c1"))
.lateralJoin(
self.spark.tvf.stack(
sf.lit(1), sf.col("t1.c2").outer(),
sf.col("t3.c2").outer()
- )
+ ).alias("t")
)
- .select("col0", "col1"),
+ .select("t.*"),
self.spark.sql(
"""
SELECT t.* FROM t1 JOIN t3 ON t1.c1 = t3.c1
@@ -570,8 +576,8 @@ class TVFTestsMixin:
assertDataFrameEqual(
variant_table.alias("t1")
-            .lateralJoin(self.spark.tvf.variant_explode(sf.col("v").outer()))
- .select("id", "pos", "key", "value"),
+            .lateralJoin(self.spark.tvf.variant_explode(sf.col("v").outer()).alias("t"))
+ .select("t1.id", "t.*"),
self.spark.sql(
"""
SELECT t1.id, t.* FROM variant_table AS t1,
@@ -629,8 +635,8 @@ class TVFTestsMixin:
assertDataFrameEqual(
variant_table.alias("t1")
-            .lateralJoin(self.spark.tvf.variant_explode_outer(sf.col("v").outer()))
- .select("id", "pos", "key", "value"),
+            .lateralJoin(self.spark.tvf.variant_explode_outer(sf.col("v").outer()).alias("t"))
+ .select("t1.id", "t.*"),
self.spark.sql(
"""
SELECT t1.id, t.* FROM variant_table AS t1,
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 490ae473a6e4..c0b4384af8b6 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -45,7 +45,7 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LazyExpression, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedTableValuedFunction, UnresolvedTranspose}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LazyExpression, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTranspose}
import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
import org.apache.spark.sql.catalyst.expressions._
@@ -566,10 +566,9 @@ class SparkConnectPlanner(
}
private def transformToDF(rel: proto.ToDF): LogicalPlan = {
- Dataset
- .ofRows(session, transformRelation(rel.getInput))
- .toDF(rel.getColumnNamesList.asScala.toSeq: _*)
- .logicalPlan
+ UnresolvedSubqueryColumnAliases(
+ rel.getColumnNamesList.asScala.toSeq,
+ transformRelation(rel.getInput))
}
private def transformMapPartitions(rel: proto.MapPartitions): LogicalPlan = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b9ae0e5b9131..287628f2cbef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -92,6 +92,23 @@ private[sql] object Dataset {
dataset
}
+ def apply[T](
+ sparkSession: SparkSession,
+ logicalPlan: LogicalPlan,
+ encoderGenerator: () => Encoder[T]): Dataset[T] = {
+ val dataset = new Dataset(sparkSession, logicalPlan, encoderGenerator)
+    // Eagerly bind the encoder so we verify that the encoder matches the underlying
+    // schema. The user will get an error if this is not the case.
+    // optimization: it is guaranteed that [[InternalRow]] can be converted to [[Row]] so
+    // do not do this check in that case. this check can be expensive since it requires running
+    // the whole [[Analyzer]] to resolve the deserializer
+    if (!dataset.queryExecution.isLazyAnalysis
+        && dataset.encoder.clsTag.runtimeClass != classOf[Row]) {
+ dataset.resolvedEnc
+ }
+ dataset
+ }
+
def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame =
sparkSession.withActive {
val qe = sparkSession.sessionState.executePlan(logicalPlan)
@@ -241,8 +258,13 @@ class Dataset[T] private[sql](
this(queryExecution, () => encoder)
}
+ def this(
+      sparkSession: SparkSession, logicalPlan: LogicalPlan, encoderGenerator: () => Encoder[T]) = {
+ this(sparkSession.sessionState.executePlan(logicalPlan), encoderGenerator)
+ }
+
  def this(sparkSession: SparkSession, logicalPlan: LogicalPlan, encoder: Encoder[T]) = {
- this(sparkSession.sessionState.executePlan(logicalPlan), encoder)
+ this(sparkSession, logicalPlan, () => encoder)
}
  def this(sqlContext: SQLContext, logicalPlan: LogicalPlan, encoder: Encoder[T]) = {
@@ -508,16 +530,8 @@ class Dataset[T] private[sql](
/** @inheritdoc */
@scala.annotation.varargs
- def toDF(colNames: String*): DataFrame = {
- require(schema.size == colNames.size,
- "The number of columns doesn't match.\n" +
-      s"Old column names (${schema.size}): " + schema.fields.map(_.name).mkString(", ") + "\n" +
- s"New column names (${colNames.size}): " + colNames.mkString(", "))
-
-    val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, newName) =>
- Column(oldAttribute).as(newName)
- }
- select(newCols : _*)
+ def toDF(colNames: String*): DataFrame = withPlan {
+ UnresolvedSubqueryColumnAliases(colNames, logicalPlan)
}
/** @inheritdoc */
@@ -854,7 +868,7 @@ class Dataset[T] private[sql](
}
/** @inheritdoc */
- def as(alias: String): Dataset[T] = withTypedPlan {
+ def as(alias: String): Dataset[T] = withSameTypedPlan {
SubqueryAlias(alias, logicalPlan)
}
@@ -909,7 +923,7 @@ class Dataset[T] private[sql](
}
/** @inheritdoc */
- def filter(condition: Column): Dataset[T] = withTypedPlan {
+ def filter(condition: Column): Dataset[T] = withSameTypedPlan {
Filter(condition.expr, logicalPlan)
}
@@ -1038,7 +1052,7 @@ class Dataset[T] private[sql](
/** @inheritdoc */
@scala.annotation.varargs
-  def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withTypedPlan {
+  def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = withSameTypedPlan {
CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id)
}
@@ -1050,12 +1064,12 @@ class Dataset[T] private[sql](
}
/** @inheritdoc */
- def limit(n: Int): Dataset[T] = withTypedPlan {
+ def limit(n: Int): Dataset[T] = withSameTypedPlan {
Limit(Literal(n), logicalPlan)
}
/** @inheritdoc */
- def offset(n: Int): Dataset[T] = withTypedPlan {
+ def offset(n: Int): Dataset[T] = withSameTypedPlan {
Offset(Literal(n), logicalPlan)
}
@@ -1142,7 +1156,7 @@ class Dataset[T] private[sql](
/** @inheritdoc */
  def sample(withReplacement: Boolean, fraction: Double, seed: Long): Dataset[T] = {
- withTypedPlan {
+ withSameTypedPlan {
Sample(0.0, fraction, withReplacement, seed, logicalPlan)
}
}
@@ -1340,7 +1354,7 @@ class Dataset[T] private[sql](
def dropDuplicates(): Dataset[T] = dropDuplicates(this.columns)
/** @inheritdoc */
- def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan {
+ def dropDuplicates(colNames: Seq[String]): Dataset[T] = withSameTypedPlan {
val groupCols = groupColsFromDropDuplicates(colNames)
Deduplicate(groupCols, logicalPlan)
}
@@ -1351,7 +1365,7 @@ class Dataset[T] private[sql](
}
/** @inheritdoc */
-  def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = withTypedPlan {
+  def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = withSameTypedPlan {
val groupCols = groupColsFromDropDuplicates(colNames)
    // UnsupportedOperationChecker will fail the query if this is called with batch Dataset.
DeduplicateWithinWatermark(groupCols, logicalPlan)
@@ -1511,7 +1525,7 @@ class Dataset[T] private[sql](
}
/** @inheritdoc */
- def repartition(numPartitions: Int): Dataset[T] = withTypedPlan {
+ def repartition(numPartitions: Int): Dataset[T] = withSameTypedPlan {
Repartition(numPartitions, shuffle = true, logicalPlan)
}
@@ -1526,7 +1540,7 @@ class Dataset[T] private[sql](
s"""Invalid partitionExprs specified: $sortOrders
|For range partitioning use repartitionByRange(...) instead.
""".stripMargin)
- withTypedPlan {
+ withSameTypedPlan {
      RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions)
}
}
@@ -1539,13 +1553,13 @@ class Dataset[T] private[sql](
case expr: SortOrder => expr
case expr: Expression => SortOrder(expr, Ascending)
})
- withTypedPlan {
+ withSameTypedPlan {
RepartitionByExpression(sortOrder, logicalPlan, numPartitions)
}
}
/** @inheritdoc */
- def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan {
+ def coalesce(numPartitions: Int): Dataset[T] = withSameTypedPlan {
Repartition(numPartitions, shuffle = false, logicalPlan)
}
@@ -2240,7 +2254,7 @@ class Dataset[T] private[sql](
SortOrder(expr, Ascending)
}
}
- withTypedPlan {
+ withSameTypedPlan {
Sort(sortOrder, global = global, logicalPlan)
}
}
@@ -2255,6 +2269,11 @@ class Dataset[T] private[sql](
Dataset(sparkSession, logicalPlan)
}
+ /** A convenient function to wrap a logical plan and produce a Dataset. */
+  @inline private def withSameTypedPlan(logicalPlan: LogicalPlan): Dataset[T] = {
+ Dataset(sparkSession, logicalPlan, encoderGenerator)
+ }
+
  /** A convenient function to wrap a set based logical plan and produce a Dataset. */
  @inline private def withSetOperator[U : Encoder](logicalPlan: LogicalPlan): Dataset[U] = {
if (isUnTyped) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index f94cf89276ec..fdfb909d9ba7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -418,6 +418,30 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
}
}
+ test("lateral join with star expansion") {
+ withView("t1", "t2") {
+ val t1 = table1()
+ val t2 = table2()
+
+ checkAnswer(
+ t1.lateralJoin(spark.range(1).select().select($"*")),
+ sql("SELECT * FROM t1, LATERAL (SELECT *)")
+ )
+ checkAnswer(
+ t1.lateralJoin(t2.select($"*")),
+ sql("SELECT * FROM t1, LATERAL (SELECT * FROM t2)")
+ )
+ checkAnswer(
+ t1.lateralJoin(t2.select($"t1.*".outer(), $"t2.*")),
+ sql("SELECT * FROM t1, LATERAL (SELECT t1.*, t2.* FROM t2)")
+ )
+ checkAnswer(
+ t1.lateralJoin(t2.alias("t1").select($"t1.*")),
+ sql("SELECT * FROM t1, LATERAL (SELECT t1.* FROM t2 AS t1)")
+ )
+ }
+ }
+
test("lateral join with different join types") {
withView("t1") {
val t1 = table1()
@@ -444,6 +468,18 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
}
}
+ test("lateral join with subquery alias") {
+ withView("t1") {
+ val t1 = table1()
+
+ checkAnswer(
+      t1.lateralJoin(spark.range(1).select($"c1".outer(), $"c2".outer()).toDF("a", "b").as("s"))
+ .select("a", "b"),
+ sql("SELECT a, b FROM t1, LATERAL (SELECT c1, c2) s(a, b)")
+ )
+ }
+ }
+
test("lateral join with correlated equality / non-equality predicates") {
withView("t1", "t2") {
val t1 = table1()
@@ -516,8 +552,8 @@ class DataFrameSubquerySuite extends QueryTest with SharedSparkSession {
checkAnswer(
t1.lateralJoin(
- t2.where($"t1.c1".outer() === $"t2.c1").select($"c2"), "left"
- ).join(t1.as("t3"), $"t2.c2" === $"t3.c2", "left"),
+ t2.where($"t1.c1".outer() === $"t2.c1").select($"c2").as("s"), "left"
+ ).join(t1.as("t3"), $"s.c2" === $"t3.c2", "left"),
sql(
"""
|SELECT * FROM t1
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
index 4f2cd275ffdf..637e0cf964fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
@@ -60,11 +60,11 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
val t3 = spark.table("t3")
checkAnswer(
- t1.lateralJoin(spark.tvf.explode(array($"c1".outer(), $"c2".outer()))),
+      t1.lateralJoin(spark.tvf.explode(array($"c1".outer(), $"c2".outer())).toDF("c3").as("t2")),
sql("SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, c2)) t2(c3)")
)
checkAnswer(
- t3.lateralJoin(spark.tvf.explode($"c2".outer())),
+ t3.lateralJoin(spark.tvf.explode($"c2".outer()).toDF("v").as("t2")),
sql("SELECT * FROM t3, LATERAL EXPLODE(c2) t2(v)")
)
checkAnswer(
@@ -112,11 +112,12 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
val t3 = spark.table("t3")
checkAnswer(
-      t1.lateralJoin(spark.tvf.explode_outer(array($"c1".outer(), $"c2".outer()))),
+ t1.lateralJoin(
+        spark.tvf.explode_outer(array($"c1".outer(), $"c2".outer())).toDF("c3").as("t2")),
sql("SELECT * FROM t1, LATERAL EXPLODE_OUTER(ARRAY(c1, c2)) t2(c3)")
)
checkAnswer(
- t3.lateralJoin(spark.tvf.explode_outer($"c2".outer())),
+      t3.lateralJoin(spark.tvf.explode_outer($"c2".outer()).toDF("v").as("t2")),
sql("SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) t2(v)")
)
checkAnswer(
@@ -164,8 +165,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
)
checkAnswer(
arrayStruct.lateralJoin(
- spark.tvf.inline($"arr".outer()),
- $"id" === $"col1",
+ spark.tvf.inline($"arr".outer()).toDF("k", "v").as("t"),
+ $"id" === $"k",
"left"
),
sql("SELECT * FROM array_struct LEFT JOIN LATERAL INLINE(arr) t(k, v)
ON id = k")
@@ -210,8 +211,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
)
checkAnswer(
arrayStruct.lateralJoin(
- spark.tvf.inline_outer($"arr".outer()),
- $"id" === $"col1",
+ spark.tvf.inline_outer($"arr".outer()).toDF("k", "v").as("t"),
+ $"id" === $"k",
"left"
),
sql("SELECT * FROM array_struct LEFT JOIN LATERAL INLINE_OUTER(arr)
t(k, v) ON id = k")
@@ -249,8 +250,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
jsonTable.as("t1").lateralJoin(
spark.tvf.json_tuple(
$"t1.jstring".outer(),
- lit("f1"), lit("f2"), lit("f3"), lit("f4"), lit("f5"))
- ).select($"key", $"c0", $"c1", $"c2", $"c3", $"c4"),
+ lit("f1"), lit("f2"), lit("f3"), lit("f4"), lit("f5")).as("t2")
+ ).select($"t1.key", $"t2.*"),
sql("SELECT t1.key, t2.* FROM json_table t1, " +
"LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2")
)
@@ -258,9 +259,9 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
jsonTable.as("t1").lateralJoin(
spark.tvf.json_tuple(
$"jstring".outer(),
- lit("f1"), lit("f2"), lit("f3"), lit("f4"), lit("f5"))
- ).where($"c0".isNotNull)
- .select($"key", $"c0", $"c1", $"c2", $"c3", $"c4"),
+ lit("f1"), lit("f2"), lit("f3"), lit("f4"), lit("f5")).as("t2")
+ ).where($"t2.c0".isNotNull)
+ .select($"t1.key", $"t2.*"),
sql("SELECT t1.key, t2.* FROM json_table t1, " +
"LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2 " +
"WHERE t2.c0 IS NOT NULL")
@@ -388,21 +389,21 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
checkAnswer(
t1.lateralJoin(
-        spark.tvf.stack(lit(2), lit("Key"), $"c1".outer(), lit("Value"), $"c2".outer())
- ).select($"col0", $"col1"),
+        spark.tvf.stack(lit(2), lit("Key"), $"c1".outer(), lit("Value"), $"c2".outer()).as("t")
+ ).select($"t.*"),
sql("SELECT t.* FROM t1, LATERAL stack(2, 'Key', c1, 'Value', c2) t")
)
checkAnswer(
t1.lateralJoin(
- spark.tvf.stack(lit(1), $"c1".outer(), $"c2".outer())
- ).select($"col0".as("x"), $"col1".as("y")),
+        spark.tvf.stack(lit(1), $"c1".outer(), $"c2".outer()).toDF("x", "y").as("t")
+ ).select($"t.*"),
sql("SELECT t.* FROM t1 JOIN LATERAL stack(1, c1, c2) t(x, y)")
)
checkAnswer(
t1.join(t3, $"t1.c1" === $"t3.c1")
.lateralJoin(
- spark.tvf.stack(lit(1), $"t1.c2".outer(), $"t3.c2".outer())
- ).select($"col0", $"col1"),
+ spark.tvf.stack(lit(1), $"t1.c2".outer(), $"t3.c2".outer()).as("t")
+ ).select($"t.*"),
sql("SELECT t.* FROM t1 JOIN t3 ON t1.c1 = t3.c1 JOIN LATERAL stack(1,
t1.c2, t3.c2) t")
)
}
@@ -466,8 +467,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
checkAnswer(
variantTable.as("t1").lateralJoin(
- spark.tvf.variant_explode($"v".outer())
- ).select($"id", $"pos", $"key", $"value"),
+ spark.tvf.variant_explode($"v".outer()).as("t")
+ ).select($"t1.id", $"t.*"),
sql("SELECT t1.id, t.* FROM variant_table AS t1, LATERAL
variant_explode(v) AS t")
)
}
@@ -519,8 +520,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest with SharedSparkSessi
checkAnswer(
variantTable.as("t1").lateralJoin(
- spark.tvf.variant_explode_outer($"v".outer())
- ).select($"id", $"pos", $"key", $"value"),
+ spark.tvf.variant_explode_outer($"v".outer()).as("t")
+ ).select($"t1.id", $"t.*"),
sql("SELECT t1.id, t.* FROM variant_table AS t1, LATERAL
variant_explode_outer(v) AS t")
)
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]