This is an automated email from the ASF dual-hosted git repository.
HeartSaVioR pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0fb04a4ac9aa [SPARK-57003][SQL][SS] Widen stateful operator output and
state schema nullability
0fb04a4ac9aa is described below
commit 0fb04a4ac9aa41cc5b9d0fa4c42877d2f1f450eb
Author: Jungtaek Lim <[email protected]>
AuthorDate: Wed May 27 22:30:16 2026 +0900
[SPARK-57003][SQL][SS] Widen stateful operator output and state schema
nullability
### What changes were proposed in this pull request?
Introduce a three-component fix for stateful-operator nullability drift,
gated by `spark.sql.streaming.statefulOperator.alwaysNullableOutput.enabled`
(pinned per-query via the offset log):
- (a) `WidenStatefulOpNullability.widenStateSchema`: every stateful
physical exec widens its state key/value schema to fully nullable at
construction. This covers `StateStoreSaveExec`, `BaseStreamingDeduplicateExec`,
`StreamingSymmetricHashJoinExec`, `FlatMapGroupsWithStateExec`,
`TransformWithStateExec` (including user-defined state variable col family
schemas), `TransformWithStateInPySparkExec`, and `StreamingGlobalLimitExec`.
- (b) `WidenStatefulOpNullability.widenOutputForStatefulOp`: every stateful
logical and physical operator widens its declared `output` to fully nullable.
- (c) `WidenStatefulOperatorAttributeNullability`: an optimizer rule that
widens `AttributeReference`s inside stateful ops' internal expressions and
propagates upward through ancestor expressions. The rule uses
`resolveOperatorsUp` (bottom-up) and scopes the widening precisely: at a
stateful operator, all children's output is included (for internal expression
references like grouping keys); at non-stateful ancestors, only children whose
subtrees contain a stateful operator are include [...]
With the above fix, the state schema is intended to be "fully" nullable
(top-level columns, nested columns, and collection types) regardless of the
input schema, and the output schema of the stateful operator is "fully"
nullable as well. The output schema change is necessary because, even if the
input schema is non-nullable, state can produce null values, so the output
can be nullable.
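To make "fully" nullable concrete, here is a minimal sketch of deep widening on a toy schema model. The types `DType`, `ArrT`, `MapT`, `Field`, `StructT` and the `DeepWiden` object are hypothetical stand-ins invented for this illustration; they are not Spark's classes and not the code in this commit, which relies on Spark's own `asNullable`.

```scala
// Toy schema model for illustration only; these are NOT Spark's classes.
sealed trait DType
case object IntT extends DType
case object StrT extends DType
final case class ArrT(element: DType, containsNull: Boolean) extends DType
final case class MapT(key: DType, value: DType, valueContainsNull: Boolean) extends DType
final case class Field(name: String, dataType: DType, nullable: Boolean)
final case class StructT(fields: Seq[Field]) extends DType

object DeepWiden {
  // Flip every nullability flag reachable from `t` to true.
  def widen(t: DType): DType = t match {
    case ArrT(e, _) => ArrT(widen(e), containsNull = true)
    case MapT(k, v, _) => MapT(widen(k), widen(v), valueContainsNull = true)
    case s: StructT => widenStruct(s)
    case leaf => leaf
  }

  // Widen each field's own flag and recurse into its data type.
  def widenStruct(s: StructT): StructT =
    StructT(s.fields.map(f => Field(f.name, widen(f.dataType), nullable = true)))
}

object DeepWidenDemo {
  def main(args: Array[String]): Unit = {
    // Two schemas that differ only in nullability flags.
    val strict = StructT(Seq(
      Field("key", StructT(Seq(Field("id", IntT, nullable = false))), nullable = false),
      Field("tags", ArrT(StrT, containsNull = false), nullable = false)))
    val loose = StructT(Seq(
      Field("key", StructT(Seq(Field("id", IntT, nullable = true))), nullable = true),
      Field("tags", ArrT(StrT, containsNull = true), nullable = false)))
    // After widening, the nullability difference disappears.
    println(DeepWiden.widenStruct(strict) == DeepWiden.widenStruct(loose))
  }
}
```

In this toy model, two schemas that differ only in nullability compare equal once widened, and widening an already-widened schema is a no-op. That is the property that lets a schema compatibility check pass regardless of input nullability.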
### Why are the changes needed?
This has been a long-standing conflict between the streaming engine and the
query optimizer. A streaming query is by nature long-running and in many
cases spans multiple Spark versions. Also, the logical plan is not always
the same across batches (e.g. there are multiple stream sources and one of
them has no new data at batch N). This exposes the streaming query to
changes in analyzer and optimizer decisions.
The state schema of a stateful operator is mostly determined by the
operator's input schema, and nullability is no exception. If the input
schema has a nullable column, the state schema has a nullable column, and
vice versa for a non-nullable column.
One of the query optimizer's optimizations is to flip nullability, say,
from nullable to non-nullable where appropriate. This can happen directly or
indirectly, and the most problematic case is when the optimization is applied
"selectively".
One simple example is the elimination of Union: for a streaming query that
combines multiple streams using Union, batch N could have one stream
non-empty while another stream is empty. In that case,
`PropagateEmptyRelation` can drop empty `Union` branches, causing a
per-column nullability flip that propagates into a stateful operator's state
schema across microbatches or restarts. This causes either
`STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE` on restart or a codegen NPE when
state-res [...]
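The Union scenario above can be sketched with a toy model. It assumes that a union column is nullable when any of its branches is nullable. `Col`, `UnionDrift`, and their methods are hypothetical names invented for this illustration, not Spark APIs.

```scala
// Toy model for illustration only; these are NOT Spark's classes.
final case class Col(name: String, nullable: Boolean)

object UnionDrift {
  // A union column is nullable if the column is nullable in any branch.
  def unionOutput(branches: Seq[Seq[Col]]): Seq[Col] =
    branches.transpose.map(cols => Col(cols.head.name, cols.exists(_.nullable)))

  // Models an optimizer dropping branches known to be empty in this batch.
  def pruneEmpty(branches: Seq[(Seq[Col], Boolean)]): Seq[Seq[Col]] =
    branches.collect { case (schema, nonEmpty) if nonEmpty => schema }

  // Models the fix: the stateful operator always sees a nullable schema.
  def widen(cols: Seq[Col]): Seq[Col] = cols.map(_.copy(nullable = true))
}

object UnionDriftDemo {
  def main(args: Array[String]): Unit = {
    val streamA = Seq(Col("k", nullable = false))
    val streamB = Seq(Col("k", nullable = true))
    // Batch 0: both streams have data, so both branches survive.
    val batch0 = UnionDrift.unionOutput(
      UnionDrift.pruneEmpty(Seq((streamA, true), (streamB, true))))
    // Batch N: streamB is empty and its branch is dropped.
    val batchN = UnionDrift.unionOutput(
      UnionDrift.pruneEmpty(Seq((streamA, true), (streamB, false))))
    println(s"drift without widening: ${batch0 != batchN}")
    println(s"drift with widening: ${UnionDrift.widen(batch0) != UnionDrift.widen(batchN)}")
  }
}
```

In this model the schema feeding the stateful operator changes between batch 0 and batch N only because a branch was dropped. Widening both sides removes the difference, which is the intent of components (a) and (b).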
### Does this PR introduce _any_ user-facing change?
No user-visible behavior change for new queries (all stateful operator
outputs become nullable, which is semantically correct). Existing queries keep
their original behavior via the offset log gate.
### How was this patch tested?
New `StreamingStatefulOperatorNullabilityDriftSuite` covering:
- New-query path: Union-branch-drop restart scenarios for aggregate,
dropDuplicates, dropDuplicatesWithinWatermark, stream-stream join,
flatMapGroupsWithState, and transformWithState.
- Codegen NPE regression with struct grouping keys.
- Existing-query path: widening forced off still triggers schema mismatch.
- State schema assertion validates all state stores and column families
(both v2 file format and v3 directory format including `_stateSchema`).
- Rule-level: scope check (non-stateful subtrees skipped).
- Helper-level: `deepWidenAttribute` recursion into nested types.
### Was this patch authored or co-authored using generative AI tooling?
Yes. Generated-by: Claude 4.7 Opus
Closes #56061 from HeartSaVioR/widen-stateful-op-nullability.
Authored-by: Jungtaek Lim <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
---
...WidenStatefulOperatorAttributeNullability.scala | 167 +++++++
.../plans/logical/basicLogicalOperators.scala | 32 +-
.../spark/sql/catalyst/plans/logical/object.scala | 12 +-
.../plans/logical/pythonLogicalOperators.scala | 10 +-
.../org/apache/spark/sql/internal/SQLConf.scala | 18 +
.../streaming/ClientStreamingQuerySuite.scala | 2 +-
.../sql/execution/adaptive/AQEOptimizer.scala | 5 +-
.../FlatMapGroupsInPandasWithStateExec.scala | 4 +-
.../TransformWithStateInPySparkExec.scala | 45 +-
.../streaming/checkpointing/OffsetSeq.scala | 6 +-
.../FlatMapGroupsWithStateExec.scala | 24 +-
.../join/StreamingSymmetricHashJoinExec.scala | 34 +-
.../operators/stateful/statefulOperators.scala | 87 ++--
.../operators/stateful/streamingLimits.scala | 4 +-
.../TransformWithStateExec.scala | 55 ++-
.../streaming/runtime/IncrementalExecution.scala | 4 +-
.../spark/sql/streaming/StreamingJoinSuite.scala | 10 +-
.../spark/sql/streaming/StreamingJoinV4Suite.scala | 12 +-
...mingStatefulOperatorNullabilityDriftSuite.scala | 534 +++++++++++++++++++++
.../sql/streaming/TransformWithStateSuite.scala | 4 +-
20 files changed, 976 insertions(+), 93 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/WidenStatefulOperatorAttributeNullability.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/WidenStatefulOperatorAttributeNullability.scala
new file mode 100644
index 000000000000..b2ce8780a2ed
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/WidenStatefulOperatorAttributeNullability.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, ExprId}
+import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataType, StructType}
+
+/**
+ * Shared helpers for the stateful-operator nullability fix. The fix has three
+ * independent components, all gated by
+ * [[SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT]] (pinned per-query via the
+ * offset log so existing queries keep their pre-fix behavior on restart):
+ *
+ * - (a) `widenStateSchema`: explicit `asNullable` at every state-schema construction
+ * site in each stateful physical exec.
+ * - (b) `widenOutputForStatefulOp`: a per-op `output` override on every stateful logical
+ * and physical operator, used by the operator's `output` definition.
+ * - (c) [[WidenStatefulOperatorAttributeNullability]] (defined below in this file): a
+ * custom optimizer rule that widens `AttributeReference`s inside stateful ops'
+ * internal expressions and propagates upward to ancestor expressions.
+ */
+object WidenStatefulOpNullability {
+
+ def isEnabled: Boolean =
+ SQLConf.get.getConf(SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT)
+
+ /**
+ * Recursively widens an attribute to be fully nullable: outer `nullable = true` plus
+ * every nested `StructField.nullable`, `ArrayType.containsNull`, and
+ * `MapType.valueContainsNull` flipped to `true` via
+ * [[org.apache.spark.sql.types.DataType#asNullable]].
+ */
+ def deepWidenAttribute(a: Attribute): Attribute = a match {
+ case ref: AttributeReference =>
+ AttributeReference(
+ ref.name, ref.dataType.asNullable, nullable = true, ref.metadata)(
+ ref.exprId, ref.qualifier)
+ case other => other.withNullability(true)
+ }
+
+ /**
+ * Component (a): widens a state schema to fully nullable. Stateful physical execs apply
+ * this at every `validateAndMaybeEvolveStateSchema(...)` call site and every
+ * `mapPartitionsWith*StateStore(...)` call site. When the conf is off, returns the
+ * schema unchanged.
+ */
+ def widenStateSchema(schema: StructType): StructType =
+ if (isEnabled) schema.asNullable else schema
+
+ /**
+ * Component (b): wraps a stateful operator's `output` to be fully nullable. The caller
+ * is responsible for only calling this from within an `output` definition on a stateful
+ * operator; gating is handled here via [[isEnabled]].
+ */
+ def widenOutputForStatefulOp(base: Seq[Attribute]): Seq[Attribute] =
+ if (isEnabled) base.map(deepWidenAttribute) else base
+
+ /**
+ * Recursively walks a schema and replaces any nested `StructType` that
+ * structurally matches `original` (by field names and base types, ignoring
+ * nullability) with `widened`. Used by TransformWithState execs to widen
+ * the grouping-key portion of col-family key schemas without touching
+ * user-defined key/value portions.
+ */
+ def widenGroupingKeyInSchema(
+ schema: StructType,
+ original: StructType,
+ widened: StructType): StructType = {
+ if (!isEnabled) return schema
+ if (DataType.equalsIgnoreNullability(schema, original)) {
+ widened
+ } else {
+ StructType(schema.fields.map { field =>
+ field.dataType match {
+ case st: StructType
+ if DataType.equalsIgnoreNullability(st, original) =>
+ field.copy(dataType = widened)
+ case st: StructType =>
+ field.copy(dataType =
+ widenGroupingKeyInSchema(st, original, widened))
+ case _ => field
+ }
+ })
+ }
+ }
+}
+
+/**
+ * Component (c) of the stateful-operator nullability fix: a custom optimizer rule that
+ * widens `AttributeReference`s inside streaming-stateful operators' internal expressions
+ * and propagates the widening upward to ancestor operators' expressions.
+ *
+ * The rule does NOT introduce any new logical or physical node. It is purely an
+ * attribute-rewrite pass using `resolveOperatorsUp` (bottom-up): for every node whose
+ * subtree contains a stateful operator, collect `exprId`s from children's output, then
+ * deep-widen every `AttributeReference` in the node's expressions whose `exprId` is in
+ * that set via [[WidenStatefulOpNullability#deepWidenAttribute]].
+ *
+ * At a stateful operator itself, all children's output attributes are included because
+ * the operator's internal expressions (e.g. grouping keys) reference them directly.
+ * At non-stateful ancestor operators, only children whose subtrees contain a stateful
+ * operator are included, to avoid unnecessary widening of non-stateful siblings.
+ * The node's own `p.output` is not needed for non-stateful ancestors because the
+ * bottom-up traversal guarantees children are already transformed, so their output
+ * attributes are already nullable and the ancestor's expressions reference those
+ * children's `exprId`s.
+ *
+ * '''Scope.''' The walk only fires on nodes whose subtree contains a stateful operator.
+ *
+ * '''Ordering constraint.''' This rule must run AFTER every `UpdateAttributeNullability`
+ * invocation in both the main optimizer and AQE.
+ *
+ * '''Idempotence.''' [[WidenStatefulOpNullability#deepWidenAttribute]] is idempotent.
+ */
+object WidenStatefulOperatorAttributeNullability extends Rule[LogicalPlan] {
+
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ if (!conf.getConf(SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT) ||
+ !plan.containsStatefulOperator) {
+ return plan
+ }
+ plan.resolveOperatorsUp {
+ case p if !p.resolved => p
+ case p: LeafNode => p
+ case p if !p.containsStatefulOperator => p
+ case p =>
+ val widenableAttrs = if (p.isStateful) {
+ p.output ++ p.children.flatMap(_.output)
+ } else {
+ p.children.filter(_.containsStatefulOperator).flatMap(_.output)
+ }
+ val widenableExprIds: Set[ExprId] = widenableAttrs
+ .iterator.collect { case ar: AttributeReference => ar.exprId }.toSet
+ if (widenableExprIds.isEmpty) {
+ p
+ } else {
+ p.transformExpressions {
+ case ar: AttributeReference if widenableExprIds.contains(ar.exprId) =>
+ val widened = WidenStatefulOpNullability.deepWidenAttribute(ar)
+ if (ar.dataType == widened.dataType && ar.nullable == widened.nullable) {
+ ar
+ } else {
+ widened
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index a7ad11848c3f..9184c5ef412b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.{AliasIdentifier, InternalRow, SQLConfHelper}
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, AnsiTypeCoercion, MultiInstanceRelation, Resolver, TypeCoercion, TypeCoercionBase, UnresolvedUnaryNode}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, AnsiTypeCoercion, MultiInstanceRelation, Resolver, TypeCoercion, TypeCoercionBase, UnresolvedUnaryNode, WidenStatefulOpNullability}
import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable}
import org.apache.spark.sql.catalyst.catalog.CatalogTable.VIEW_STORING_ANALYZED_PLAN
import org.apache.spark.sql.catalyst.expressions._
@@ -746,7 +746,10 @@ case class Join(
}
}
- override def output: Seq[Attribute] = Join.computeOutput(joinType, left.output, right.output)
+ override def output: Seq[Attribute] = {
+ val base = Join.computeOutput(joinType, left.output, right.output)
+ if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+ }
override def metadataOutput: Seq[Attribute] = {
joinType match {
@@ -1225,7 +1228,10 @@ case class Aggregate(
expressions.forall(_.resolved) && childrenResolved && !hasWindowExpressions
}
- override def output: Seq[Attribute] = aggregateExpressions.map(_.toAttribute)
+ override def output: Seq[Attribute] = {
+ val base = aggregateExpressions.map(_.toAttribute)
+ if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+ }
override def metadataOutput: Seq[Attribute] = Nil
override def maxRows: Option[Long] = {
if (groupingExpressions.isEmpty) {
@@ -1749,7 +1755,10 @@ object Limit {
* order.
*/
case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
- override def output: Seq[Attribute] = child.output
+ override def output: Seq[Attribute] = {
+ val base = child.output
+ if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+ }
override def maxRows: Option[Long] = {
limitExpr match {
case IntegerLiteral(limit) => Some(limit)
@@ -2004,7 +2013,10 @@ case class Sample(
*/
case class Distinct(child: LogicalPlan) extends UnaryNode {
override def maxRows: Option[Long] = child.maxRows
- override def output: Seq[Attribute] = child.output
+ override def output: Seq[Attribute] = {
+ val base = child.output
+ if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+ }
final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
override protected def withNewChildInternal(newChild: LogicalPlan): Distinct =
copy(child = newChild)
@@ -2174,7 +2186,10 @@ case class Deduplicate(
keys: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def maxRows: Option[Long] = child.maxRows
- override def output: Seq[Attribute] = child.output
+ override def output: Seq[Attribute] = {
+ val base = child.output
+ if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+ }
final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
override protected def withNewChildInternal(newChild: LogicalPlan): Deduplicate =
copy(child = newChild)
@@ -2186,7 +2201,10 @@ case class DeduplicateWithinWatermark(keys: Seq[Attribute], child: LogicalPlan)
override def references: AttributeSet = AttributeSet(keys) ++
AttributeSet(child.output.filter(_.metadata.contains(EventTimeWatermark.delayKey)))
override def maxRows: Option[Long] = child.maxRows
- override def output: Seq[Attribute] = child.output
+ override def output: Seq[Attribute] = {
+ val base = child.output
+ if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+ }
final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
override protected def withNewChildInternal(newChild: LogicalPlan): DeduplicateWithinWatermark =
copy(child = newChild)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 0c6f59073559..720b0dd640d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{catalyst, Encoder, Row}
-import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedDeserializer}
+import org.apache.spark.sql.catalyst.analysis.{Resolver, UnresolvedDeserializer, WidenStatefulOpNullability}
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
@@ -568,6 +568,11 @@ case class FlatMapGroupsWithState(
newLeft: LogicalPlan, newRight: LogicalPlan): FlatMapGroupsWithState =
copy(child = newLeft, initialState = newRight)
override def isStateful: Boolean = child.isStreaming
+
+ override def output: Seq[Attribute] = {
+ val base = super.output
+ if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+ }
}
object TransformWithState {
@@ -657,6 +662,11 @@ case class TransformWithState(
newLeft: LogicalPlan, newRight: LogicalPlan): TransformWithState =
copy(child = newLeft, initialState = newRight)
override def isStateful: Boolean = child.isStreaming
+
+ override def output: Seq[Attribute] = {
+ val base = super.output
+ if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(base) else base
+ }
}
/** Factory for constructing new `FlatMapGroupsInR` nodes. */
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index 56dc2f6de043..31e7d9402968 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.resource.ResourceProfile
import org.apache.spark.sql.catalyst.SQLConfHelper
-import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase, MultiInstanceRelation, UnresolvedAttribute, UnresolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{FunctionRegistryBase, MultiInstanceRelation, UnresolvedAttribute, UnresolvedStar, WidenStatefulOpNullability}
import org.apache.spark.sql.catalyst.analysis.TableFunctionRegistry.TableFunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeSet, Expression, ExpressionDescription, ExpressionInfo, JsonToStructs, PythonUDF, PythonUDTF}
import org.apache.spark.sql.catalyst.trees.TreePattern._
@@ -159,7 +159,9 @@ case class FlatMapGroupsInPandasWithState(
timeout: GroupStateTimeout,
child: LogicalPlan) extends UnaryNode {
- override def output: Seq[Attribute] = outputAttrs
+ override def output: Seq[Attribute] =
+ if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(outputAttrs)
+ else outputAttrs
override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
@@ -206,7 +208,9 @@ case class TransformWithStateInPySpark(
override def right: LogicalPlan = initialState
- override def output: Seq[Attribute] = outputAttrs
+ override def output: Seq[Attribute] =
+ if (isStateful) WidenStatefulOpNullability.widenOutputForStatefulOp(outputAttrs)
+ else outputAttrs
override def producedAttributes: AttributeSet = AttributeSet(outputAttrs)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 328f434195f4..0aed28e92558 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -3444,6 +3444,24 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT =
+ buildConf("spark.sql.streaming.statefulOperator.alwaysNullableOutput.enabled")
+ .internal()
+ .withBindingPolicy(ConfigBindingPolicy.SESSION)
+ .doc("When true, every streaming stateful operator reports its output schema with " +
+ "nullable=true on all columns (including nested struct fields, array elements, and " +
+ "map values), and the state schema is widened at every construction site, so the " +
+ "existing state schema " +
+ "compatibility check trivially passes regardless of input nullability. " +
+ "This prevents query-optimizer decisions (e.g., PropagateEmptyRelation dropping a " +
+ "Union branch) from flipping the state schema nullability across microbatches or " +
+ "restarts. The effective value is pinned per query via the offset log at batch 0, " +
+ "so pre-existing queries keep their original behavior; only newly started queries " +
+ "pick this up.")
+ .version("4.3.0")
+ .booleanConf
+ .createWithDefault(true)
+
val FILESTREAM_SINK_METADATA_IGNORED =
buildConf("spark.sql.streaming.fileStreamSink.ignoreMetadata")
.internal()
diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala
index c8a25652dacb..057e2fdc4775 100644
--- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala
+++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/streaming/ClientStreamingQuerySuite.scala
@@ -86,7 +86,7 @@ class ClientStreamingQuerySuite extends QueryTest with RemoteSparkSession with L
.count()
.selectExpr("window.start as timestamp", "count as num_events")
- assert(countsDF.schema.toDDL == "timestamp TIMESTAMP,num_events BIGINT NOT NULL")
+ assert(countsDF.schema.toDDL == "timestamp TIMESTAMP,num_events BIGINT")
// Start the query
val queryName = "sparkConnectStreamingQuery"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
index f16c6d9cfe6d..3c23930090ab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AQEOptimizer.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.execution.adaptive
import org.apache.spark.internal.LogKeys.{BATCH_NAME, RULE_NAME}
-import org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability
+import org.apache.spark.sql.catalyst.analysis.{UpdateAttributeNullability, WidenStatefulOperatorAttributeNullability}
import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, EliminateLimits, OptimizeOneRowPlan}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LogicalPlanIntegrity}
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
@@ -44,7 +44,8 @@ class AQEOptimizer(conf: SQLConf, extendedRuntimeOptimizerRules: Seq[Rule[Logica
Batch("Dynamic Join Selection", Once, DynamicJoinSelection),
Batch("Eliminate Limits", fixedPoint, EliminateLimits),
Batch("Optimize One Row Plan", fixedPoint, OptimizeOneRowPlan)) :+
- Batch("User Provided Runtime Optimizers", fixedPoint, extendedRuntimeOptimizerRules: _*)
+ Batch("User Provided Runtime Optimizers", fixedPoint, extendedRuntimeOptimizerRules: _*) :+
+ Batch("Widen Stateful Op Nullability", Once, WidenStatefulOperatorAttributeNullability)
final override protected def batches: Seq[Batch] = {
val excludedRules = conf.getConf(SQLConf.ADAPTIVE_OPTIMIZER_EXCLUDED_RULES)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala
index e9430ed9f9b7..a61f90515836 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/FlatMapGroupsInPandasWithStateExec.scala
@@ -20,6 +20,7 @@ import org.apache.spark.{JobArtifactSet, SparkException, SparkUnsupportedOperati
import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{EventTimeTimeout, ProcessingTimeTimeout}
@@ -81,7 +82,8 @@ case class FlatMapGroupsInPandasWithStateExec(
override protected val stateEncoder: ExpressionEncoder[Any] =
ExpressionEncoder(stateType).resolveAndBind().asInstanceOf[ExpressionEncoder[Any]]
- override def output: Seq[Attribute] = outAttributes
+ override def output: Seq[Attribute] =
+ WidenStatefulOpNullability.widenOutputForStatefulOp(outAttributes)
private val sessionLocalTimeZone = conf.sessionLocalTimeZone
private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
index 45f2af5c1dfe..d3fd757784e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
@@ -27,6 +27,7 @@ import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF}
import org.apache.spark.sql.catalyst.plans.logical.TransformWithStateInPySpark
@@ -39,7 +40,7 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.{StatefulOper
import org.apache.spark.sql.execution.streaming.operators.stateful.join.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.{TransformWithStateExecBase, TransformWithStateVariableInfo}
import org.apache.spark.sql.execution.streaming.operators.stateful.transformwithstate.statefulprocessor.{DriverStatefulProcessorHandleImpl, StatefulProcessorHandleImpl}
-import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, StateStoreProvider, StateStoreProviderId}
+import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, PrefixKeyScanStateEncoderSpec, RangeKeyScanStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, StateStoreProvider, StateStoreProviderId, TimestampAsPostfixKeyStateEncoderSpec, TimestampAsPrefixKeyStateEncoderSpec}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming.{OutputMode, TimeMode}
import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
@@ -51,7 +52,7 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Uti
*
* @param functionExpr function called on each group
* @param groupingAttributes used to group the data
- * @param output used to define the output rows
+ * @param outputAttrs used to define the output rows
* @param outputMode defines the output mode for the statefulProcessor
* @param timeMode The time mode semantics of the stateful processor for timers and TTL.
* @param stateInfo Used to identify the state store for a given operator.
@@ -69,7 +70,7 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Uti
case class TransformWithStateInPySparkExec(
functionExpr: Expression,
groupingAttributes: Seq[Attribute],
- output: Seq[Attribute],
+ outputAttrs: Seq[Attribute],
outputMode: OutputMode,
timeMode: TimeMode,
stateInfo: Option[StatefulOperatorStateInfo],
@@ -94,6 +95,9 @@ case class TransformWithStateInPySparkExec(
initialStateGroupingAttrs,
initialState) {
+ override def output: Seq[Attribute] =
+ WidenStatefulOpNullability.widenOutputForStatefulOp(outputAttrs)
+
// NOTE: This is needed to comply with existing release of transformWithStateInPandas.
override def shortName: String = if (
userFacingDataType == TransformWithStateInPySpark.UserFacingDataType.PANDAS
@@ -127,16 +131,49 @@ case class TransformWithStateInPySparkExec(
// Each state variable has its own schema, this is a dummy one.
protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
+ private lazy val widenedGroupingKeySchema: StructType =
+ WidenStatefulOpNullability.widenStateSchema(groupingKeySchema)
+
override def getColFamilySchemas(
shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema] = {
// For Python, the user can explicitly set nullability on schema, so
// we need to throw an error if the schema is nullable
- driverProcessorHandle.getColumnFamilySchemas(
+ val schemas = driverProcessorHandle.getColumnFamilySchemas(
shouldCheckNullable = shouldBeNullable,
shouldSetNullable = shouldBeNullable
)
+ widenColFamilyGroupingKeys(schemas)
}
+ private def widenColFamilyGroupingKeys(
+ schemas: Map[String, StateStoreColFamilySchema])
+ : Map[String, StateStoreColFamilySchema] = {
+ if (!WidenStatefulOpNullability.isEnabled) return schemas
+ val original = groupingKeySchema
+ val widened = widenedGroupingKeySchema
+ def widenKey(ks: StructType): StructType =
+ WidenStatefulOpNullability.widenGroupingKeyInSchema(
+ ks, original, widened)
+ schemas.map { case (name, cf) =>
+ val widenedSpec = cf.keyStateEncoderSpec.map {
+ case NoPrefixKeyStateEncoderSpec(ks) =>
+ NoPrefixKeyStateEncoderSpec(widenKey(ks))
+ case PrefixKeyScanStateEncoderSpec(ks, n) =>
+ PrefixKeyScanStateEncoderSpec(widenKey(ks), n)
+ case RangeKeyScanStateEncoderSpec(ks, o) =>
+ RangeKeyScanStateEncoderSpec(widenKey(ks), o)
+ case TimestampAsPrefixKeyStateEncoderSpec(ks) =>
+ TimestampAsPrefixKeyStateEncoderSpec(widenKey(ks))
+ case TimestampAsPostfixKeyStateEncoderSpec(ks) =>
+ TimestampAsPostfixKeyStateEncoderSpec(widenKey(ks))
+ }
+ name -> cf.copy(
+ keySchema = widenKey(cf.keySchema),
+ keyStateEncoderSpec = widenedSpec)
+ }
+ }
+
+
override def getStateVariableInfos(): Map[String,
TransformWithStateVariableInfo] = {
driverProcessorHandle.getStateVariableInfos
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
index bf2278b81492..9ba99ac2c036 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/checkpointing/OffsetSeq.scala
@@ -204,7 +204,8 @@ object OffsetSeqMetadata extends Logging {
STATEFUL_OPERATOR_USE_STRICT_DISTRIBUTION,
PRUNE_FILTERS_CAN_PRUNE_STREAMING_SUBPLAN,
STREAMING_STATE_STORE_ENCODING_FORMAT,
STATE_STORE_ROW_CHECKSUM_ENABLED, PROTOBUF_EXTENSIONS_SUPPORT_ENABLED,
- ENABLE_STREAMING_SOURCE_EVOLUTION
+ ENABLE_STREAMING_SOURCE_EVOLUTION,
+ STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT
)
/**
@@ -254,7 +255,8 @@ object OffsetSeqMetadata extends Logging {
STATE_STORE_ROW_CHECKSUM_ENABLED.key -> "false",
STATE_STORE_ROCKSDB_MERGE_OPERATOR_VERSION.key -> "1",
PROTOBUF_EXTENSIONS_SUPPORT_ENABLED.key -> "false",
- ENABLE_STREAMING_SOURCE_EVOLUTION.key -> "false"
+ ENABLE_STREAMING_SOURCE_EVOLUTION.key -> "false",
+ STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT.key -> "false"
)
def readValue[T](metadataLog: OffsetSeqMetadataBase, confKey:
ConfigEntry[T]): String = {
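
The two maps touched above are what pin the new conf per query. A minimal sketch of the lookup order one would expect on restart, assuming the value recorded in the offset log wins and the "false" entry is the fallback for checkpoints written before the conf existed. `PinnedConfSketch` and `effectiveValue` are hypothetical names, not Spark APIs; the real logic lives in `OffsetSeqMetadata` and is not shown in this hunk.

```scala
object PinnedConfSketch {
  // Conf key as quoted in the commit description.
  val Key = "spark.sql.streaming.statefulOperator.alwaysNullableOutput.enabled"

  // Assumed fallback for offset logs that predate the conf (mirrors the "false" entry above).
  val defaultsForMissingKeys: Map[String, String] = Map(Key -> "false")

  /** Value a restarted query would run with: offset log first, then legacy default, then session. */
  def effectiveValue(
      offsetLogConf: Map[String, String],
      sessionConf: Map[String, String]): Option[String] =
    offsetLogConf.get(Key)
      .orElse(defaultsForMissingKeys.get(Key))
      .orElse(sessionConf.get(Key))

  def main(args: Array[String]): Unit = {
    // Old checkpoint: key absent from the offset log, so the query stays on "false"
    // even though the session asks for "true".
    assert(effectiveValue(Map.empty, Map(Key -> "true")) == Some("false"))
    // Checkpoint created with the flag on keeps it on regardless of the session value.
    assert(effectiveValue(Map(Key -> "true"), Map(Key -> "false")) == Some("true"))
  }
}
```

Under this reading, an existing checkpoint never picks up the widened schemas silently; only queries started with the flag enabled record "true" and keep it.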
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala
index 6b9f90a9ab5c..48d1dad70f5e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/flatmapgroupswithstate/FlatMapGroupsWithStateExec.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.{SparkException, SparkThrowable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute,
Expression, SortOrder, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical._
@@ -36,6 +37,7 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.join.Streamin
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
import org.apache.spark.sql.streaming.GroupStateTimeout.NoTimeout
+import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{CompletionIterator, SerializableConfiguration}
/**
@@ -72,6 +74,11 @@ trait FlatMapGroupsWithStateExecBase
lazy val stateManager: StateManager =
createStateManager(stateEncoder, isTimeoutEnabled, stateFormatVersion)
+ private lazy val stateKeySchema: StructType =
+ WidenStatefulOpNullability.widenStateSchema(groupingAttributes.toStructType)
+ private lazy val stateValueSchema: StructType =
+ WidenStatefulOpNullability.widenStateSchema(stateManager.stateSchema)
+
/**
* Distribute by grouping attributes - We need the underlying data and the initial state data
* to have the same grouping so that the data are co-lacated on the same task.
@@ -200,7 +207,7 @@ trait FlatMapGroupsWithStateExecBase
batchId: Long,
stateSchemaVersion: Int): List[StateSchemaValidationResult] = {
val newStateSchema =
List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0,
- groupingAttributes.toStructType, 0, stateManager.stateSchema))
+ stateKeySchema, 0, stateValueSchema))
List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo,
hadoopConf,
newStateSchema, session.sessionState, stateSchemaVersion))
}
@@ -243,9 +250,9 @@ trait FlatMapGroupsWithStateExecBase
val storeProviderId = StateStoreProviderId(stateStoreId,
stateInfo.get.queryRunId)
val store = StateStore.get(
storeProviderId,
- groupingAttributes.toStructType,
- stateManager.stateSchema,
- NoPrefixKeyStateEncoderSpec(groupingAttributes.toStructType),
+ stateKeySchema,
+ stateValueSchema,
+ NoPrefixKeyStateEncoderSpec(stateKeySchema),
stateInfo.get.storeVersion,
stateInfo.get.getStateStoreCkptId(partitionId).map(_.head),
None,
@@ -257,9 +264,9 @@ trait FlatMapGroupsWithStateExecBase
} else {
child.execute().mapPartitionsWithStateStore[InternalRow](
getStateInfo,
- groupingAttributes.toStructType,
- stateManager.stateSchema,
- NoPrefixKeyStateEncoderSpec(groupingAttributes.toStructType),
+ stateKeySchema,
+ stateValueSchema,
+ NoPrefixKeyStateEncoderSpec(stateKeySchema),
session.sessionState,
Some(session.streams.stateStoreCoordinator)
) { case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
@@ -425,6 +432,9 @@ case class FlatMapGroupsWithStateExec(
skipEmittingInitialStateKeys: Boolean,
child: SparkPlan)
extends FlatMapGroupsWithStateExecBase with BinaryExecNode with
ObjectProducerExec {
+ override def output: Seq[Attribute] =
+ WidenStatefulOpNullability.widenOutputForStatefulOp(super.output)
+
import GroupStateImpl._
import FlatMapGroupsWithStateExecHelper._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
index 9eca04c98591..8f90a603c7ef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/join/StreamingSymmetricHashJoinExec.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet,
Expression, GenericInternalRow, JoinedRow, Literal, Predicate,
UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
@@ -231,13 +232,16 @@ case class StreamingSymmetricHashJoinExec(
StatefulOpClusteredDistribution(leftKeys, getStateInfo.numPartitions) ::
StatefulOpClusteredDistribution(rightKeys, getStateInfo.numPartitions)
:: Nil
- override def output: Seq[Attribute] = joinType match {
- case _: InnerLike => left.output ++ right.output
- case LeftOuter => left.output ++ right.output.map(_.withNullability(true))
- case RightOuter => left.output.map(_.withNullability(true)) ++ right.output
- case FullOuter => (left.output ++ right.output).map(_.withNullability(true))
- case LeftSemi => left.output
- case _ => throwBadJoinTypeException()
+ override def output: Seq[Attribute] = {
+ val base = joinType match {
+ case _: InnerLike => left.output ++ right.output
+ case LeftOuter => left.output ++ right.output.map(_.withNullability(true))
+ case RightOuter => left.output.map(_.withNullability(true)) ++ right.output
+ case FullOuter => (left.output ++ right.output).map(_.withNullability(true))
+ case LeftSemi => left.output
+ case _ => throwBadJoinTypeException()
+ }
+ WidenStatefulOpNullability.widenOutputForStatefulOp(base)
}
override def outputPartitioning: Partitioning = joinType match {
@@ -279,11 +283,16 @@ case class StreamingSymmetricHashJoinExec(
override def getColFamilySchemas(
shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema] = {
assert(useVirtualColumnFamilies)
- // We only have one state store for the join, but there are four distinct schemas
- SymmetricHashJoinStateManager
+ val raw = SymmetricHashJoinStateManager
.getSchemasForStateStoreWithColFamily(LeftSide, left.output, leftKeys, stateFormatVersion) ++
- SymmetricHashJoinStateManager
- .getSchemasForStateStoreWithColFamily(RightSide, right.output, rightKeys, stateFormatVersion)
+ SymmetricHashJoinStateManager
+ .getSchemasForStateStoreWithColFamily(
+ RightSide, right.output, rightKeys, stateFormatVersion)
+ raw.map { case (name, cf) =>
+ name -> cf.copy(
+ keySchema = WidenStatefulOpNullability.widenStateSchema(cf.keySchema),
+ valueSchema = WidenStatefulOpNullability.widenStateSchema(cf.valueSchema))
+ }
}
override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
@@ -328,7 +337,8 @@ case class StreamingSymmetricHashJoinExec(
// we have to add the default column family schema because the RocksDBStateEncoder
// expects this entry to be present in the stateSchemaProvider.
val newStateSchema =
List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0,
- keySchema, 0, valueSchema))
+ WidenStatefulOpNullability.widenStateSchema(keySchema), 0,
+ WidenStatefulOpNullability.widenStateSchema(valueSchema)))
StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo,
hadoopConf,
newStateSchema, session.sessionState, stateSchemaVersion, storeName
= stateStoreName)
}.toList
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
index 59a2b9ee74f8..022fa3469eea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/statefulOperators.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers
@@ -767,11 +768,16 @@ case class StateStoreRestoreExec(
private[sql] val stateManager =
StreamingAggregationStateManager.createStateManager(
keyExpressions, child.output, stateFormatVersion)
+ private val stateKeySchema: StructType =
+ WidenStatefulOpNullability.widenStateSchema(keyExpressions.toStructType)
+ private val stateValueSchema: StructType =
+ WidenStatefulOpNullability.widenStateSchema(stateManager.getStateValueSchema)
+
override def validateAndMaybeEvolveStateSchema(
hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int):
List[StateSchemaValidationResult] = {
val newStateSchema =
List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
- 0, keyExpressions.toStructType, 0, stateManager.getStateValueSchema))
+ 0, stateKeySchema, 0, stateValueSchema))
List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo,
hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion))
}
@@ -781,9 +787,9 @@ case class StateStoreRestoreExec(
child.execute().mapPartitionsWithReadStateStore(
getStateInfo,
- keyExpressions.toStructType,
- stateManager.getStateValueSchema,
- NoPrefixKeyStateEncoderSpec(keyExpressions.toStructType),
+ stateKeySchema,
+ stateValueSchema,
+ NoPrefixKeyStateEncoderSpec(stateKeySchema),
session.sessionState,
Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
val hasInput = iter.hasNext
@@ -805,7 +811,8 @@ case class StateStoreRestoreExec(
}
}
- override def output: Seq[Attribute] = child.output
+ override def output: Seq[Attribute] =
+ WidenStatefulOpNullability.widenOutputForStatefulOp(child.output)
override def outputPartitioning: Partitioning = child.outputPartitioning
@@ -838,13 +845,18 @@ case class StateStoreSaveExec(
private[sql] val stateManager =
StreamingAggregationStateManager.createStateManager(
keyExpressions, child.output, stateFormatVersion)
+ private val stateKeySchema: StructType =
+ WidenStatefulOpNullability.widenStateSchema(keyExpressions.toStructType)
+ private val stateValueSchema: StructType =
+ WidenStatefulOpNullability.widenStateSchema(stateManager.getStateValueSchema)
+
override def validateAndMaybeEvolveStateSchema(
hadoopConf: Configuration,
batchId: Long,
stateSchemaVersion: Int): List[StateSchemaValidationResult] = {
val newStateSchema =
List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
- keySchemaId = 0, keyExpressions.toStructType, valueSchemaId = 0,
- stateManager.getStateValueSchema))
+ keySchemaId = 0, stateKeySchema, valueSchemaId = 0,
+ stateValueSchema))
List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo,
hadoopConf, newStateSchema, session.sessionState, stateSchemaVersion))
}
@@ -856,9 +868,9 @@ case class StateStoreSaveExec(
child.execute().mapPartitionsWithStateStore(
getStateInfo,
- keyExpressions.toStructType,
- stateManager.getStateValueSchema,
- NoPrefixKeyStateEncoderSpec(keyExpressions.toStructType),
+ stateKeySchema,
+ stateValueSchema,
+ NoPrefixKeyStateEncoderSpec(stateKeySchema),
session.sessionState,
Some(session.streams.stateStoreCoordinator)) { (store, iter) =>
val numOutputRows = longMetric("numOutputRows")
@@ -1000,7 +1012,8 @@ case class StateStoreSaveExec(
}
}
- override def output: Seq[Attribute] = child.output
+ override def output: Seq[Attribute] =
+ WidenStatefulOpNullability.widenOutputForStatefulOp(child.output)
override def outputPartitioning: Partitioning = child.outputPartitioning
@@ -1054,12 +1067,17 @@ case class SessionWindowStateStoreRestoreExec(
private val stateManager =
StreamingSessionWindowStateManager.createStateManager(
keyWithoutSessionExpressions, sessionExpression, child.output,
stateFormatVersion)
+ private val stateKeySchema: StructType =
+ WidenStatefulOpNullability.widenStateSchema(stateManager.getStateKeySchema)
+ private val stateValueSchema: StructType =
+ WidenStatefulOpNullability.widenStateSchema(stateManager.getStateValueSchema)
+
override def validateAndMaybeEvolveStateSchema(
hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int):
List[StateSchemaValidationResult] = {
val newStateSchema =
List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
- keySchemaId = 0, stateManager.getStateKeySchema, valueSchemaId = 0,
- stateManager.getStateValueSchema))
+ keySchemaId = 0, stateKeySchema, valueSchemaId = 0,
+ stateValueSchema))
List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo,
hadoopConf,
newStateSchema, session.sessionState, stateSchemaVersion))
}
@@ -1069,9 +1087,9 @@ case class SessionWindowStateStoreRestoreExec(
child.execute().mapPartitionsWithReadStateStore(
getStateInfo,
- stateManager.getStateKeySchema,
- stateManager.getStateValueSchema,
- PrefixKeyScanStateEncoderSpec(stateManager.getStateKeySchema,
+ stateKeySchema,
+ stateValueSchema,
+ PrefixKeyScanStateEncoderSpec(stateKeySchema,
stateManager.getNumColsForPrefixKey),
session.sessionState,
Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
@@ -1099,7 +1117,8 @@ case class SessionWindowStateStoreRestoreExec(
}
}
- override def output: Seq[Attribute] = child.output
+ override def output: Seq[Attribute] =
+ WidenStatefulOpNullability.widenOutputForStatefulOp(child.output)
override def outputPartitioning: Partitioning = child.outputPartitioning
@@ -1147,11 +1166,16 @@ case class SessionWindowStateStoreSaveExec(
private val stateManager =
StreamingSessionWindowStateManager.createStateManager(
keyWithoutSessionExpressions, sessionExpression, child.output,
stateFormatVersion)
+ private val stateKeySchema: StructType =
+ WidenStatefulOpNullability.widenStateSchema(stateManager.getStateKeySchema)
+ private val stateValueSchema: StructType =
+ WidenStatefulOpNullability.widenStateSchema(stateManager.getStateValueSchema)
+
override def validateAndMaybeEvolveStateSchema(
hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int):
List[StateSchemaValidationResult] = {
val newStateSchema =
List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0,
- stateManager.getStateKeySchema, 0, stateManager.getStateValueSchema))
+ stateKeySchema, 0, stateValueSchema))
List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo,
hadoopConf,
newStateSchema, session.sessionState, stateSchemaVersion))
}
@@ -1165,9 +1189,9 @@ case class SessionWindowStateStoreSaveExec(
child.execute().mapPartitionsWithStateStore(
getStateInfo,
- stateManager.getStateKeySchema,
- stateManager.getStateValueSchema,
- PrefixKeyScanStateEncoderSpec(stateManager.getStateKeySchema,
+ stateKeySchema,
+ stateValueSchema,
+ PrefixKeyScanStateEncoderSpec(stateKeySchema,
stateManager.getNumColsForPrefixKey),
session.sessionState,
Some(session.streams.stateStoreCoordinator)) { case (store, iter) =>
@@ -1251,7 +1275,8 @@ case class SessionWindowStateStoreSaveExec(
}
}
- override def output: Seq[Attribute] = child.output
+ override def output: Seq[Attribute] =
+ WidenStatefulOpNullability.widenOutputForStatefulOp(child.output)
override def outputPartitioning: Partitioning = child.outputPartitioning
@@ -1355,14 +1380,19 @@ abstract class BaseStreamingDeduplicateExec
protected val schemaForValueRow: StructType
protected val extraOptionOnStateStore: Map[String, String]
+ protected lazy val stateKeySchema: StructType =
+ WidenStatefulOpNullability.widenStateSchema(keyExpressions.toStructType)
+ protected lazy val stateValueSchema: StructType =
+ WidenStatefulOpNullability.widenStateSchema(schemaForValueRow)
+
override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver
child.execute().mapPartitionsWithStateStore(
getStateInfo,
- keyExpressions.toStructType,
- schemaForValueRow,
- NoPrefixKeyStateEncoderSpec(keyExpressions.toStructType),
+ stateKeySchema,
+ stateValueSchema,
+ NoPrefixKeyStateEncoderSpec(stateKeySchema),
session.sessionState,
Some(session.streams.stateStoreCoordinator),
extraOptions = extraOptionOnStateStore) { (store, iter) =>
@@ -1422,7 +1452,8 @@ abstract class BaseStreamingDeduplicateExec
protected def evictDupInfoFromState(store: StateStore): Unit
- override def output: Seq[Attribute] = child.output
+ override def output: Seq[Attribute] =
+ WidenStatefulOpNullability.widenOutputForStatefulOp(child.output)
override def outputPartitioning: Partitioning = child.outputPartitioning
@@ -1476,7 +1507,7 @@ case class StreamingDeduplicateExec(
hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int):
List[StateSchemaValidationResult] = {
val newStateSchema =
List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0,
- keyExpressions.toStructType, 0, schemaForValueRow))
+ stateKeySchema, 0, stateValueSchema))
List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo,
hadoopConf,
newStateSchema, session.sessionState, stateSchemaVersion,
extraOptions = extraOptionOnStateStore))
@@ -1562,7 +1593,7 @@ case class StreamingDeduplicateWithinWatermarkExec(
hadoopConf: Configuration, batchId: Long, stateSchemaVersion: Int):
List[StateSchemaValidationResult] = {
val newStateSchema =
List(StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME, 0,
- keyExpressions.toStructType, 0, schemaForValueRow))
+ stateKeySchema, 0, stateValueSchema))
List(StateSchemaCompatibilityChecker.validateAndMaybeEvolveStateSchema(getStateInfo,
hadoopConf,
newStateSchema, session.sessionState, stateSchemaVersion,
extraOptions = extraOptionOnStateStore))
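
Every operator in this file now routes its state key and value schemas through `WidenStatefulOpNullability.widenStateSchema`, whose body is not part of these hunks. A rough sketch of what "fully nullable" could mean for a `StructType`, assuming top-level fields, nested struct fields, array elements, and map values are all relaxed. `widenFully` is a hypothetical helper, not the Spark implementation, and the block needs spark-sql on the classpath.

```scala
import org.apache.spark.sql.types._

object WidenSchemaSketch {
  // Hypothetical stand-in for widenStateSchema; the real method may differ in detail.
  def widenFully(dt: DataType): DataType = dt match {
    case s: StructType =>
      // Relax every field and recurse into its type.
      StructType(s.fields.map { f =>
        f.copy(dataType = widenFully(f.dataType), nullable = true)
      })
    case ArrayType(elementType, _) =>
      ArrayType(widenFully(elementType), containsNull = true)
    case MapType(keyType, valueType, _) =>
      // Map keys cannot be null in Spark; only types nested inside the key are widened.
      MapType(widenFully(keyType), widenFully(valueType), valueContainsNull = true)
    case other => other
  }

  def main(args: Array[String]): Unit = {
    val strict = new StructType()
      .add("key", IntegerType, nullable = false)
      .add("tags", ArrayType(StringType, containsNull = false), nullable = false)
    val widened = widenFully(strict).asInstanceOf[StructType]
    assert(widened.fields.forall(_.nullable))
    assert(widened("tags").dataType == ArrayType(StringType, containsNull = true))
    // Widening is idempotent, so re-planning yields the same schema on every batch.
    assert(widenFully(widened) == widened)
  }
}
```

Idempotence is the property that matters for the schema compatibility check: whatever nullability the upstream plan reports on a given run, the schema handed to the state store comes out the same.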
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/streamingLimits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/streamingLimits.scala
index 6816be103f6e..da54c0ce0fe6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/streamingLimits.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/streamingLimits.scala
@@ -22,6 +22,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
import org.apache.spark.sql.catalyst.expressions.{Attribute,
GenericInternalRow, SortOrder, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, Distribution,
Partitioning}
import org.apache.spark.sql.execution.{LimitExec, SparkPlan, UnaryExecNode}
@@ -98,7 +99,8 @@ case class StreamingGlobalLimitExec(
}
}
- override def output: Seq[Attribute] = child.output
+ override def output: Seq[Attribute] =
+ WidenStatefulOpNullability.widenOutputForStatefulOp(child.output)
override def outputPartitioning: Partitioning = child.outputPartitioning
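
Each `output` override in the hunks above delegates to `WidenStatefulOpNullability.widenOutputForStatefulOp`. A sketch of the idea, assuming the helper amounts to a conf-gated `withNullability(true)` over the attributes; any widening of nested types is left out. `widenOutput` and its `enabled` parameter are hypothetical, and the block needs spark-catalyst on the classpath.

```scala
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType}

object OutputWideningSketch {
  // Hypothetical stand-in for widenOutputForStatefulOp; `enabled` models the pinned conf.
  def widenOutput(attrs: Seq[Attribute], enabled: Boolean): Seq[Attribute] =
    if (enabled) attrs.map(_.withNullability(true)) else attrs

  def main(args: Array[String]): Unit = {
    val key = AttributeReference("key", IntegerType, nullable = false)()
    val value = AttributeReference("value", StringType, nullable = true)()

    val widened = widenOutput(Seq(key, value), enabled = true)
    assert(widened.forall(_.nullable))
    // withNullability keeps the expression id, so parent operators still resolve the attribute.
    assert(widened.map(_.exprId) == Seq(key.exprId, value.exprId))

    // With the flag off the declared output is passed through untouched.
    assert(!widenOutput(Seq(key), enabled = false).head.nullable)
  }
}
```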
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala
index b200bde96cbc..f0e3003b2b71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/operators/stateful/transformwithstate/TransformWithStateExec.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression,
UnsafeRow}
import org.apache.spark.sql.catalyst.plans.logical._
@@ -35,6 +36,7 @@ import org.apache.spark.sql.execution.streaming.operators.stateful.transformwith
import org.apache.spark.sql.execution.streaming.state._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming._
+import org.apache.spark.sql.types.StructType
import org.apache.spark.util.{CompletionIterator, SerializableConfiguration,
Utils}
/**
@@ -88,6 +90,12 @@ case class TransformWithStateExec(
initialState)
with ObjectProducerExec {
+ override def output: Seq[Attribute] =
+ WidenStatefulOpNullability.widenOutputForStatefulOp(super.output)
+
+ private lazy val stateKeySchema: StructType =
+ WidenStatefulOpNullability.widenStateSchema(keyExpressions.toStructType)
+
override def shortName: String =
StatefulOperatorsUtils.TRANSFORM_WITH_STATE_EXEC_OP_NAME
// We need to just initialize key and value deserializer once per partition.
@@ -133,12 +141,11 @@ case class TransformWithStateExec(
override def getColFamilySchemas(
shouldBeNullable: Boolean
): Map[String, StateStoreColFamilySchema] = {
- val keySchema = keyExpressions.toStructType
// we have to add the default column family schema because the RocksDBStateEncoder
// expects this entry to be present in the stateSchemaProvider.
val defaultSchema =
StateStoreColFamilySchema(StateStore.DEFAULT_COL_FAMILY_NAME,
- 0, keyExpressions.toStructType, 0, DUMMY_VALUE_ROW_SCHEMA,
- Some(NoPrefixKeyStateEncoderSpec(keySchema)))
+ 0, stateKeySchema, 0, DUMMY_VALUE_ROW_SCHEMA,
+ Some(NoPrefixKeyStateEncoderSpec(stateKeySchema)))
// For Scala, the user can't explicitly set nullability on schema, so there is
// no reason to throw an error, and we can simply set the schema to nullable.
@@ -147,9 +154,37 @@ case class TransformWithStateExec(
shouldCheckNullable = false, shouldSetNullable = shouldBeNullable) ++
Map(StateStore.DEFAULT_COL_FAMILY_NAME -> defaultSchema)
closeProcessorHandle()
- columnFamilySchemas
+ widenColFamilyGroupingKeys(columnFamilySchemas)
}
+ private def widenColFamilyGroupingKeys(
+ schemas: Map[String, StateStoreColFamilySchema])
+ : Map[String, StateStoreColFamilySchema] = {
+ if (!WidenStatefulOpNullability.isEnabled) return schemas
+ val original = keyEncoder.schema
+ val widened = stateKeySchema
+ def widenKey(ks: StructType): StructType =
+ WidenStatefulOpNullability.widenGroupingKeyInSchema(ks, original, widened)
+ schemas.map { case (name, cf) =>
+ val widenedSpec = cf.keyStateEncoderSpec.map {
+ case NoPrefixKeyStateEncoderSpec(ks) =>
+ NoPrefixKeyStateEncoderSpec(widenKey(ks))
+ case PrefixKeyScanStateEncoderSpec(ks, n) =>
+ PrefixKeyScanStateEncoderSpec(widenKey(ks), n)
+ case RangeKeyScanStateEncoderSpec(ks, o) =>
+ RangeKeyScanStateEncoderSpec(widenKey(ks), o)
+ case TimestampAsPrefixKeyStateEncoderSpec(ks) =>
+ TimestampAsPrefixKeyStateEncoderSpec(widenKey(ks))
+ case TimestampAsPostfixKeyStateEncoderSpec(ks) =>
+ TimestampAsPostfixKeyStateEncoderSpec(widenKey(ks))
+ }
+ name -> cf.copy(
+ keySchema = widenKey(cf.keySchema),
+ keyStateEncoderSpec = widenedSpec)
+ }
+ }
+
+
override def getStateVariableInfos(): Map[String,
TransformWithStateVariableInfo] = {
val stateVariableInfos = getDriverProcessorHandle().getStateVariableInfos
closeProcessorHandle()
@@ -401,9 +436,9 @@ case class TransformWithStateExec(
val storeProviderId = StateStoreProviderId(stateStoreId,
stateInfo.get.queryRunId)
val store = StateStore.get(
storeProviderId = storeProviderId,
- keyEncoder.schema,
+ stateKeySchema,
DUMMY_VALUE_ROW_SCHEMA,
- NoPrefixKeyStateEncoderSpec(keyEncoder.schema),
+ NoPrefixKeyStateEncoderSpec(stateKeySchema),
version = stateInfo.get.storeVersion,
stateStoreCkptId =
stateInfo.get.getStateStoreCkptId(partitionId).map(_.head),
stateSchemaBroadcast = stateInfo.get.stateSchemaMetadata,
@@ -423,9 +458,9 @@ case class TransformWithStateExec(
if (isStreaming) {
child.execute().mapPartitionsWithStateStore[InternalRow](
getStateInfo,
- keyEncoder.schema,
+ stateKeySchema,
DUMMY_VALUE_ROW_SCHEMA,
- NoPrefixKeyStateEncoderSpec(keyEncoder.schema),
+ NoPrefixKeyStateEncoderSpec(stateKeySchema),
session.sessionState,
Some(session.streams.stateStoreCoordinator),
useColumnFamilies = true
@@ -473,9 +508,9 @@ case class TransformWithStateExec(
// Create StateStoreProvider for this partition
val stateStoreProvider = StateStoreProvider.createAndInit(
providerId,
- keyEncoder.schema,
+ stateKeySchema,
DUMMY_VALUE_ROW_SCHEMA,
- NoPrefixKeyStateEncoderSpec(keyEncoder.schema),
+ NoPrefixKeyStateEncoderSpec(stateKeySchema),
useColumnFamilies = true,
storeConf = storeConf,
hadoopConf = hadoopConfBroadcast.value.value,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala
index 9fc72241e83b..0d2e4a6941a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/IncrementalExecution.scala
@@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys.{BATCH_TIMESTAMP, ERROR}
import org.apache.spark.sql.catalyst.QueryPlanningTracker
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOperatorAttributeNullability
import org.apache.spark.sql.catalyst.expressions.{CurrentBatchTimestamp,
ExpressionWithRandomSeed}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
@@ -133,7 +134,7 @@ class IncrementalExecution(
// of sink information.
case w: WriteToMicroBatchDataSourceV1 => w.child
}
- sparkSession.sessionState.optimizer.executeAndTrack(preOptimized,
+ val optimized = sparkSession.sessionState.optimizer.executeAndTrack(preOptimized,
tracker).transformAllExpressionsWithPruning(
_.containsAnyPattern(CURRENT_LIKE, EXPRESSION_WITH_RANDOM_SEED)) {
case ts @ CurrentBatchTimestamp(timestamp, _, _) =>
@@ -141,6 +142,7 @@ class IncrementalExecution(
ts.toLiteral
case e: ExpressionWithRandomSeed =>
e.withNewSeed(Utils.random.nextLong())
}
+ WidenStatefulOperatorAttributeNullability(optimized)
}
// Use `this` for explain so the already-open transaction and executedPlan are reused.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
index 1e1aa451a0ae..c46f0076721b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala
@@ -1181,16 +1181,16 @@ abstract class StreamingInnerJoinSuite extends StreamingInnerJoinBase {
val hadoopConf = spark.sessionState.newHadoopConf()
val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf)
- val keySchemaForNums = new StructType().add("field0", IntegerType, nullable = false)
+ val keySchemaForNums = new StructType().add("field0", IntegerType)
val keySchemaForIndex = keySchemaForNums.add("index", LongType)
val numSchema: StructType = new StructType().add("value", LongType)
val leftIndexSchema: StructType = new StructType()
- .add("key", IntegerType, nullable = false)
- .add("leftValue", IntegerType, nullable = false)
+ .add("key", IntegerType)
+ .add("leftValue", IntegerType)
.add("matched", BooleanType)
val rightIndexSchema: StructType = new StructType()
- .add("key", IntegerType, nullable = false)
- .add("rightValue", IntegerType, nullable = false)
+ .add("key", IntegerType)
+ .add("rightValue", IntegerType)
.add("matched", BooleanType)
val schemaLeftIndex = StateStoreColFamilySchema(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala
index e58af3b2bf65..6d4a97861efe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinV4Suite.scala
@@ -112,16 +112,16 @@ class StreamingInnerJoinV4Suite
CheckpointFileManager.create(stateSchemaPath, hadoopConf)
val keySchemaWithTimestamp = new StructType()
- .add("field0", IntegerType, nullable = false)
- .add("__event_time", LongType, nullable = false)
+ .add("field0", IntegerType)
+ .add("__event_time", LongType)
val leftValueSchema: StructType = new StructType()
- .add("key", IntegerType, nullable = false)
- .add("leftValue", IntegerType, nullable = false)
+ .add("key", IntegerType)
+ .add("leftValue", IntegerType)
.add("matched", BooleanType)
val rightValueSchema: StructType = new StructType()
- .add("key", IntegerType, nullable = false)
- .add("rightValue", IntegerType, nullable = false)
+ .add("key", IntegerType)
+ .add("rightValue", IntegerType)
.add("matched", BooleanType)
val dummyValueSchema =
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStatefulOperatorNullabilityDriftSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStatefulOperatorNullabilityDriftSuite.scala
new file mode 100644
index 000000000000..5278f68fbb0a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingStatefulOperatorNullabilityDriftSuite.scala
@@ -0,0 +1,534 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.streaming
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkUnsupportedOperationException
+import org.apache.spark.sql.{DataFrame, Encoders}
+import org.apache.spark.sql.catalyst.analysis.WidenStatefulOpNullability
+import org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager
+import org.apache.spark.sql.execution.streaming.runtime.MemoryStream
+import org.apache.spark.sql.execution.streaming.state.{RocksDBStateStoreProvider, StateSchemaCompatibilityChecker, StateStore}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
+
+/**
+ * Regression suite for stateful-operator nullability drift.
+ *
+ * Driver: `PropagateEmptyRelation` drops empty `Union` branches without a streaming
+ * guard, so the surviving branch's per-column nullability becomes the Union's
+ * nullability and propagates into a stateful operator above -- across microbatches or
+ * restarts.
+ *
+ * Coverage:
+ * - New-query (default conf): originally-failing scenarios now complete cleanly.
+ * - Existing-query (conf forced false): pre-fix behavior preserved verbatim.
+ * - Helper invariant: `WidenStatefulOpNullability.deepWidenAttribute` recurses into
+ * nested types.
+ */
+class StreamingStatefulOperatorNullabilityDriftSuite extends StreamTest {
+
+ import testImplicits._
+
+ private def buildTwoSources(): (MemoryStream[Int], MemoryStream[Int], DataFrame, DataFrame) = {
+ val inputA = MemoryStream[Int]
+ val inputB = MemoryStream[Int]
+
+ val dfA = inputA.toDF().select($"value".as("key"))
+ val dfB = inputB.toDF()
+ .select(when($"value" > Int.MinValue, $"value")
+ .otherwise(lit(null).cast("int"))
+ .as("key"))
+
+ (inputA, inputB, dfA, dfB)
+ }
+
+ private def buildTwoSourcesWithWatermark()
+ : (MemoryStream[Int], MemoryStream[Int], DataFrame, DataFrame) = {
+ val inputA = MemoryStream[Int]
+ val inputB = MemoryStream[Int]
+
+ val dfA = inputA.toDF()
+ .select($"value".as("key"),
+ timestamp_seconds($"value").as("ts"))
+ .withWatermark("ts", "1 minute")
+ val dfB = inputB.toDF()
+ .select(when($"value" > Int.MinValue, $"value")
+ .otherwise(lit(null).cast("int")).as("key"),
+ timestamp_seconds($"value").as("ts"))
+ .withWatermark("ts", "1 minute")
+
+ (inputA, inputB, dfA, dfB)
+ }
+
+ private def runUnionBranchDropRestart(
+ buildSources: () => (MemoryStream[Int], MemoryStream[Int], DataFrame, DataFrame),
+ buildQuery: (DataFrame, DataFrame) => DataFrame,
+ outputMode: OutputMode,
+ nullableToNonNullable: Boolean): Unit = {
+ withTempDir { checkpointDir =>
+ val checkpointPath = checkpointDir.getAbsolutePath
+
+ val (inputA, inputB, dfA, dfB) = buildSources()
+ val q = buildQuery(dfA, dfB)
+
+ if (nullableToNonNullable) {
+ testStream(q, outputMode)(
+ StartStream(checkpointLocation = checkpointPath),
+ MultiAddData(inputA, 1, 2, 3)(inputB, 4, 5),
+ ProcessAllAvailable(),
+ StopStream
+ )
+ } else {
+ testStream(q, outputMode)(
+ StartStream(checkpointLocation = checkpointPath),
+ AddData(inputA, 1, 2, 3),
+ ProcessAllAvailable(),
+ StopStream
+ )
+ }
+
+ assertJournaledStateSchemaAllNullable(checkpointPath)
+
+ if (nullableToNonNullable) {
+ testStream(q, outputMode)(
+ StartStream(checkpointLocation = checkpointPath),
+ AddData(inputA, 6),
+ ProcessAllAvailable()
+ )
+ } else {
+ testStream(q, outputMode)(
+ StartStream(checkpointLocation = checkpointPath),
+ MultiAddData(inputA, 6)(inputB, 7),
+ ProcessAllAvailable()
+ )
+ }
+ }
+ }
+
+ private def assertJournaledStateSchemaAllNullable(checkpointPath: String): Unit = {
+ val partId = StateStore.PARTITION_ID_TO_CHECK_SCHEMA
+ val operatorRoot = new Path(checkpointPath, "state/0")
+ val partitionRoot = new Path(operatorRoot, s"$partId")
+ val hadoopConf = spark.sessionState.newHadoopConf()
+ val fm = CheckpointFileManager.create(operatorRoot, hadoopConf)
+ val fs = operatorRoot.getFileSystem(hadoopConf)
+
+ def collectSchemaFiles(dir: Path): Seq[Path] = {
+ if (!fm.exists(dir)) return Seq.empty
+ if (fs.getFileStatus(dir).isDirectory) {
+ fs.listStatus(dir).filter(_.isFile).map(_.getPath).toSeq
+ } else {
+ Seq(dir)
+ }
+ }
+
+ val schemaFiles = scala.collection.mutable.ArrayBuffer.empty[Path]
+
+ val storeDirs = scala.collection.mutable.ArrayBuffer(partitionRoot)
+ if (fs.exists(partitionRoot)) {
+ fs.listStatus(partitionRoot)
+ .filter(_.isDirectory)
+ .filterNot(_.getPath.getName.startsWith("_"))
+ .foreach(d => storeDirs += d.getPath)
+ }
+ storeDirs.foreach { storeDir =>
+ schemaFiles ++= collectSchemaFiles(
+ new Path(storeDir, "_metadata/schema"))
+ }
+
+ val stateSchemaRoot = new Path(operatorRoot, "_stateSchema")
+ if (fs.exists(stateSchemaRoot)) {
+ fs.listStatus(stateSchemaRoot)
+ .filter(_.isDirectory)
+ .foreach { storeDir =>
+ schemaFiles ++= collectSchemaFiles(storeDir.getPath)
+ }
+ }
+
+ assert(schemaFiles.nonEmpty,
+ s"expected at least one schema file under $operatorRoot")
+ schemaFiles.foreach { schemaFile =>
+ val inStream = fm.open(schemaFile)
+ try {
+ val schemas = StateSchemaCompatibilityChecker.readSchemaFile(inStream)
+ schemas.foreach { s =>
+ assertSchemaAllNullable(s.keySchema,
+ s"$schemaFile: key schema for col family ${s.colFamilyName}")
+ }
+ } finally inStream.close()
+ }
+ }
+
+ private def assertSchemaAllNullable(schema: StructType, label: String): Unit = {
+ schema.fields.foreach { f =>
+ assert(f.nullable, s"$label: field ${f.name} should be nullable")
+ assertDataTypeAllNullable(f.dataType, s"$label.${f.name}")
+ }
+ }
+
+ private def assertDataTypeAllNullable(dataType: DataType, label: String): Unit = dataType match {
+ case s: StructType => assertSchemaAllNullable(s, label)
+ case ArrayType(elementType, containsNull) =>
+ assert(containsNull, s"$label: array element should be nullable")
+ assertDataTypeAllNullable(elementType, s"$label[]")
+ case MapType(keyType, valueType, valueContainsNull) =>
+ assert(valueContainsNull, s"$label: map value should be nullable")
+ assertDataTypeAllNullable(keyType, s"$label.key")
+ assertDataTypeAllNullable(valueType, s"$label.value")
+ case _ =>
+ }
+
+ test("streaming aggregate: non-nullable -> nullable widening remains restart-compatible") {
+ runUnionBranchDropRestart(
+ buildSources = () => buildTwoSources(),
+ buildQuery = (dfA, dfB) => dfA.union(dfB).groupBy($"key").count(),
+ outputMode = OutputMode.Update(),
+ nullableToNonNullable = false)
+ }
+
+ test("streaming aggregate: nullable -> non-nullable narrowing remains restart-compatible") {
+ runUnionBranchDropRestart(
+ buildSources = () => buildTwoSources(),
+ buildQuery = (dfA, dfB) => dfA.union(dfB).groupBy($"key").count(),
+ outputMode = OutputMode.Update(),
+ nullableToNonNullable = true)
+ }
+
+ test("streaming dropDuplicates: non-nullable -> nullable widening remains restart-compatible") {
+ runUnionBranchDropRestart(
+ buildSources = () => buildTwoSources(),
+ buildQuery = (dfA, dfB) => dfA.union(dfB).dropDuplicates(Seq("key")),
+ outputMode = OutputMode.Append(),
+ nullableToNonNullable = false)
+ }
+
+ test("streaming dropDuplicatesWithinWatermark: " +
+ "non-nullable -> nullable widening remains restart-compatible") {
+ runUnionBranchDropRestart(
+ buildSources = () => buildTwoSourcesWithWatermark(),
+ buildQuery = (dfA, dfB) => dfA.union(dfB).dropDuplicatesWithinWatermark(Seq("key")),
+ outputMode = OutputMode.Append(),
+ nullableToNonNullable = false)
+ }
+
+ test("streaming aggregate (Complete mode): no codegen NPE on state-restored null " +
+ "struct grouping key after fix") {
+ import org.apache.spark.sql.functions.struct
+
+ def mkQuery(inNullableK: MemoryStream[Int], inNonNullK: MemoryStream[Int]): DataFrame = {
+ val dfNullable = inNullableK.toDF()
+ .select(
+ when($"value" > 0, struct($"value".as("v")))
+ .otherwise(lit(null).cast("struct<v:int>"))
+ .as("key"),
+ lit(1).as("metric"))
+
+ val dfNonNull = inNonNullK.toDF()
+ .select(
+ struct($"value".as("v")).as("key"),
+ lit(1).as("metric"))
+
+ dfNullable.union(dfNonNull)
+ .groupBy($"key")
+ .agg(sum($"metric").as("c"))
+ .select($"key.v".as("v"), $"c")
+ }
+
+ withTempDir { checkpointDir =>
+ withSQLConf(
+ SQLConf.STATE_SCHEMA_CHECK_ENABLED.key -> "false",
+ SQLConf.STATE_STORE_FORMAT_VALIDATION_ENABLED.key -> "false",
+ SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ val inNullable = MemoryStream[Int]
+ val inNonNull = MemoryStream[Int]
+ val q = mkQuery(inNullable, inNonNull)
+ testStream(q, OutputMode.Complete())(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inNullable, 0),
+ ProcessAllAvailable(),
+ StopStream
+ )
+
+ testStream(q, OutputMode.Complete())(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inNonNull, 1),
+ ProcessAllAvailable()
+ )
+ }
+ }
+ }
+
+ test("streaming aggregate: with widening forced off (existing-query path), " +
+ "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE still triggers on restart") {
+ withTempDir { checkpointDir =>
+ withSQLConf(
+ SQLConf.STATEFUL_OPERATOR_ALWAYS_NULLABLE_OUTPUT.key -> "false") {
+ val (inputA, inputB, dfA, dfB) = buildTwoSources()
+ val aggregated = dfA.union(dfB).groupBy($"key").count()
+ testStream(aggregated, OutputMode.Update())(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ AddData(inputA, 1, 2, 3),
+ ProcessAllAvailable(),
+ StopStream
+ )
+
+ inputA.addData(4)
+ inputB.addData(5)
+
+ val ex = intercept[SparkUnsupportedOperationException] {
+ testStream(aggregated, OutputMode.Update())(
+ StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+ ProcessAllAvailable()
+ )
+ }
+
+ checkError(
+ ex,
+ condition = "STATE_STORE_KEY_SCHEMA_NOT_COMPATIBLE",
+ parameters = Map(
+ "storedKeySchema" -> ".*",
+ "newKeySchema" -> ".*"),
+ matchPVals = true
+ )
+ }
+ }
+ }
+
+ test("stream-stream join: non-nullable -> nullable widening remains restart-compatible") {
+ withTempDir { checkpointDir =>
+ val checkpointPath = checkpointDir.getAbsolutePath
+
+ def buildJoinQuery(): (MemoryStream[Int], MemoryStream[Int], DataFrame) = {
+ val leftInput = MemoryStream[Int]
+ val rightInput = MemoryStream[Int]
+
+ val left = leftInput.toDF()
+ .select($"value".as("key"),
+ timestamp_seconds($"value").as("leftTime"))
+ .withWatermark("leftTime", "10 seconds")
+ val right = rightInput.toDF()
+ .select(
+ when($"value" > Int.MinValue, $"value")
+ .otherwise(lit(null).cast("int")).as("key"),
+ timestamp_seconds($"value").as("rightTime"))
+ .withWatermark("rightTime", "10 seconds")
+
+ val joined = left.join(right,
+ left("key") === right("key") &&
+ left("leftTime") > right("rightTime") - expr("INTERVAL 10 SECONDS") &&
+ left("leftTime") < right("rightTime") + expr("INTERVAL 10 SECONDS"),
+ "inner")
+ (leftInput, rightInput, joined)
+ }
+
+ val (leftInput1, rightInput1, joined1) = buildJoinQuery()
+ testStream(joined1, OutputMode.Append())(
+ StartStream(checkpointLocation = checkpointPath),
+ MultiAddData(leftInput1, 1, 2, 3)(rightInput1, 1, 2),
+ ProcessAllAvailable(),
+ StopStream
+ )
+
+ assertJournaledStateSchemaAllNullable(checkpointPath)
+
+ val (leftInput2, rightInput2, joined2) = buildJoinQuery()
+ testStream(joined2, OutputMode.Append())(
+ StartStream(checkpointLocation = checkpointPath),
+ MultiAddData(leftInput2, 4)(rightInput2, 5),
+ ProcessAllAvailable()
+ )
+ }
+ }
+
+ test("streaming flatMapGroupsWithState: " +
+ "non-nullable -> nullable widening remains restart-compatible") {
+ val stateFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => {
+ val sum = values.sum + state.getOption.getOrElse(0)
+ state.update(sum)
+ Iterator((key, sum))
+ }
+
+ withTempDir { checkpointDir =>
+ val checkpointPath = checkpointDir.getAbsolutePath
+
+ def buildFmgwsQuery()
+ : (MemoryStream[Int], MemoryStream[Int], DataFrame) = {
+ val (inputA, inputB, dfA, dfB) = buildTwoSources()
+ val result = dfA.union(dfB)
+ .as[Int]
+ .groupByKey(identity)
+ .flatMapGroupsWithState(
+ OutputMode.Update(), GroupStateTimeout.NoTimeout())(stateFunc)
+ .toDF("key", "sum")
+ (inputA, inputB, result)
+ }
+
+ val (inputA1, inputB1, q1) = buildFmgwsQuery()
+ testStream(q1, OutputMode.Update())(
+ StartStream(checkpointLocation = checkpointPath),
+ AddData(inputA1, 1, 2, 3),
+ ProcessAllAvailable(),
+ StopStream
+ )
+
+ assertJournaledStateSchemaAllNullable(checkpointPath)
+
+ val (inputA2, inputB2, q2) = buildFmgwsQuery()
+ testStream(q2, OutputMode.Update())(
+ StartStream(checkpointLocation = checkpointPath),
+ MultiAddData(inputA2, 4)(inputB2, 5),
+ ProcessAllAvailable()
+ )
+ }
+ }
+
+ test("streaming transformWithState: " +
+ "non-nullable -> nullable widening remains restart-compatible") {
+ withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+ classOf[RocksDBStateStoreProvider].getName) {
+ withTempDir { checkpointDir =>
+ val checkpointPath = checkpointDir.getAbsolutePath
+
+ def buildTwsQuery()
+ : (MemoryStream[Int], MemoryStream[Int], DataFrame) = {
+ val (inputA, inputB, dfA, dfB) = buildTwoSources()
+ val result = dfA.union(dfB)
+ .as[Int]
+ .groupByKey(identity)
+ .transformWithState(
+ new NullabilityDriftCountProcessor(),
+ TimeMode.None(),
+ OutputMode.Update())
+ (inputA, inputB, result.toDF())
+ }
+
+ val (inputA1, inputB1, q1) = buildTwsQuery()
+ testStream(q1, OutputMode.Update())(
+ StartStream(checkpointLocation = checkpointPath),
+ AddData(inputA1, 1, 2, 3),
+ ProcessAllAvailable(),
+ StopStream
+ )
+
+ assertJournaledStateSchemaAllNullable(checkpointPath)
+
+ val (inputA2, inputB2, q2) = buildTwsQuery()
+ testStream(q2, OutputMode.Update())(
+ StartStream(checkpointLocation = checkpointPath),
+ MultiAddData(inputA2, 4)(inputB2, 5),
+ ProcessAllAvailable()
+ )
+ }
+ }
+ }
+
+ test("rule skips non-stateful nodes whose subtree has no stateful operator") {
+ import org.apache.spark.sql.catalyst.analysis.WidenStatefulOperatorAttributeNullability
+ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression}
+ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation, Project}
+ import org.apache.spark.sql.types.IntegerType
+
+ val key = AttributeReference("key", IntegerType, nullable = false)()
+ val payload = AttributeReference("payload", IntegerType, nullable = false)()
+ val source = LocalRelation(Seq(key, payload), isStreaming = true)
+ val project = Project(Seq(key, payload), source)
+ val agg = Aggregate(
+ groupingExpressions = Seq(key),
+ aggregateExpressions = Seq(key.asInstanceOf[NamedExpression]),
+ child = project)
+
+ val widened = WidenStatefulOperatorAttributeNullability(agg)
+
+ val projectAfter = widened.collectFirst { case p: Project => p }.getOrElse(
+ fail(s"expected to find a Project node in the rewritten plan: $widened"))
+ assert(projectAfter.projectList.forall {
+ case ar: AttributeReference => !ar.nullable
+ case _ => true
+ }, s"Project.projectList below a stateful op should remain non-nullable: " +
+ s"${projectAfter.projectList}")
+
+ val aggAfter = widened.asInstanceOf[Aggregate]
+ assert(aggAfter.aggregateExpressions.forall {
+ case ar: AttributeReference => ar.nullable
+ case _ => true
+ }, s"Aggregate.aggregateExpressions should be widened to nullable: " +
+ s"${aggAfter.aggregateExpressions}")
+ assert(aggAfter.groupingExpressions.forall {
+ case ar: AttributeReference => ar.nullable
+ case _ => true
+ }, s"Aggregate.groupingExpressions should be widened to nullable: " +
+ s"${aggAfter.groupingExpressions}")
+ }
+
+ test("deepWidenAttribute recurses into struct fields, array elements, map values") {
+ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId}
+ import org.apache.spark.sql.types._
+
+ val nestedStruct = StructType(Seq(
+ StructField("inner_nn", IntegerType, nullable = false),
+ StructField("inner_nl", StringType, nullable = true)))
+ val arrayOfNonNull = ArrayType(IntegerType, containsNull = false)
+ val mapWithNonNullValue = MapType(StringType, IntegerType, valueContainsNull = false)
+ val combined = StructType(Seq(
+ StructField("s", nestedStruct, nullable = false),
+ StructField("a", arrayOfNonNull, nullable = false),
+ StructField("m", mapWithNonNullValue, nullable = false)))
+
+ val attr = AttributeReference("complex", combined, nullable = false)(ExprId(42L))
+ val widened = WidenStatefulOpNullability.deepWidenAttribute(attr)
+
+ assert(widened.nullable, "outer attribute should be widened to nullable")
+ val widenedStruct = widened.dataType.asInstanceOf[StructType]
+ val widenedNested = widenedStruct("s").dataType.asInstanceOf[StructType]
+ assert(
+ widenedStruct("s").nullable && widenedStruct("a").nullable && widenedStruct("m").nullable,
+ "all top-level fields should be widened to nullable")
+ assert(widenedNested("inner_nn").nullable && widenedNested("inner_nl").nullable,
+ "nested struct fields should be widened to nullable")
+ val widenedArray = widenedStruct("a").dataType.asInstanceOf[ArrayType]
+ assert(widenedArray.containsNull, "array element nullability should be widened")
+ val widenedMap = widenedStruct("m").dataType.asInstanceOf[MapType]
+ assert(widenedMap.valueContainsNull, "map value nullability should be widened")
+
+ assert(widened.exprId == attr.exprId)
+ assert(widened.name == attr.name)
+ assert(widened.qualifier == attr.qualifier)
+ }
+}
+
+class NullabilityDriftCountProcessor
+ extends StatefulProcessor[Int, Int, (Int, Long)] {
+ @transient private var countState: ValueState[Long] = _
+
+ override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
+ countState = getHandle.getValueState[Long](
+ "count", Encoders.scalaLong, TTLConfig.NONE)
+ }
+
+ override def handleInputRows(
+ key: Int,
+ rows: Iterator[Int],
+ timerValues: TimerValues): Iterator[(Int, Long)] = {
+ val count = (if (countState.exists()) countState.get() else 0L) + rows.size
+ countState.update(count)
+ Iterator((key, count))
+ }
+}
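
Editor's note (not part of the commit): the `deepWidenAttribute` test and the `assertSchemaAllNullable` helper in the suite above both rely on widening that reaches struct fields, array elements, and map values. The following is a minimal standalone sketch of that recursion. It assumes a hypothetical miniature type model (`DType`, `Field`, `StructT`, `ArrayT`, `MapT`) in place of Spark's `DataType` classes so that it runs without Spark on the classpath; the names `DeepWiden.deepWiden` and `DeepWiden.allNullable` are illustrative only and are not taken from the commit.

```scala
// Hypothetical miniature schema model; NOT Spark's org.apache.spark.sql.types.
sealed trait DType
case object IntT extends DType
case object StrT extends DType
final case class Field(name: String, dataType: DType, nullable: Boolean)
final case class StructT(fields: Seq[Field]) extends DType
final case class ArrayT(elementType: DType, containsNull: Boolean) extends DType
final case class MapT(keyType: DType, valueType: DType, valueContainsNull: Boolean) extends DType

object DeepWiden {
  // Recursively set every nullability flag to true, leaving names and leaf types untouched.
  def deepWiden(dt: DType): DType = dt match {
    case StructT(fields) =>
      StructT(fields.map(f => f.copy(dataType = deepWiden(f.dataType), nullable = true)))
    case ArrayT(elem, _) =>
      ArrayT(deepWiden(elem), containsNull = true)
    case MapT(k, v, _) =>
      // Key and value types are both traversed; only the value carries its own flag.
      MapT(deepWiden(k), deepWiden(v), valueContainsNull = true)
    case leaf => leaf
  }

  // Same shape as the suite's assertSchemaAllNullable, but returning a Boolean.
  def allNullable(dt: DType): Boolean = dt match {
    case StructT(fields) => fields.forall(f => f.nullable && allNullable(f.dataType))
    case ArrayT(elem, containsNull) => containsNull && allNullable(elem)
    case MapT(k, v, valueContainsNull) => valueContainsNull && allNullable(k) && allNullable(v)
    case _ => true
  }
}
```

Under this model, a schema with any non-nullable field, element, or map value fails `allNullable` before widening and passes it afterwards, which is the invariant the suite checks against the journaled state schema files.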
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
index 0454c67f6a61..de1bc0d9c3d7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala
@@ -1787,14 +1787,14 @@ abstract class TransformWithStateSuite extends StateStoreMetricsTest
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString,
SQLConf.STREAMING_STATE_STORE_ENCODING_FORMAT.key -> encoding) {
withTempDir { checkpointDir =>
- // When Avro is used, we want to set the StructFields to nullable
- val shouldBeNullable = encoding == "avro"
val metadataPathPostfix = "state/0/_stateSchema/default"
val stateSchemaPath = new Path(checkpointDir.toString, s"$metadataPathPostfix")
val hadoopConf = spark.sessionState.newHadoopConf()
val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf)
+ // When Avro is used, we want to set the StructFields to nullable
+ val shouldBeNullable = encoding == "avro"
val keySchema = new StructType().add("value", StringType)
val schema0 = StateStoreColFamilySchema(
"countState", 0,