This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new c9cfaac90fd4 [SPARK-46452][SQL] Add a new API in DataWriter to write an iterator of records
c9cfaac90fd4 is described below
commit c9cfaac90fd423c3a38e295234e24744b946cb02
Author: allisonwang-db <[email protected]>
AuthorDate: Wed Dec 20 19:17:21 2023 +0800
[SPARK-46452][SQL] Add a new API in DataWriter to write an iterator of records
### What changes were proposed in this pull request?
This PR proposes to add a new method in `DataWriter` that supports writing
an iterator of records:
```java
void writeAll(Iterator<T> records) throws IOException
```
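For illustration only, here is a minimal sketch (not part of this patch) of how a connector could override `writeAll` to consume the whole iterator in batches, while `write` remains the per-record path used by the default implementation. The class name, batch size, and flush placeholder are hypothetical, and the override assumes a Spark build that includes this change:
```java
// Hypothetical connector writer, for illustration only (not part of this patch).
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.write.DataWriter;
import org.apache.spark.sql.connector.write.WriterCommitMessage;

class BatchingDataWriter implements DataWriter<InternalRow> {
  private static final int BATCH_SIZE = 1000;          // hypothetical batch size
  private final List<InternalRow> buffer = new ArrayList<>();

  // Per-record path: still required, and used by the default writeAll().
  @Override
  public void write(InternalRow record) throws IOException {
    buffer.add(record.copy());                          // defensive copy in case rows are reused upstream
    flush();
  }

  // Iterator-level path: the writer decides how to consume the input,
  // here by grouping rows into fixed-size batches before each flush.
  @Override
  public void writeAll(Iterator<InternalRow> records) throws IOException {
    while (records.hasNext()) {
      buffer.add(records.next().copy());
      if (buffer.size() >= BATCH_SIZE) {
        flush();
      }
    }
    flush();                                            // flush any remaining rows
  }

  private void flush() throws IOException {
    if (buffer.isEmpty()) return;
    // Send the buffered rows to the external system here (omitted in this sketch).
    buffer.clear();
  }

  @Override
  public WriterCommitMessage commit() throws IOException {
    flush();
    return new WriterCommitMessage() {};                // placeholder commit message
  }

  @Override
  public void abort() throws IOException {
    buffer.clear();
  }

  @Override
  public void close() throws IOException {
    buffer.clear();
  }
}
```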
### Why are the changes needed?
To make the API more flexible and support more use cases (e.g., Python data sources). See https://github.com/apache/spark/pull/43791
### Does this PR introduce _any_ user-facing change?
Yes. This PR introduces a new `writeAll` method in `DataWriter`. Since the method has a default implementation, existing `DataWriter` implementations are unaffected.
### How was this patch tested?
Existing unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #44410 from allisonwang-db/spark-46452-dsv2-write-all.
Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/connector/write/DataWriter.java | 18 +++
.../datasources/v2/WriteToDataSourceV2Exec.scala | 121 ++++++++++++---------
2 files changed, 88 insertions(+), 51 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
index 6a1cee181bc2..d6e94fe2ca8b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/write/DataWriter.java
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connector.write;
import java.io.Closeable;
import java.io.IOException;
+import java.util.Iterator;
import org.apache.spark.annotation.Evolving;
import org.apache.spark.sql.connector.metric.CustomTaskMetric;
@@ -74,6 +75,23 @@ public interface DataWriter<T> extends Closeable {
*/
void write(T record) throws IOException;
+ /**
+ * Writes all records provided by the given iterator. By default, it calls the {@link #write}
+ * method for each record in the iterator.
+ * <p>
+ * If this method fails (by throwing an exception), {@link #abort()} will be called and this
+ * data writer is considered to have been failed.
+ *
+ * @throws IOException if failure happens during disk/network IO like writing files.
+ *
+ * @since 4.0.0
+ */
+ default void writeAll(Iterator<T> records) throws IOException {
+ while (records.hasNext()) {
+ write(records.next());
+ }
+ }
+
/**
* Commits this writer after all records are written successfully, returns a commit message which
* will be sent back to driver side and passed to
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
index 2527f201f3a8..97c1f7ced508 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala
@@ -421,7 +421,7 @@ trait V2TableWriteExec extends V2CommandExec with UnaryExecNode {
trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serializable {
- protected def write(writer: W, row: InternalRow): Unit
+ protected def write(writer: W, iter: java.util.Iterator[InternalRow]): Unit
def run(
writerFactory: DataWriterFactory,
@@ -436,19 +436,11 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial
val attemptId = context.attemptNumber()
val dataWriter = writerFactory.createWriter(partId, taskId).asInstanceOf[W]
- var count = 0L
+ val iterWithMetrics = IteratorWithMetrics(iter, dataWriter, customMetrics)
+
// write the data and commit this writer.
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
- while (iter.hasNext) {
- if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
- CustomMetrics.updateMetrics(
- dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics)
- }
-
- // Count is here.
- count += 1
- write(dataWriter, iter.next())
- }
+ write(dataWriter, iterWithMetrics)
CustomMetrics.updateMetrics(
dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics)
@@ -476,7 +468,7 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial
logInfo(s"Committed partition $partId (task $taskId, attempt $attemptId,
" +
s"stage $stageId.$stageAttempt)")
- DataWritingSparkTaskResult(count, msg)
+ DataWritingSparkTaskResult(iterWithMetrics.count, msg)
})(catchBlock = {
// If there is an error, abort this writer
@@ -489,11 +481,30 @@ trait WritingSparkTask[W <: DataWriter[InternalRow]] extends Logging with Serial
dataWriter.close()
})
}
+
+ private case class IteratorWithMetrics(
+ iter: Iterator[InternalRow],
+ dataWriter: W,
+ customMetrics: Map[String, SQLMetric]) extends java.util.Iterator[InternalRow] {
+ var count = 0L
+
+ override def hasNext: Boolean = iter.hasNext
+
+ override def next(): InternalRow = {
+ if (count % CustomMetrics.NUM_ROWS_PER_UPDATE == 0) {
+ CustomMetrics.updateMetrics(
+ dataWriter.currentMetricsValues.toImmutableArraySeq, customMetrics)
+ }
+ count += 1
+ iter.next()
+ }
+ }
}
object DataWritingSparkTask extends WritingSparkTask[DataWriter[InternalRow]] {
- override protected def write(writer: DataWriter[InternalRow], row: InternalRow): Unit = {
- writer.write(row)
+ override protected def write(
+ writer: DataWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = {
+ writer.writeAll(iter)
}
}
@@ -503,25 +514,29 @@ case class DeltaWritingSparkTask(
private lazy val rowProjection = projections.rowProjection.orNull
private lazy val rowIdProjection = projections.rowIdProjection
- override protected def write(writer: DeltaWriter[InternalRow], row: InternalRow): Unit = {
- val operation = row.getInt(0)
+ override protected def write(
+ writer: DeltaWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = {
+ while (iter.hasNext) {
+ val row = iter.next()
+ val operation = row.getInt(0)
- operation match {
- case DELETE_OPERATION =>
- rowIdProjection.project(row)
- writer.delete(null, rowIdProjection)
+ operation match {
+ case DELETE_OPERATION =>
+ rowIdProjection.project(row)
+ writer.delete(null, rowIdProjection)
- case UPDATE_OPERATION =>
- rowProjection.project(row)
- rowIdProjection.project(row)
- writer.update(null, rowIdProjection, rowProjection)
+ case UPDATE_OPERATION =>
+ rowProjection.project(row)
+ rowIdProjection.project(row)
+ writer.update(null, rowIdProjection, rowProjection)
- case INSERT_OPERATION =>
- rowProjection.project(row)
- writer.insert(rowProjection)
+ case INSERT_OPERATION =>
+ rowProjection.project(row)
+ writer.insert(rowProjection)
- case other =>
- throw new SparkException(s"Unexpected operation ID: $other")
+ case other =>
+ throw new SparkException(s"Unexpected operation ID: $other")
+ }
}
}
}
@@ -533,27 +548,31 @@ case class DeltaWithMetadataWritingSparkTask(
private lazy val rowIdProjection = projections.rowIdProjection
private lazy val metadataProjection = projections.metadataProjection.orNull
- override protected def write(writer: DeltaWriter[InternalRow], row: InternalRow): Unit = {
- val operation = row.getInt(0)
-
- operation match {
- case DELETE_OPERATION =>
- rowIdProjection.project(row)
- metadataProjection.project(row)
- writer.delete(metadataProjection, rowIdProjection)
-
- case UPDATE_OPERATION =>
- rowProjection.project(row)
- rowIdProjection.project(row)
- metadataProjection.project(row)
- writer.update(metadataProjection, rowIdProjection, rowProjection)
-
- case INSERT_OPERATION =>
- rowProjection.project(row)
- writer.insert(rowProjection)
-
- case other =>
- throw new SparkException(s"Unexpected operation ID: $other")
+ override protected def write(
+ writer: DeltaWriter[InternalRow], iter: java.util.Iterator[InternalRow]): Unit = {
+ while (iter.hasNext) {
+ val row = iter.next()
+ val operation = row.getInt(0)
+
+ operation match {
+ case DELETE_OPERATION =>
+ rowIdProjection.project(row)
+ metadataProjection.project(row)
+ writer.delete(metadataProjection, rowIdProjection)
+
+ case UPDATE_OPERATION =>
+ rowProjection.project(row)
+ rowIdProjection.project(row)
+ metadataProjection.project(row)
+ writer.update(metadataProjection, rowIdProjection, rowProjection)
+
+ case INSERT_OPERATION =>
+ rowProjection.project(row)
+ writer.insert(rowProjection)
+
+ case other =>
+ throw new SparkException(s"Unexpected operation ID: $other")
+ }
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]