This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d44fd2b57110 [SPARK-50694][SQL] Support renames in subqueries
d44fd2b57110 is described below

commit d44fd2b5711095a3ea39b6d6e0fcc0dbc7118727
Author: Takuya Ueshin <[email protected]>
AuthorDate: Mon Jan 6 12:27:47 2025 -0800

    [SPARK-50694][SQL] Support renames in subqueries
    
    ### What changes were proposed in this pull request?
    
    Supports renames in subqueries:
    
    - `sub.toDF(...)`
    - `sub.alias(...)`
    
    ### Why are the changes needed?
    
    When a query is used as a subquery (that is, it references outer columns via
    `col.outer()`), `toDF` and `alias` don't work because they require an analyzed plan.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, these APIs can now be used in subqueries.
    
    ### How was this patch tested?
    
    Added / modified the related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49336 from ueshin/issues/SPARK-50694/renames.
    
    Authored-by: Takuya Ueshin <[email protected]>
    Signed-off-by: Takuya Ueshin <[email protected]>
---
 .../apache/spark/sql/DataFrameSubquerySuite.scala  | 37 ++++++++++-
 .../sql/DataFrameTableValuedFunctionsSuite.scala   | 73 +++++++++++-----------
 python/pyspark/sql/dataframe.py                    |  3 +-
 python/pyspark/sql/tests/test_subquery.py          | 42 ++++++++++++-
 python/pyspark/sql/tests/test_tvf.py               | 58 +++++++++--------
 .../sql/connect/planner/SparkConnectPlanner.scala  |  9 ++-
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 67 +++++++++++++-------
 .../apache/spark/sql/DataFrameSubquerySuite.scala  | 40 +++++++++++-
 .../sql/DataFrameTableValuedFunctionsSuite.scala   | 47 +++++++-------
 9 files changed, 255 insertions(+), 121 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index fc37444f7719..4b36d36983a5 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -354,6 +354,28 @@ class DataFrameSubquerySuite extends QueryTest with 
RemoteSparkSession {
     }
   }
 
+  test("lateral join with star expansion") {
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      checkAnswer(
+        t1.lateralJoin(spark.range(1).select().select($"*")),
+        sql("SELECT * FROM t1, LATERAL (SELECT *)"))
+      checkAnswer(
+        t1.lateralJoin(t2.select($"*")).toDF("c1", "c2", "c3", "c4"),
+        sql("SELECT * FROM t1, LATERAL (SELECT * FROM t2)").toDF("c1", "c2", 
"c3", "c4"))
+      checkAnswer(
+        t1.lateralJoin(t2.select($"t1.*".outer(), $"t2.*"))
+          .toDF("c1", "c2", "c3", "c4", "c5", "c6"),
+        sql("SELECT * FROM t1, LATERAL (SELECT t1.*, t2.* FROM t2)")
+          .toDF("c1", "c2", "c3", "c4", "c5", "c6"))
+      checkAnswer(
+        t1.lateralJoin(t2.alias("t1").select($"t1.*")).toDF("c1", "c2", "c3", 
"c4"),
+        sql("SELECT * FROM t1, LATERAL (SELECT t1.* FROM t2 AS 
t1)").toDF("c1", "c2", "c3", "c4"))
+    }
+  }
+
   test("lateral join with different join types") {
     withView("t1") {
       val t1 = table1()
@@ -375,6 +397,17 @@ class DataFrameSubquerySuite extends QueryTest with 
RemoteSparkSession {
     }
   }
 
+  test("lateral join with subquery alias") {
+    withView("t1") {
+      val t1 = table1()
+
+      checkAnswer(
+        t1.lateralJoin(spark.range(1).select($"c1".outer(), 
$"c2".outer()).toDF("a", "b").as("s"))
+          .select("a", "b"),
+        sql("SELECT a, b FROM t1, LATERAL (SELECT c1, c2) s(a, b)"))
+    }
+  }
+
   test("lateral join with correlated equality / non-equality predicates") {
     withView("t1", "t2") {
       val t1 = table1()
@@ -441,8 +474,8 @@ class DataFrameSubquerySuite extends QueryTest with 
RemoteSparkSession {
       val t2 = table2()
 
       checkAnswer(
-        t1.lateralJoin(t2.where($"t1.c1".outer() === $"t2.c1").select($"c2"), 
"left")
-          .join(t1.as("t3"), $"t2.c2" === $"t3.c2", "left")
+        t1.lateralJoin(t2.where($"t1.c1".outer() === 
$"t2.c1").select($"c2").as("s"), "left")
+          .join(t1.as("t3"), $"s.c2" === $"t3.c2", "left")
           .toDF("c1", "c2", "c3", "c4", "c5"),
         sql("""
             |SELECT * FROM t1
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
index aeef2e8f0fcf..12a49ad21676 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
@@ -61,10 +61,11 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest 
with RemoteSparkSessi
       val t3 = spark.table("t3")
 
       checkAnswer(
-        t1.lateralJoin(spark.tvf.explode(array($"c1".outer(), $"c2".outer()))),
+        t1.lateralJoin(
+          spark.tvf.explode(array($"c1".outer(), 
$"c2".outer())).toDF("c3").as("t2")),
         sql("SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, c2)) t2(c3)"))
       checkAnswer(
-        t3.lateralJoin(spark.tvf.explode($"c2".outer())),
+        t3.lateralJoin(spark.tvf.explode($"c2".outer()).toDF("v").as("t2")),
         sql("SELECT * FROM t3, LATERAL EXPLODE(c2) t2(v)"))
       checkAnswer(
         spark.tvf
@@ -113,10 +114,11 @@ class DataFrameTableValuedFunctionsSuite extends 
QueryTest with RemoteSparkSessi
       val t3 = spark.table("t3")
 
       checkAnswer(
-        t1.lateralJoin(spark.tvf.explode_outer(array($"c1".outer(), 
$"c2".outer()))),
+        t1.lateralJoin(
+          spark.tvf.explode_outer(array($"c1".outer(), 
$"c2".outer())).toDF("c3").as("t2")),
         sql("SELECT * FROM t1, LATERAL EXPLODE_OUTER(ARRAY(c1, c2)) t2(c3)"))
       checkAnswer(
-        t3.lateralJoin(spark.tvf.explode_outer($"c2".outer())),
+        
t3.lateralJoin(spark.tvf.explode_outer($"c2".outer()).toDF("v").as("t2")),
         sql("SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) t2(v)"))
       checkAnswer(
         spark.tvf
@@ -161,7 +163,10 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest 
with RemoteSparkSessi
         arrayStruct.lateralJoin(spark.tvf.inline($"arr".outer())),
         sql("SELECT * FROM array_struct JOIN LATERAL INLINE(arr)"))
       checkAnswer(
-        arrayStruct.lateralJoin(spark.tvf.inline($"arr".outer()), $"id" === 
$"col1", "left"),
+        arrayStruct.lateralJoin(
+          spark.tvf.inline($"arr".outer()).toDF("k", "v").as("t"),
+          $"id" === $"k",
+          "left"),
         sql("SELECT * FROM array_struct LEFT JOIN LATERAL INLINE(arr) t(k, v) 
ON id = k"))
     }
   }
@@ -202,8 +207,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest 
with RemoteSparkSessi
         sql("SELECT * FROM array_struct JOIN LATERAL INLINE_OUTER(arr)"))
       checkAnswer(
         arrayStruct.lateralJoin(
-          spark.tvf.inline_outer($"arr".outer()),
-          $"id" === $"col1",
+          spark.tvf.inline_outer($"arr".outer()).toDF("k", "v").as("t"),
+          $"id" === $"k",
           "left"),
         sql("SELECT * FROM array_struct LEFT JOIN LATERAL INLINE_OUTER(arr) 
t(k, v) ON id = k"))
     }
@@ -238,30 +243,27 @@ class DataFrameTableValuedFunctionsSuite extends 
QueryTest with RemoteSparkSessi
         jsonTable
           .as("t1")
           .lateralJoin(
-            spark.tvf.json_tuple(
-              $"t1.jstring".outer(),
-              lit("f1"),
-              lit("f2"),
-              lit("f3"),
-              lit("f4"),
-              lit("f5")))
-          .select($"key", $"c0", $"c1", $"c2", $"c3", $"c4"),
+            spark.tvf
+              .json_tuple(
+                $"t1.jstring".outer(),
+                lit("f1"),
+                lit("f2"),
+                lit("f3"),
+                lit("f4"),
+                lit("f5"))
+              .as("t2"))
+          .select($"t1.key", $"t2.*"),
         sql(
           "SELECT t1.key, t2.* FROM json_table t1, " +
             "LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2"))
       checkAnswer(
         jsonTable
           .as("t1")
-          .lateralJoin(
-            spark.tvf.json_tuple(
-              $"jstring".outer(),
-              lit("f1"),
-              lit("f2"),
-              lit("f3"),
-              lit("f4"),
-              lit("f5")))
-          .where($"c0".isNotNull)
-          .select($"key", $"c0", $"c1", $"c2", $"c3", $"c4"),
+          .lateralJoin(spark.tvf
+            .json_tuple($"jstring".outer(), lit("f1"), lit("f2"), lit("f3"), 
lit("f4"), lit("f5"))
+            .as("t2"))
+          .where($"t2.c0".isNotNull)
+          .select($"t1.key", $"t2.*"),
         sql(
           "SELECT t1.key, t2.* FROM json_table t1, " +
             "LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2 " 
+
@@ -390,17 +392,18 @@ class DataFrameTableValuedFunctionsSuite extends 
QueryTest with RemoteSparkSessi
 
       checkAnswer(
         t1.lateralJoin(
-          spark.tvf.stack(lit(2), lit("Key"), $"c1".outer(), lit("Value"), 
$"c2".outer()))
-          .select($"col0", $"col1"),
+          spark.tvf.stack(lit(2), lit("Key"), $"c1".outer(), lit("Value"), 
$"c2".outer()).as("t"))
+          .select($"t.*"),
         sql("SELECT t.* FROM t1, LATERAL stack(2, 'Key', c1, 'Value', c2) t"))
       checkAnswer(
-        t1.lateralJoin(spark.tvf.stack(lit(1), $"c1".outer(), $"c2".outer()))
-          .select($"col0".as("x"), $"col1".as("y")),
+        t1.lateralJoin(
+          spark.tvf.stack(lit(1), $"c1".outer(), $"c2".outer()).toDF("x", 
"y").as("t"))
+          .select($"t.*"),
         sql("SELECT t.* FROM t1 JOIN LATERAL stack(1, c1, c2) t(x, y)"))
       checkAnswer(
         t1.join(t3, $"t1.c1" === $"t3.c1")
-          .lateralJoin(spark.tvf.stack(lit(1), $"t1.c2".outer(), 
$"t3.c2".outer()))
-          .select($"col0", $"col1"),
+          .lateralJoin(spark.tvf.stack(lit(1), $"t1.c2".outer(), 
$"t3.c2".outer()).as("t"))
+          .select($"t.*"),
         sql("SELECT t.* FROM t1 JOIN t3 ON t1.c1 = t3.c1 JOIN LATERAL stack(1, 
t1.c2, t3.c2) t"))
     }
   }
@@ -463,8 +466,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest 
with RemoteSparkSessi
       checkAnswer(
         variantTable
           .as("t1")
-          .lateralJoin(spark.tvf.variant_explode($"v".outer()))
-          .select($"id", $"pos", $"key", $"value"),
+          .lateralJoin(spark.tvf.variant_explode($"v".outer()).as("t"))
+          .select($"t1.id", $"t.*"),
         sql("SELECT t1.id, t.* FROM variant_table AS t1, LATERAL 
variant_explode(v) AS t"))
     }
   }
@@ -515,8 +518,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest 
with RemoteSparkSessi
       checkAnswer(
         variantTable
           .as("t1")
-          .lateralJoin(spark.tvf.variant_explode_outer($"v".outer()))
-          .select($"id", $"pos", $"key", $"value"),
+          .lateralJoin(spark.tvf.variant_explode_outer($"v".outer()).as("t"))
+          .select($"t1.id", $"t.*"),
         sql("SELECT t1.id, t.* FROM variant_table AS t1, LATERAL 
variant_explode_outer(v) AS t"))
     }
   }
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 660f577f56f8..e321f2c8d755 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -2713,11 +2713,10 @@ class DataFrame:
         >>> customers.alias("c").lateralJoin(
         ...     orders.alias("o")
         ...     .where(sf.col("o.customer_id") == 
sf.col("c.customer_id").outer())
+        ...     .select("order_id", "order_date")
         ...     .orderBy(sf.col("order_date").desc())
         ...     .limit(2),
         ...     how="left"
-        ... ).select(
-        ...     "c.customer_id", "name", "order_id", "order_date"
         ... ).orderBy("customer_id", "order_id").show()
         +-----------+-------+--------+----------+
         |customer_id|   name|order_id|order_date|
diff --git a/python/pyspark/sql/tests/test_subquery.py 
b/python/pyspark/sql/tests/test_subquery.py
index 0f431589b461..99a22d7c2966 100644
--- a/python/pyspark/sql/tests/test_subquery.py
+++ b/python/pyspark/sql/tests/test_subquery.py
@@ -518,6 +518,28 @@ class SubqueryTestsMixin:
                 self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT t1.c1 + 
t2.c1 FROM t2)"""),
             )
 
+    def test_lateral_join_with_star_expansion(self):
+        with self.tempView("t1", "t2"):
+            t1 = self.table1()
+            t2 = self.table2()
+
+            assertDataFrameEqual(
+                
t1.lateralJoin(self.spark.range(1).select().select(sf.col("*"))),
+                self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT *)"""),
+            )
+            assertDataFrameEqual(
+                t1.lateralJoin(t2.select(sf.col("*"))),
+                self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT * FROM 
t2)"""),
+            )
+            assertDataFrameEqual(
+                t1.lateralJoin(t2.select(sf.col("t1.*").outer(), 
sf.col("t2.*"))),
+                self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT t1.*, t2.* 
FROM t2)"""),
+            )
+            assertDataFrameEqual(
+                t1.lateralJoin(t2.alias("t1").select(sf.col("t1.*"))),
+                self.spark.sql("""SELECT * FROM t1, LATERAL (SELECT t1.* FROM 
t2 AS t1)"""),
+            )
+
     def test_lateral_join_with_different_join_types(self):
         with self.tempView("t1"):
             t1 = self.table1()
@@ -572,6 +594,20 @@ class SubqueryTestsMixin:
                 },
             )
 
+    def test_lateral_join_with_subquery_alias(self):
+        with self.tempView("t1"):
+            t1 = self.table1()
+
+            assertDataFrameEqual(
+                t1.lateralJoin(
+                    self.spark.range(1)
+                    .select(sf.col("c1").outer(), sf.col("c2").outer())
+                    .toDF("a", "b")
+                    .alias("s")
+                ).select("a", "b"),
+                self.spark.sql("""SELECT a, b FROM t1, LATERAL (SELECT c1, c2) 
s(a, b)"""),
+            )
+
     def test_lateral_join_with_correlated_predicates(self):
         with self.tempView("t1", "t2"):
             t1 = self.table1()
@@ -661,9 +697,11 @@ class SubqueryTestsMixin:
 
             assertDataFrameEqual(
                 t1.lateralJoin(
-                    t2.where(sf.col("t1.c1").outer() == 
sf.col("t2.c1")).select(sf.col("c2")),
+                    t2.where(sf.col("t1.c1").outer() == sf.col("t2.c1"))
+                    .select(sf.col("c2"))
+                    .alias("s"),
                     how="left",
-                ).join(t1.alias("t3"), sf.col("t2.c2") == sf.col("t3.c2"), 
how="left"),
+                ).join(t1.alias("t3"), sf.col("s.c2") == sf.col("t3.c2"), 
how="left"),
                 self.spark.sql(
                     """
                     SELECT * FROM t1
diff --git a/python/pyspark/sql/tests/test_tvf.py 
b/python/pyspark/sql/tests/test_tvf.py
index ea20cbf9b8f3..c7274c0810cf 100644
--- a/python/pyspark/sql/tests/test_tvf.py
+++ b/python/pyspark/sql/tests/test_tvf.py
@@ -65,11 +65,13 @@ class TVFTestsMixin:
             assertDataFrameEqual(
                 t1.lateralJoin(
                     self.spark.tvf.explode(sf.array(sf.col("c1").outer(), 
sf.col("c2").outer()))
-                ).toDF("c1", "c2", "c3"),
+                    .toDF("c3")
+                    .alias("t2")
+                ),
                 self.spark.sql("""SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, 
c2)) t2(c3)"""),
             )
             assertDataFrameEqual(
-                
t3.lateralJoin(self.spark.tvf.explode(sf.col("c2").outer())).toDF("c1", "c2", 
"v"),
+                
t3.lateralJoin(self.spark.tvf.explode(sf.col("c2").outer()).toDF("v").alias("t2")),
                 self.spark.sql("""SELECT * FROM t3, LATERAL EXPLODE(c2) 
t2(v)"""),
             )
             assertDataFrameEqual(
@@ -127,12 +129,14 @@ class TVFTestsMixin:
                     self.spark.tvf.explode_outer(
                         sf.array(sf.col("c1").outer(), sf.col("c2").outer())
                     )
-                ).toDF("c1", "c2", "c3"),
+                    .toDF("c3")
+                    .alias("t2")
+                ),
                 self.spark.sql("""SELECT * FROM t1, LATERAL 
EXPLODE_OUTER(ARRAY(c1, c2)) t2(c3)"""),
             )
             assertDataFrameEqual(
-                
t3.lateralJoin(self.spark.tvf.explode_outer(sf.col("c2").outer())).toDF(
-                    "c1", "c2", "v"
+                t3.lateralJoin(
+                    
self.spark.tvf.explode_outer(sf.col("c2").outer()).toDF("v").alias("t2")
                 ),
                 self.spark.sql("""SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) 
t2(v)"""),
             )
@@ -193,10 +197,10 @@ class TVFTestsMixin:
             )
             assertDataFrameEqual(
                 array_struct.lateralJoin(
-                    self.spark.tvf.inline(sf.col("arr").outer()),
-                    sf.col("id") == sf.col("col1"),
+                    self.spark.tvf.inline(sf.col("arr").outer()).toDF("k", 
"v").alias("t"),
+                    sf.col("id") == sf.col("k"),
                     "left",
-                ).toDF("id", "arr", "k", "v"),
+                ),
                 self.spark.sql(
                     """
                     SELECT * FROM array_struct LEFT JOIN LATERAL INLINE(arr) 
t(k, v) ON id = k
@@ -252,10 +256,10 @@ class TVFTestsMixin:
             )
             assertDataFrameEqual(
                 array_struct.lateralJoin(
-                    self.spark.tvf.inline_outer(sf.col("arr").outer()),
-                    sf.col("id") == sf.col("col1"),
+                    
self.spark.tvf.inline_outer(sf.col("arr").outer()).toDF("k", "v").alias("t"),
+                    sf.col("id") == sf.col("k"),
                     "left",
-                ).toDF("id", "arr", "k", "v"),
+                ),
                 self.spark.sql(
                     """
                     SELECT * FROM array_struct LEFT JOIN LATERAL 
INLINE_OUTER(arr) t(k, v) ON id = k
@@ -302,9 +306,9 @@ class TVFTestsMixin:
                         sf.lit("f3"),
                         sf.lit("f4"),
                         sf.lit("f5"),
-                    )
+                    ).alias("t2")
                 )
-                .select("key", "c0", "c1", "c2", "c3", "c4"),
+                .select("t1.key", "t2.*"),
                 self.spark.sql(
                     """
                     SELECT t1.key, t2.* FROM json_table t1,
@@ -322,10 +326,10 @@ class TVFTestsMixin:
                         sf.lit("f3"),
                         sf.lit("f4"),
                         sf.lit("f5"),
-                    )
+                    ).alias("t2")
                 )
-                .where(sf.col("c0").isNotNull())
-                .select("key", "c0", "c1", "c2", "c3", "c4"),
+                .where(sf.col("t2.c0").isNotNull())
+                .select("t1.key", "t2.*"),
                 self.spark.sql(
                     """
                     SELECT t1.key, t2.* FROM json_table t1,
@@ -485,8 +489,8 @@ class TVFTestsMixin:
                         sf.col("c1").outer(),
                         sf.lit("Value"),
                         sf.col("c2").outer(),
-                    )
-                ).select("col0", "col1"),
+                    ).alias("t")
+                ).select("t.*"),
                 self.spark.sql(
                     """SELECT t.* FROM t1, LATERAL stack(2, 'Key', c1, 
'Value', c2) t"""
                 ),
@@ -494,17 +498,19 @@ class TVFTestsMixin:
             assertDataFrameEqual(
                 t1.lateralJoin(
                     self.spark.tvf.stack(sf.lit(1), sf.col("c1").outer(), 
sf.col("c2").outer())
-                ).select("col0", "col1"),
-                self.spark.sql("""SELECT t.* FROM t1 JOIN LATERAL stack(1, c1, 
c2) t"""),
+                    .toDF("x", "y")
+                    .alias("t")
+                ).select("t.*"),
+                self.spark.sql("""SELECT t.* FROM t1 JOIN LATERAL stack(1, c1, 
c2) t(x, y)"""),
             )
             assertDataFrameEqual(
                 t1.join(t3, sf.col("t1.c1") == sf.col("t3.c1"))
                 .lateralJoin(
                     self.spark.tvf.stack(
                         sf.lit(1), sf.col("t1.c2").outer(), 
sf.col("t3.c2").outer()
-                    )
+                    ).alias("t")
                 )
-                .select("col0", "col1"),
+                .select("t.*"),
                 self.spark.sql(
                     """
                     SELECT t.* FROM t1 JOIN t3 ON t1.c1 = t3.c1
@@ -570,8 +576,8 @@ class TVFTestsMixin:
 
             assertDataFrameEqual(
                 variant_table.alias("t1")
-                
.lateralJoin(self.spark.tvf.variant_explode(sf.col("v").outer()))
-                .select("id", "pos", "key", "value"),
+                
.lateralJoin(self.spark.tvf.variant_explode(sf.col("v").outer()).alias("t"))
+                .select("t1.id", "t.*"),
                 self.spark.sql(
                     """
                     SELECT t1.id, t.* FROM variant_table AS t1,
@@ -629,8 +635,8 @@ class TVFTestsMixin:
 
             assertDataFrameEqual(
                 variant_table.alias("t1")
-                
.lateralJoin(self.spark.tvf.variant_explode_outer(sf.col("v").outer()))
-                .select("id", "pos", "key", "value"),
+                
.lateralJoin(self.spark.tvf.variant_explode_outer(sf.col("v").outer()).alias("t"))
+                .select("t1.id", "t.*"),
                 self.spark.sql(
                     """
                     SELECT t1.id, t.* FROM variant_table AS t1,
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 490ae473a6e4..c0b4384af8b6 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -45,7 +45,7 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, 
SESSION_ID}
 import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, 
TaskResourceProfile, TaskResourceRequest}
 import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, 
Observation, RelationalGroupedDataset, Row, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, 
FunctionIdentifier, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, 
GlobalTempView, LazyExpression, LocalTempView, MultiAlias, 
NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, 
UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, 
UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex, 
UnresolvedRelation, UnresolvedStar, UnresolvedTableValuedFunction, 
UnresolvedTranspose}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, 
GlobalTempView, LazyExpression, LocalTempView, MultiAlias, 
NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, 
UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, 
UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex, 
UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, 
UnresolvedTableValuedFunction, UnresolvedTranspose}
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, 
ExpressionEncoder, RowEncoder}
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.expressions._
@@ -566,10 +566,9 @@ class SparkConnectPlanner(
   }
 
   private def transformToDF(rel: proto.ToDF): LogicalPlan = {
-    Dataset
-      .ofRows(session, transformRelation(rel.getInput))
-      .toDF(rel.getColumnNamesList.asScala.toSeq: _*)
-      .logicalPlan
+    UnresolvedSubqueryColumnAliases(
+      rel.getColumnNamesList.asScala.toSeq,
+      transformRelation(rel.getInput))
   }
 
   private def transformMapPartitions(rel: proto.MapPartitions): LogicalPlan = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b9ae0e5b9131..287628f2cbef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -92,6 +92,23 @@ private[sql] object Dataset {
     dataset
   }
 
+  def apply[T](
+      sparkSession: SparkSession,
+      logicalPlan: LogicalPlan,
+      encoderGenerator: () => Encoder[T]): Dataset[T] = {
+    val dataset = new Dataset(sparkSession, logicalPlan, encoderGenerator)
+    // Eagerly bind the encoder so we verify that the encoder matches the 
underlying
+    // schema. The user will get an error if this is not the case.
+    // optimization: it is guaranteed that [[InternalRow]] can be converted to 
[[Row]] so
+    // do not do this check in that case. this check can be expensive since it 
requires running
+    // the whole [[Analyzer]] to resolve the deserializer
+    if (!dataset.queryExecution.isLazyAnalysis
+        && dataset.encoder.clsTag.runtimeClass != classOf[Row]) {
+      dataset.resolvedEnc
+    }
+    dataset
+  }
+
   def ofRows(sparkSession: SparkSession, logicalPlan: LogicalPlan): DataFrame =
     sparkSession.withActive {
       val qe = sparkSession.sessionState.executePlan(logicalPlan)
@@ -241,8 +258,13 @@ class Dataset[T] private[sql](
     this(queryExecution, () => encoder)
   }
 
+  def this(
+      sparkSession: SparkSession, logicalPlan: LogicalPlan, encoderGenerator: 
() => Encoder[T]) = {
+    this(sparkSession.sessionState.executePlan(logicalPlan), encoderGenerator)
+  }
+
   def this(sparkSession: SparkSession, logicalPlan: LogicalPlan, encoder: 
Encoder[T]) = {
-    this(sparkSession.sessionState.executePlan(logicalPlan), encoder)
+    this(sparkSession, logicalPlan, () => encoder)
   }
 
   def this(sqlContext: SQLContext, logicalPlan: LogicalPlan, encoder: 
Encoder[T]) = {
@@ -508,16 +530,8 @@ class Dataset[T] private[sql](
 
   /** @inheritdoc */
   @scala.annotation.varargs
-  def toDF(colNames: String*): DataFrame = {
-    require(schema.size == colNames.size,
-      "The number of columns doesn't match.\n" +
-        s"Old column names (${schema.size}): " + 
schema.fields.map(_.name).mkString(", ") + "\n" +
-        s"New column names (${colNames.size}): " + colNames.mkString(", "))
-
-    val newCols = logicalPlan.output.zip(colNames).map { case (oldAttribute, 
newName) =>
-      Column(oldAttribute).as(newName)
-    }
-    select(newCols : _*)
+  def toDF(colNames: String*): DataFrame = withPlan {
+    UnresolvedSubqueryColumnAliases(colNames, logicalPlan)
   }
 
   /** @inheritdoc */
@@ -854,7 +868,7 @@ class Dataset[T] private[sql](
   }
 
   /** @inheritdoc */
-  def as(alias: String): Dataset[T] = withTypedPlan {
+  def as(alias: String): Dataset[T] = withSameTypedPlan {
     SubqueryAlias(alias, logicalPlan)
   }
 
@@ -909,7 +923,7 @@ class Dataset[T] private[sql](
   }
 
   /** @inheritdoc */
-  def filter(condition: Column): Dataset[T] = withTypedPlan {
+  def filter(condition: Column): Dataset[T] = withSameTypedPlan {
     Filter(condition.expr, logicalPlan)
   }
 
@@ -1038,7 +1052,7 @@ class Dataset[T] private[sql](
 
   /** @inheritdoc */
   @scala.annotation.varargs
-  def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = 
withTypedPlan {
+  def observe(name: String, expr: Column, exprs: Column*): Dataset[T] = 
withSameTypedPlan {
     CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id)
   }
 
@@ -1050,12 +1064,12 @@ class Dataset[T] private[sql](
   }
 
   /** @inheritdoc */
-  def limit(n: Int): Dataset[T] = withTypedPlan {
+  def limit(n: Int): Dataset[T] = withSameTypedPlan {
     Limit(Literal(n), logicalPlan)
   }
 
   /** @inheritdoc */
-  def offset(n: Int): Dataset[T] = withTypedPlan {
+  def offset(n: Int): Dataset[T] = withSameTypedPlan {
     Offset(Literal(n), logicalPlan)
   }
 
@@ -1142,7 +1156,7 @@ class Dataset[T] private[sql](
 
   /** @inheritdoc */
   def sample(withReplacement: Boolean, fraction: Double, seed: Long): 
Dataset[T] = {
-    withTypedPlan {
+    withSameTypedPlan {
       Sample(0.0, fraction, withReplacement, seed, logicalPlan)
     }
   }
@@ -1340,7 +1354,7 @@ class Dataset[T] private[sql](
   def dropDuplicates(): Dataset[T] = dropDuplicates(this.columns)
 
   /** @inheritdoc */
-  def dropDuplicates(colNames: Seq[String]): Dataset[T] = withTypedPlan {
+  def dropDuplicates(colNames: Seq[String]): Dataset[T] = withSameTypedPlan {
     val groupCols = groupColsFromDropDuplicates(colNames)
     Deduplicate(groupCols, logicalPlan)
   }
@@ -1351,7 +1365,7 @@ class Dataset[T] private[sql](
   }
 
   /** @inheritdoc */
-  def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = 
withTypedPlan {
+  def dropDuplicatesWithinWatermark(colNames: Seq[String]): Dataset[T] = 
withSameTypedPlan {
     val groupCols = groupColsFromDropDuplicates(colNames)
     // UnsupportedOperationChecker will fail the query if this is called with 
batch Dataset.
     DeduplicateWithinWatermark(groupCols, logicalPlan)
@@ -1511,7 +1525,7 @@ class Dataset[T] private[sql](
   }
 
   /** @inheritdoc */
-  def repartition(numPartitions: Int): Dataset[T] = withTypedPlan {
+  def repartition(numPartitions: Int): Dataset[T] = withSameTypedPlan {
     Repartition(numPartitions, shuffle = true, logicalPlan)
   }
 
@@ -1526,7 +1540,7 @@ class Dataset[T] private[sql](
       s"""Invalid partitionExprs specified: $sortOrders
          |For range partitioning use repartitionByRange(...) instead.
        """.stripMargin)
-    withTypedPlan {
+    withSameTypedPlan {
       RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, 
numPartitions)
     }
   }
@@ -1539,13 +1553,13 @@ class Dataset[T] private[sql](
       case expr: SortOrder => expr
       case expr: Expression => SortOrder(expr, Ascending)
     })
-    withTypedPlan {
+    withSameTypedPlan {
       RepartitionByExpression(sortOrder, logicalPlan, numPartitions)
     }
   }
 
   /** @inheritdoc */
-  def coalesce(numPartitions: Int): Dataset[T] = withTypedPlan {
+  def coalesce(numPartitions: Int): Dataset[T] = withSameTypedPlan {
     Repartition(numPartitions, shuffle = false, logicalPlan)
   }
 
@@ -2240,7 +2254,7 @@ class Dataset[T] private[sql](
           SortOrder(expr, Ascending)
       }
     }
-    withTypedPlan {
+    withSameTypedPlan {
       Sort(sortOrder, global = global, logicalPlan)
     }
   }
@@ -2255,6 +2269,11 @@ class Dataset[T] private[sql](
     Dataset(sparkSession, logicalPlan)
   }
 
+  /** A convenient function to wrap a logical plan and produce a Dataset. */
+  @inline private def withSameTypedPlan(logicalPlan: LogicalPlan): Dataset[T] 
= {
+    Dataset(sparkSession, logicalPlan, encoderGenerator)
+  }
+
   /** A convenient function to wrap a set based logical plan and produce a 
Dataset. */
   @inline private def withSetOperator[U : Encoder](logicalPlan: LogicalPlan): 
Dataset[U] = {
     if (isUnTyped) {
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index f94cf89276ec..fdfb909d9ba7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -418,6 +418,30 @@ class DataFrameSubquerySuite extends QueryTest with 
SharedSparkSession {
     }
   }
 
+  test("lateral join with star expansion") {
+    withView("t1", "t2") {
+      val t1 = table1()
+      val t2 = table2()
+
+      checkAnswer(
+        t1.lateralJoin(spark.range(1).select().select($"*")),
+        sql("SELECT * FROM t1, LATERAL (SELECT *)")
+      )
+      checkAnswer(
+        t1.lateralJoin(t2.select($"*")),
+        sql("SELECT * FROM t1, LATERAL (SELECT * FROM t2)")
+      )
+      checkAnswer(
+        t1.lateralJoin(t2.select($"t1.*".outer(), $"t2.*")),
+        sql("SELECT * FROM t1, LATERAL (SELECT t1.*, t2.* FROM t2)")
+      )
+      checkAnswer(
+        t1.lateralJoin(t2.alias("t1").select($"t1.*")),
+        sql("SELECT * FROM t1, LATERAL (SELECT t1.* FROM t2 AS t1)")
+      )
+    }
+  }
+
   test("lateral join with different join types") {
     withView("t1") {
       val t1 = table1()
@@ -444,6 +468,18 @@ class DataFrameSubquerySuite extends QueryTest with 
SharedSparkSession {
     }
   }
 
+  test("lateral join with subquery alias") {
+    withView("t1") {
+      val t1 = table1()
+
+      checkAnswer(
+        t1.lateralJoin(spark.range(1).select($"c1".outer(), 
$"c2".outer()).toDF("a", "b").as("s"))
+          .select("a", "b"),
+        sql("SELECT a, b FROM t1, LATERAL (SELECT c1, c2) s(a, b)")
+      )
+    }
+  }
+
   test("lateral join with correlated equality / non-equality predicates") {
     withView("t1", "t2") {
       val t1 = table1()
@@ -516,8 +552,8 @@ class DataFrameSubquerySuite extends QueryTest with 
SharedSparkSession {
 
       checkAnswer(
         t1.lateralJoin(
-          t2.where($"t1.c1".outer() === $"t2.c1").select($"c2"), "left"
-        ).join(t1.as("t3"), $"t2.c2" === $"t3.c2", "left"),
+          t2.where($"t1.c1".outer() === $"t2.c1").select($"c2").as("s"), "left"
+        ).join(t1.as("t3"), $"s.c2" === $"t3.c2", "left"),
         sql(
           """
             |SELECT * FROM t1
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
index 4f2cd275ffdf..637e0cf964fe 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTableValuedFunctionsSuite.scala
@@ -60,11 +60,11 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest 
with SharedSparkSessi
       val t3 = spark.table("t3")
 
       checkAnswer(
-        t1.lateralJoin(spark.tvf.explode(array($"c1".outer(), $"c2".outer()))),
+        t1.lateralJoin(spark.tvf.explode(array($"c1".outer(), 
$"c2".outer())).toDF("c3").as("t2")),
         sql("SELECT * FROM t1, LATERAL EXPLODE(ARRAY(c1, c2)) t2(c3)")
       )
       checkAnswer(
-        t3.lateralJoin(spark.tvf.explode($"c2".outer())),
+        t3.lateralJoin(spark.tvf.explode($"c2".outer()).toDF("v").as("t2")),
         sql("SELECT * FROM t3, LATERAL EXPLODE(c2) t2(v)")
       )
       checkAnswer(
@@ -112,11 +112,12 @@ class DataFrameTableValuedFunctionsSuite extends 
QueryTest with SharedSparkSessi
       val t3 = spark.table("t3")
 
       checkAnswer(
-        t1.lateralJoin(spark.tvf.explode_outer(array($"c1".outer(), 
$"c2".outer()))),
+        t1.lateralJoin(
+          spark.tvf.explode_outer(array($"c1".outer(), 
$"c2".outer())).toDF("c3").as("t2")),
         sql("SELECT * FROM t1, LATERAL EXPLODE_OUTER(ARRAY(c1, c2)) t2(c3)")
       )
       checkAnswer(
-        t3.lateralJoin(spark.tvf.explode_outer($"c2".outer())),
+        
t3.lateralJoin(spark.tvf.explode_outer($"c2".outer()).toDF("v").as("t2")),
         sql("SELECT * FROM t3, LATERAL EXPLODE_OUTER(c2) t2(v)")
       )
       checkAnswer(
@@ -164,8 +165,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest 
with SharedSparkSessi
       )
       checkAnswer(
         arrayStruct.lateralJoin(
-          spark.tvf.inline($"arr".outer()),
-          $"id" === $"col1",
+          spark.tvf.inline($"arr".outer()).toDF("k", "v").as("t"),
+          $"id" === $"k",
           "left"
         ),
         sql("SELECT * FROM array_struct LEFT JOIN LATERAL INLINE(arr) t(k, v) 
ON id = k")
@@ -210,8 +211,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest 
with SharedSparkSessi
       )
       checkAnswer(
         arrayStruct.lateralJoin(
-          spark.tvf.inline_outer($"arr".outer()),
-          $"id" === $"col1",
+          spark.tvf.inline_outer($"arr".outer()).toDF("k", "v").as("t"),
+          $"id" === $"k",
           "left"
         ),
         sql("SELECT * FROM array_struct LEFT JOIN LATERAL INLINE_OUTER(arr) 
t(k, v) ON id = k")
@@ -249,8 +250,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest 
with SharedSparkSessi
         jsonTable.as("t1").lateralJoin(
           spark.tvf.json_tuple(
             $"t1.jstring".outer(),
-            lit("f1"), lit("f2"), lit("f3"), lit("f4"), lit("f5"))
-        ).select($"key", $"c0", $"c1", $"c2", $"c3", $"c4"),
+            lit("f1"), lit("f2"), lit("f3"), lit("f4"), lit("f5")).as("t2")
+        ).select($"t1.key", $"t2.*"),
         sql("SELECT t1.key, t2.* FROM json_table t1, " +
           "LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2")
       )
@@ -258,9 +259,9 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest 
with SharedSparkSessi
         jsonTable.as("t1").lateralJoin(
           spark.tvf.json_tuple(
             $"jstring".outer(),
-            lit("f1"), lit("f2"), lit("f3"), lit("f4"), lit("f5"))
-        ).where($"c0".isNotNull)
-          .select($"key", $"c0", $"c1", $"c2", $"c3", $"c4"),
+            lit("f1"), lit("f2"), lit("f3"), lit("f4"), lit("f5")).as("t2")
+        ).where($"t2.c0".isNotNull)
+          .select($"t1.key", $"t2.*"),
         sql("SELECT t1.key, t2.* FROM json_table t1, " +
           "LATERAL json_tuple(t1.jstring, 'f1', 'f2', 'f3', 'f4', 'f5') t2 " +
           "WHERE t2.c0 IS NOT NULL")
@@ -388,21 +389,21 @@ class DataFrameTableValuedFunctionsSuite extends 
QueryTest with SharedSparkSessi
 
       checkAnswer(
         t1.lateralJoin(
-          spark.tvf.stack(lit(2), lit("Key"), $"c1".outer(), lit("Value"), 
$"c2".outer())
-        ).select($"col0", $"col1"),
+          spark.tvf.stack(lit(2), lit("Key"), $"c1".outer(), lit("Value"), 
$"c2".outer()).as("t")
+        ).select($"t.*"),
         sql("SELECT t.* FROM t1, LATERAL stack(2, 'Key', c1, 'Value', c2) t")
       )
       checkAnswer(
         t1.lateralJoin(
-          spark.tvf.stack(lit(1), $"c1".outer(), $"c2".outer())
-        ).select($"col0".as("x"), $"col1".as("y")),
+          spark.tvf.stack(lit(1), $"c1".outer(), $"c2".outer()).toDF("x", 
"y").as("t")
+        ).select($"t.*"),
         sql("SELECT t.* FROM t1 JOIN LATERAL stack(1, c1, c2) t(x, y)")
       )
       checkAnswer(
         t1.join(t3, $"t1.c1" === $"t3.c1")
           .lateralJoin(
-            spark.tvf.stack(lit(1), $"t1.c2".outer(), $"t3.c2".outer())
-          ).select($"col0", $"col1"),
+            spark.tvf.stack(lit(1), $"t1.c2".outer(), $"t3.c2".outer()).as("t")
+          ).select($"t.*"),
         sql("SELECT t.* FROM t1 JOIN t3 ON t1.c1 = t3.c1 JOIN LATERAL stack(1, 
t1.c2, t3.c2) t")
       )
     }
@@ -466,8 +467,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest 
with SharedSparkSessi
 
       checkAnswer(
         variantTable.as("t1").lateralJoin(
-          spark.tvf.variant_explode($"v".outer())
-        ).select($"id", $"pos", $"key", $"value"),
+          spark.tvf.variant_explode($"v".outer()).as("t")
+        ).select($"t1.id", $"t.*"),
         sql("SELECT t1.id, t.* FROM variant_table AS t1, LATERAL 
variant_explode(v) AS t")
       )
     }
@@ -519,8 +520,8 @@ class DataFrameTableValuedFunctionsSuite extends QueryTest 
with SharedSparkSessi
 
       checkAnswer(
         variantTable.as("t1").lateralJoin(
-          spark.tvf.variant_explode_outer($"v".outer())
-        ).select($"id", $"pos", $"key", $"value"),
+          spark.tvf.variant_explode_outer($"v".outer()).as("t")
+        ).select($"t1.id", $"t.*"),
         sql("SELECT t1.id, t.* FROM variant_table AS t1, LATERAL 
variant_explode_outer(v) AS t")
       )
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to