This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 37ead02eb13a [SPARK-54827][SQL] Add helper function
`TreeNode.containsTag`
37ead02eb13a is described below
commit 37ead02eb13afde8837db1013df150ab01039b0a
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Dec 24 20:20:48 2025 +0800
[SPARK-54827][SQL] Add helper function `TreeNode.containsTag`
### What changes were proposed in this pull request?
Add helper function `TreeNode.containsTag`
### Why are the changes needed?
In many places, we don't care about the tag value; we only need to check whether
a tag exists.
This new function can help simplify the code a bit, e.g.
`getTagValue(Cast.BY_TABLE_INSERTION).isDefined` ->
`containsTag(Cast.BY_TABLE_INSERTION)`
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
Existing CI.
### Was this patch authored or co-authored using generative AI tooling?
no
Closes #53587 from zhengruifeng/containsTag.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../org/apache/spark/sql/catalyst/analysis/AliasResolution.scala | 2 +-
.../sql/catalyst/analysis/ApplyDefaultCollationToStringType.scala | 2 +-
.../org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 2 +-
.../apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala | 2 +-
.../spark/sql/catalyst/analysis/ColumnResolutionHelper.scala | 7 +++----
.../spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala | 2 +-
.../spark/sql/catalyst/analysis/TypeCoercionValidation.scala | 3 +--
.../catalyst/analysis/resolver/DefaultCollationTypeCoercion.scala | 2 +-
.../apache/spark/sql/catalyst/analysis/resolver/JoinResolver.scala | 2 +-
.../spark/sql/catalyst/analysis/resolver/ResolverGuard.scala | 2 +-
.../scala/org/apache/spark/sql/catalyst/expressions/Cast.scala | 2 +-
.../apache/spark/sql/catalyst/expressions/namedExpressions.scala | 2 +-
.../spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala | 2 +-
.../org/apache/spark/sql/catalyst/optimizer/expressions.scala | 2 +-
.../scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala | 2 +-
.../main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 4 ++++
.../main/scala/org/apache/spark/sql/catalyst/util/package.scala | 2 +-
.../src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 2 +-
.../spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala | 2 +-
.../spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala | 4 ++--
.../src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala | 2 +-
21 files changed, 27 insertions(+), 25 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AliasResolution.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AliasResolution.scala
index 1811eb4e403b..c16170dd84dc 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AliasResolution.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AliasResolution.scala
@@ -70,7 +70,7 @@ object AliasResolution {
private def extractOnly(e: Expression): Boolean = e match {
case _: ExtractValue => e.children.forall(extractOnly)
case _: Literal => true
- case attr: Attribute if
attr.getTagValue(ResolverTag.SINGLE_PASS_IS_LCA).isEmpty => true
+ case attr: Attribute if !attr.containsTag(ResolverTag.SINGLE_PASS_IS_LCA)
=> true
case _ => false
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollationToStringType.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollationToStringType.scala
index c51d8ebbe92c..11e813276283 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollationToStringType.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ApplyDefaultCollationToStringType.scala
@@ -268,7 +268,7 @@ object ApplyDefaultCollationToStringType extends
Rule[LogicalPlan] {
newType => columnDef.copy(dataType =
replaceDefaultStringType(columnDef.dataType, newType))
case cast: Cast if hasDefaultStringType(cast.dataType) &&
- cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined =>
+ cast.containsTag(Cast.USER_SPECIFIED_CAST) =>
newType => cast.copy(dataType = replaceDefaultStringType(cast.dataType,
newType))
case Literal(value, dt) if hasDefaultStringType(dt) =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 3b8a363e704a..c34ae507e758 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -898,7 +898,7 @@ trait CheckAnalysis extends LookupCatalog with
QueryErrorsBase with PlanToString
"invalidExprSqls" -> invalidExprSqls.mkString(", ")))
case j @ LateralJoin(_, right, _, _)
- if j.getTagValue(LateralJoin.BY_TABLE_ARGUMENT).isEmpty =>
+ if !j.containsTag(LateralJoin.BY_TABLE_ARGUMENT) =>
right.plan.foreach {
case Generate(pyudtf: PythonUDTF, _, _, _, _, _)
if pyudtf.evalType == PythonEvalType.SQL_ARROW_UDTF =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
index 800a036df596..c77ab7cc7a2e 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CollationTypeCoercion.scala
@@ -36,7 +36,7 @@ object CollationTypeCoercion extends SQLConfHelper {
private val COLLATION_CONTEXT_TAG = new
TreeNodeTag[DataType]("collationContext")
private def hasCollationContextTag(expr: Expression): Boolean = {
- expr.getTagValue(COLLATION_CONTEXT_TAG).isDefined
+ expr.containsTag(COLLATION_CONTEXT_TAG)
}
def apply(expression: Expression): Expression = expression match {
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index 870e03364225..e9f4a4d92c7b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -140,8 +140,7 @@ trait ColumnResolutionHelper extends Logging with
DataTypeErrorsBase {
}
matched(ordinal)
- case u @ UnresolvedAttribute(nameParts)
- if u.getTagValue(LogicalPlan.PLAN_ID_TAG).isEmpty =>
+ case u @ UnresolvedAttribute(nameParts) if
!u.containsTag(LogicalPlan.PLAN_ID_TAG) =>
// UnresolvedAttribute with PLAN_ID_TAG should be resolved in
resolveDataFrameColumn
val result = withPosition(u) {
resolveColumnByName(nameParts)
@@ -451,7 +450,7 @@ trait ColumnResolutionHelper extends Logging with
DataTypeErrorsBase {
u: UnresolvedAttribute,
q: LogicalPlan,
includeLastResort: Boolean = false): Option[Expression] = {
- assert(u.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty,
+ assert(u.containsTag(LogicalPlan.PLAN_ID_TAG),
s"UnresolvedAttribute $u should have a Plan Id tag")
resolveDataFrameColumn(u, q.children).map { r =>
@@ -524,7 +523,7 @@ trait ColumnResolutionHelper extends Logging with
DataTypeErrorsBase {
val planId = planIdOpt.get
logDebug(s"Extract plan_id $planId from $u")
- val isMetadataAccess = u.getTagValue(LogicalPlan.IS_METADATA_COL).nonEmpty
+ val isMetadataAccess = u.containsTag(LogicalPlan.IS_METADATA_COL)
val (resolved, matched) = resolveDataFrameColumnByPlanId(
u, planId, isMetadataAccess, q, 0)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
index a0f67fa3f445..00a0665bc4eb 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveDataFrameDropColumns.scala
@@ -36,7 +36,7 @@ class ResolveDataFrameDropColumns(val catalogManager:
CatalogManager)
// df.drop(col("non-existing-column"))
val dropped = d.dropList.flatMap {
case u: UnresolvedAttribute =>
- if (u.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty) {
+ if (u.containsTag(LogicalPlan.PLAN_ID_TAG)) {
// Plan Id comes from Spark Connect,
// Here we ignore the `UnresolvedAttribute` if its Plan Id can be
found
// but column not found.
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionValidation.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionValidation.scala
index 24097c55895e..67d0a4fb17aa 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionValidation.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionValidation.scala
@@ -92,8 +92,7 @@ object TypeCoercionValidation extends QueryErrorsBase {
var issueFixedIfAnsiOff = true
getAllExpressions(nonAnsiPlan).foreach(_.foreachUp {
case e: Expression
- if e.getTagValue(DATA_TYPE_MISMATCH_ERROR).isDefined &&
- e.checkInputDataTypes().isFailure =>
+ if e.containsTag(DATA_TYPE_MISMATCH_ERROR) &&
e.checkInputDataTypes().isFailure =>
e.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckFailure(_) | _:
TypeCheckResult.DataTypeMismatch =>
issueFixedIfAnsiOff = false
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/DefaultCollationTypeCoercion.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/DefaultCollationTypeCoercion.scala
index 38e00000a17e..bda40f71ecd1 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/DefaultCollationTypeCoercion.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/DefaultCollationTypeCoercion.scala
@@ -92,7 +92,7 @@ object DefaultCollationTypeCoercion {
* we should change all its occurrences to [[StringType]] with default
collation.
*/
private def shouldApplyCollationToCast(cast: Cast): Boolean = {
- cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined &&
+ cast.containsTag(Cast.USER_SPECIFIED_CAST) &&
hasDefaultStringType(cast.dataType)
}
}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/JoinResolver.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/JoinResolver.scala
index c718a8ba3782..2d9bd845fab9 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/JoinResolver.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/JoinResolver.scala
@@ -208,7 +208,7 @@ class JoinResolver(resolver: Resolver, expressionResolver:
ExpressionResolver)
scopes.current.hiddenOutput.filter(_.qualifiedAccessOnly)
val newProjectList =
- if (unresolvedJoin.getTagValue(ResolverTag.TOP_LEVEL_OPERATOR).isEmpty) {
+ if (!unresolvedJoin.containsTag(ResolverTag.TOP_LEVEL_OPERATOR)) {
newOutputList ++ qualifiedAccessOnlyColumnsFromHiddenOutput
} else {
newOutputList
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala
index 5b28d5369e38..8c26003b733b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/resolver/ResolverGuard.scala
@@ -328,7 +328,7 @@ class ResolverGuard(catalogManager: CatalogManager) extends
SQLConfHelper {
private def checkUnresolvedAttribute(unresolvedAttribute:
UnresolvedAttribute) =
!ResolverGuard.UNSUPPORTED_ATTRIBUTE_NAMES.contains(unresolvedAttribute.nameParts.head)
&&
- !unresolvedAttribute.getTagValue(LogicalPlan.PLAN_ID_TAG).isDefined
+ !unresolvedAttribute.containsTag(LogicalPlan.PLAN_ID_TAG)
private def checkUnresolvedPredicate(unresolvedPredicate: Predicate) =
unresolvedPredicate match {
case inSubquery: InSubquery =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index 727b155e8579..849f3b8a0d1b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -575,7 +575,7 @@ case class Cast(
private def typeCheckFailureInCast: DataTypeMismatch = evalMode match {
case EvalMode.ANSI =>
- if (getTagValue(Cast.BY_TABLE_INSERTION).isDefined) {
+ if (containsTag(Cast.BY_TABLE_INSERTION)) {
Cast.typeCheckFailureMessage(child.dataType, dataType,
Some(SQLConf.STORE_ASSIGNMENT_POLICY.key ->
SQLConf.StoreAssignmentPolicy.LEGACY.toString))
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 732fc9a02a1d..ed06fb2ae05d 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -417,7 +417,7 @@ case class PrettyAttribute(
override def toString: String = name
override def sql: String = {
- if (getTagValue(ResolverTag.SINGLE_PASS_IS_LCA).nonEmpty) {
+ if (containsTag(ResolverTag.SINGLE_PASS_IS_LCA)) {
// For a query like:
//
// {{{ select 1 as a, a + 1 }}}
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
index 86316494f6ff..aae092bcb263 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PropagateEmptyRelation.scala
@@ -148,7 +148,7 @@ abstract class PropagateEmptyRelationBase extends
Rule[LogicalPlan] with CastSup
case _: LocalLimit if !p.isStreaming => empty(p)
case _: Offset => empty(p)
case _: RepartitionOperation =>
- if (p.getTagValue(ROOT_REPARTITION).isEmpty) {
+ if (!p.containsTag(ROOT_REPARTITION)) {
empty(p)
} else {
p.unsetTagValue(ROOT_REPARTITION)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 71eb3e5ea2bd..661e43f8548b 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -88,7 +88,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
case Size(c: CreateMap, _) if c.children.forall(hasNoSideEffect) =>
Literal(c.children.length / 2)
- case e if e.getTagValue(FAILED_TO_EVALUATE).isDefined => e
+ case e if e.containsTag(FAILED_TO_EVALUATE) => e
// Fold expressions that are foldable.
case e if e.foldable => tryFold(e, isConditionalBranch)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala
index aeac6f2e6914..a9fe34242d9e 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/NormalizePlan.scala
@@ -157,7 +157,7 @@ object NormalizePlan extends PredicateHelper {
.reduce(And)
Join(left, right, newJoinType, Some(newCondition), hint)
case project: Project
- if
project.getTagValue(ResolverTag.PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION).isDefined
=>
+ if
project.containsTag(ResolverTag.PROJECT_FOR_EXPRESSION_ID_DEDUPLICATION) =>
project.child
case aggregate @ Aggregate(_, _, innerProject: Project, _) =>
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 5acd441e98f4..e82e6a30b9bb 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -196,6 +196,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
tags(tag) = value
}
+ def containsTag[T](tag: TreeNodeTag[T]): Boolean = {
+ getTagValue[T](tag).isDefined
+ }
+
def getTagValue[T](tag: TreeNodeTag[T]): Option[T] = {
if (isTagsEmpty) {
None
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
index 562a02e6a111..9f8e9c706760 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala
@@ -116,7 +116,7 @@ package object util extends Logging {
),
dataType = r.dataType
)
- case c: Cast if c.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty =>
+ case c: Cast if !c.containsTag(Cast.USER_SPECIFIED_CAST) =>
PrettyAttribute(usePrettyExpression(c.child,
shouldTrimTempResolvedColumn).sql, c.dataType)
case p: PythonFuncExpression => PrettyPythonUDF(p.name, p.dataType,
p.children)
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index e7844a88bf14..7f94cc77f345 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -122,7 +122,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with
Logging with Serializ
private def setLogicalLink(logicalPlan: LogicalPlan, inherited: Boolean =
false): Unit = {
// Stop at a descendant which is the root of a sub-tree transformed from
another logical node.
- if (inherited && getTagValue(SparkPlan.LOGICAL_PLAN_TAG).isDefined) {
+ if (inherited && containsTag(SparkPlan.LOGICAL_PLAN_TAG)) {
return
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
index 7b3e0cd549b8..e2a013b9e814 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEPropagateEmptyRelation.scala
@@ -43,7 +43,7 @@ object AQEPropagateEmptyRelation extends
PropagateEmptyRelationBase {
override protected def empty(plan: LogicalPlan): LogicalPlan =
EmptyRelation(plan)
private def isRootRepartition(plan: LogicalPlan): Boolean = plan match {
- case l: LogicalQueryStage if l.getTagValue(ROOT_REPARTITION).isDefined =>
true
+ case l: LogicalQueryStage if l.containsTag(ROOT_REPARTITION) => true
case _ => false
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index 0e50c03b6cc9..4840016bf745 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -718,7 +718,7 @@ case class AdaptiveSparkPlanExec(
private def setLogicalLinkForNewQueryStage(stage: QueryStageExec, plan:
SparkPlan): Unit = {
val link = plan.getTagValue(TEMP_LOGICAL_PLAN_TAG).orElse(
plan.logicalLink.orElse(plan.collectFirst {
- case p if p.getTagValue(TEMP_LOGICAL_PLAN_TAG).isDefined =>
+ case p if p.containsTag(TEMP_LOGICAL_PLAN_TAG) =>
p.getTagValue(TEMP_LOGICAL_PLAN_TAG).get
case p if p.logicalLink.isDefined => p.logicalLink.get
}))
@@ -835,7 +835,7 @@ case class AdaptiveSparkPlanExec(
*/
private def cleanUpTempTags(plan: SparkPlan): Unit = {
plan.foreach {
- case plan: SparkPlan if
plan.getTagValue(TEMP_LOGICAL_PLAN_TAG).isDefined =>
+ case plan: SparkPlan if plan.containsTag(TEMP_LOGICAL_PLAN_TAG) =>
plan.unsetTagValue(TEMP_LOGICAL_PLAN_TAG)
case _ =>
}
diff --git
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 392813467cd4..07d9df59b86d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -264,7 +264,7 @@ case class RelationConversions(
tableDesc, _, query, overwrite, ifPartitionNotExists, _, _, _, _, _, _)
if query.resolved && DDLUtils.isHiveTable(tableDesc) &&
tableDesc.partitionColumnNames.isEmpty && isConvertible(tableDesc)
&&
- conf.getConf(HiveUtils.CONVERT_METASTORE_CTAS) &&
i.getTagValue(BY_CTAS).isDefined =>
+ conf.getConf(HiveUtils.CONVERT_METASTORE_CTAS) &&
i.containsTag(BY_CTAS) =>
// validation is required to be done here before relation conversion.
DDLUtils.checkTableColumns(tableDesc.copy(schema = query.schema))
val hiveTable = DDLUtils.readHiveTable(tableDesc)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]