This is an automated email from the ASF dual-hosted git repository.

ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 22731393069a [SPARK-50601][SQL] Support withColumns / withColumnsRenamed in subqueries
22731393069a is described below

commit 22731393069a3f180a9e719e57a694347c0ce87b
Author: Takuya Ueshin <[email protected]>
AuthorDate: Mon Jan 13 18:22:16 2025 -0800

    [SPARK-50601][SQL] Support withColumns / withColumnsRenamed in subqueries
    
    ### What changes were proposed in this pull request?
    
    Supports `withColumns` / `withColumnsRenamed` in subqueries.
    
    ### Why are the changes needed?
    
    When a query is used as a subquery by adding `col.outer()`, `withColumns` and
    `withColumnsRenamed` don't work because they require an analyzed plan
    (they read `queryExecution.analyzed.output`).
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, `withColumns` and `withColumnsRenamed` are now available in subqueries.
    
    ### How was this patch tested?
    
    Added the related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #49386 from ueshin/issues/SPARK-50601/with_columns.
    
    Lead-authored-by: Takuya Ueshin <[email protected]>
    Co-authored-by: Takuya UESHIN <[email protected]>
    Signed-off-by: Takuya Ueshin <[email protected]>
---
 .../apache/spark/sql/DataFrameSubquerySuite.scala  |  57 +++++++--
 .../sql/tests/connect/test_parity_subquery.py      |   4 -
 python/pyspark/sql/tests/test_subquery.py          |  39 +++++-
 .../spark/sql/catalyst/analysis/unresolved.scala   | 132 ++++++++++++++++++---
 .../sql/connect/planner/SparkConnectPlanner.scala  |  33 +++---
 .../connect/planner/SparkConnectPlannerSuite.scala |  33 +++---
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  56 +++------
 .../apache/spark/sql/DataFrameSubquerySuite.scala  |  48 +++++++-
 8 files changed, 295 insertions(+), 107 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index 4b36d36983a5..1d2165b668f6 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.{SparkException, SparkRuntimeException}
+import org.apache.spark.SparkRuntimeException
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.test.{QueryTest, RemoteSparkSession}
 
@@ -665,15 +665,52 @@ class DataFrameSubquerySuite extends QueryTest with 
RemoteSparkSession {
     withView("t1") {
       val t1 = table1()
 
-      // TODO(SPARK-50601): Fix the SparkConnectPlanner to support this case
-      checkError(
-        intercept[SparkException] {
-          t1.withColumn("scalar", spark.range(1).select($"c1".outer() + 
$"c2".outer()).scalar())
-            .collect()
-        },
-        "INTERNAL_ERROR",
-        parameters = Map("message" -> "Found the unresolved operator: .*"),
-        matchPVals = true)
+      checkAnswer(
+        t1.withColumn(
+          "scalar",
+          spark
+            .range(1)
+            .select($"c1".outer() + $"c2".outer())
+            .scalar()),
+        t1.select($"*", ($"c1" + $"c2").as("scalar")))
+
+      checkAnswer(
+        t1.withColumn(
+          "scalar",
+          spark
+            .range(1)
+            .withColumn("c1", $"c1".outer())
+            .select($"c1" + $"c2".outer())
+            .scalar()),
+        t1.select($"*", ($"c1" + $"c2").as("scalar")))
+
+      checkAnswer(
+        t1.withColumn(
+          "scalar",
+          spark
+            .range(1)
+            .select($"c1".outer().as("c1"))
+            .withColumn("c2", $"c2".outer())
+            .select($"c1" + $"c2")
+            .scalar()),
+        t1.select($"*", ($"c1" + $"c2").as("scalar")))
+    }
+  }
+
+  test("subquery in withColumnsRenamed") {
+    withView("t1") {
+      val t1 = table1()
+
+      checkAnswer(
+        t1.withColumn(
+          "scalar",
+          spark
+            .range(1)
+            .select($"c1".outer().as("c1"), $"c2".outer().as("c2"))
+            .withColumnsRenamed(Map("c1" -> "x", "c2" -> "y"))
+            .select($"x" + $"y")
+            .scalar()),
+        t1.select($"*", ($"c1".as("x") + $"c2".as("y")).as("scalar")))
     }
   }
 
diff --git a/python/pyspark/sql/tests/connect/test_parity_subquery.py 
b/python/pyspark/sql/tests/connect/test_parity_subquery.py
index dae60a354d20..f3225fcb7f2d 100644
--- a/python/pyspark/sql/tests/connect/test_parity_subquery.py
+++ b/python/pyspark/sql/tests/connect/test_parity_subquery.py
@@ -45,10 +45,6 @@ class SubqueryParityTests(SubqueryTestsMixin, 
ReusedConnectTestCase):
     def test_subquery_in_unpivot(self):
         self.check_subquery_in_unpivot(None, None)
 
-    @unittest.skip("SPARK-50601: Fix the SparkConnectPlanner to support this 
case")
-    def test_subquery_in_with_columns(self):
-        super().test_subquery_in_with_columns()
-
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_parity_subquery import *  # noqa: F401
diff --git a/python/pyspark/sql/tests/test_subquery.py 
b/python/pyspark/sql/tests/test_subquery.py
index 99a22d7c2966..7c63ddb69458 100644
--- a/python/pyspark/sql/tests/test_subquery.py
+++ b/python/pyspark/sql/tests/test_subquery.py
@@ -939,7 +939,44 @@ class SubqueryTestsMixin:
                     .select(sf.col("c1").outer() + sf.col("c2").outer())
                     .scalar(),
                 ),
-                t1.withColumn("scalar", sf.col("c1") + sf.col("c2")),
+                t1.select("*", (sf.col("c1") + sf.col("c2")).alias("scalar")),
+            )
+            assertDataFrameEqual(
+                t1.withColumn(
+                    "scalar",
+                    self.spark.range(1)
+                    .withColumn("c1", sf.col("c1").outer())
+                    .select(sf.col("c1") + sf.col("c2").outer())
+                    .scalar(),
+                ),
+                t1.select("*", (sf.col("c1") + sf.col("c2")).alias("scalar")),
+            )
+            assertDataFrameEqual(
+                t1.withColumn(
+                    "scalar",
+                    self.spark.range(1)
+                    .select(sf.col("c1").outer().alias("c1"))
+                    .withColumn("c2", sf.col("c2").outer())
+                    .select(sf.col("c1") + sf.col("c2"))
+                    .scalar(),
+                ),
+                t1.select("*", (sf.col("c1") + sf.col("c2")).alias("scalar")),
+            )
+
+    def test_subquery_in_with_columns_renamed(self):
+        with self.tempView("t1"):
+            t1 = self.table1()
+
+            assertDataFrameEqual(
+                t1.withColumn(
+                    "scalar",
+                    self.spark.range(1)
+                    .select(sf.col("c1").outer().alias("c1"), 
sf.col("c2").outer().alias("c2"))
+                    .withColumnsRenamed({"c1": "x", "c2": "y"})
+                    .select(sf.col("x") + sf.col("y"))
+                    .scalar(),
+                ),
+                t1.select("*", (sf.col("c1").alias("x") + 
sf.col("c2").alias("y")).alias("scalar")),
             )
 
     def test_subquery_in_drop(self):
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index b47af90c651a..fabe551d054c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils.toSQLId
 import org.apache.spark.sql.connector.catalog.TableWritePrivilege
 import org.apache.spark.sql.errors.{QueryCompilationErrors, 
QueryExecutionErrors}
 import org.apache.spark.sql.types.{DataType, Metadata, StructType}
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils}
 import org.apache.spark.util.ArrayImplicits._
 
 /**
@@ -429,7 +429,7 @@ object UnresolvedFunction {
  * Represents all of the input attributes to a given relational operator, for 
example in
  * "SELECT * FROM ...". A [[Star]] gets automatically expanded during analysis.
  */
-abstract class Star extends LeafExpression with NamedExpression {
+trait Star extends NamedExpression {
 
   override def name: String = throw new UnresolvedException("name")
   override def exprId: ExprId = throw new UnresolvedException("exprId")
@@ -451,15 +451,20 @@ abstract class Star extends LeafExpression with 
NamedExpression {
  * This is also used to expand structs. For example:
  * "SELECT record.* from (SELECT struct(a,b,c) as record ...)
  *
- * @param target an optional name that should be the target of the expansion.  
If omitted all
- *               targets' columns are produced. This can either be a table 
name or struct name. This
- *               is a list of identifiers that is the path of the expansion.
- *
- * This class provides the shared behavior between the classes for SELECT * 
([[UnresolvedStar]])
- * and SELECT * EXCEPT ([[UnresolvedStarExceptOrReplace]]). [[UnresolvedStar]] 
is just a case class
- * of this, while [[UnresolvedStarExceptOrReplace]] adds some additional logic 
to the expand method.
+ * This trait provides the shared behavior among the classes for SELECT * 
([[UnresolvedStar]])
+ * and SELECT * EXCEPT ([[UnresolvedStarExceptOrReplace]]), etc. 
[[UnresolvedStar]] is just a case
+ * class of this, while [[UnresolvedStarExceptOrReplace]] or other classes add 
some additional logic
+ * to the expand method.
  */
-abstract class UnresolvedStarBase(target: Option[Seq[String]]) extends Star 
with Unevaluable {
+trait UnresolvedStarBase extends Star with Unevaluable {
+
+  /**
+   * An optional name that should be the target of the expansion. If omitted 
all
+   * targets' columns are produced. This can either be a table name or struct 
name. This
+   * is a list of identifiers that is the path of the expansion.
+   */
+  def target: Option[Seq[String]]
+
   /**
    * Returns true if the nameParts is a subset of the last elements of 
qualifier of the attribute.
    *
@@ -583,7 +588,7 @@ case class UnresolvedStarExceptOrReplace(
     target: Option[Seq[String]],
     excepts: Seq[Seq[String]],
     replacements: Option[Seq[NamedExpression]])
-  extends UnresolvedStarBase(target) {
+  extends LeafExpression with UnresolvedStarBase {
 
   /**
    * We expand the * EXCEPT by the following three steps:
@@ -712,6 +717,103 @@ case class UnresolvedStarExceptOrReplace(
   }
 }
 
+/**
+ * Represents some of the input attributes to a given relational operator, for 
example in
+ * `df.withColumn`.
+ *
+ * @param colNames a list of column names that should be replaced or produced.
+ *
+ * @param exprs the corresponding expressions for `colNames`.
+ *
+ * @param explicitMetadata an optional list of explicit metadata to associate 
with the columns.
+ */
+case class UnresolvedStarWithColumns(
+     colNames: Seq[String],
+     exprs: Seq[Expression],
+     explicitMetadata: Option[Seq[Metadata]] = None)
+  extends UnresolvedStarBase {
+
+  override def target: Option[Seq[String]] = None
+  override def children: Seq[Expression] = exprs
+
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): UnresolvedStarWithColumns =
+    copy(exprs = newChildren)
+
+  override def expand(input: LogicalPlan, resolver: Resolver): 
Seq[NamedExpression] = {
+    assert(colNames.size == exprs.size,
+      s"The size of column names: ${colNames.size} isn't equal to " +
+        s"the size of expressions: ${exprs.size}")
+    explicitMetadata.foreach { m =>
+      assert(colNames.size == m.size,
+        s"The size of column names: ${colNames.size} isn't equal to " +
+          s"the size of metadata elements: ${m.size}")
+    }
+
+    SchemaUtils.checkColumnNameDuplication(colNames, resolver)
+
+    val expandedCols = super.expand(input, resolver)
+
+    val columnSeq = explicitMetadata match {
+      case Some(ms) => colNames.zip(exprs).zip(ms.map(Some(_)))
+      case _ => colNames.zip(exprs).map((_, None))
+    }
+
+    val replacedAndExistingColumns = expandedCols.map { field =>
+      columnSeq.find { case ((colName, _), _) =>
+        resolver(field.name, colName)
+      } match {
+        case Some(((colName, expr), m)) => Alias(expr, 
colName)(explicitMetadata = m)
+        case _ => field
+      }
+    }
+
+    val newColumns = columnSeq.filter { case ((colName, _), _) =>
+      !expandedCols.exists(f => resolver(f.name, colName))
+    }.map {
+      case ((colName, expr), m) => Alias(expr, colName)(explicitMetadata = m)
+    }
+
+    replacedAndExistingColumns ++ newColumns
+  }
+}
+
+/**
+ * Represents some of the input attributes to a given relational operator, for 
example in
+ * `df.withColumnRenamed`.
+ *
+ * @param existingNames a list of column names that should be replaced.
+ *                      If the column does not exist, it is ignored.
+ *
+ * @param newNames a list of new column names that should be used to replace 
the existing columns.
+ */
+case class UnresolvedStarWithColumnsRenames(
+    existingNames: Seq[String],
+    newNames: Seq[String])
+  extends LeafExpression with UnresolvedStarBase {
+
+  override def target: Option[Seq[String]] = None
+
+  override def expand(input: LogicalPlan, resolver: Resolver): 
Seq[NamedExpression] = {
+    assert(existingNames.size == newNames.size,
+      s"The size of existing column names: ${existingNames.size} isn't equal 
to " +
+        s"the size of new column names: ${newNames.size}")
+
+    val expandedCols = super.expand(input, resolver)
+
+    existingNames.zip(newNames).foldLeft(expandedCols) {
+      case (attrs, (existingName, newName)) =>
+        attrs.map(attr =>
+          if (resolver(attr.name, existingName)) {
+            Alias(attr, newName)()
+          } else {
+            attr
+          }
+        )
+    }
+  }
+}
+
 /**
  * Represents all of the input attributes to a given relational operator, for 
example in
  * "SELECT * FROM ...".
@@ -723,7 +825,8 @@ case class UnresolvedStarExceptOrReplace(
  *              targets' columns are produced. This can either be a table name 
or struct name. This
  *              is a list of identifiers that is the path of the expansion.
  */
-case class UnresolvedStar(target: Option[Seq[String]]) extends 
UnresolvedStarBase(target)
+case class UnresolvedStar(target: Option[Seq[String]])
+  extends LeafExpression with UnresolvedStarBase
 
 /**
  * Represents all of the input attributes to a given relational operator, for 
example in
@@ -733,7 +836,7 @@ case class UnresolvedStar(target: Option[Seq[String]]) 
extends UnresolvedStarBas
  *              tables' columns are produced.
  */
 case class UnresolvedRegex(regexPattern: String, table: Option[String], 
caseSensitive: Boolean)
-  extends Star with Unevaluable {
+  extends LeafExpression with Star with Unevaluable {
   override def expand(input: LogicalPlan, resolver: Resolver): 
Seq[NamedExpression] = {
     val pattern = if (caseSensitive) regexPattern else s"(?i)$regexPattern"
     table match {
@@ -791,7 +894,8 @@ case class MultiAlias(child: Expression, names: Seq[String])
  *
  * @param expressions Expressions to expand.
  */
-case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star with 
Unevaluable {
+case class ResolvedStar(expressions: Seq[NamedExpression])
+  extends LeafExpression with Star with Unevaluable {
   override def newInstance(): NamedExpression = throw new 
UnresolvedException("newInstance")
   override def expand(input: LogicalPlan, resolver: Resolver): 
Seq[NamedExpression] = expressions
   override def toString: String = expressions.mkString("ResolvedStar(", ", ", 
")")
diff --git 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 6ab69aea12e5..acbbeb49b267 100644
--- 
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -45,7 +45,7 @@ import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, 
SESSION_ID}
 import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, 
TaskResourceProfile, TaskResourceRequest}
 import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, 
Observation, RelationalGroupedDataset, Row, SparkSession}
 import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, 
FunctionIdentifier, QueryPlanningTracker}
-import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, 
GlobalTempView, LazyExpression, LocalTempView, MultiAlias, 
NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, 
UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, 
UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex, 
UnresolvedRelation, UnresolvedStar, UnresolvedSubqueryColumnAliases, 
UnresolvedTableValuedFunction, UnresolvedTranspose}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, 
GlobalTempView, LazyExpression, LocalTempView, MultiAlias, 
NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, 
UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, 
UnresolvedExtractValue, UnresolvedFunction, UnresolvedPlanId, UnresolvedRegex, 
UnresolvedRelation, UnresolvedStar, UnresolvedStarWithColumns, 
UnresolvedStarWithColumnsRenames, UnresolvedSubqueryColumnAliases, 
UnresolvedTableValuedFunc [...]
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, 
ExpressionEncoder, RowEncoder}
 import 
org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.UnboundRowEncoder
 import org.apache.spark.sql.catalyst.expressions._
@@ -1065,25 +1065,21 @@ class SparkConnectPlanner(
   }
 
   private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): 
LogicalPlan = {
-    if (rel.getRenamesCount > 0) {
-      val (colNames, newColNames) = rel.getRenamesList.asScala.toSeq.map { 
rename =>
+    val (colNames, newColNames) = if (rel.getRenamesCount > 0) {
+      rel.getRenamesList.asScala.toSeq.map { rename =>
         (rename.getColName, rename.getNewColName)
       }.unzip
-      Dataset
-        .ofRows(session, transformRelation(rel.getInput))
-        .withColumnsRenamed(colNames, newColNames)
-        .logicalPlan
     } else {
       // for backward compatibility
-      Dataset
-        .ofRows(session, transformRelation(rel.getInput))
-        .withColumnsRenamed(rel.getRenameColumnsMapMap)
-        .logicalPlan
+      rel.getRenameColumnsMapMap.asScala.toSeq.unzip
     }
+    Project(
+      Seq(UnresolvedStarWithColumnsRenames(existingNames = colNames, newNames 
= newColNames)),
+      transformRelation(rel.getInput))
   }
 
   private def transformWithColumns(rel: proto.WithColumns): LogicalPlan = {
-    val (colNames, cols, metadata) =
+    val (colNames, exprs, metadata) =
       rel.getAliasesList.asScala.toSeq.map { alias =>
         if (alias.getNameCount != 1) {
           throw InvalidPlanInput(s"""WithColumns require column name only 
contains one name part,
@@ -1096,13 +1092,16 @@ class SparkConnectPlanner(
           Metadata.empty
         }
 
-        (alias.getName(0), Column(transformExpression(alias.getExpr)), 
metadata)
+        (alias.getName(0), transformExpression(alias.getExpr), metadata)
       }.unzip3
 
-    Dataset
-      .ofRows(session, transformRelation(rel.getInput))
-      .withColumns(colNames, cols, metadata)
-      .logicalPlan
+    Project(
+      Seq(
+        UnresolvedStarWithColumns(
+          colNames = colNames,
+          exprs = exprs,
+          explicitMetadata = Some(metadata))),
+      transformRelation(rel.getInput))
   }
 
   private def transformWithWatermark(rel: proto.WithWatermark): LogicalPlan = {
diff --git 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
index aaeb5d9fe509..054a32179935 100644
--- 
a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
+++ 
b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectPlannerSuite.scala
@@ -504,26 +504,27 @@ class SparkConnectPlannerSuite extends SparkFunSuite with 
SparkConnectPlanTest {
   }
 
   test("Test duplicated names in WithColumns") {
-    intercept[AnalysisException] {
-      transform(
-        proto.Relation
-          .newBuilder()
-          .setWithColumns(
-            proto.WithColumns
-              .newBuilder()
-              .setInput(readRel)
-              .addAliases(proto.Expression.Alias
+    val logical = transform(
+      proto.Relation
+        .newBuilder()
+        .setWithColumns(
+          proto.WithColumns
+            .newBuilder()
+            .setInput(readRel)
+            .addAliases(
+              proto.Expression.Alias
                 .newBuilder()
                 .addName("test")
                 .setExpr(proto.Expression.newBuilder
                   
.setLiteral(proto.Expression.Literal.newBuilder.setInteger(32))))
-              .addAliases(proto.Expression.Alias
-                .newBuilder()
-                .addName("test")
-                .setExpr(proto.Expression.newBuilder
-                  
.setLiteral(proto.Expression.Literal.newBuilder.setInteger(32)))))
-          .build())
-    }
+            .addAliases(proto.Expression.Alias
+              .newBuilder()
+              .addName("test")
+              .setExpr(proto.Expression.newBuilder
+                
.setLiteral(proto.Expression.Literal.newBuilder.setInteger(32)))))
+        .build())
+
+    intercept[AnalysisException](Dataset.ofRows(spark, logical))
   }
 
   test("Test multi nameparts for column names in WithColumns") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index e4e782a50e3d..e41521cba533 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1275,29 +1275,14 @@ class Dataset[T] private[sql](
     require(colNames.size == cols.size,
       s"The size of column names: ${colNames.size} isn't equal to " +
         s"the size of columns: ${cols.size}")
-    SchemaUtils.checkColumnNameDuplication(
-      colNames,
-      sparkSession.sessionState.conf.caseSensitiveAnalysis)
-
-    val resolver = sparkSession.sessionState.analyzer.resolver
-    val output = queryExecution.analyzed.output
-
-    val columnSeq = colNames.zip(cols)
-
-    val replacedAndExistingColumns = output.map { field =>
-      columnSeq.find { case (colName, _) =>
-        resolver(field.name, colName)
-      } match {
-        case Some((colName: String, col: Column)) => col.as(colName)
-        case _ => Column(field)
-      }
+    withPlan {
+      Project(
+        Seq(
+          UnresolvedStarWithColumns(
+            colNames = colNames,
+            exprs = cols.map(_.expr))),
+        logicalPlan)
     }
-
-    val newColumns = columnSeq.filter { case (colName, col) =>
-      !output.exists(f => resolver(f.name, colName))
-    }.map { case (colName, col) => col.as(colName) }
-
-    select(replacedAndExistingColumns ++ newColumns : _*)
   }
 
   /** @inheritdoc */
@@ -1324,26 +1309,13 @@ class Dataset[T] private[sql](
     require(colNames.size == newColNames.size,
       s"The size of existing column names: ${colNames.size} isn't equal to " +
         s"the size of new column names: ${newColNames.size}")
-
-    val resolver = sparkSession.sessionState.analyzer.resolver
-    val output: Seq[NamedExpression] = queryExecution.analyzed.output
-    var shouldRename = false
-
-    val projectList = colNames.zip(newColNames).foldLeft(output) {
-      case (attrs, (existingName, newName)) =>
-        attrs.map(attr =>
-          if (resolver(attr.name, existingName)) {
-            shouldRename = true
-            Alias(attr, newName)()
-          } else {
-            attr
-          }
-        )
-    }
-    if (shouldRename) {
-      withPlan(Project(projectList, logicalPlan))
-    } else {
-      toDF()
+    withPlan {
+      Project(
+        Seq(
+          UnresolvedStarWithColumnsRenames(
+            existingNames = colNames,
+            newNames = newColNames)),
+        logicalPlan)
     }
   }
 
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
index fdfb909d9ba7..621d468454d4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSubquerySuite.scala
@@ -777,9 +777,51 @@ class DataFrameSubquerySuite extends QueryTest with 
SharedSparkSession {
       val t1 = table1()
 
       checkAnswer(
-        t1.withColumn("scalar", spark.range(1).select($"c1".outer() + 
$"c2".outer()).scalar()),
-        t1.withColumn("scalar", $"c1" + $"c2")
-      )
+        t1.withColumn(
+          "scalar",
+          spark
+            .range(1)
+            .select($"c1".outer() + $"c2".outer())
+            .scalar()),
+        t1.select($"*", ($"c1" + $"c2").as("scalar")))
+
+      checkAnswer(
+        t1.withColumn(
+          "scalar",
+          spark
+            .range(1)
+            .withColumn("c1", $"c1".outer())
+            .select($"c1" + $"c2".outer())
+            .scalar()),
+        t1.select($"*", ($"c1" + $"c2").as("scalar")))
+
+      checkAnswer(
+        t1.withColumn(
+          "scalar",
+          spark
+            .range(1)
+            .select($"c1".outer().as("c1"))
+            .withColumn("c2", $"c2".outer())
+            .select($"c1" + $"c2")
+            .scalar()),
+        t1.select($"*", ($"c1" + $"c2").as("scalar")))
+    }
+  }
+
+  test("subquery in withColumnsRenamed") {
+    withView("t1") {
+      val t1 = table1()
+
+      checkAnswer(
+        t1.withColumn(
+          "scalar",
+          spark
+            .range(1)
+            .select($"c1".outer().as("c1"), $"c2".outer().as("c2"))
+            .withColumnsRenamed(Map("c1" -> "x", "c2" -> "y"))
+            .select($"x" + $"y")
+            .scalar()),
+        t1.select($"*", ($"c1".as("x") + $"c2".as("y")).as("scalar")))
     }
   }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to