This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new a9bfacb084e6 [SPARK-52391][SS] Refactor TransformWithStateExec to extract shared functions and variables into an abstract base class for Scala and Python
a9bfacb084e6 is described below

commit a9bfacb084e696265a9d1473efe5001d03700ee3
Author: huanliwang-db <huanli.w...@databricks.com>
AuthorDate: Fri Jun 6 16:22:53 2025 +0900

    [SPARK-52391][SS] Refactor TransformWithStateExec to extract shared functions and variables into an abstract base class for Scala and Python

    ### What changes were proposed in this pull request?

    Refactor the TWS exec code to extract the common functions/variables and move them to an abstract base class so that they can be shared by both the Scala exec and the Python exec.

    ### Why are the changes needed?

    Code elegance - less duplicated code.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    No functional changes - existing UTs should provide test coverage.

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #51077 from huanliwang-db/refactor-tws.

    Authored-by: huanliwang-db <huanli.w...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../TransformWithStateInPySparkExec.scala | 119 ++--------
 .../streaming/TransformWithStateExec.scala | 198 ++---------------
 .../streaming/TransformWithStateExecBase.scala | 239 +++++++++++++++++++++
 3 files changed, 264 insertions(+), 292 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala index d0a4d8f6b284..b65d46fb1632 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala @@ -28,18 +28,16 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, PythonUDF, SortOrder} -import org.apache.spark.sql.catalyst.plans.logical.ProcessingTime +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF} import org.apache.spark.sql.catalyst.plans.logical.TransformWithStateInPySpark -import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.catalyst.types.DataTypeUtils -import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan} +import org.apache.spark.sql.execution.{CoGroupedIterator, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.python.ArrowPythonRunner import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, groupAndProject, resolveArgOffsets} -import org.apache.spark.sql.execution.streaming.{DriverStatefulProcessorHandleImpl, StatefulOperatorCustomMetric, StatefulOperatorCustomSumMetric, StatefulOperatorPartitioning, StatefulOperatorStateInfo, StatefulProcessorHandleImpl, StateStoreWriter, TransformWithStateMetadataUtils, TransformWithStateVariableInfo, WatermarkSupport} +import org.apache.spark.sql.execution.streaming.{DriverStatefulProcessorHandleImpl, StatefulOperatorStateInfo, StatefulProcessorHandleImpl, 
TransformWithStateExecBase, TransformWithStateVariableInfo} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, OperatorStateMetadata, RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, StateStoreProvider, StateStoreProviderId} +import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, StateStoreProvider, StateStoreProviderId} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{OutputMode, TimeMode} import org.apache.spark.sql.types.{BinaryType, StructField, StructType} @@ -83,10 +81,15 @@ case class TransformWithStateInPySparkExec( initialState: SparkPlan, initialStateGroupingAttrs: Seq[Attribute], initialStateSchema: StructType) - extends BinaryExecNode - with StateStoreWriter - with WatermarkSupport - with TransformWithStateMetadataUtils { + extends TransformWithStateExecBase( + groupingAttributes, + timeMode, + outputMode, + batchTimestampMs, + eventTimeWatermarkForEviction, + child, + initialStateGroupingAttrs, + initialState) { // NOTE: This is needed to comply with existing release of transformWithStateInPandas. override def shortName: String = if ( @@ -115,17 +118,12 @@ case class TransformWithStateInPySparkExec( private val numOutputRows: SQLMetric = longMetric("numOutputRows") - // The keys that may have a watermark attribute. - override def keyExpressions: Seq[Attribute] = groupingAttributes - // Each state variable has its own schema, this is a dummy one. protected val schemaForKeyRow: StructType = new StructType().add("key", BinaryType) // Each state variable has its own schema, this is a dummy one. protected val schemaForValueRow: StructType = new StructType().add("value", BinaryType) - override def operatorStateMetadataVersion: Int = 2 - override def getColFamilySchemas( shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema] = { // For Python, the user can explicitly set nullability on schema, so @@ -146,37 +144,6 @@ case class TransformWithStateInPySparkExec( private val driverProcessorHandle: DriverStatefulProcessorHandleImpl = new DriverStatefulProcessorHandleImpl(timeMode, groupingKeyExprEncoder) - /** - * Distribute by grouping attributes - We need the underlying data and the initial state data - * to have the same grouping so that the data are co-located on the same task. - */ - override def requiredChildDistribution: Seq[Distribution] = { - StatefulOperatorPartitioning.getCompatibleDistribution(groupingAttributes, - getStateInfo, conf) :: - StatefulOperatorPartitioning.getCompatibleDistribution( - initialStateGroupingAttrs, getStateInfo, conf) :: - Nil - } - - /** - * We need the initial state to also use the ordering as the data so that we can co-locate the - * keys from the underlying data and the initial state. 
- */ - override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( - groupingAttributes.map(SortOrder(_, Ascending)), - initialStateGroupingAttrs.map(SortOrder(_, Ascending))) - - override def operatorStateMetadata( - stateSchemaPaths: List[List[String]]): OperatorStateMetadata = { - getOperatorStateMetadata(stateSchemaPaths, getStateInfo, shortName, timeMode, outputMode) - } - - override def validateNewMetadata( - oldOperatorMetadata: OperatorStateMetadata, - newOperatorMetadata: OperatorStateMetadata): Unit = { - validateNewMetadataForTWS(oldOperatorMetadata, newOperatorMetadata) - } - override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, @@ -208,60 +175,6 @@ case class TransformWithStateInPySparkExec( conf.stateStoreEncodingFormat) } - override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { - if (timeMode == ProcessingTime) { - // TODO SPARK-50180: check if we can return true only if actual timers are registered, - // or there is expired state - true - } else if (outputMode == OutputMode.Append || outputMode == OutputMode.Update) { - eventTimeWatermarkForEviction.isDefined && - newInputWatermark > eventTimeWatermarkForEviction.get - } else { - false - } - } - - /** - * Controls watermark propagation to downstream modes. If timeMode is - * ProcessingTime, the output rows cannot be interpreted in eventTime, hence - * this node will not propagate watermark in this timeMode. - * - * For timeMode EventTime, output watermark is same as input Watermark because - * transformWithState does not allow users to set the event time column to be - * earlier than the watermark. - */ - override def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = { - timeMode match { - case ProcessingTime => - None - case _ => - Some(inputWatermarkMs) - } - } - - override def customStatefulOperatorMetrics: Seq[StatefulOperatorCustomMetric] = { - Seq( - // metrics around state variables - StatefulOperatorCustomSumMetric("numValueStateVars", "Number of value state variables"), - StatefulOperatorCustomSumMetric("numListStateVars", "Number of list state variables"), - StatefulOperatorCustomSumMetric("numMapStateVars", "Number of map state variables"), - StatefulOperatorCustomSumMetric("numDeletedStateVars", "Number of deleted state variables"), - // metrics around timers - StatefulOperatorCustomSumMetric("numRegisteredTimers", "Number of registered timers"), - StatefulOperatorCustomSumMetric("numDeletedTimers", "Number of deleted timers"), - StatefulOperatorCustomSumMetric("numExpiredTimers", "Number of expired timers"), - // metrics around TTL - StatefulOperatorCustomSumMetric("numValueStateWithTTLVars", - "Number of value state variables with TTL"), - StatefulOperatorCustomSumMetric("numListStateWithTTLVars", - "Number of list state variables with TTL"), - StatefulOperatorCustomSumMetric("numMapStateWithTTLVars", - "Number of map state variables with TTL"), - StatefulOperatorCustomSumMetric("numValuesRemovedDueToTTLExpiry", - "Number of values removed due to TTL expiry") - ) - } - /** * Produces the result of the query as an `RDD[InternalRow]` */ @@ -376,8 +289,6 @@ case class TransformWithStateInPySparkExec( } } - override def supportsSchemaEvolution: Boolean = true - private def processDataWithPartition( store: StateStore, dataIterator: Iterator[InternalRow], @@ -491,10 +402,6 @@ case class TransformWithStateInPySparkExec( } else { copy(child = newLeft) } - - override def left: SparkPlan = child - - override def right: SparkPlan = initialState } // 
scalastyle:off argcount diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 8943e8898c6d..80fdaa1e71e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -21,21 +21,18 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.hadoop.conf.Configuration -import org.apache.spark.SparkThrowable import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, UnsafeRow} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ -import org.apache.spark.sql.types.{BinaryType, StructType} -import org.apache.spark.util.{CompletionIterator, NextIterator, SerializableConfiguration, Utils} +import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils} /** * Physical operator for executing `TransformWithState` @@ -76,17 +73,19 @@ case class TransformWithStateExec( initialStateDataAttrs: Seq[Attribute], initialStateDeserializer: Expression, initialState: SparkPlan) - extends BinaryExecNode - with StateStoreWriter - with WatermarkSupport - with ObjectProducerExec - with TransformWithStateMetadataUtils { + extends TransformWithStateExecBase( + groupingAttributes, + timeMode, + outputMode, + batchTimestampMs, + eventTimeWatermarkForEviction, + child, + initialStateGroupingAttrs, + initialState) + with ObjectProducerExec { override def shortName: String = "transformWithStateExec" - // dummy value schema, the real schema will get during state variable init time - private val DUMMY_VALUE_ROW_SCHEMA = new StructType().add("value", BinaryType) - // We need to just initialize key and value deserializer once per partition. // The deserializers need to be lazily created on the executor since they // are not serializable. @@ -98,21 +97,6 @@ case class TransformWithStateExec( private lazy val getValueObj = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) - override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { - if (timeMode == ProcessingTime) { - // TODO SPARK-50180: check if we can return true only if actual timers are registered, - // or there is expired state - true - } else if (outputMode == OutputMode.Append || outputMode == OutputMode.Update) { - eventTimeWatermarkForEviction.isDefined && - newInputWatermark > eventTimeWatermarkForEviction.get - } else { - false - } - } - - override def operatorStateMetadataVersion: Int = 2 - /** * We initialize this processor handle in the driver to run the init function * and fetch the schemas of the state variables initialized in this processor. 
@@ -168,28 +152,6 @@ case class TransformWithStateExec( stateVariableInfos } - /** - * Controls watermark propagation to downstream modes. If timeMode is - * ProcessingTime, the output rows cannot be interpreted in eventTime, hence - * this node will not propagate watermark in this timeMode. - * - * For timeMode EventTime, output watermark is same as input Watermark because - * transformWithState does not allow users to set the event time column to be - * earlier than the watermark. - */ - override def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = { - timeMode match { - case ProcessingTime => - None - case _ => - Some(inputWatermarkMs) - } - } - - override def left: SparkPlan = child - - override def right: SparkPlan = initialState - override protected def withNewChildrenInternal( newLeft: SparkPlan, newRight: SparkPlan): TransformWithStateExec = { if (hasInitialState) { @@ -199,66 +161,6 @@ case class TransformWithStateExec( } } - override def keyExpressions: Seq[Attribute] = groupingAttributes - - /** - * Distribute by grouping attributes - We need the underlying data and the initial state data - * to have the same grouping so that the data are co-located on the same task. - */ - override def requiredChildDistribution: Seq[Distribution] = { - StatefulOperatorPartitioning.getCompatibleDistribution( - groupingAttributes, getStateInfo, conf) :: - StatefulOperatorPartitioning.getCompatibleDistribution( - initialStateGroupingAttrs, getStateInfo, conf) :: - Nil - } - - /** - * We need the initial state to also use the ordering as the data so that we can co-locate the - * keys from the underlying data and the initial state. - */ - override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( - groupingAttributes.map(SortOrder(_, Ascending)), - initialStateGroupingAttrs.map(SortOrder(_, Ascending))) - - // Wrapper to ensure that the implicit key is set when the methods on the iterator - // are called. We process all the values for a particular key at a time, so we - // only have to set the implicit key when the first call to the iterator is made, and - // we have to remove it when the iterator is closed. - // - // Note: if we ever start to interleave the processing of the iterators we get back - // from handleInputRows (i.e. we don't process each iterator all at once), then this - // iterator will need to set/unset the implicit key every time hasNext/next is called, - // not just at the first and last calls to hasNext. 
- private def iteratorWithImplicitKeySet( - key: Any, - iter: Iterator[InternalRow], - onClose: () => Unit = () => {} - ): Iterator[InternalRow] = { - new NextIterator[InternalRow] { - var hasStarted = false - - override protected def getNext(): InternalRow = { - if (!hasStarted) { - hasStarted = true - ImplicitGroupingKeyTracker.setImplicitKey(key) - } - - if (!iter.hasNext) { - finished = true - null - } else { - iter.next() - } - } - - override protected def close(): Unit = { - onClose() - ImplicitGroupingKeyTracker.removeImplicitKey() - } - } - } - private def handleInputRows(keyRow: UnsafeRow, valueRowIter: Iterator[InternalRow]): Iterator[InternalRow] = { @@ -455,35 +357,6 @@ case class TransformWithStateExec( } } - // operator specific metrics - override def customStatefulOperatorMetrics: Seq[StatefulOperatorCustomMetric] = { - Seq( - // metrics around initial state - StatefulOperatorCustomSumMetric("initialStateProcessingTimeMs", - "Number of milliseconds taken to process all initial state"), - // metrics around state variables - StatefulOperatorCustomSumMetric("numValueStateVars", "Number of value state variables"), - StatefulOperatorCustomSumMetric("numListStateVars", "Number of list state variables"), - StatefulOperatorCustomSumMetric("numMapStateVars", "Number of map state variables"), - StatefulOperatorCustomSumMetric("numDeletedStateVars", "Number of deleted state variables"), - // metrics around timers - StatefulOperatorCustomSumMetric("timerProcessingTimeMs", - "Number of milliseconds taken to process all timers"), - StatefulOperatorCustomSumMetric("numRegisteredTimers", "Number of registered timers"), - StatefulOperatorCustomSumMetric("numDeletedTimers", "Number of deleted timers"), - StatefulOperatorCustomSumMetric("numExpiredTimers", "Number of expired timers"), - // metrics around TTL - StatefulOperatorCustomSumMetric("numValueStateWithTTLVars", - "Number of value state variables with TTL"), - StatefulOperatorCustomSumMetric("numListStateWithTTLVars", - "Number of list state variables with TTL"), - StatefulOperatorCustomSumMetric("numMapStateWithTTLVars", - "Number of map state variables with TTL"), - StatefulOperatorCustomSumMetric("numValuesRemovedDueToTTLExpiry", - "Number of values removed due to TTL expiry") - ) - } - override def validateAndMaybeEvolveStateSchema( hadoopConf: Configuration, batchId: Long, @@ -494,19 +367,6 @@ case class TransformWithStateExec( info, stateSchemaDir, session, operatorStateMetadataVersion, conf.stateStoreEncodingFormat) } - /** Metadata of this stateful operator and its states stores. */ - override def operatorStateMetadata( - stateSchemaPaths: List[List[String]]): OperatorStateMetadata = { - val info = getStateInfo - getOperatorStateMetadata(stateSchemaPaths, info, shortName, timeMode, outputMode) - } - - override def validateNewMetadata( - oldOperatorMetadata: OperatorStateMetadata, - newOperatorMetadata: OperatorStateMetadata): Unit = { - validateNewMetadataForTWS(oldOperatorMetadata, newOperatorMetadata) - } - override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver @@ -578,28 +438,6 @@ case class TransformWithStateExec( } } - /** - * Executes a block of code with standardized error handling for StatefulProcessor - * operations. Rethrows SparkThrowables directly and wraps other exceptions in - * TransformWithStateUserFunctionException with the provided function name. 
- * - * @param functionName The name of the function being executed (for error reporting) - * @param block The code block to execute with error handling - * @return The result of the block execution - */ - private def withStatefulProcessorErrorHandling[R](functionName: String)(block: => R): R = { - try { - block - } catch { - case st: Exception with SparkThrowable if st.getCondition != null => - throw st - case e: Exception => - throw TransformWithStateUserFunctionException(e, functionName) - } - } - - override def supportsSchemaEvolution: Boolean = true - /** * Create a new StateStore for given partitionId and instantiate a temp directory * on the executors. Process data and close the stateStore provider afterwards. @@ -693,18 +531,6 @@ case class TransformWithStateExec( processDataWithPartition(childDataIterator, store, processorHandle) } - - private def validateTimeMode(): Unit = { - timeMode match { - case ProcessingTime => - TransformWithStateVariableUtils.validateTimeMode(timeMode, batchTimestampMs) - - case EventTime => - TransformWithStateVariableUtils.validateTimeMode(timeMode, eventTimeWatermarkForEviction) - - case _ => - } - } } // scalastyle:off argcount diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExecBase.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExecBase.scala new file mode 100644 index 000000000000..df68b21e0bb9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExecBase.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.SparkThrowable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{EventTime, ProcessingTime} +import org.apache.spark.sql.catalyst.plans.physical.Distribution +import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} +import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, TransformWithStateUserFunctionException} +import org.apache.spark.sql.streaming.{OutputMode, TimeMode} +import org.apache.spark.sql.types.{BinaryType, StructType} +import org.apache.spark.util.NextIterator + +/** + * This is the base class for physical node that execute `TransformWithState`. + * + * It contains some common logics like state store metrics handling, co-locate + * initial state with the incoming data, and etc. Concrete physical node like + * `TransformWithStateInPySparkExec` and `TransformWithStateExec` should extend + * this class. 
+ */ +abstract class TransformWithStateExecBase( + groupingAttributes: Seq[Attribute], + timeMode: TimeMode, + outputMode: OutputMode, + batchTimestampMs: Option[Long], + eventTimeWatermarkForEviction: Option[Long], + child: SparkPlan, + initialStateGroupingAttrs: Seq[Attribute], + initialState: SparkPlan) + extends BinaryExecNode + with StateStoreWriter + with WatermarkSupport + with TransformWithStateMetadataUtils { + + override def operatorStateMetadataVersion: Int = 2 + + override def supportsSchemaEvolution: Boolean = true + + override def left: SparkPlan = child + + override def right: SparkPlan = initialState + + // The keys that may have a watermark attribute. + override def keyExpressions: Seq[Attribute] = groupingAttributes + + /** + * Distribute by grouping attributes - We need the underlying data and the initial state data to + * have the same grouping so that the data are co-located on the same task. + */ + override def requiredChildDistribution: Seq[Distribution] = { + StatefulOperatorPartitioning.getCompatibleDistribution( + groupingAttributes, + getStateInfo, + conf) :: + StatefulOperatorPartitioning.getCompatibleDistribution( + initialStateGroupingAttrs, + getStateInfo, + conf) :: + Nil + } + + /** + * We need the initial state to also use the ordering as the data so that we can co-locate the + * keys from the underlying data and the initial state. + */ + override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq( + groupingAttributes.map(SortOrder(_, Ascending)), + initialStateGroupingAttrs.map(SortOrder(_, Ascending))) + + override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = { + if (timeMode == ProcessingTime) { + // TODO SPARK-50180: check if we can return true only if actual timers are registered, + // or there is expired state + true + } else if (outputMode == OutputMode.Append || outputMode == OutputMode.Update) { + eventTimeWatermarkForEviction.isDefined && + newInputWatermark > eventTimeWatermarkForEviction.get + } else { + false + } + } + + /** + * Controls watermark propagation to downstream modes. If timeMode is ProcessingTime, the output + * rows cannot be interpreted in eventTime, hence this node will not propagate watermark in this + * timeMode. + * + * For timeMode EventTime, output watermark is same as input Watermark because + * transformWithState does not allow users to set the event time column to be earlier than the + * watermark. 
+ */ + override def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = { + timeMode match { + case ProcessingTime => + None + case _ => + Some(inputWatermarkMs) + } + } + + // operator specific metrics + override def customStatefulOperatorMetrics: Seq[StatefulOperatorCustomMetric] = { + Seq( + // metrics around initial state + StatefulOperatorCustomSumMetric( + "initialStateProcessingTimeMs", + "Number of milliseconds taken to process all initial state"), + // metrics around state variables + StatefulOperatorCustomSumMetric("numValueStateVars", "Number of value state variables"), + StatefulOperatorCustomSumMetric("numListStateVars", "Number of list state variables"), + StatefulOperatorCustomSumMetric("numMapStateVars", "Number of map state variables"), + StatefulOperatorCustomSumMetric("numDeletedStateVars", "Number of deleted state variables"), + // metrics around timers + StatefulOperatorCustomSumMetric( + "timerProcessingTimeMs", + "Number of milliseconds taken to process all timers"), + StatefulOperatorCustomSumMetric("numRegisteredTimers", "Number of registered timers"), + StatefulOperatorCustomSumMetric("numDeletedTimers", "Number of deleted timers"), + StatefulOperatorCustomSumMetric("numExpiredTimers", "Number of expired timers"), + // metrics around TTL + StatefulOperatorCustomSumMetric( + "numValueStateWithTTLVars", + "Number of value state variables with TTL"), + StatefulOperatorCustomSumMetric( + "numListStateWithTTLVars", + "Number of list state variables with TTL"), + StatefulOperatorCustomSumMetric( + "numMapStateWithTTLVars", + "Number of map state variables with TTL"), + StatefulOperatorCustomSumMetric( + "numValuesRemovedDueToTTLExpiry", + "Number of values removed due to TTL expiry")) + } + + /** Metadata of this stateful operator and its states stores. */ + override def operatorStateMetadata( + stateSchemaPaths: List[List[String]]): OperatorStateMetadata = { + val info = getStateInfo + getOperatorStateMetadata(stateSchemaPaths, info, shortName, timeMode, outputMode) + } + + override def validateNewMetadata( + oldOperatorMetadata: OperatorStateMetadata, + newOperatorMetadata: OperatorStateMetadata): Unit = { + validateNewMetadataForTWS(oldOperatorMetadata, newOperatorMetadata) + } + + // dummy value schema, the real schema will get during state variable init time + protected val DUMMY_VALUE_ROW_SCHEMA = new StructType().add("value", BinaryType) + + // Wrapper to ensure that the implicit key is set when the methods on the iterator + // are called. We process all the values for a particular key at a time, so we + // only have to set the implicit key when the first call to the iterator is made, and + // we have to remove it when the iterator is closed. + // + // Note: if we ever start to interleave the processing of the iterators we get back + // from handleInputRows (i.e. we don't process each iterator all at once), then this + // iterator will need to set/unset the implicit key every time hasNext/next is called, + // not just at the first and last calls to hasNext. 
+ protected def iteratorWithImplicitKeySet( + key: Any, + iter: Iterator[InternalRow], + onClose: () => Unit = () => {}): Iterator[InternalRow] = { + new NextIterator[InternalRow] { + var hasStarted = false + + override protected def getNext(): InternalRow = { + if (!hasStarted) { + hasStarted = true + ImplicitGroupingKeyTracker.setImplicitKey(key) + } + + if (!iter.hasNext) { + finished = true + null + } else { + iter.next() + } + } + + override protected def close(): Unit = { + onClose() + ImplicitGroupingKeyTracker.removeImplicitKey() + } + } + } + + protected def validateTimeMode(): Unit = { + timeMode match { + case ProcessingTime => + TransformWithStateVariableUtils.validateTimeMode(timeMode, batchTimestampMs) + + case EventTime => + TransformWithStateVariableUtils.validateTimeMode(timeMode, eventTimeWatermarkForEviction) + + case _ => + } + } + + /** + * Executes a block of code with standardized error handling for StatefulProcessor operations. + * Rethrows SparkThrowables directly and wraps other exceptions in + * TransformWithStateUserFunctionException with the provided function name. + * + * @param functionName + * The name of the function being executed (for error reporting) + * @param block + * The code block to execute with error handling + * @return + * The result of the block execution + */ + protected def withStatefulProcessorErrorHandling[R](functionName: String)(block: => R): R = { + try { + block + } catch { + case st: Exception with SparkThrowable if st.getCondition != null => + throw st + case e: Exception => + throw TransformWithStateUserFunctionException(e, functionName) + } + } +}
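
To make the shape of the refactoring easier to see at a glance, here is a minimal, self-contained Scala sketch of the same pattern: shared overrides are pulled up into an abstract base class that both concrete operators extend. The names used below (StatefulExecBase, ScalaExec, PySparkExec, keyColumns, metadataVersion) are simplified placeholders for illustration only, not the actual Spark classes.

// Minimal sketch of the extraction pattern (illustrative names only):
// members shared by both operators live in the abstract base, and each
// concrete exec overrides only what actually differs.
abstract class StatefulExecBase(val keyColumns: Seq[String]) {
  // shared default, analogous to operatorStateMetadataVersion moving into the base
  def metadataVersion: Int = 2

  // shared grouping-key handling, analogous to keyExpressions
  def keyExpressions: Seq[String] = keyColumns

  // each concrete operator still provides its own short name
  def shortName: String
}

// Scala-side operator: inherits the shared members, adds only its specifics.
class ScalaExec(keyColumns: Seq[String]) extends StatefulExecBase(keyColumns) {
  override def shortName: String = "transformWithStateExec"
}

// Python-side operator: same shared base, different short name.
class PySparkExec(keyColumns: Seq[String]) extends StatefulExecBase(keyColumns) {
  override def shortName: String = "transformWithStateInPySparkExec"
}

object ExecDemo {
  def main(args: Array[String]): Unit = {
    val execs: Seq[StatefulExecBase] = Seq(new ScalaExec(Seq("userId")), new PySparkExec(Seq("userId")))
    // both operators report the same inherited metadata version and key handling
    execs.foreach { e =>
      println(s"${e.shortName}: metadataVersion=${e.metadataVersion}, keys=${e.keyExpressions.mkString(",")}")
    }
  }
}

In the commit itself, the members that moved into TransformWithStateExecBase include requiredChildDistribution, requiredChildOrdering, produceOutputWatermark, customStatefulOperatorMetrics, and the iteratorWithImplicitKeySet wrapper, while each operator keeps its own shortName and execution logic.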