jkolash commented on issue #13297: URL: https://github.com/apache/iceberg/issues/13297#issuecomment-2978695785
On the v2 Parquet reader side quest: this complete set of Spark changes gets the v2 Parquet reader working.

<img width="1351" alt="Image" src="https://github.com/user-attachments/assets/f4307624-8168-4c2f-b91a-fb948fba5a65" />

I introduce a `GarbageCollectableRecordReader` that nulls out its delegate once `close()` has been called. Callers were already calling `close()` correctly, but the task-completion callback still held a reference to the reader and prevented it from being garbage collected even after close. (A short standalone sketch of the retention pattern follows the diff.)

```diff
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
index 67e77a9786..288e9de16c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
@@ -17,8 +17,9 @@ package org.apache.spark.sql.execution.datasources.v2
 
-import scala.language.existentials
+import java.util.concurrent.atomic.AtomicReference
 
+import scala.language.existentials
 import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
@@ -29,6 +30,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 
+
 class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition])
   extends Partition with Serializable
 
@@ -74,24 +76,33 @@ class DataSourceRDD(
           val inputPartition = inputPartitions(currentIndex)
           currentIndex += 1
+          val exhaustCallback = new InvokeOnceCallback()
+
           // TODO: SPARK-25083 remove the type erasure hack in data source scan
           val (iter, reader) = if (columnarReads) {
             val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
             val iter = new MetricsBatchIterator(
-              new PartitionIterator[ColumnarBatch](batchReader, customMetrics))
+              new PartitionIterator[ColumnarBatch](batchReader, customMetrics), exhaustCallback)
             (iter, batchReader)
           } else {
             val rowReader = partitionReaderFactory.createReader(inputPartition)
             val iter = new MetricsRowIterator(
-              new PartitionIterator[InternalRow](rowReader, customMetrics))
+              new PartitionIterator[InternalRow](rowReader, customMetrics), exhaustCallback)
             (iter, rowReader)
           }
+
+          exhaustCallback.setCallback(callback = new Runnable() {
+            override def run(): Unit = {
+              // In case of early stopping before consuming the entire iterator,
+              // we need to do one more metric update at the end of the task.
+              CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
+              iter.forceUpdateMetrics()
+              reader.close()
+            }
+          })
+
           context.addTaskCompletionListener[Unit] { _ =>
-            // In case of early stopping before consuming the entire iterator,
-            // we need to do one more metric update at the end of the task.
-            CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
-            iter.forceUpdateMetrics()
-            reader.close()
+            exhaustCallback.run()
           }
           currentIter = Some(iter)
           hasNext
@@ -107,6 +118,21 @@ class DataSourceRDD(
   }
 }
 
+private class InvokeOnceCallback extends Runnable {
+  val originalCallback = new AtomicReference[Runnable](null)
+
+  override def run(): Unit = {
+    if (originalCallback.get() != null) {
+      originalCallback.get().run()
+      originalCallback.set(null);
+    }
+  }
+
+  def setCallback(callback: Runnable): Unit = {
+    originalCallback.set(callback);
+  }
+}
+
 private class PartitionIterator[T](
     reader: PartitionReader[T],
     customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
@@ -151,14 +177,16 @@ private class MetricsHandler extends Logging with Serializable {
   }
 }
 
-private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] {
+private abstract class MetricsIterator[I](
+    iter: Iterator[I],
+    exhaustionCallback: InvokeOnceCallback) extends Iterator[I] {
   protected val metricsHandler = new MetricsHandler
 
   override def hasNext: Boolean = {
     if (iter.hasNext) {
       true
     } else {
-      forceUpdateMetrics()
+      exhaustionCallback.run()
       false
     }
   }
@@ -167,7 +195,8 @@ private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I]
 }
 
 private class MetricsRowIterator(
-    iter: Iterator[InternalRow]) extends MetricsIterator[InternalRow](iter) {
+    iter: Iterator[InternalRow],
+    callback: InvokeOnceCallback) extends MetricsIterator[InternalRow](iter, callback) {
   override def next(): InternalRow = {
     val item = iter.next
     metricsHandler.updateMetrics(1)
@@ -176,7 +205,8 @@ private class MetricsRowIterator(
 }
 
 private class MetricsBatchIterator(
-    iter: Iterator[ColumnarBatch]) extends MetricsIterator[ColumnarBatch](iter) {
+    iter: Iterator[ColumnarBatch],
+    callback: InvokeOnceCallback) extends MetricsIterator[ColumnarBatch](iter, callback) {
   override def next(): ColumnarBatch = {
     val batch: ColumnarBatch = iter.next
     metricsHandler.updateMetrics(batch.numRows)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
index 5951c1d8dd..165fe88bef 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
@@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.datasources.v2.parquet
 
 import java.time.ZoneId
-
 import org.apache.hadoop.mapred.FileSplit
 import org.apache.hadoop.mapreduce._
 import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
@@ -26,7 +25,6 @@ import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
 import org.apache.parquet.format.converter.ParquetMetadataConverter.{NO_FILTER, SKIP_ROW_GROUPS}
 import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader}
 import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata}
-
 import org.apache.spark.TaskContext
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.Logging
@@ -36,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
 import org.apache.spark.sql.execution.WholeStageCodegenExec
-import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, DataSourceUtils, PartitionedFile, RecordReaderIterator}
+import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, DataSourceUtils, PartitionedFile}
 import org.apache.spark.sql.execution.datasources.parquet._
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.internal.SQLConf
@@ -45,6 +43,8 @@ import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.SerializableConfiguration
 
+import java.util.concurrent.atomic.AtomicReference
+
 /**
  * A factory used to create Parquet readers.
  *
@@ -158,7 +158,7 @@ case class ParquetPartitionReaderFactory(
   override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
     val fileReader = if (aggregation.isEmpty) {
       val vectorizedReader = createVectorizedReader(file)
-      vectorizedReader.enableReturningBatches()
+      vectorizedReader.delegate.get().asInstanceOf[VectorizedParquetRecordReader].enableReturningBatches()
 
       new PartitionReader[ColumnarBatch] {
         override def next(): Boolean = vectorizedReader.nextKeyValue()
@@ -205,7 +205,8 @@ case class ParquetPartitionReaderFactory(
        InternalRow,
        Option[FilterPredicate],
        Option[ZoneId],
        RebaseSpec,
-       RebaseSpec) => RecordReader[Void, T]): RecordReader[Void, T] = {
+       RebaseSpec) =>
+        GarbageCollectableRecordReader[Void, T]): GarbageCollectableRecordReader[Void, T] = {
     val conf = broadcastedConf.value.value
     val filePath = file.toPath
@@ -279,7 +280,7 @@ case class ParquetPartitionReaderFactory(
       pushed: Option[FilterPredicate],
       convertTz: Option[ZoneId],
       datetimeRebaseSpec: RebaseSpec,
-      int96RebaseSpec: RebaseSpec): RecordReader[Void, InternalRow] = {
+      int96RebaseSpec: RebaseSpec): GarbageCollectableRecordReader[Void, InternalRow] = {
     logDebug(s"Falling back to parquet-mr")
     val taskContext = Option(TaskContext.get())
     // ParquetRecordReader returns InternalRow
@@ -296,17 +297,57 @@ case class ParquetPartitionReaderFactory(
     }
     val readerWithRowIndexes = ParquetRowIndexUtil.addRowIndexToRecordReaderIfNeeded(
       reader, readDataSchema)
-    val iter = new RecordReaderIterator(readerWithRowIndexes)
+    val delegatingRecordReader =
+      new GarbageCollectableRecordReader[Void, InternalRow](readerWithRowIndexes)
     // SPARK-23457 Register a task completion listener before `initialization`.
-    taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
-    readerWithRowIndexes
+    taskContext.foreach(_.addTaskCompletionListener[Unit](_ => delegatingRecordReader.close()))
+    delegatingRecordReader
+  }
+
+  private class GarbageCollectableRecordReader[K, V](reader: RecordReader[K, V])
+    extends RecordReader[K, V] {
+    val delegate = new AtomicReference[RecordReader[K, V]](reader)
+
+    override def initialize(inputSplit: InputSplit,
+        taskAttemptContext: TaskAttemptContext): Unit = {
+      delegate.get().initialize(inputSplit, taskAttemptContext)
+    }
+
+    override def nextKeyValue(): Boolean = {
+      delegate.get().nextKeyValue()
+    }
+
+    override def getCurrentKey: K = {
+      delegate.get().getCurrentKey
+    }
+
+    override def getCurrentValue: V = {
+      delegate.get().getCurrentValue
+    }
+
+    override def getProgress: Float = {
+      if (delegate.get() == null) {
+        1.0f
+      } else {
+        delegate.get().getProgress
+      }
+    }
+
+    override def close(): Unit = {
+      if (delegate.get() != null) {
+        delegate.get().close()
+        delegate.set(null)
+      }
+    }
   }
 
-  private def createVectorizedReader(file: PartitionedFile): VectorizedParquetRecordReader = {
-    val vectorizedReader = buildReaderBase(file, createParquetVectorizedReader)
+  private def createVectorizedReader(file: PartitionedFile):
+      GarbageCollectableRecordReader[Void, InternalRow] = {
+    val gcReader = buildReaderBase(file, createParquetVectorizedReader)
+    val vectorizedReader = gcReader.delegate.get()
       .asInstanceOf[VectorizedParquetRecordReader]
     vectorizedReader.initBatch(partitionSchema, file.partitionValues)
-    vectorizedReader
+    gcReader
   }
 
   private def createParquetVectorizedReader(
@@ -314,7 +355,7 @@ case class ParquetPartitionReaderFactory(
       pushed: Option[FilterPredicate],
      convertTz: Option[ZoneId],
       datetimeRebaseSpec: RebaseSpec,
-      int96RebaseSpec: RebaseSpec): VectorizedParquetRecordReader = {
+      int96RebaseSpec: RebaseSpec): GarbageCollectableRecordReader[Void, InternalRow] = {
     val taskContext = Option(TaskContext.get())
     val vectorizedReader = new VectorizedParquetRecordReader(
       convertTz.orNull,
@@ -323,11 +364,14 @@ case class ParquetPartitionReaderFactory(
       int96RebaseSpec.mode.toString,
       int96RebaseSpec.timeZone,
       enableOffHeapColumnVector && taskContext.isDefined,
-      capacity)
-    val iter = new RecordReaderIterator(vectorizedReader)
+      capacity).asInstanceOf[RecordReader[Void, InternalRow]]
+
+    val delegatingRecordReader =
+      new GarbageCollectableRecordReader[Void, InternalRow](vectorizedReader)
+
     // SPARK-23457 Register a task completion listener before `initialization`.
-    taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
+    taskContext.foreach(_.addTaskCompletionListener[Unit](_ => delegatingRecordReader.close()))
     logDebug(s"Appending $partitionSchema $partitionValues")
-    vectorizedReader
+    delegatingRecordReader
   }
 }
```
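For anyone skimming the diff, here is a minimal, self-contained sketch of the retention problem the wrapper addresses. The names (`HeavyReader`, `GcFriendlyReader`, `taskCompletionListeners`) are invented for illustration and are not part of the patch; the real change wraps Hadoop's `RecordReader` and registers the close through Spark's `TaskContext` completion listeners.

```scala
import java.util.concurrent.atomic.AtomicReference
import scala.collection.mutable.ArrayBuffer

// Stand-in for a heavy reader (e.g. a vectorized Parquet reader holding large column buffers).
final class HeavyReader {
  private val buffers = new Array[Byte](64 * 1024 * 1024) // simulated large allocation
  def close(): Unit = println(s"closed reader holding ${buffers.length} bytes")
}

// Same idea as GarbageCollectableRecordReader: close() nulls the delegate, so a
// long-lived callback that still references this small wrapper no longer keeps
// the heavy reader reachable once it has been closed.
final class GcFriendlyReader(reader: HeavyReader) {
  val delegate = new AtomicReference[HeavyReader](reader)

  def close(): Unit = {
    val r = delegate.getAndSet(null)
    if (r != null) r.close() // a second call is a no-op, like InvokeOnceCallback
  }
}

object RetentionSketch extends App {
  // Stand-in for TaskContext completion listeners: callbacks that stay referenced
  // until the whole task finishes, long after each partition's reader is done.
  val taskCompletionListeners = ArrayBuffer.empty[() => Unit]

  val wrapped = new GcFriendlyReader(new HeavyReader)
  taskCompletionListeners += (() => wrapped.close())

  // The caller closes early; only the small wrapper stays reachable from the
  // listener, so the HeavyReader and its buffers can be collected right away.
  wrapped.close()

  // At task end the listener fires again; the null check makes that harmless.
  taskCompletionListeners.foreach(_.apply())
}
```

The `InvokeOnceCallback` in `DataSourceRDD` follows the same run-once idea: the iterator-exhaustion path and the task-completion listener share one callback, so whichever fires first updates the metrics and closes the reader, and the second invocation does nothing.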