This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new a9bfacb084e6 [SPARK-52391][SS] Refactor TransformWithStateExec to 
extract shared functions and variables into an abstract base class for Scala 
and Python
a9bfacb084e6 is described below

commit a9bfacb084e696265a9d1473efe5001d03700ee3
Author: huanliwang-db <huanli.w...@databricks.com>
AuthorDate: Fri Jun 6 16:22:53 2025 +0900

    [SPARK-52391][SS] Refactor TransformWithStateExec to extract shared 
functions and variables into an abstract base class for Scala and Python
    
    ### What changes were proposed in this pull request?
    
    Refactor the TWS exec code to extract the common functions and variables into an abstract
    base class, `TransformWithStateExecBase`, so that they can be shared by both the Scala exec
    (`TransformWithStateExec`) and the Python exec (`TransformWithStateInPySparkExec`).
    
    ### Why are the changes needed?
    
    Code cleanliness: it removes code that was duplicated between the Scala and Python exec nodes.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
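    No functional change is intended, except that the PySpark operator now also declares the
    `initialStateProcessingTimeMs` and `timerProcessingTimeMs` custom metrics, because the metric
    list is now shared. Existing unit tests provide test coverage.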
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #51077 from huanliwang-db/refactor-tws.
    
    Authored-by: huanliwang-db <huanli.w...@databricks.com>
    Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../TransformWithStateInPySparkExec.scala          | 119 ++--------
 .../streaming/TransformWithStateExec.scala         | 198 ++---------------
 .../streaming/TransformWithStateExecBase.scala     | 239 +++++++++++++++++++++
 3 files changed, 264 insertions(+), 292 deletions(-)

diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
index d0a4d8f6b284..b65d46fb1632 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPySparkExec.scala
@@ -28,18 +28,16 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, 
Expression, PythonUDF, SortOrder}
-import org.apache.spark.sql.catalyst.plans.logical.ProcessingTime
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, 
PythonUDF}
 import org.apache.spark.sql.catalyst.plans.logical.TransformWithStateInPySpark
-import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
-import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, 
SparkPlan}
+import org.apache.spark.sql.execution.{CoGroupedIterator, SparkPlan}
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ArrowPythonRunner
 import org.apache.spark.sql.execution.python.PandasGroupUtils.{executePython, 
groupAndProject, resolveArgOffsets}
-import 
org.apache.spark.sql.execution.streaming.{DriverStatefulProcessorHandleImpl, 
StatefulOperatorCustomMetric, StatefulOperatorCustomSumMetric, 
StatefulOperatorPartitioning, StatefulOperatorStateInfo, 
StatefulProcessorHandleImpl, StateStoreWriter, TransformWithStateMetadataUtils, 
TransformWithStateVariableInfo, WatermarkSupport}
+import 
org.apache.spark.sql.execution.streaming.{DriverStatefulProcessorHandleImpl, 
StatefulOperatorStateInfo, StatefulProcessorHandleImpl, 
TransformWithStateExecBase, TransformWithStateVariableInfo}
 import 
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper
-import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
OperatorStateMetadata, RocksDBStateStoreProvider, StateSchemaValidationResult, 
StateStore, StateStoreColFamilySchema, StateStoreConf, StateStoreId, 
StateStoreOps, StateStoreProvider, StateStoreProviderId}
+import 
org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, 
RocksDBStateStoreProvider, StateSchemaValidationResult, StateStore, 
StateStoreColFamilySchema, StateStoreConf, StateStoreId, StateStoreOps, 
StateStoreProvider, StateStoreProviderId}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.{OutputMode, TimeMode}
 import org.apache.spark.sql.types.{BinaryType, StructField, StructType}
@@ -83,10 +81,15 @@ case class TransformWithStateInPySparkExec(
     initialState: SparkPlan,
     initialStateGroupingAttrs: Seq[Attribute],
     initialStateSchema: StructType)
-  extends BinaryExecNode
-  with StateStoreWriter
-  with WatermarkSupport
-  with TransformWithStateMetadataUtils {
+  extends TransformWithStateExecBase(
+    groupingAttributes,
+    timeMode,
+    outputMode,
+    batchTimestampMs,
+    eventTimeWatermarkForEviction,
+    child,
+    initialStateGroupingAttrs,
+    initialState) {
 
   // NOTE: This is needed to comply with existing release of 
transformWithStateInPandas.
   override def shortName: String = if (
@@ -115,17 +118,12 @@ case class TransformWithStateInPySparkExec(
 
   private val numOutputRows: SQLMetric = longMetric("numOutputRows")
 
-  // The keys that may have a watermark attribute.
-  override def keyExpressions: Seq[Attribute] = groupingAttributes
-
   // Each state variable has its own schema, this is a dummy one.
   protected val schemaForKeyRow: StructType = new StructType().add("key", 
BinaryType)
 
   // Each state variable has its own schema, this is a dummy one.
   protected val schemaForValueRow: StructType = new StructType().add("value", 
BinaryType)
 
-  override def operatorStateMetadataVersion: Int = 2
-
   override def getColFamilySchemas(
       shouldBeNullable: Boolean): Map[String, StateStoreColFamilySchema] = {
     // For Python, the user can explicitly set nullability on schema, so
@@ -146,37 +144,6 @@ case class TransformWithStateInPySparkExec(
   private val driverProcessorHandle: DriverStatefulProcessorHandleImpl =
     new DriverStatefulProcessorHandleImpl(timeMode, groupingKeyExprEncoder)
 
-  /**
-   * Distribute by grouping attributes - We need the underlying data and the 
initial state data
-   * to have the same grouping so that the data are co-located on the same 
task.
-   */
-  override def requiredChildDistribution: Seq[Distribution] = {
-    StatefulOperatorPartitioning.getCompatibleDistribution(groupingAttributes,
-      getStateInfo, conf) ::
-      StatefulOperatorPartitioning.getCompatibleDistribution(
-        initialStateGroupingAttrs, getStateInfo, conf) ::
-      Nil
-  }
-
-  /**
-   * We need the initial state to also use the ordering as the data so that we 
can co-locate the
-   * keys from the underlying data and the initial state.
-   */
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
-    groupingAttributes.map(SortOrder(_, Ascending)),
-    initialStateGroupingAttrs.map(SortOrder(_, Ascending)))
-
-  override def operatorStateMetadata(
-      stateSchemaPaths: List[List[String]]): OperatorStateMetadata = {
-    getOperatorStateMetadata(stateSchemaPaths, getStateInfo, shortName, 
timeMode, outputMode)
-  }
-
-  override def validateNewMetadata(
-      oldOperatorMetadata: OperatorStateMetadata,
-      newOperatorMetadata: OperatorStateMetadata): Unit = {
-    validateNewMetadataForTWS(oldOperatorMetadata, newOperatorMetadata)
-  }
-
   override def validateAndMaybeEvolveStateSchema(
       hadoopConf: Configuration,
       batchId: Long,
@@ -208,60 +175,6 @@ case class TransformWithStateInPySparkExec(
         conf.stateStoreEncodingFormat)
   }
 
-  override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
-    if (timeMode == ProcessingTime) {
-      // TODO SPARK-50180: check if we can return true only if actual timers 
are registered,
-      //  or there is expired state
-      true
-    } else if (outputMode == OutputMode.Append || outputMode == 
OutputMode.Update) {
-      eventTimeWatermarkForEviction.isDefined &&
-        newInputWatermark > eventTimeWatermarkForEviction.get
-    } else {
-      false
-    }
-  }
-
-  /**
-   * Controls watermark propagation to downstream modes. If timeMode is
-   * ProcessingTime, the output rows cannot be interpreted in eventTime, hence
-   * this node will not propagate watermark in this timeMode.
-   *
-   * For timeMode EventTime, output watermark is same as input Watermark 
because
-   * transformWithState does not allow users to set the event time column to be
-   * earlier than the watermark.
-   */
-  override def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = {
-    timeMode match {
-      case ProcessingTime =>
-        None
-      case _ =>
-        Some(inputWatermarkMs)
-    }
-  }
-
-  override def customStatefulOperatorMetrics: 
Seq[StatefulOperatorCustomMetric] = {
-    Seq(
-      // metrics around state variables
-      StatefulOperatorCustomSumMetric("numValueStateVars", "Number of value 
state variables"),
-      StatefulOperatorCustomSumMetric("numListStateVars", "Number of list 
state variables"),
-      StatefulOperatorCustomSumMetric("numMapStateVars", "Number of map state 
variables"),
-      StatefulOperatorCustomSumMetric("numDeletedStateVars", "Number of 
deleted state variables"),
-      // metrics around timers
-      StatefulOperatorCustomSumMetric("numRegisteredTimers", "Number of 
registered timers"),
-      StatefulOperatorCustomSumMetric("numDeletedTimers", "Number of deleted 
timers"),
-      StatefulOperatorCustomSumMetric("numExpiredTimers", "Number of expired 
timers"),
-      // metrics around TTL
-      StatefulOperatorCustomSumMetric("numValueStateWithTTLVars",
-        "Number of value state variables with TTL"),
-      StatefulOperatorCustomSumMetric("numListStateWithTTLVars",
-        "Number of list state variables with TTL"),
-      StatefulOperatorCustomSumMetric("numMapStateWithTTLVars",
-        "Number of map state variables with TTL"),
-      StatefulOperatorCustomSumMetric("numValuesRemovedDueToTTLExpiry",
-        "Number of values removed due to TTL expiry")
-    )
-  }
-
   /**
    * Produces the result of the query as an `RDD[InternalRow]`
    */
@@ -376,8 +289,6 @@ case class TransformWithStateInPySparkExec(
     }
   }
 
-  override def supportsSchemaEvolution: Boolean = true
-
   private def processDataWithPartition(
       store: StateStore,
       dataIterator: Iterator[InternalRow],
@@ -491,10 +402,6 @@ case class TransformWithStateInPySparkExec(
     } else {
       copy(child = newLeft)
     }
-
-  override def left: SparkPlan = child
-
-  override def right: SparkPlan = initialState
 }
 
 // scalastyle:off argcount
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
index 8943e8898c6d..80fdaa1e71e2 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala
@@ -21,21 +21,18 @@ import java.util.concurrent.TimeUnit.NANOSECONDS
 
 import org.apache.hadoop.conf.Configuration
 
-import org.apache.spark.SparkThrowable
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, 
Expression, SortOrder, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, 
UnsafeRow}
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.plans.physical.Distribution
 import org.apache.spark.sql.execution._
 import 
org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper
 import org.apache.spark.sql.execution.streaming.state._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
-import org.apache.spark.sql.types.{BinaryType, StructType}
-import org.apache.spark.util.{CompletionIterator, NextIterator, 
SerializableConfiguration, Utils}
+import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, 
Utils}
 
 /**
  * Physical operator for executing `TransformWithState`
@@ -76,17 +73,19 @@ case class TransformWithStateExec(
     initialStateDataAttrs: Seq[Attribute],
     initialStateDeserializer: Expression,
     initialState: SparkPlan)
-  extends BinaryExecNode
-  with StateStoreWriter
-  with WatermarkSupport
-  with ObjectProducerExec
-  with TransformWithStateMetadataUtils {
+  extends TransformWithStateExecBase(
+    groupingAttributes,
+    timeMode,
+    outputMode,
+    batchTimestampMs,
+    eventTimeWatermarkForEviction,
+    child,
+    initialStateGroupingAttrs,
+    initialState)
+  with ObjectProducerExec {
 
   override def shortName: String = "transformWithStateExec"
 
-  // dummy value schema, the real schema will get during state variable init 
time
-  private val DUMMY_VALUE_ROW_SCHEMA = new StructType().add("value", 
BinaryType)
-
   // We need to just initialize key and value deserializer once per partition.
   // The deserializers need to be lazily created on the executor since they
   // are not serializable.
@@ -98,21 +97,6 @@ case class TransformWithStateExec(
   private lazy val getValueObj =
     ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
 
-  override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
-    if (timeMode == ProcessingTime) {
-      // TODO SPARK-50180: check if we can return true only if actual timers 
are registered,
-      //  or there is expired state
-      true
-    } else if (outputMode == OutputMode.Append || outputMode == 
OutputMode.Update) {
-      eventTimeWatermarkForEviction.isDefined &&
-      newInputWatermark > eventTimeWatermarkForEviction.get
-    } else {
-      false
-    }
-  }
-
-  override def operatorStateMetadataVersion: Int = 2
-
   /**
    * We initialize this processor handle in the driver to run the init function
    * and fetch the schemas of the state variables initialized in this 
processor.
@@ -168,28 +152,6 @@ case class TransformWithStateExec(
     stateVariableInfos
   }
 
-  /**
-   * Controls watermark propagation to downstream modes. If timeMode is
-   * ProcessingTime, the output rows cannot be interpreted in eventTime, hence
-   * this node will not propagate watermark in this timeMode.
-   *
-   * For timeMode EventTime, output watermark is same as input Watermark 
because
-   * transformWithState does not allow users to set the event time column to be
-   * earlier than the watermark.
-   */
-  override def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = {
-    timeMode match {
-      case ProcessingTime =>
-        None
-      case _ =>
-        Some(inputWatermarkMs)
-    }
-  }
-
-  override def left: SparkPlan = child
-
-  override def right: SparkPlan = initialState
-
   override protected def withNewChildrenInternal(
       newLeft: SparkPlan, newRight: SparkPlan): TransformWithStateExec = {
     if (hasInitialState) {
@@ -199,66 +161,6 @@ case class TransformWithStateExec(
     }
   }
 
-  override def keyExpressions: Seq[Attribute] = groupingAttributes
-
-  /**
-   * Distribute by grouping attributes - We need the underlying data and the 
initial state data
-   * to have the same grouping so that the data are co-located on the same 
task.
-   */
-  override def requiredChildDistribution: Seq[Distribution] = {
-    StatefulOperatorPartitioning.getCompatibleDistribution(
-      groupingAttributes, getStateInfo, conf) ::
-    StatefulOperatorPartitioning.getCompatibleDistribution(
-        initialStateGroupingAttrs, getStateInfo, conf) ::
-    Nil
-  }
-
-  /**
-   * We need the initial state to also use the ordering as the data so that we 
can co-locate the
-   * keys from the underlying data and the initial state.
-   */
-  override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
-    groupingAttributes.map(SortOrder(_, Ascending)),
-    initialStateGroupingAttrs.map(SortOrder(_, Ascending)))
-
-  // Wrapper to ensure that the implicit key is set when the methods on the 
iterator
-  // are called. We process all the values for a particular key at a time, so 
we
-  // only have to set the implicit key when the first call to the iterator is 
made, and
-  // we have to remove it when the iterator is closed.
-  //
-  // Note: if we ever start to interleave the processing of the iterators we 
get back
-  // from handleInputRows (i.e. we don't process each iterator all at once), 
then this
-  // iterator will need to set/unset the implicit key every time hasNext/next 
is called,
-  // not just at the first and last calls to hasNext.
-  private def iteratorWithImplicitKeySet(
-      key: Any,
-      iter: Iterator[InternalRow],
-      onClose: () => Unit = () => {}
-  ): Iterator[InternalRow] = {
-    new NextIterator[InternalRow] {
-      var hasStarted = false
-
-      override protected def getNext(): InternalRow = {
-        if (!hasStarted) {
-          hasStarted = true
-          ImplicitGroupingKeyTracker.setImplicitKey(key)
-        }
-
-        if (!iter.hasNext) {
-          finished = true
-          null
-        } else {
-          iter.next()
-        }
-      }
-
-      override protected def close(): Unit = {
-        onClose()
-        ImplicitGroupingKeyTracker.removeImplicitKey()
-      }
-    }
-  }
-
   private def handleInputRows(keyRow: UnsafeRow, valueRowIter: 
Iterator[InternalRow]):
     Iterator[InternalRow] = {
 
@@ -455,35 +357,6 @@ case class TransformWithStateExec(
     }
   }
 
-  // operator specific metrics
-  override def customStatefulOperatorMetrics: 
Seq[StatefulOperatorCustomMetric] = {
-    Seq(
-      // metrics around initial state
-      StatefulOperatorCustomSumMetric("initialStateProcessingTimeMs",
-        "Number of milliseconds taken to process all initial state"),
-      // metrics around state variables
-      StatefulOperatorCustomSumMetric("numValueStateVars", "Number of value 
state variables"),
-      StatefulOperatorCustomSumMetric("numListStateVars", "Number of list 
state variables"),
-      StatefulOperatorCustomSumMetric("numMapStateVars", "Number of map state 
variables"),
-      StatefulOperatorCustomSumMetric("numDeletedStateVars", "Number of 
deleted state variables"),
-      // metrics around timers
-      StatefulOperatorCustomSumMetric("timerProcessingTimeMs",
-        "Number of milliseconds taken to process all timers"),
-      StatefulOperatorCustomSumMetric("numRegisteredTimers", "Number of 
registered timers"),
-      StatefulOperatorCustomSumMetric("numDeletedTimers", "Number of deleted 
timers"),
-      StatefulOperatorCustomSumMetric("numExpiredTimers", "Number of expired 
timers"),
-      // metrics around TTL
-      StatefulOperatorCustomSumMetric("numValueStateWithTTLVars",
-        "Number of value state variables with TTL"),
-      StatefulOperatorCustomSumMetric("numListStateWithTTLVars",
-        "Number of list state variables with TTL"),
-      StatefulOperatorCustomSumMetric("numMapStateWithTTLVars",
-        "Number of map state variables with TTL"),
-      StatefulOperatorCustomSumMetric("numValuesRemovedDueToTTLExpiry",
-        "Number of values removed due to TTL expiry")
-    )
-  }
-
   override def validateAndMaybeEvolveStateSchema(
       hadoopConf: Configuration,
       batchId: Long,
@@ -494,19 +367,6 @@ case class TransformWithStateExec(
       info, stateSchemaDir, session, operatorStateMetadataVersion, 
conf.stateStoreEncodingFormat)
   }
 
-  /** Metadata of this stateful operator and its states stores. */
-  override def operatorStateMetadata(
-      stateSchemaPaths: List[List[String]]): OperatorStateMetadata = {
-    val info = getStateInfo
-    getOperatorStateMetadata(stateSchemaPaths, info, shortName, timeMode, 
outputMode)
-  }
-
-  override def validateNewMetadata(
-      oldOperatorMetadata: OperatorStateMetadata,
-      newOperatorMetadata: OperatorStateMetadata): Unit = {
-    validateNewMetadataForTWS(oldOperatorMetadata, newOperatorMetadata)
-  }
-
   override protected def doExecute(): RDD[InternalRow] = {
     metrics // force lazy init at driver
 
@@ -578,28 +438,6 @@ case class TransformWithStateExec(
     }
   }
 
-  /**
-   * Executes a block of code with standardized error handling for 
StatefulProcessor
-   * operations. Rethrows SparkThrowables directly and wraps other exceptions 
in
-   * TransformWithStateUserFunctionException with the provided function name.
-   *
-   * @param functionName The name of the function being executed (for error 
reporting)
-   * @param block The code block to execute with error handling
-   * @return The result of the block execution
-   */
-  private def withStatefulProcessorErrorHandling[R](functionName: 
String)(block: => R): R = {
-    try {
-      block
-    } catch {
-      case st: Exception with SparkThrowable if st.getCondition != null =>
-        throw st
-      case e: Exception =>
-        throw TransformWithStateUserFunctionException(e, functionName)
-    }
-  }
-
-  override def supportsSchemaEvolution: Boolean = true
-
   /**
    * Create a new StateStore for given partitionId and instantiate a temp 
directory
    * on the executors. Process data and close the stateStore provider 
afterwards.
@@ -693,18 +531,6 @@ case class TransformWithStateExec(
 
     processDataWithPartition(childDataIterator, store, processorHandle)
   }
-
-  private def validateTimeMode(): Unit = {
-    timeMode match {
-      case ProcessingTime =>
-        TransformWithStateVariableUtils.validateTimeMode(timeMode, 
batchTimestampMs)
-
-      case EventTime =>
-        TransformWithStateVariableUtils.validateTimeMode(timeMode, 
eventTimeWatermarkForEviction)
-
-      case _ =>
-    }
-  }
 }
 
 // scalastyle:off argcount
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExecBase.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExecBase.scala
new file mode 100644
index 000000000000..df68b21e0bb9
--- /dev/null
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExecBase.scala
@@ -0,0 +1,239 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.streaming
+
+import org.apache.spark.SparkThrowable
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, 
SortOrder}
+import org.apache.spark.sql.catalyst.plans.logical.{EventTime, ProcessingTime}
+import org.apache.spark.sql.catalyst.plans.physical.Distribution
+import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan}
+import org.apache.spark.sql.execution.streaming.state.{OperatorStateMetadata, 
TransformWithStateUserFunctionException}
+import org.apache.spark.sql.streaming.{OutputMode, TimeMode}
+import org.apache.spark.sql.types.{BinaryType, StructType}
+import org.apache.spark.util.NextIterator
+
+/**
+ * Base class for the physical nodes that execute `TransformWithState`.
+ *
+ * It contains the logic shared by the concrete operators, such as state store metrics handling,
+ * co-locating the initial state with the incoming data, watermark propagation and operator state
+ * metadata handling. Concrete physical nodes like `TransformWithStateInPySparkExec` and
+ * `TransformWithStateExec` should extend this class.
+ *
+ * NOTE: the constructor parameters are referenced from methods of this class, so this class keeps
+ * its own copy of them, separate from the identically named fields of the concrete case class.
+ * Subclasses are expected to forward their own fields unchanged, so that e.g. `left`/`right`
+ * agree with the subclass's `child`/`initialState`.
+ */
+abstract class TransformWithStateExecBase(
+    groupingAttributes: Seq[Attribute],
+    timeMode: TimeMode,
+    outputMode: OutputMode,
+    batchTimestampMs: Option[Long],
+    eventTimeWatermarkForEviction: Option[Long],
+    child: SparkPlan,
+    initialStateGroupingAttrs: Seq[Attribute],
+    initialState: SparkPlan)
+  extends BinaryExecNode
+  with StateStoreWriter
+  with WatermarkSupport
+  with TransformWithStateMetadataUtils {
+
+  override def operatorStateMetadataVersion: Int = 2
+
+  override def supportsSchemaEvolution: Boolean = true
+
+  // The input data is the left child; the initial state is the right child.
+  override def left: SparkPlan = child
+
+  override def right: SparkPlan = initialState
+
+  // The keys that may have a watermark attribute.
+  override def keyExpressions: Seq[Attribute] = groupingAttributes
+
+  /**
+   * Distribute by grouping attributes - We need the underlying data and the initial state data to
+   * have the same grouping so that the data are co-located on the same task.
+   */
+  override def requiredChildDistribution: Seq[Distribution] = {
+    StatefulOperatorPartitioning.getCompatibleDistribution(
+      groupingAttributes, getStateInfo, conf) ::
+      StatefulOperatorPartitioning.getCompatibleDistribution(
+        initialStateGroupingAttrs, getStateInfo, conf) ::
+      Nil
+  }
+
+  /**
+   * We need the initial state to use the same ordering as the data so that we can co-locate the
+   * keys from the underlying data and the initial state.
+   */
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(
+    groupingAttributes.map(SortOrder(_, Ascending)),
+    initialStateGroupingAttrs.map(SortOrder(_, Ascending)))
+
+  /**
+   * Decides whether another batch should be run for the given new input watermark.
+   *
+   * Always true for ProcessingTime. Otherwise, in Append or Update output mode, true only when
+   * `eventTimeWatermarkForEviction` is set and `newInputWatermark` is strictly greater than it;
+   * false in any other output mode.
+   */
+  override def shouldRunAnotherBatch(newInputWatermark: Long): Boolean = {
+    if (timeMode == ProcessingTime) {
+      // TODO SPARK-50180: check if we can return true only if actual timers are registered,
+      //  or there is expired state
+      true
+    } else if (outputMode == OutputMode.Append || outputMode == OutputMode.Update) {
+      eventTimeWatermarkForEviction.exists(newInputWatermark > _)
+    } else {
+      false
+    }
+  }
+
+  /**
+   * Controls watermark propagation to downstream nodes. If timeMode is ProcessingTime, the output
+   * rows cannot be interpreted in event time, hence this node will not propagate the watermark in
+   * this timeMode.
+   *
+   * For any other timeMode (e.g. EventTime), the output watermark is the same as the input
+   * watermark because transformWithState does not allow users to set the event time column to be
+   * earlier than the watermark.
+   */
+  override def produceOutputWatermark(inputWatermarkMs: Long): Option[Long] = {
+    timeMode match {
+      case ProcessingTime =>
+        None
+      case _ =>
+        Some(inputWatermarkMs)
+    }
+  }
+
+  // Operator specific metrics. The list is shared by all concrete operators, so an operator that
+  // never updates one of these metrics still declares it.
+  override def customStatefulOperatorMetrics: Seq[StatefulOperatorCustomMetric] = {
+    Seq(
+      // metrics around initial state
+      StatefulOperatorCustomSumMetric("initialStateProcessingTimeMs",
+        "Number of milliseconds taken to process all initial state"),
+      // metrics around state variables
+      StatefulOperatorCustomSumMetric("numValueStateVars", "Number of value state variables"),
+      StatefulOperatorCustomSumMetric("numListStateVars", "Number of list state variables"),
+      StatefulOperatorCustomSumMetric("numMapStateVars", "Number of map state variables"),
+      StatefulOperatorCustomSumMetric("numDeletedStateVars", "Number of deleted state variables"),
+      // metrics around timers
+      StatefulOperatorCustomSumMetric("timerProcessingTimeMs",
+        "Number of milliseconds taken to process all timers"),
+      StatefulOperatorCustomSumMetric("numRegisteredTimers", "Number of registered timers"),
+      StatefulOperatorCustomSumMetric("numDeletedTimers", "Number of deleted timers"),
+      StatefulOperatorCustomSumMetric("numExpiredTimers", "Number of expired timers"),
+      // metrics around TTL
+      StatefulOperatorCustomSumMetric("numValueStateWithTTLVars",
+        "Number of value state variables with TTL"),
+      StatefulOperatorCustomSumMetric("numListStateWithTTLVars",
+        "Number of list state variables with TTL"),
+      StatefulOperatorCustomSumMetric("numMapStateWithTTLVars",
+        "Number of map state variables with TTL"),
+      StatefulOperatorCustomSumMetric("numValuesRemovedDueToTTLExpiry",
+        "Number of values removed due to TTL expiry"))
+  }
+
+  /** Metadata of this stateful operator and its state stores. */
+  override def operatorStateMetadata(
+      stateSchemaPaths: List[List[String]]): OperatorStateMetadata = {
+    getOperatorStateMetadata(stateSchemaPaths, getStateInfo, shortName, timeMode, outputMode)
+  }
+
+  /** Validates `newOperatorMetadata` against `oldOperatorMetadata` via `validateNewMetadataForTWS`. */
+  override def validateNewMetadata(
+      oldOperatorMetadata: OperatorStateMetadata,
+      newOperatorMetadata: OperatorStateMetadata): Unit = {
+    validateNewMetadataForTWS(oldOperatorMetadata, newOperatorMetadata)
+  }
+
+  // Dummy value schema; the real schema is obtained when the state variables are initialized.
+  protected val DUMMY_VALUE_ROW_SCHEMA: StructType = new StructType().add("value", BinaryType)
+
+  // Wrapper to ensure that the implicit key is set when the methods on the iterator are called.
+  // We process all the values for a particular key at a time, so we only have to set the
+  // implicit key when the first call to the iterator is made, and we have to remove it when the
+  // iterator is closed.
+  //
+  // Note: if we ever start to interleave the processing of the iterators we get back from
+  // handleInputRows (i.e. we don't process each iterator all at once), then this iterator will
+  // need to set/unset the implicit key every time hasNext/next is called, not just at the first
+  // and last calls to hasNext.
+  protected def iteratorWithImplicitKeySet(
+      key: Any,
+      iter: Iterator[InternalRow],
+      onClose: () => Unit = () => {}): Iterator[InternalRow] = {
+    new NextIterator[InternalRow] {
+      // Whether the implicit key has already been set for this iterator.
+      private var hasStarted = false
+
+      override protected def getNext(): InternalRow = {
+        if (!hasStarted) {
+          hasStarted = true
+          ImplicitGroupingKeyTracker.setImplicitKey(key)
+        }
+
+        if (!iter.hasNext) {
+          finished = true
+          null
+        } else {
+          iter.next()
+        }
+      }
+
+      override protected def close(): Unit = {
+        // Remove the implicit key even if the caller-supplied callback throws, so that it is
+        // not left behind for whatever runs next.
+        try {
+          onClose()
+        } finally {
+          ImplicitGroupingKeyTracker.removeImplicitKey()
+        }
+      }
+    }
+  }
+
+  /**
+   * Validates the time mode configuration by delegating to
+   * `TransformWithStateVariableUtils.validateTimeMode` with the batch timestamp for
+   * ProcessingTime or the eviction watermark for EventTime. Other time modes are not checked.
+   */
+  protected def validateTimeMode(): Unit = {
+    timeMode match {
+      case ProcessingTime =>
+        TransformWithStateVariableUtils.validateTimeMode(timeMode, batchTimestampMs)
+
+      case EventTime =>
+        TransformWithStateVariableUtils.validateTimeMode(timeMode, eventTimeWatermarkForEviction)
+
+      case _ =>
+    }
+  }
+
+  /**
+   * Executes a block of code with standardized error handling for StatefulProcessor operations.
+   * Exceptions that are SparkThrowables with a non-null error condition are rethrown as is; any
+   * other Exception is wrapped in TransformWithStateUserFunctionException with the provided
+   * function name.
+   *
+   * @param functionName
+   *   The name of the function being executed (for error reporting)
+   * @param block
+   *   The code block to execute with error handling
+   * @return
+   *   The result of the block execution
+   */
+  protected def withStatefulProcessorErrorHandling[R](functionName: String)(block: => R): R = {
+    try {
+      block
+    } catch {
+      case st: Exception with SparkThrowable if st.getCondition != null =>
+        throw st
+      case e: Exception =>
+        throw TransformWithStateUserFunctionException(e, functionName)
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to