jkolash commented on issue #13297:
URL: https://github.com/apache/iceberg/issues/13297#issuecomment-2978695785

   On the v2 Parquet reader side quest: this complete set of Spark changes allows the v2 Parquet reader to work.
   
   <img width="1351" alt="Image" 
src="https://github.com/user-attachments/assets/f4307624-8168-4c2f-b91a-fb948fba5a65";
 />
   
   I introduce a GarbageCollectableRecordReader that nulls out its delegate once close() has been called. Callers were already invoking close() correctly, but the task-completion callback kept a reference to the reader and prevented it from being garbage collected even after close() had run.
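
   The core idea is a close-once delegating wrapper: the first close() closes the underlying reader and drops the reference to it, so the reader and its buffers become unreachable even while the long-lived task-completion callback still holds the wrapper, and any later close() is a no-op. A minimal sketch of that pattern with illustrative names (Reader and CloseOnceReader are not the actual classes in the diff below):

   ```scala
   import java.io.Closeable
   import java.util.concurrent.atomic.AtomicReference

   // Hypothetical stand-in for a heavyweight reader that pins large buffers.
   trait Reader[T] extends Closeable {
     def next(): Option[T]
   }

   // Close-once wrapper: close() closes and drops the delegate exactly once, so
   // the wrapped reader can be garbage collected even while a task-completion
   // listener still references this wrapper.
   final class CloseOnceReader[T](underlying: Reader[T]) extends Reader[T] {
     private val delegate = new AtomicReference[Reader[T]](underlying)

     override def next(): Option[T] = Option(delegate.get()).flatMap(_.next())

     override def close(): Unit = {
       val current = delegate.getAndSet(null)
       if (current != null) current.close()
     }
   }
   ```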
   
   ```diff
   diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
   index 67e77a9786..288e9de16c 100644
   --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
   +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceRDD.scala
   @@ -17,8 +17,9 @@
    
    package org.apache.spark.sql.execution.datasources.v2
    
   -import scala.language.existentials
   +import java.util.concurrent.atomic.AtomicReference
    
   +import scala.language.existentials
    import org.apache.spark._
    import org.apache.spark.deploy.SparkHadoopUtil
    import org.apache.spark.internal.Logging
   @@ -29,6 +30,7 @@ import org.apache.spark.sql.errors.QueryExecutionErrors
    import org.apache.spark.sql.execution.metric.{CustomMetrics, SQLMetric}
    import org.apache.spark.sql.vectorized.ColumnarBatch
    
   +
    class DataSourceRDDPartition(val index: Int, val inputPartitions: Seq[InputPartition])
      extends Partition with Serializable
    
   @@ -74,24 +76,33 @@ class DataSourceRDD(
              val inputPartition = inputPartitions(currentIndex)
              currentIndex += 1
    
   +          val exhaustCallback = new InvokeOnceCallback()
   +
              // TODO: SPARK-25083 remove the type erasure hack in data source scan
              val (iter, reader) = if (columnarReads) {
                val batchReader = partitionReaderFactory.createColumnarReader(inputPartition)
                val iter = new MetricsBatchIterator(
   -              new PartitionIterator[ColumnarBatch](batchReader, customMetrics))
   +              new PartitionIterator[ColumnarBatch](batchReader, customMetrics), exhaustCallback)
                (iter, batchReader)
              } else {
                val rowReader = partitionReaderFactory.createReader(inputPartition)
                val iter = new MetricsRowIterator(
   -              new PartitionIterator[InternalRow](rowReader, customMetrics))
   +              new PartitionIterator[InternalRow](rowReader, customMetrics), exhaustCallback)
                (iter, rowReader)
              }
   +
   +          exhaustCallback.setCallback(callback = new Runnable() {
   +            override def run(): Unit = {
   +              // In case of early stopping before consuming the entire iterator,
   +              // we need to do one more metric update at the end of the task.
   +              CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
   +              iter.forceUpdateMetrics()
   +              reader.close()
   +            }
   +          })
   +
              context.addTaskCompletionListener[Unit] { _ =>
   -            // In case of early stopping before consuming the entire iterator,
   -            // we need to do one more metric update at the end of the task.
   -            CustomMetrics.updateMetrics(reader.currentMetricsValues, customMetrics)
   -            iter.forceUpdateMetrics()
   -            reader.close()
   +            exhaustCallback.run()
              }
              currentIter = Some(iter)
              hasNext
   @@ -107,6 +118,21 @@ class DataSourceRDD(
      }
    }
    
   +private class InvokeOnceCallback extends Runnable {
   +  val originalCallback = new AtomicReference[Runnable](null)
   +
   +  override def run(): Unit = {
   +    if (originalCallback.get() != null) {
   +      originalCallback.get().run()
   +      originalCallback.set(null);
   +    }
   +  }
   +
   +  def setCallback(callback: Runnable): Unit = {
   +    originalCallback.set(callback);
   +  }
   +}
   +
    private class PartitionIterator[T](
        reader: PartitionReader[T],
        customMetrics: Map[String, SQLMetric]) extends Iterator[T] {
   @@ -151,14 +177,16 @@ private class MetricsHandler extends Logging with Serializable {
      }
    }
    
   -private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I] {
   +private abstract class MetricsIterator[I](
   +    iter: Iterator[I],
   +    exhaustionCallback: InvokeOnceCallback) extends Iterator[I] {
      protected val metricsHandler = new MetricsHandler
    
      override def hasNext: Boolean = {
        if (iter.hasNext) {
          true
        } else {
   -      forceUpdateMetrics()
   +      exhaustionCallback.run()
          false
        }
      }
   @@ -167,7 +195,8 @@ private abstract class MetricsIterator[I](iter: Iterator[I]) extends Iterator[I]
    }
    
    private class MetricsRowIterator(
   -    iter: Iterator[InternalRow]) extends MetricsIterator[InternalRow](iter) {
   +    iter: Iterator[InternalRow],
   +    callback: InvokeOnceCallback) extends MetricsIterator[InternalRow](iter, callback) {
      override def next(): InternalRow = {
        val item = iter.next
        metricsHandler.updateMetrics(1)
   @@ -176,7 +205,8 @@ private class MetricsRowIterator(
    }
    
    private class MetricsBatchIterator(
   -    iter: Iterator[ColumnarBatch]) extends MetricsIterator[ColumnarBatch](iter) {
   +    iter: Iterator[ColumnarBatch],
   +    callback: InvokeOnceCallback) extends MetricsIterator[ColumnarBatch](iter, callback) {
      override def next(): ColumnarBatch = {
        val batch: ColumnarBatch = iter.next
        metricsHandler.updateMetrics(batch.numRows)
   diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
   index 5951c1d8dd..165fe88bef 100644
   --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
   +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/parquet/ParquetPartitionReaderFactory.scala
   @@ -17,7 +17,6 @@
    package org.apache.spark.sql.execution.datasources.v2.parquet
    
    import java.time.ZoneId
   -
    import org.apache.hadoop.mapred.FileSplit
    import org.apache.hadoop.mapreduce._
    import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl
   @@ -26,7 +25,6 @@ import org.apache.parquet.filter2.predicate.{FilterApi, FilterPredicate}
    import org.apache.parquet.format.converter.ParquetMetadataConverter.{NO_FILTER, SKIP_ROW_GROUPS}
    import org.apache.parquet.hadoop.{ParquetInputFormat, ParquetRecordReader}
    import org.apache.parquet.hadoop.metadata.{FileMetaData, ParquetMetadata}
   -
    import org.apache.spark.TaskContext
    import org.apache.spark.broadcast.Broadcast
    import org.apache.spark.internal.Logging
   @@ -36,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
    import org.apache.spark.sql.connector.expressions.aggregate.Aggregation
    import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
    import org.apache.spark.sql.execution.WholeStageCodegenExec
   -import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, DataSourceUtils, PartitionedFile, RecordReaderIterator}
   +import org.apache.spark.sql.execution.datasources.{AggregatePushDownUtils, DataSourceUtils, PartitionedFile}
    import org.apache.spark.sql.execution.datasources.parquet._
    import org.apache.spark.sql.execution.datasources.v2._
    import org.apache.spark.sql.internal.SQLConf
   @@ -45,6 +43,8 @@ import org.apache.spark.sql.types.StructType
    import org.apache.spark.sql.vectorized.ColumnarBatch
    import org.apache.spark.util.SerializableConfiguration
    
   +import java.util.concurrent.atomic.AtomicReference
   +
    /**
     * A factory used to create Parquet readers.
     *
   @@ -158,7 +158,7 @@ case class ParquetPartitionReaderFactory(
      override def buildColumnarReader(file: PartitionedFile): PartitionReader[ColumnarBatch] = {
        val fileReader = if (aggregation.isEmpty) {
          val vectorizedReader = createVectorizedReader(file)
   -      vectorizedReader.enableReturningBatches()
   +      vectorizedReader.delegate.get().asInstanceOf[VectorizedParquetRecordReader].enableReturningBatches()
    
          new PartitionReader[ColumnarBatch] {
            override def next(): Boolean = vectorizedReader.nextKeyValue()
   @@ -205,7 +205,8 @@ case class ParquetPartitionReaderFactory(
            InternalRow,
              Option[FilterPredicate], Option[ZoneId],
              RebaseSpec,
   -          RebaseSpec) => RecordReader[Void, T]): RecordReader[Void, T] = {
   +          RebaseSpec) =>
   +        GarbageCollectableRecordReader[Void, T]): GarbageCollectableRecordReader[Void, T] = {
        val conf = broadcastedConf.value.value
    
        val filePath = file.toPath
   @@ -279,7 +280,7 @@ case class ParquetPartitionReaderFactory(
          pushed: Option[FilterPredicate],
          convertTz: Option[ZoneId],
          datetimeRebaseSpec: RebaseSpec,
   -      int96RebaseSpec: RebaseSpec): RecordReader[Void, InternalRow] = {
   +      int96RebaseSpec: RebaseSpec): GarbageCollectableRecordReader[Void, InternalRow] = {
        logDebug(s"Falling back to parquet-mr")
        val taskContext = Option(TaskContext.get())
        // ParquetRecordReader returns InternalRow
   @@ -296,17 +297,57 @@ case class ParquetPartitionReaderFactory(
        }
        val readerWithRowIndexes = ParquetRowIndexUtil.addRowIndexToRecordReaderIfNeeded(
          reader, readDataSchema)
   -    val iter = new RecordReaderIterator(readerWithRowIndexes)
   +    val delegatingRecordReader =
   +      new GarbageCollectableRecordReader[Void, InternalRow](readerWithRowIndexes)
        // SPARK-23457 Register a task completion listener before `initialization`.
   -    taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
   -    readerWithRowIndexes
   +    taskContext.foreach(_.addTaskCompletionListener[Unit](_ => delegatingRecordReader.close()))
   +    delegatingRecordReader
   +  }
   +
   +  private class GarbageCollectableRecordReader[K, V](reader: RecordReader[K, V])
   +    extends RecordReader[K, V] {
   +    val delegate = new AtomicReference[RecordReader[K, V]](reader)
   +
   +    override def initialize(inputSplit: InputSplit,
   +                            taskAttemptContext: TaskAttemptContext): Unit = {
   +      delegate.get().initialize(inputSplit, taskAttemptContext)
   +    }
   +
   +    override def nextKeyValue(): Boolean = {
   +      delegate.get().nextKeyValue()
   +    }
   +
   +    override def getCurrentKey: K = {
   +      delegate.get().getCurrentKey
   +    }
   +
   +    override def getCurrentValue: V = {
   +      delegate.get().getCurrentValue
   +    }
   +
   +    override def getProgress: Float = {
   +      if (delegate.get() == null) {
   +        1.0f
   +      } else {
   +        delegate.get().getProgress
   +      }
   +    }
   +
   +    override def close(): Unit = {
   +      if (delegate.get() != null) {
   +        delegate.get().close()
   +        delegate.set(null)
   +      }
   +    }
      }
    
   -  private def createVectorizedReader(file: PartitionedFile): VectorizedParquetRecordReader = {
   -    val vectorizedReader = buildReaderBase(file, createParquetVectorizedReader)
   +  private def createVectorizedReader(file: PartitionedFile):
   +    GarbageCollectableRecordReader[Void, InternalRow] = {
   +    val gcReader = buildReaderBase(file, createParquetVectorizedReader)
   +    val vectorizedReader = gcReader.delegate.get()
          .asInstanceOf[VectorizedParquetRecordReader]
        vectorizedReader.initBatch(partitionSchema, file.partitionValues)
   -    vectorizedReader
   +    gcReader
      }
    
      private def createParquetVectorizedReader(
   @@ -314,7 +355,7 @@ case class ParquetPartitionReaderFactory(
          pushed: Option[FilterPredicate],
          convertTz: Option[ZoneId],
          datetimeRebaseSpec: RebaseSpec,
   -      int96RebaseSpec: RebaseSpec): VectorizedParquetRecordReader = {
   +      int96RebaseSpec: RebaseSpec): GarbageCollectableRecordReader[Void, InternalRow] = {
        val taskContext = Option(TaskContext.get())
        val vectorizedReader = new VectorizedParquetRecordReader(
          convertTz.orNull,
   @@ -323,11 +364,14 @@ case class ParquetPartitionReaderFactory(
          int96RebaseSpec.mode.toString,
          int96RebaseSpec.timeZone,
          enableOffHeapColumnVector && taskContext.isDefined,
   -      capacity)
   -    val iter = new RecordReaderIterator(vectorizedReader)
   +      capacity).asInstanceOf[RecordReader[Void, InternalRow]]
   +
   +    val delegatingRecordReader =
   +      new GarbageCollectableRecordReader[Void, InternalRow](vectorizedReader)
   +
        // SPARK-23457 Register a task completion listener before `initialization`.
   -    taskContext.foreach(_.addTaskCompletionListener[Unit](_ => iter.close()))
   +    taskContext.foreach(_.addTaskCompletionListener[Unit](_ => delegatingRecordReader.close()))
        logDebug(s"Appending $partitionSchema $partitionValues")
   -    vectorizedReader
   +    delegatingRecordReader
      }
    }
   
   ```

