This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 07ec264da9ed [SPARK-45813][CONNECT][PYTHON] Return the observed metrics from commands
07ec264da9ed is described below
commit 07ec264da9ed56c0de21ab60fff95bab64d3579e
Author: Takuya UESHIN <[email protected]>
AuthorDate: Wed Nov 15 10:49:22 2023 +0900
[SPARK-45813][CONNECT][PYTHON] Return the observed metrics from commands
### What changes were proposed in this pull request?
Returns the observed metrics from commands.
### Why are the changes needed?
Currently, the observed metrics on commands are not available.
For example:
```py
>>> df = spark.range(10)
>>>
>>> observation = Observation()
>>> observed_df = df.observe(observation, count(lit(1)).alias("cnt"))
>>>
>>> observed_df.show()
...
>>> observation.get
{}
```
After this change, it should be:
```py
>>> observation.get
{'cnt': 10}
```
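A minimal sketch of the behavior this enables for commands, mirroring the new `test_observe_on_commands` test added in this PR (the `noop` write format is used here purely for illustration):
```py
>>> from pyspark.sql import Observation
>>> from pyspark.sql.functions import count, lit
>>>
>>> observation = Observation()
>>> observed_df = spark.range(10).observe(observation, count(lit(1)).alias("cnt"))
>>>
>>> # A command (a write, in this case) should now also populate the observation.
>>> observed_df.write.format("noop").mode("overwrite").save()
>>> observation.get
{'cnt': 10}
```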
### Does this PR introduce _any_ user-facing change?
Yes, the observed metrics on commands will be available.
### How was this patch tested?
Added the related tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43690 from ueshin/issues/SPARK-45813/observed_metrics.
Authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../connect/execution/ExecuteThreadRunner.scala | 15 +++++++++
.../execution/SparkConnectPlanExecution.scala | 37 ++++++++++++++++------
.../sql/connect/planner/SparkConnectPlanner.scala | 13 ++++++--
.../spark/sql/connect/service/ExecuteHolder.scala | 7 ++++
python/pyspark/sql/connect/client/core.py | 14 +++++---
.../sql/tests/connect/test_connect_basic.py | 2 +-
python/pyspark/sql/tests/test_dataframe.py | 21 ++++++++++++
.../scala/org/apache/spark/sql/Observation.scala | 31 +++++++++++++-----
8 files changed, 115 insertions(+), 25 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
index ea2bbe0093fc..24b3c302b759 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/ExecuteThreadRunner.scala
@@ -162,6 +162,21 @@ private[connect] class ExecuteThreadRunner(executeHolder: ExecuteHolder) extends
s"${executeHolder.request.getPlan.getOpTypeCase} not supported.")
}
+ if (executeHolder.observations.nonEmpty) {
+ val observedMetrics = executeHolder.observations.map { case (name, observation) =>
+ val values = observation.getOrEmpty.map { case (key, value) =>
+ (Some(key), value)
+ }.toSeq
+ name -> values
+ }.toMap
+ executeHolder.responseObserver.onNext(
+ SparkConnectPlanExecution
+ .createObservedMetricsResponse(
+ executeHolder.sessionHolder.sessionId,
+ executeHolder.sessionHolder.serverSessionId,
+ observedMetrics))
+ }
+
lock.synchronized {
// Synchronized before sending ResultComplete, and up until completing the result stream
// to prevent a situation in which a client of reattachable execution receives
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
index 002239aba96e..23390bf7aba8 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala
@@ -66,9 +66,8 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
processAsArrowBatches(dataframe, responseObserver, executeHolder)
responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
- if (dataframe.queryExecution.observedMetrics.nonEmpty) {
- responseObserver.onNext(createObservedMetricsResponse(request.getSessionId, dataframe))
- }
+ createObservedMetricsResponse(request.getSessionId, dataframe).foreach(
+ responseObserver.onNext)
}
type Batch = (Array[Byte], Long)
@@ -245,15 +244,33 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
private def createObservedMetricsResponse(
sessionId: String,
- dataframe: DataFrame): ExecutePlanResponse = {
- val observedMetrics = dataframe.queryExecution.observedMetrics.map { case (name, row) =>
- val cols = (0 until row.length).map(i => toLiteralProto(row(i)))
+ dataframe: DataFrame): Option[ExecutePlanResponse] = {
+ val observedMetrics = dataframe.queryExecution.observedMetrics.collect {
+ case (name, row) if !executeHolder.observations.contains(name) =>
+ val values = (0 until row.length).map { i =>
+ (if (row.schema != null) Some(row.schema.fieldNames(i)) else None, row(i))
+ }
+ name -> values
+ }
+ if (observedMetrics.nonEmpty) {
+ Some(SparkConnectPlanExecution
+ .createObservedMetricsResponse(sessionId, sessionHolder.serverSessionId, observedMetrics))
+ } else None
+ }
+}
+
+object SparkConnectPlanExecution {
+ def createObservedMetricsResponse(
+ sessionId: String,
+ serverSessionId: String,
+ metrics: Map[String, Seq[(Option[String], Any)]]): ExecutePlanResponse = {
+ val observedMetrics = metrics.map { case (name, values) =>
val metrics = ExecutePlanResponse.ObservedMetrics
.newBuilder()
.setName(name)
- .addAllValues(cols.asJava)
- if (row.schema != null) {
- metrics.addAllKeys(row.schema.fieldNames.toList.asJava)
+ values.foreach { case (key, value) =>
+ metrics.addValues(toLiteralProto(value))
+ key.foreach(metrics.addKeys)
}
metrics.build()
}
@@ -261,7 +278,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
ExecutePlanResponse
.newBuilder()
.setSessionId(sessionId)
- .setServerSideSessionId(sessionHolder.serverSessionId)
+ .setServerSideSessionId(serverSessionId)
.addAllObservedMetrics(observedMetrics.asJava)
.build()
}
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 654513857824..637ed09798a5 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -42,7 +42,7 @@ import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.Streami
import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
import org.apache.spark.internal.Logging
import org.apache.spark.ml.{functions => MLFunctions}
-import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, RelationalGroupedDataset, SparkSession}
+import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
import org.apache.spark.sql.avro.{AvroDataToCatalyst, CatalystDataToAvro}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier}
import org.apache.spark.sql.catalyst.analysis.{GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar}
@@ -1069,8 +1069,17 @@ class SparkConnectPlanner(
val metrics = rel.getMetricsList.asScala.toSeq.map { expr =>
Column(transformExpression(expr))
}
+ val name = rel.getName
+ val input = transformRelation(rel.getInput)
- CollectMetrics(rel.getName, metrics.map(_.named), transformRelation(rel.getInput), planId)
+ if (input.isStreaming || executeHolderOpt.isEmpty) {
+ CollectMetrics(name, metrics.map(_.named), transformRelation(rel.getInput), planId)
+ } else {
+ val observation = Observation(name)
+ observation.register(session, planId)
+ executeHolderOpt.get.addObservation(name, observation)
+ CollectMetrics(name, metrics.map(_.named), transformRelation(rel.getInput), planId)
+ }
}
private def transformDeduplicate(rel: proto.Deduplicate): LogicalPlan = {
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
index eed8cc01f7c6..8b910154d2f4 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/ExecuteHolder.scala
@@ -25,6 +25,7 @@ import scala.jdk.CollectionConverters._
import org.apache.spark.{SparkEnv, SparkSQLException}
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
+import org.apache.spark.sql.Observation
import org.apache.spark.sql.connect.common.ProtoUtils
import org.apache.spark.sql.connect.config.Connect.CONNECT_EXECUTE_REATTACHABLE_ENABLED
import org.apache.spark.sql.connect.execution.{ExecuteGrpcResponseSender, ExecuteResponseObserver, ExecuteThreadRunner}
@@ -89,6 +90,8 @@ private[connect] class ExecuteHolder(
val eventsManager: ExecuteEventsManager = ExecuteEventsManager(this, new SystemClock())
+ val observations: mutable.Map[String, Observation] = mutable.Map.empty
+
private val runner: ExecuteThreadRunner = new ExecuteThreadRunner(this)
/** System.currentTimeMillis when this ExecuteHolder was created. */
@@ -132,6 +135,10 @@ private[connect] class ExecuteHolder(
runner.join()
}
+ def addObservation(name: String, observation: Observation): Unit = synchronized {
+ observations += (name -> observation)
+ }
+
/**
* Attach an ExecuteGrpcResponseSender that will consume responses from the query and send them
* out on the Grpc response stream. The sender will start from the start of the response stream.
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index b98de0f9ceea..a2590dec960d 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1176,10 +1176,16 @@ class SparkConnectClient(object):
logger.debug("Received observed metric batch.")
for observed_metrics in self._build_observed_metrics(b.observed_metrics):
if observed_metrics.name in observations:
- observations[observed_metrics.name]._result = {
- key: LiteralExpression._to_value(metric)
- for key, metric in zip(observed_metrics.keys, observed_metrics.metrics)
- }
+ observation_result = observations[observed_metrics.name]._result
+ assert observation_result is not None
+ observation_result.update(
+ {
+ key: LiteralExpression._to_value(metric)
+ for key, metric in zip(
+ observed_metrics.keys, observed_metrics.metrics
+ )
+ }
+ )
yield observed_metrics
if b.HasField("schema"):
logger.debug("Received the schema.")
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index e926eb835a80..d2febcd6b089 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -1824,7 +1824,7 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.assert_eq(cdf, df)
- self.assert_eq(cobservation.get, observation.get)
+ self.assertEquals(cobservation.get, observation.get)
observed_metrics = cdf.attrs["observed_metrics"]
self.assert_eq(len(observed_metrics), 1)
diff --git a/python/pyspark/sql/tests/test_dataframe.py b/python/pyspark/sql/tests/test_dataframe.py
index 527cf702bce9..3b2fb87123eb 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1078,6 +1078,27 @@ class DataFrameTestsMixin:
self.assertEqual(observation1.get, dict(cnt=50))
self.assertEqual(observation2.get, dict(cnt=100))
+ def test_observe_on_commands(self):
+ from pyspark.sql import Observation
+
+ df = self.spark.range(50)
+
+ test_table = "test_table"
+
+ # DataFrameWriter
+ with self.table(test_table):
+ for command, action in [
+ ("collect", lambda df: df.collect()),
+ ("show", lambda df: df.show(50)),
+ ("save", lambda df:
df.write.format("noop").mode("overwrite").save()),
+ ("create", lambda df:
df.writeTo(test_table).using("parquet").create()),
+ ]:
+ with self.subTest(command=command):
+ observation = Observation()
+ observed_df = df.observe(observation, count(lit(1)).alias("cnt"))
+ action(observed_df)
+ self.assertEqual(observation.get, dict(cnt=50))
+
def test_sample(self):
with self.assertRaises(PySparkTypeError) as pe:
self.spark.range(1).sample()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
index f4b518c1e9fb..104e7c101fd1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Observation.scala
@@ -58,7 +58,7 @@ class Observation(val name: String) {
private val listener: ObservationListener = ObservationListener(this)
- @volatile private var ds: Option[Dataset[_]] = None
+ @volatile private var dataframeId: Option[(SparkSession, Long)] = None
@volatile private var metrics: Option[Map[String, Any]] = None
@@ -79,7 +79,7 @@ class Observation(val name: String) {
". Please register a StreamingQueryListener and get the metric for
each microbatch in " +
"QueryProgressEvent.progress, or use query.lastProgress or
query.recentProgress.")
}
- register(ds)
+ register(ds.sparkSession, ds.id)
ds.observe(name, expr, exprs: _*)
}
@@ -117,29 +117,44 @@ class Observation(val name: String) {
get.map { case (key, value) => (key, value.asInstanceOf[Object])}.asJava
}
- private def register(ds: Dataset[_]): Unit = {
+ /**
+ * Get the observed metrics. This returns the metrics if they are available, otherwise an empty map.
+ *
+ * @return the observed metrics as a `Map[String, Any]`
+ */
+ @throws[InterruptedException]
+ private[sql] def getOrEmpty: Map[String, _] = {
+ synchronized {
+ if (metrics.isEmpty) {
+ wait(100) // Wait for 100ms to see if metrics are available
+ }
+ metrics.getOrElse(Map.empty)
+ }
+ }
+
+ private[sql] def register(sparkSession: SparkSession, dataframeId: Long): Unit = {
// makes this class thread-safe:
// only the first thread entering this block can set sparkSession
// all other threads will see the exception, as it is only allowed to do this once
synchronized {
- if (this.ds.isDefined) {
+ if (this.dataframeId.isDefined) {
throw new IllegalArgumentException("An Observation can be used with a Dataset only once")
}
- this.ds = Some(ds)
+ this.dataframeId = Some((sparkSession, dataframeId))
}
- ds.sparkSession.listenerManager.register(this.listener)
+ sparkSession.listenerManager.register(this.listener)
}
private def unregister(): Unit = {
- this.ds.foreach(_.sparkSession.listenerManager.unregister(this.listener))
+ this.dataframeId.foreach(_._1.listenerManager.unregister(this.listener))
}
private[spark] def onFinish(qe: QueryExecution): Unit = {
synchronized {
if (this.metrics.isEmpty && qe.logical.exists {
case CollectMetrics(name, _, _, dataframeId) =>
- name == this.name && dataframeId == ds.get.id
+ name == this.name && dataframeId == this.dataframeId.get._2
case _ => false
}) {
val row = qe.observedMetrics.get(name)