Repository: spark
Updated Branches:
  refs/heads/master 12b1e91e6 -> a71f6a175


[SPARK-25414][SS][TEST] make it clear that the numRows metrics should be 
counted for each scan of the source

## What changes were proposed in this pull request?

For self-join/self-union, Spark will produce a physical plan which has multiple 
`DataSourceV2ScanExec` instances referring to the same `ReadSupport` instance. 
In this case, the streaming source is indeed scanned multiple times, and the 
`numInputRows` metrics should be counted for each scan.

Actually we already have 2 test cases to verify the behavior:
1. `StreamingQuerySuite.input row calculation with same V2 source used twice in 
self-join`
2. `KafkaMicroBatchSourceSuiteBase.ensure stream-stream self-join generates 
only one offset in log and correct metrics`.

However, the expected results of these 2 tests are different, which is 
confusing. It turns out that the first test doesn't trigger exchange reuse, so 
the source is scanned twice. The second test triggers exchange reuse, so the 
source is scanned only once.

This PR proposes to improve these 2 tests, to test with/without exchange reuse.

## How was this patch tested?

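This is a test-only change: the only non-test edit (in `ProgressReporter`) 
removes unused imports and adds a comment.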

Closes #22402 from cloud-fan/bug.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a71f6a17
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a71f6a17
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a71f6a17

Branch: refs/heads/master
Commit: a71f6a1750fd0a29ecae6b98673ee15840da1c62
Parents: 12b1e91
Author: Wenchen Fan <[email protected]>
Authored: Thu Sep 20 00:29:48 2018 +0800
Committer: Wenchen Fan <[email protected]>
Committed: Thu Sep 20 00:29:48 2018 +0800

----------------------------------------------------------------------
 .../kafka010/KafkaMicroBatchSourceSuite.scala   | 39 +++++++++++++----
 .../execution/streaming/ProgressReporter.scala  |  6 +--
 .../sql/streaming/StreamingQuerySuite.scala     | 46 +++++++++++++++-----
 3 files changed, 68 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a71f6a17/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
----------------------------------------------------------------------
diff --git 
a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
 
b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
index 8e246db..e5f0088 100644
--- 
a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
+++ 
b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala
@@ -35,9 +35,11 @@ import org.scalatest.time.SpanSugar._
 
 import org.apache.spark.sql.{ForeachWriter, SparkSession}
 import 
org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution
 import org.apache.spark.sql.functions.{count, window}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.kafka010.KafkaSourceProvider._
 import org.apache.spark.sql.sources.v2.DataSourceOptions
 import org.apache.spark.sql.streaming.{ProcessingTime, StreamTest}
@@ -598,18 +600,37 @@ abstract class KafkaMicroBatchSourceSuiteBase extends 
KafkaSourceSuiteBase {
 
     val join = values.join(values, "key")
 
-    testStream(join)(
-      makeSureGetOffsetCalled,
-      AddKafkaData(Set(topic), 1, 2),
-      CheckAnswer((1, 1, 1), (2, 2, 2)),
-      AddKafkaData(Set(topic), 6, 3),
-      CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 
6, 6)),
-      AssertOnQuery { q =>
+    def checkQuery(check: AssertOnQuery): Unit = {
+      testStream(join)(
+        makeSureGetOffsetCalled,
+        AddKafkaData(Set(topic), 1, 2),
+        CheckAnswer((1, 1, 1), (2, 2, 2)),
+        AddKafkaData(Set(topic), 6, 3),
+        CheckAnswer((1, 1, 1), (2, 2, 2), (3, 3, 3), (1, 6, 1), (1, 1, 6), (1, 
6, 6)),
+        check
+      )
+    }
+
+    withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") {
+      checkQuery(AssertOnQuery { q =>
         assert(q.availableOffsets.iterator.size == 1)
+        // The kafka source is scanned twice because of self-join
+        assert(q.recentProgress.map(_.numInputRows).sum == 8)
+        true
+      })
+    }
+
+    withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "true") {
+      checkQuery(AssertOnQuery { q =>
+        assert(q.availableOffsets.iterator.size == 1)
+        assert(q.lastExecution.executedPlan.collect {
+          case r: ReusedExchangeExec => r
+        }.length == 1)
+        // The kafka source is scanned only once because of exchange reuse.
         assert(q.recentProgress.map(_.numInputRows).sum == 4)
         true
-      }
-    )
+      })
+    }
   }
 
   test("read Kafka transactional messages: read_committed") {

http://git-wip-us.apache.org/repos/asf/spark/blob/a71f6a17/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
index 73b1804..392229b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala
@@ -240,9 +240,6 @@ trait ProgressReporter extends Logging {
   /** Extract number of input sources for each streaming source in plan */
   private def extractSourceToNumInputRows(): Map[BaseStreamingSource, Long] = {
 
-    import java.util.IdentityHashMap
-    import scala.collection.JavaConverters._
-
     def sumRows(tuples: Seq[(BaseStreamingSource, Long)]): 
Map[BaseStreamingSource, Long] = {
       tuples.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each 
source
     }
@@ -255,6 +252,9 @@ trait ProgressReporter extends Logging {
     }
 
     if (onlyDataSourceV2Sources) {
+      // It's possible that multiple DataSourceV2ScanExec instances may refer 
to the same source
+      // (can happen with self-unions or self-joins). This means the source is 
scanned multiple
+      // times in the query, we should count the numRows for each scan.
       val sourceToInputRowsTuples = lastExecution.executedPlan.collect {
         case s: DataSourceV2ScanExec if 
s.readSupport.isInstanceOf[BaseStreamingSource] =>
           val numRows = 
s.metrics.get("numOutputRows").map(_.value).getOrElse(0L)

http://git-wip-us.apache.org/repos/asf/spark/blob/a71f6a17/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
index 1dd8175..c170641 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.SparkException
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Column, DataFrame, Dataset, Row}
 import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, 
Shuffle, Uuid}
+import org.apache.spark.sql.execution.exchange.ReusedExchangeExec
 import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter
 import org.apache.spark.sql.functions._
@@ -500,29 +501,52 @@ class StreamingQuerySuite extends StreamTest with 
BeforeAndAfter with Logging wi
       AssertOnQuery { q =>
         val lastProgress = getLastProgressWithData(q)
         assert(lastProgress.nonEmpty)
-        assert(lastProgress.get.numInputRows == 6)
         assert(lastProgress.get.sources.length == 1)
-        assert(lastProgress.get.sources(0).numInputRows == 6)
+        // The source is scanned twice because of self-union
+        assert(lastProgress.get.numInputRows == 6)
         true
       }
     )
   }
 
   test("input row calculation with same V2 source used twice in self-join") {
-    val streamInput = MemoryStream[Int]
-    val df = streamInput.toDF()
-    testStream(df.join(df, "value"), useV2Sink = true)(
-      AddData(streamInput, 1, 2, 3),
-      CheckAnswer(1, 2, 3),
-      AssertOnQuery { q =>
+    def checkQuery(check: AssertOnQuery): Unit = {
+      val memoryStream = MemoryStream[Int]
+      // TODO: currently the streaming framework always add a dummy Project 
above streaming source
+      // relation, which breaks exchange reuse, as the optimizer will remove 
Project from one side.
+      // Here we manually add a useful Project, to trigger exchange reuse.
+      val streamDF = memoryStream.toDF().select('value + 0 as "v")
+      testStream(streamDF.join(streamDF, "v"), useV2Sink = true)(
+        AddData(memoryStream, 1, 2, 3),
+        CheckAnswer(1, 2, 3),
+        check
+      )
+    }
+
+    withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "false") {
+      checkQuery(AssertOnQuery { q =>
         val lastProgress = getLastProgressWithData(q)
         assert(lastProgress.nonEmpty)
+        assert(lastProgress.get.sources.length == 1)
+        // The source is scanned twice because of self-join
         assert(lastProgress.get.numInputRows == 6)
+        true
+      })
+    }
+
+    withSQLConf(SQLConf.EXCHANGE_REUSE_ENABLED.key -> "true") {
+      checkQuery(AssertOnQuery { q =>
+        val lastProgress = getLastProgressWithData(q)
+        assert(lastProgress.nonEmpty)
         assert(lastProgress.get.sources.length == 1)
-        assert(lastProgress.get.sources(0).numInputRows == 6)
+        assert(q.lastExecution.executedPlan.collect {
+          case r: ReusedExchangeExec => r
+        }.length == 1)
+        // The source is scanned only once because of exchange reuse
+        assert(lastProgress.get.numInputRows == 3)
         true
-      }
-    )
+      })
+    }
   }
 
   test("input row calculation with trigger having data for only one of two V2 
sources") {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to