This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new d9f0c44e7f24 [SPARK-45770][SQL][PYTHON][CONNECT][3.5] Introduce plan DataFrameDropColumns for Dataframe.drop
d9f0c44e7f24 is described below
commit d9f0c44e7f24cba95f7bf1737bb52ff73a7b9094
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Tue Nov 14 12:09:37 2023 +0800
[SPARK-45770][SQL][PYTHON][CONNECT][3.5] Introduce plan DataFrameDropColumns for Dataframe.drop
### What changes were proposed in this pull request?
Backport https://github.com/apache/spark/pull/43683 to branch-3.5.
### Why are the changes needed?
To fix a Spark Connect bug in `DataFrame.drop` (see the new `test_drop_join` test below for the affected drop-after-join scenario).
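For context, a minimal PySpark sketch of the affected pattern, adapted from the `test_drop_join` test added below (the session setup and data are illustrative only):

```python
from pyspark.sql import SparkSession

# Works against a classic or a Spark Connect session; the builder options
# are left to the environment (illustrative setup, not part of the patch).
spark = SparkSession.builder.getOrCreate()

left_df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "c")], ["join_key", "value1"])
right_df = spark.createDataFrame([(1, "aa"), (2, "bb"), (4, "dd")], ["join_key", "value2"])
joined_df = left_df.join(right_df, left_df["join_key"] == right_df["join_key"], "left")

# Dropping a join key by Column reference should remove only that side's
# column; this backport makes Connect resolve it the same way as classic Spark.
assert joined_df.drop(left_df["join_key"]).columns == ["value1", "join_key", "value2"]
assert joined_df.drop(right_df["join_key"]).columns == ["join_key", "value1", "value2"]
```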
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
CI.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #43776 from zhengruifeng/sql_drop_plan_35.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/tests/test_dataframe.py | 37 ++++++++++++++++
.../spark/sql/catalyst/analysis/Analyzer.scala | 1 +
.../analysis/ResolveDataFrameDropColumns.scala | 49 ++++++++++++++++++++++
.../plans/logical/basicLogicalOperators.scala | 14 +++++++
.../spark/sql/catalyst/trees/TreePatterns.scala | 1 +
.../main/scala/org/apache/spark/sql/Dataset.scala | 15 +------
6 files changed, 104 insertions(+), 13 deletions(-)
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 33049233dee9..5907c8c09fb4 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -106,6 +106,43 @@ class DataFrameTestsMixin:
        self.assertEqual(df.drop(col("name"), col("age")).columns, ["active"])
        self.assertEqual(df.drop(col("name"), col("age"), col("random")).columns, ["active"])
+    def test_drop_join(self):
+        left_df = self.spark.createDataFrame(
+            [(1, "a"), (2, "b"), (3, "c")],
+            ["join_key", "value1"],
+        )
+        right_df = self.spark.createDataFrame(
+            [(1, "aa"), (2, "bb"), (4, "dd")],
+            ["join_key", "value2"],
+        )
+        joined_df = left_df.join(
+            right_df,
+            on=left_df["join_key"] == right_df["join_key"],
+            how="left",
+        )
+
+        dropped_1 = joined_df.drop(left_df["join_key"])
+        self.assertEqual(dropped_1.columns, ["value1", "join_key", "value2"])
+        self.assertEqual(
+            dropped_1.sort("value1").collect(),
+            [
+                Row(value1="a", join_key=1, value2="aa"),
+                Row(value1="b", join_key=2, value2="bb"),
+                Row(value1="c", join_key=None, value2=None),
+            ],
+        )
+
+        dropped_2 = joined_df.drop(right_df["join_key"])
+        self.assertEqual(dropped_2.columns, ["join_key", "value1", "value2"])
+        self.assertEqual(
+            dropped_2.sort("value1").collect(),
+            [
+                Row(join_key=1, value1="a", value2="aa"),
+                Row(join_key=2, value1="b", value2="bb"),
+                Row(join_key=3, value1="c", value2=None),
+            ],
+        )
+
def test_with_columns_renamed(self):
df = self.spark.createDataFrame([("Alice", 50), ("Alice", 60)],
["name", "age"])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 8e3c9b30c61b..80cb5d8c6087 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -307,6 +307,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
      ResolveWindowFrame ::
      ResolveNaturalAndUsingJoin ::
      ResolveOutputRelation ::
+      new ResolveDataFrameDropColumns(catalogManager) ::
      ExtractWindowExpressions ::
      GlobalAggregates ::
      ResolveAggregateFunctions ::
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
new file mode 100644
index 000000000000..2642b4a1c5da
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.DF_DROP_COLUMNS
+import org.apache.spark.sql.connector.catalog.CatalogManager
+
+/**
+ * A rule that rewrites DataFrameDropColumns to Project.
+ * Note that DataFrameDropColumns allows and ignores non-existing columns.
+ */
+class ResolveDataFrameDropColumns(val catalogManager: CatalogManager)
+  extends Rule[LogicalPlan] with ColumnResolutionHelper {
+
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsWithPruning(
+    _.containsPattern(DF_DROP_COLUMNS)) {
+    case d: DataFrameDropColumns if d.childrenResolved =>
+      // expressions in dropList can be unresolved, e.g.
+      //   df.drop(col("non-existing-column"))
+      val dropped = d.dropList.map {
+        case u: UnresolvedAttribute =>
+          resolveExpressionByPlanChildren(u, d.child)
+        case e => e
+      }
+      val remaining = d.child.output.filterNot(attr => dropped.exists(_.semanticEquals(attr)))
+      if (remaining.size == d.child.output.size) {
+        d.child
+      } else {
+        Project(remaining, d.child)
+      }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index 96b67fc52e0d..0e460706fc5b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -235,6 +235,20 @@ object Project {
}
}
+case class DataFrameDropColumns(dropList: Seq[Expression], child: LogicalPlan) extends UnaryNode {
+  override def output: Seq[Attribute] = Nil
+
+  override def maxRows: Option[Long] = child.maxRows
+  override def maxRowsPerPartition: Option[Long] = child.maxRowsPerPartition
+
+  final override val nodePatterns: Seq[TreePattern] = Seq(DF_DROP_COLUMNS)
+
+  override lazy val resolved: Boolean = false
+
+  override protected def withNewChildInternal(newChild: LogicalPlan): DataFrameDropColumns =
+    copy(child = newChild)
+}
+
/**
 * Applies a [[Generator]] to a stream of input rows, combining the
 * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index b806ebbed52d..bf7b2db1719f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -105,6 +105,7 @@ object TreePattern extends Enumeration {
  val AS_OF_JOIN: Value = Value
  val COMMAND: Value = Value
  val CTE: Value = Value
+  val DF_DROP_COLUMNS: Value = Value
  val DISTINCT_LIKE: Value = Value
  val EVAL_PYTHON_UDF: Value = Value
  val EVAL_PYTHON_UDTF: Value = Value
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index e047b927b905..f53c6ddaa388 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -3013,19 +3013,8 @@ class Dataset[T] private[sql](
   * @since 3.4.0
   */
  @scala.annotation.varargs
-  def drop(col: Column, cols: Column*): DataFrame = {
-    val allColumns = col +: cols
-    val expressions = (for (col <- allColumns) yield col match {
-      case Column(u: UnresolvedAttribute) =>
-        queryExecution.analyzed.resolveQuoted(
-          u.name, sparkSession.sessionState.analyzer.resolver).getOrElse(u)
-      case Column(expr: Expression) => expr
-    })
-    val attrs = this.logicalPlan.output
-    val colsAfterDrop = attrs.filter { attr =>
-      expressions.forall(expression => !attr.semanticEquals(expression))
-    }.map(attr => Column(attr))
-    select(colsAfterDrop : _*)
+  def drop(col: Column, cols: Column*): DataFrame = withPlan {
+    DataFrameDropColumns((col +: cols).map(_.expr), logicalPlan)
  }
/**
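For completeness, a short PySpark sketch (illustrative only, not part of the patch) of the semantics the new rule keeps, per the scaladoc note above that non-existing columns in the drop list are allowed and silently ignored:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()  # illustrative session setup
df = spark.createDataFrame([("Tom", 23, True)], ["name", "age", "active"])

# An unknown column in the drop list is ignored; the plan is rewritten to a
# Project over the remaining attributes.
assert df.drop(col("name"), col("random")).columns == ["age", "active"]

# If nothing in the drop list matches, the rule returns the child plan
# unchanged rather than inserting a redundant Project.
assert df.drop(col("random")).columns == ["name", "age", "active"]
```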
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]