jkolash commented on issue #13297: URL: https://github.com/apache/iceberg/issues/13297#issuecomment-2977406636
So I made the following changes to just Spark, without my Iceberg changes, and I was able to avoid the OOM. The idea is to close each `PartitionReader` as soon as its iterator is exhausted, via an invoke-once callback, instead of only at task completion, so readers for already-consumed input partitions no longer stay open for the lifetime of the task.

<img width="679" alt="Image" src="https://github.com/user-attachments/assets/1158512c-6157-455e-aebe-2a01fa8589e3" />

```diff
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 67e77a9786..288e9de16c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -17,8 +17,9 @@ package org.apache.spark.sql.execution.datasources.v2
 
-import scala.language.existentials
+import java.util.concurrent.atomic.AtomicReference
 
+import scala.language.existentials
 import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
@@ -29,6 +30,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
+
 class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition])
   extends Partition with Serializable
 
@@ -74,24 +76,33 @@ class DataSourceRDD(
           val inputPartition = inputPartitions(currentIndex)
           currentIndex += 1
 
+          val exhaustCallback = new InvokeOnceCallback()
+
           // TODO: SPARK-25083 remove the type erasure hack in data source scan
           val (iter, reader) = if (columnarReads) {
             val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
             val iter = new MetricsBatchIterator(
-              new PartitionIterator[ColumnarBatch](batchReader, customMetrics))
+              new PartitionIterator[ColumnarBatch](batchReader, customMetrics), exhaustCallback)
             (iter, batchReader)
           } else {
             val rowReader = partitionReaderFactory.createReader(inputPartition)
             val iter = new MetricsRowIterator(
-              new PartitionIterator[InternalRow](rowReader, customMetrics))
+              new PartitionIterator[InternalRow](rowReader, customMetrics), exhaustCallback)
             (iter, rowReader)
           }
+
+          exhaustCallback.setCallback(callback = new Runnable() {
+            override def run(): Unit = {
+              // In case of early stopping before consuming the entire iterator,
+              // we need to do one more metric update at the end of the task.
+              CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
+              iter.forceUpdateMetrics()
+              reader.close()
+            }
+          })
+
           context.addTaskCompletionListener[Unit] { _ =>
-            // In case of early stopping before consuming the entire iterator,
-            // we need to do one more metric update at the end of the task.
-            CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
-            iter.forceUpdateMetrics()
-            reader.close()
+            exhaustCallback.run()
           }
           currentIter = Some(iter)
           hasNext
@@ -107,6 +118,21 @@ class DataSourceRDD(
   }
 }
 
+private class InvokeOnceCallback extends Runnable {
+  val originalCallback = new AtomicReference[Runnable](null)
+
+  override def run(): Unit = {
+    if (originalCallback.get() != null) {
+      originalCallback.get().run()
+      originalCallback.set(null)
+    }
+  }
+
+  def setCallback(callback: Runnable): Unit = {
+    originalCallback.set(callback)
+  }
+}
+
 private class PartitionIterator[T](
     reader: PartitionReader[T],
     customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
@@ -151,14 +177,16 @@ private class MetricsHandler extends Logging with Serializable {
   }
 }
 
-private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] {
+private abstract class MetricsIterator[I](
+    iter: Iterator[I],
+    exhaustionCallback: InvokeOnceCallback) extends Iterator[I] {
   protected val metricsHandler = new MetricsHandler
 
   override def hasNext: Boolean = {
     if (iter.hasNext) {
       true
     } else {
-      forceUpdateMetrics()
+      exhaustionCallback.run()
       false
     }
   }
@@ -167,7 +195,8 @@ private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I]
 }
 
 private class MetricsRowIterator(
-    iter: Iterator[InternalRow]) extends MetricsIterator[InternalRow](iter) {
+    iter: Iterator[InternalRow],
+    callback: InvokeOnceCallback) extends MetricsIterator[InternalRow](iter, callback) {
   override def next(): InternalRow = {
     val item = iter.next
     metricsHandler.updateMetrics(1)
@@ -176,7 +205,8 @@ private class MetricsRowIterator(
 }
 
 private class MetricsBatchIterator(
-    iter: Iterator[ColumnarBatch]) extends MetricsIterator[ColumnarBatch](iter) {
+    iter: Iterator[ColumnarBatch],
+    callback: InvokeOnceCallback) extends MetricsIterator[ColumnarBatch](iter, callback) {
   override def next(): ColumnarBatch = {
     val batch: ColumnarBatch = iter.next
     metricsHandler.updateMetrics(batch.numRows)
```
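One side note on the `InvokeOnceCallback` in the diff: `run()` checks `get() != null` and then calls `set(null)`, which is a check-then-act sequence rather than an atomic claim. Since both the exhaustion path in `hasNext` and the task-completion listener can invoke it, an atomic `getAndSet` would guarantee the reader is closed at most once even if those two paths ever raced. A minimal standalone sketch of that variant (class and method names are my own, not from the diff):

```scala
import java.util.concurrent.atomic.AtomicReference

// Invoke-once callback where getAndSet(null) claims the pending Runnable
// atomically, so the wrapped action fires at most once even under a race
// between the iterator-exhaustion path and the completion listener.
class InvokeOnceCallback extends Runnable {
  private val pending = new AtomicReference[Runnable](null)

  def setCallback(callback: Runnable): Unit = pending.set(callback)

  override def run(): Unit = {
    val cb = pending.getAndSet(null) // claim exactly once
    if (cb != null) cb.run()
  }
}

object InvokeOnceDemo extends App {
  val once = new InvokeOnceCallback
  once.setCallback(() => println("closing reader"))
  once.run() // prints "closing reader"
  once.run() // no-op: the callback was already consumed
}
```

In the normal single-threaded task flow the posted version behaves the same; the `getAndSet` form just makes the invoke-once guarantee hold by construction.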