Repository: spark Updated Branches: refs/heads/master 89e8d556b -> ffbca8451
[SPARK-23202][SQL] Add new API in DataSourceWriter: onDataWriterCommit ## What changes were proposed in this pull request? The current DataSourceWriter API makes it hard to implement `onTaskCommit(taskCommit: TaskCommitMessage)` in `FileCommitProtocol`. In general, on receiving a commit message, the driver can start processing messages (e.g. persisting messages into files) before all the messages are collected. The proposal is to add a new API: `onDataWriterCommit(WriterCommitMessage message)`: handles a commit message received from a successful data writer. This should make the whole API of DataSourceWriter compatible with `FileCommitProtocol`, and more flexible. There was another, more radical attempt in #20386. This one should be more reasonable. ## How was this patch tested? Unit test Author: Wang Gengliang <[email protected]> Closes #20454 from gengliangwang/write_api. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ffbca845 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ffbca845 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ffbca845 Branch: refs/heads/master Commit: ffbca84519011a747e0552632e88f5e4956e493d Parents: 89e8d55 Author: Wang Gengliang <[email protected]> Authored: Thu Feb 1 20:39:15 2018 +0800 Committer: Wenchen Fan <[email protected]> Committed: Thu Feb 1 20:39:15 2018 +0800 ---------------------------------------------------------------------- .../sql/sources/v2/writer/DataSourceWriter.java | 14 +++++++++++-- .../datasources/v2/WriteToDataSourceV2.scala | 5 ++++- .../sql/sources/v2/DataSourceV2Suite.scala | 21 +++++++++++++++++++- .../sources/v2/SimpleWritableDataSource.scala | 21 ++++++++++++++++++++ 4 files changed, 57 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/ffbca845/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java 
---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java index 7096aec..52324b3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceWriter.java @@ -63,6 +63,14 @@ public interface DataSourceWriter { DataWriterFactory<Row> createWriterFactory(); /** + * Handles a commit message received from a successful data writer. + * + * If this method fails (by throwing an exception), this writing job is considered to have + * failed, and {@link #abort(WriterCommitMessage[])} would be called. + */ + default void onDataWriterCommit(WriterCommitMessage message) {} + + /** * Commits this writing job with a list of commit messages. The commit messages are collected from * successful data writers and are produced by {@link DataWriter#commit()}. * @@ -78,8 +86,10 @@ public interface DataSourceWriter { void commit(WriterCommitMessage[] messages); /** - * Aborts this writing job because some data writers are failed and keep failing when retry, or - * the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} fails. + * Aborts this writing job because some data writers are failed and keep failing when retry, + * or the Spark job fails with some unknown reasons, + * or {@link #onDataWriterCommit(WriterCommitMessage)} fails, + * or {@link #commit(WriterCommitMessage[])} fails. * * If this method fails (by throwing an exception), the underlying data source may require manual * cleanup. 
http://git-wip-us.apache.org/repos/asf/spark/blob/ffbca845/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala index 6592bd7..eefbcf4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -80,7 +80,10 @@ case class WriteToDataSourceV2Exec(writer: DataSourceWriter, query: SparkPlan) e rdd, runTask, rdd.partitions.indices, - (index, message: WriterCommitMessage) => messages(index) = message + (index, message: WriterCommitMessage) => { + messages(index) = message + writer.onDataWriterCommit(message) + } ) if (!writer.isInstanceOf[StreamWriter]) { http://git-wip-us.apache.org/repos/asf/spark/blob/ffbca845/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 2f49b07..1c3ba78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -21,7 +21,7 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ -import org.apache.spark.SparkException +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import 
org.apache.spark.sql.execution.exchange.ShuffleExchangeExec @@ -198,6 +198,25 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } } + + test("simple counter in writer with onDataWriterCommit") { + Seq(classOf[SimpleWritableDataSource]).foreach { cls => + withTempPath { file => + val path = file.getCanonicalPath + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + val numPartition = 6 + spark.range(0, 10, 1, numPartition).select('id, -'id).write.format(cls.getName) + .option("path", path).save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).select('id, -'id)) + + assert(SimpleCounter.getCounter == numPartition, + "method onDataWriterCommit should be called as many as the number of partitions") + } + } + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { http://git-wip-us.apache.org/repos/asf/spark/blob/ffbca845/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index a131b16..36dd2a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -66,9 +66,14 @@ class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteS class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceWriter { override def createWriterFactory(): DataWriterFactory[Row] = { + SimpleCounter.resetCounter new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) } + override def onDataWriterCommit(message: WriterCommitMessage): Unit = { + SimpleCounter.increaseCounter + } + override 
def commit(messages: Array[WriterCommitMessage]): Unit = { val finalPath = new Path(path) val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) @@ -183,6 +188,22 @@ class SimpleCSVDataReaderFactory(path: String, conf: SerializableConfiguration) } } +private[v2] object SimpleCounter { + private var count: Int = 0 + + def increaseCounter: Unit = { + count += 1 + } + + def getCounter: Int = { + count + } + + def resetCounter: Unit = { + count = 0 + } +} + class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) extends DataWriterFactory[Row] { --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
