This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7c3c7c5a4bd [SPARK-41086][SQL] Use DataFrame ID to semantically validate CollectMetrics
7c3c7c5a4bd is described below
commit 7c3c7c5a4bd94c9e05b5e680a5242c2485875633
Author: Rui Wang <[email protected]>
AuthorDate: Fri Sep 22 11:07:25 2023 +0800
[SPARK-41086][SQL] Use DataFrame ID to semantically validate CollectMetrics
### What changes were proposed in this pull request?
In the existing code, plan matching is used to detect whether two CollectMetrics
nodes have the same name but different semantics. However, the plan-matching
approach is fragile. A better way to tackle this is to use the unique DataFrame
ID instead. This works because the observe API is only supported by the DataFrame
API; SQL has no equivalent syntax.
Therefore, two CollectMetrics nodes are semantically the same if and only if they
have the same name and the same DataFrame ID.
### Why are the changes needed?
This replaces a fragile approach with a more stable one.
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
Unit tests (updated AnalysisSuite).
### Was this patch authored or co-authored using generative AI tooling?
NO
Closes #43010 from amaliujia/another_approch_for_collect_metrics.
Authored-by: Rui Wang <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../sql/connect/planner/SparkConnectPlanner.scala | 6 +--
python/pyspark/sql/connect/plan.py | 1 +
.../spark/sql/catalyst/analysis/Analyzer.scala | 4 +-
.../sql/catalyst/analysis/CheckAnalysis.scala | 36 ++------------
.../plans/logical/basicLogicalOperators.scala | 3 +-
.../sql/catalyst/analysis/AnalysisSuite.scala | 55 +++++++++-------------
.../main/scala/org/apache/spark/sql/Dataset.scala | 2 +-
.../spark/sql/execution/SparkStrategies.scala | 2 +-
8 files changed, 35 insertions(+), 74 deletions(-)
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 924169715f7..dda7a713fa0 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -164,7 +164,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder)
extends Logging {
case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION =>
transformCachedRemoteRelation(rel.getCachedRemoteRelation)
case proto.Relation.RelTypeCase.COLLECT_METRICS =>
- transformCollectMetrics(rel.getCollectMetrics)
+ transformCollectMetrics(rel.getCollectMetrics, rel.getCommon.getPlanId)
case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
throw new IndexOutOfBoundsException("Expected Relation to be set, but
is empty.")
@@ -1048,12 +1048,12 @@ class SparkConnectPlanner(val sessionHolder:
SessionHolder) extends Logging {
numPartitionsOpt)
}
- private def transformCollectMetrics(rel: proto.CollectMetrics): LogicalPlan
= {
+ private def transformCollectMetrics(rel: proto.CollectMetrics, planId:
Long): LogicalPlan = {
val metrics = rel.getMetricsList.asScala.toSeq.map { expr =>
Column(transformExpression(expr))
}
- CollectMetrics(rel.getName, metrics.map(_.named),
transformRelation(rel.getInput))
+ CollectMetrics(rel.getName, metrics.map(_.named),
transformRelation(rel.getInput), planId)
}
private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
diff --git a/python/pyspark/sql/connect/plan.py
b/python/pyspark/sql/connect/plan.py
index d069081e1af..219545cf646 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1192,6 +1192,7 @@ class CollectMetrics(LogicalPlan):
assert self._child is not None
plan = proto.Relation()
+ plan.common.plan_id = self._child._plan_id
plan.collect_metrics.input.CopyFrom(self._child.plan(session))
plan.collect_metrics.name = self._name
plan.collect_metrics.metrics.extend([self.col_to_expr(x, session) for
x in self._exprs])
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cff29de858e..aac85e19721 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -3893,9 +3893,9 @@ object CleanupAliases extends Rule[LogicalPlan] with
AliasHelper {
Window(cleanedWindowExprs, partitionSpec.map(trimAliases),
orderSpec.map(trimAliases(_).asInstanceOf[SortOrder]), child)
- case CollectMetrics(name, metrics, child) =>
+ case CollectMetrics(name, metrics, child, dataframeId) =>
val cleanedMetrics = metrics.map(trimNonTopLevelAliases)
- CollectMetrics(name, cleanedMetrics, child)
+ CollectMetrics(name, cleanedMetrics, child, dataframeId)
case Unpivot(ids, values, aliases, variableColumnName, valueColumnNames,
child) =>
val cleanedIds = ids.map(_.map(trimNonTopLevelAliases))
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 3c9a816df26..83b682bc917 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -497,7 +497,7 @@ trait CheckAnalysis extends PredicateHelper with
LookupCatalog with QueryErrorsB
groupingExprs.foreach(checkValidGroupingExprs)
aggregateExprs.foreach(checkValidAggregateExpression)
- case CollectMetrics(name, metrics, _) =>
+ case CollectMetrics(name, metrics, _, _) =>
if (name == null || name.isEmpty) {
operator.failAnalysis(
errorClass = "INVALID_OBSERVED_METRICS.MISSING_NAME",
@@ -1097,17 +1097,15 @@ trait CheckAnalysis extends PredicateHelper with
LookupCatalog with QueryErrorsB
* are allowed (e.g. self-joins).
*/
private def checkCollectedMetrics(plan: LogicalPlan): Unit = {
- val metricsMap = mutable.Map.empty[String, LogicalPlan]
+ val metricsMap = mutable.Map.empty[String, CollectMetrics]
def check(plan: LogicalPlan): Unit = plan.foreach { node =>
node match {
- case metrics @ CollectMetrics(name, _, _) =>
- val simplifiedMetrics =
simplifyPlanForCollectedMetrics(metrics.canonicalized)
+ case metrics @ CollectMetrics(name, _, _, dataframeId) =>
metricsMap.get(name) match {
case Some(other) =>
- val simplifiedOther =
simplifyPlanForCollectedMetrics(other.canonicalized)
// Exact duplicates are allowed. They can be the result
// of a CTE that is used multiple times or a self join.
- if (simplifiedMetrics != simplifiedOther) {
+ if (dataframeId != other.dataframeId) {
failAnalysis(
errorClass = "DUPLICATED_METRICS_NAME",
messageParameters = Map("metricName" -> name))
@@ -1126,32 +1124,6 @@ trait CheckAnalysis extends PredicateHelper with
LookupCatalog with QueryErrorsB
check(plan)
}
- /**
- * This method is only used for checking collected metrics. This method
tries to
- * remove extra project which only re-assign expr ids from the plan so that
we can identify exact
- * duplicates metric definition.
- */
- def simplifyPlanForCollectedMetrics(plan: LogicalPlan): LogicalPlan = {
- plan.resolveOperators {
- case p: Project if p.projectList.size == p.child.output.size =>
- val assignExprIdOnly = p.projectList.zipWithIndex.forall {
- case (Alias(attr: AttributeReference, _), index) =>
- // The input plan of this method is already canonicalized. The
attribute id becomes the
- // ordinal of this attribute in the child outputs. So an
alias-only Project means the
- // the id of the aliased attribute is the same as its index in the
project list.
- attr.exprId.id == index
- case (left: AttributeReference, index) =>
- left.exprId.id == index
- case _ => false
- }
- if (assignExprIdOnly) {
- p.child
- } else {
- p
- }
- }
- }
-
/**
* Validates to make sure the outer references appearing inside the subquery
* are allowed.
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index efb7dbb44ef..8f976a49a2b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -1969,7 +1969,8 @@ trait SupportsSubquery extends LogicalPlan
case class CollectMetrics(
name: String,
metrics: Seq[NamedExpression],
- child: LogicalPlan)
+ child: LogicalPlan,
+ dataframeId: Long)
extends UnaryNode {
override lazy val resolved: Boolean = {
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index ed3137430df..ffc12a2b981 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -779,34 +779,35 @@ class AnalysisSuite extends AnalysisTest with Matchers {
val literal = Literal(1).as("lit")
// Ok
- assert(CollectMetrics("event", literal :: sum :: random_sum :: Nil,
testRelation).resolved)
+ assert(CollectMetrics("event", literal :: sum :: random_sum :: Nil,
testRelation, 0).resolved)
// Bad name
- assert(!CollectMetrics("", sum :: Nil, testRelation).resolved)
+ assert(!CollectMetrics("", sum :: Nil, testRelation, 0).resolved)
assertAnalysisErrorClass(
- CollectMetrics("", sum :: Nil, testRelation),
+ CollectMetrics("", sum :: Nil, testRelation, 0),
expectedErrorClass = "INVALID_OBSERVED_METRICS.MISSING_NAME",
expectedMessageParameters = Map(
- "operator" -> "'CollectMetrics , [sum(a#x) AS sum#xL]\n+-
LocalRelation <empty>, [a#x]\n")
+ "operator" ->
+ "'CollectMetrics , [sum(a#x) AS sum#xL], 0\n+- LocalRelation
<empty>, [a#x]\n")
)
// No columns
- assert(!CollectMetrics("evt", Nil, testRelation).resolved)
+ assert(!CollectMetrics("evt", Nil, testRelation, 0).resolved)
def checkAnalysisError(exprs: Seq[NamedExpression], errors: String*): Unit
= {
- assertAnalysisError(CollectMetrics("event", exprs, testRelation), errors)
+ assertAnalysisError(CollectMetrics("event", exprs, testRelation, 0),
errors)
}
// Unwrapped attribute
assertAnalysisErrorClass(
- CollectMetrics("event", a :: Nil, testRelation),
+ CollectMetrics("event", a :: Nil, testRelation, 0),
expectedErrorClass =
"INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_ATTRIBUTE",
expectedMessageParameters = Map("expr" -> "\"a\"")
)
// Unwrapped non-deterministic expression
assertAnalysisErrorClass(
- CollectMetrics("event", Rand(10).as("rnd") :: Nil, testRelation),
+ CollectMetrics("event", Rand(10).as("rnd") :: Nil, testRelation, 0),
expectedErrorClass =
"INVALID_OBSERVED_METRICS.NON_AGGREGATE_FUNC_ARG_IS_NON_DETERMINISTIC",
expectedMessageParameters = Map("expr" -> "\"rand(10) AS rnd\"")
)
@@ -816,7 +817,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
CollectMetrics(
"event",
Sum(a).toAggregateExpression(isDistinct = true).as("sum") :: Nil,
- testRelation),
+ testRelation, 0),
expectedErrorClass =
"INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_DISTINCT_UNSUPPORTED",
expectedMessageParameters = Map("expr" -> "\"sum(DISTINCT a) AS sum\"")
@@ -827,7 +828,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
CollectMetrics(
"event",
Sum(Sum(a).toAggregateExpression()).toAggregateExpression().as("sum")
:: Nil,
- testRelation),
+ testRelation, 0),
expectedErrorClass =
"INVALID_OBSERVED_METRICS.NESTED_AGGREGATES_UNSUPPORTED",
expectedMessageParameters = Map("expr" -> "\"sum(sum(a)) AS sum\"")
)
@@ -838,7 +839,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
WindowSpecDefinition(Nil, a.asc :: Nil,
SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow)))
assertAnalysisErrorClass(
- CollectMetrics("event", windowExpr.as("rn") :: Nil, testRelation),
+ CollectMetrics("event", windowExpr.as("rn") :: Nil, testRelation, 0),
expectedErrorClass =
"INVALID_OBSERVED_METRICS.WINDOW_EXPRESSIONS_UNSUPPORTED",
expectedMessageParameters = Map(
"expr" ->
@@ -856,14 +857,14 @@ class AnalysisSuite extends AnalysisTest with Matchers {
// Same result - duplicate names are allowed
assertAnalysisSuccess(Union(
- CollectMetrics("evt1", count :: Nil, testRelation) ::
- CollectMetrics("evt1", count :: Nil, testRelation) :: Nil))
+ CollectMetrics("evt1", count :: Nil, testRelation, 0) ::
+ CollectMetrics("evt1", count :: Nil, testRelation, 0) :: Nil))
// Same children, structurally different metrics - fail
assertAnalysisErrorClass(
Union(
- CollectMetrics("evt1", count :: Nil, testRelation) ::
- CollectMetrics("evt1", sum :: Nil, testRelation) :: Nil),
+ CollectMetrics("evt1", count :: Nil, testRelation, 0) ::
+ CollectMetrics("evt1", sum :: Nil, testRelation, 1) :: Nil),
expectedErrorClass = "DUPLICATED_METRICS_NAME",
expectedMessageParameters = Map("metricName" -> "evt1")
)
@@ -873,17 +874,17 @@ class AnalysisSuite extends AnalysisTest with Matchers {
val tblB = LocalRelation(b)
assertAnalysisErrorClass(
Union(
- CollectMetrics("evt1", count :: Nil, testRelation) ::
- CollectMetrics("evt1", count :: Nil, tblB) :: Nil),
+ CollectMetrics("evt1", count :: Nil, testRelation, 0) ::
+ CollectMetrics("evt1", count :: Nil, tblB, 1) :: Nil),
expectedErrorClass = "DUPLICATED_METRICS_NAME",
expectedMessageParameters = Map("metricName" -> "evt1")
)
// Subquery different tree - fail
- val subquery = Aggregate(Nil, sum :: Nil, CollectMetrics("evt1", count ::
Nil, testRelation))
+ val subquery = Aggregate(Nil, sum :: Nil, CollectMetrics("evt1", count ::
Nil, testRelation, 0))
val query = Project(
b :: ScalarSubquery(subquery, Nil).as("sum") :: Nil,
- CollectMetrics("evt1", count :: Nil, tblB))
+ CollectMetrics("evt1", count :: Nil, tblB, 1))
assertAnalysisErrorClass(
query,
expectedErrorClass = "DUPLICATED_METRICS_NAME",
@@ -895,7 +896,7 @@ class AnalysisSuite extends AnalysisTest with Matchers {
case a: AggregateExpression => a.copy(filter = Some(true))
}.asInstanceOf[NamedExpression]
assertAnalysisErrorClass(
- CollectMetrics("evt1", sumWithFilter :: Nil, testRelation),
+ CollectMetrics("evt1", sumWithFilter :: Nil, testRelation, 0),
expectedErrorClass =
"INVALID_OBSERVED_METRICS.AGGREGATE_EXPRESSION_WITH_FILTER_UNSUPPORTED",
expectedMessageParameters = Map("expr" -> "\"sum(a) FILTER (WHERE true)
AS sum\"")
@@ -1675,18 +1676,4 @@ class AnalysisSuite extends AnalysisTest with Matchers {
checkAnalysis(ident2.select($"a"), testRelation.select($"a").analyze)
}
}
-
- test("simplifyPlanForCollectedMetrics should handle non alias-only project
case") {
- val inner = Project(
- Seq(
- Alias(testRelation2.output(0), "a")(),
- testRelation2.output(1),
- Alias(testRelation2.output(2), "c")(),
- testRelation2.output(3),
- testRelation2.output(4)
- ),
- testRelation2)
- val actualPlan =
getAnalyzer.simplifyPlanForCollectedMetrics(inner.canonicalized)
- assert(actualPlan == testRelation2.canonicalized)
- }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 528904bb29a..f07496e6430 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2218,7 +2218,7 @@ class Dataset[T] private[sql](
*/
@varargs
def observe(name: String, expr: Column, exprs: Column*): Dataset[T] =
withTypedPlan {
- CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan)
+ CollectMetrics(name, (expr +: exprs).map(_.named), logicalPlan, id)
}
/**
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 903565a6d59..d851eacd5ab 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -935,7 +935,7 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("UPDATE
TABLE")
case _: MergeIntoTable =>
throw QueryExecutionErrors.ddlUnsupportedTemporarilyError("MERGE INTO
TABLE")
- case logical.CollectMetrics(name, metrics, child) =>
+ case logical.CollectMetrics(name, metrics, child, _) =>
execution.CollectMetricsExec(name, metrics, planLater(child)) :: Nil
case WriteFiles(child, fileFormat, partitionColumns, bucket, options,
staticPartitions) =>
WriteFilesExec(planLater(child), fileFormat, partitionColumns, bucket,
options,
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]