rahil-c commented on code in PR #18403:
URL: https://github.com/apache/hudi/pull/18403#discussion_r3031682791


##########
hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/execution/datasources/lance/SparkLanceReaderBase.scala:
##########
@@ -111,65 +112,285 @@ class SparkLanceReaderBase(enableVectorizedReader: 
Boolean) extends SparkColumna
         // Read data with column projection (filters not supported yet)
         val arrowReader = lanceReader.readAll(columnNames, null, 
DEFAULT_BATCH_SIZE)
 
-        // Create iterator using shared LanceRecordIterator
-        lanceIterator = new LanceRecordIterator(
-          allocator,
-          lanceReader,
-          arrowReader,
-          requestSchema,
-          filePath
-        )
-
-        // Register cleanup listener
-        Option(TaskContext.get()).foreach { ctx =>
-          ctx.addTaskCompletionListener[Unit](_ => lanceIterator.close())
-        }
-
-        // Create the following projections for schema evolution:
-        // 1. Padding projection: add NULL for missing columns
-        // 2. Casting projection: handle type conversions
-        val schemaUtils = sparkAdapter.getSchemaUtils
-        val paddingProj = 
SparkSchemaTransformUtils.generateNullPaddingProjection(requestSchema, 
requiredSchema)
-        val castProj = SparkSchemaTransformUtils.generateUnsafeProjection(
-          schemaUtils.toAttributes(requiredSchema),
-          Some(SQLConf.get.sessionLocalTimeZone),
-          implicitTypeChangeInfo,
-          requiredSchema,
-          new StructType(),
-          schemaUtils
-        )
-
-        // Unify projections by applying padding and then casting for each row
-        val projection: UnsafeProjection = new UnsafeProjection {
-          def apply(row: InternalRow): UnsafeRow =
-            castProj(paddingProj(row))
-        }
-        val projectedIter = lanceIterator.asScala.map(projection.apply)
-
-        // Handle partition columns
-        if (partitionSchema.length == 0) {
-          // No partition columns - return rows directly
-          projectedIter
+        // Decide between batch mode and row mode.
+        // Fall back to row mode if type casting is needed (batch-level type 
casting deferred to follow-up).
+        val hasTypeChanges = !implicitTypeChangeInfo.isEmpty
+        if (enableVectorizedReader && !hasTypeChanges) {
+          readBatch(file, allocator, lanceReader, arrowReader, filePath,
+            requestSchema, requiredSchema, partitionSchema)
         } else {
-          // Create UnsafeProjection to convert JoinedRow to UnsafeRow
-          val fullSchema = (requiredSchema.fields ++ 
partitionSchema.fields).map(f =>
-            AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
-          val unsafeProjection = GenerateUnsafeProjection.generate(fullSchema, 
fullSchema)
-
-          // Append partition values to each row using JoinedRow, then convert 
to UnsafeRow
-          val joinedRow = new JoinedRow()
-          projectedIter.map(row => unsafeProjection(joinedRow(row, 
file.partitionValues)))
+          readRows(file, allocator, lanceReader, arrowReader, filePath,
+            requestSchema, requiredSchema, partitionSchema, 
implicitTypeChangeInfo)
         }
 
       } catch {
         case e: Exception =>
-          if (lanceIterator != null) {
-            lanceIterator.close()  // Close iterator which handles lifecycle 
for all objects
+          allocator.close()
+          throw new IOException(s"Failed to read Lance file: $filePath", e)
+      }
+    }
+  }
+
+  /**
+   * Columnar batch reading path. Returns Iterator[ColumnarBatch] type-erased 
as Iterator[InternalRow].
+   * Used when enableVectorizedReader=true and no type casting is needed.
+   */
+  private def readBatch(file: PartitionedFile,
+                        allocator: org.apache.arrow.memory.BufferAllocator,
+                        lanceReader: LanceFileReader,
+                        arrowReader: org.apache.arrow.vector.ipc.ArrowReader,
+                        filePath: String,
+                        requestSchema: StructType,
+                        requiredSchema: StructType,
+                        partitionSchema: StructType): Iterator[InternalRow] = {
+
+    val batchIterator = new LanceBatchIterator(allocator, lanceReader, 
arrowReader, filePath)
+
+    // Build column mapping: for each column in requiredSchema, find its index 
in requestSchema (file columns)
+    // Returns -1 if the column is missing from the file (schema evolution: 
column addition)
+    val columnMapping: Array[Int] = requiredSchema.fields.map { field =>
+      requestSchema.fieldNames.indexOf(field.name)
+    }
+
+    // Create Arrow-backed null vectors for columns missing from the file.
+    // Uses LanceArrowColumnVector so that Spark's vectorTypes() contract is 
satisfied
+    // (FileSourceScanExec expects all data columns to be 
LanceArrowColumnVector).
+    val nullAllocator = if (columnMapping.contains(-1)) {
+      HoodieArrowAllocator.newChildAllocator(
+        getClass.getSimpleName + "-null-" + filePath, 
HoodieSparkLanceReader.LANCE_DATA_ALLOCATOR_SIZE)
+    } else null
+
+    val nullColumnVectors: Array[(Int, LanceArrowColumnVector, 
org.apache.arrow.vector.FieldVector)] =
+      if (nullAllocator != null) {
+        columnMapping.zipWithIndex.filter(_._1 < 0).map { case (_, idx) =>
+          val field = LanceArrowUtils.toArrowField(
+            requiredSchema(idx).name, requiredSchema(idx).dataType, 
requiredSchema(idx).nullable, "UTC")
+          val arrowVector = field.createVector(nullAllocator)
+          arrowVector.allocateNew()
+          arrowVector.setValueCount(DEFAULT_BATCH_SIZE)
+          (idx, new LanceArrowColumnVector(arrowVector), arrowVector)
+        }
+      } else {
+        Array.empty
+      }
+
+    // Pre-create partition column vectors (reused across batches, reset per 
batch)
+    val hasPartitionColumns = partitionSchema.length > 0
+    val partitionVectors: Array[WritableColumnVector] = if 
(hasPartitionColumns) {
+      partitionSchema.fields.map(f => new 
OnHeapColumnVector(DEFAULT_BATCH_SIZE, f.dataType))
+    } else {
+      Array.empty
+    }
+
+    // Populate partition vectors with constant values
+    var lastPopulatedNumRows = DEFAULT_BATCH_SIZE
+    if (hasPartitionColumns) {
+      populatePartitionVectors(partitionVectors, partitionSchema, 
file.partitionValues, DEFAULT_BATCH_SIZE)
+    }
+
+    val totalColumns = requiredSchema.length + partitionSchema.length
+
+    // Map each source batch to a batch with the correct column layout.
+    val mappedIterator = new Iterator[ColumnarBatch] with Closeable {
+      override def hasNext: Boolean = batchIterator.hasNext()
+
+      override def next(): ColumnarBatch = {
+        val sourceBatch = batchIterator.next()
+        val numRows = sourceBatch.numRows()
+
+        val vectors = new Array[ColumnVector](totalColumns)
+
+        // Data columns: reorder from source batch or substitute null Arrow 
vector
+        var i = 0
+        while (i < requiredSchema.length) {
+          if (columnMapping(i) >= 0) {
+            vectors(i) = sourceBatch.column(columnMapping(i))
           } else {
-            allocator.close()      // Close allocator directly
+            // Find the pre-created null vector for this index
+            val entry = nullColumnVectors.find(_._1 == i).get
+            // Adjust valueCount if batch size differs from allocated size
+            if (numRows != entry._3.getValueCount) {
+              entry._3.setValueCount(numRows)
+            }
+            vectors(i) = entry._2
           }
-          throw new IOException(s"Failed to read Lance file: $filePath", e)
+          i += 1
+        }
+
+        // Partition columns: constant vectors
+        if (hasPartitionColumns) {
+          if (numRows != lastPopulatedNumRows) {
+            populatePartitionVectors(partitionVectors, partitionSchema, 
file.partitionValues, numRows)
+            lastPopulatedNumRows = numRows
+          }
+          var j = 0
+          while (j < partitionSchema.length) {
+            vectors(requiredSchema.length + j) = partitionVectors(j)
+            j += 1
+          }
+        }
+
+        val result = new ColumnarBatch(vectors)
+        result.setNumRows(numRows)
+        result
+      }
+
+      override def close(): Unit = {
+        // Close null Arrow vectors and their allocator before batchIterator 
(which closes the data allocator)
+        nullColumnVectors.foreach { case (_, columnVector, arrowVector) =>
+          columnVector.close()
+          arrowVector.close()

Review Comment:
   Is this `arrowVector.close()` call needed? Doesn't the `columnVector.close()` call on the line above already close the underlying Arrow vector implicitly?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to