ankitsultana commented on code in PR #13748: URL: https://github.com/apache/pinot/pull/13748#discussion_r1707302876
########## pinot-connectors/pinot-spark-3-connector/src/main/scala/org/apache/pinot/connector/spark/v3/datasource/PinotDataWriter.scala: ########## @@ -0,0 +1,246 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.connector.spark.v3.datasource + +import org.apache.commons.io.FileUtils +import org.apache.pinot.common.utils.TarGzCompressionUtils +import org.apache.pinot.connector.spark.common.PinotDataSourceWriteOptions +import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} +import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl +import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig +import org.apache.pinot.spi.config.table.{IndexingConfig, SegmentsValidationAndRetentionConfig, TableConfig, TableCustomConfig, TenantConfig} +import org.apache.pinot.spi.data.readers.GenericRow +import org.apache.pinot.spi.data.Schema +import org.apache.pinot.spi.ingestion.batch.spec.Constants +import org.apache.pinot.spi.utils.DataSizeUtils +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.types.StructType +import org.slf4j.{Logger, LoggerFactory} + +import java.io.File +import java.nio.file.Files +import 
java.util.regex.Pattern + +class PinotDataWriter[InternalRow]( + partitionId: Int, + taskId: Long, + writeOptions: PinotDataSourceWriteOptions, + writeSchema: StructType, + pinotSchema: Schema) + extends DataWriter[org.apache.spark.sql.catalyst.InternalRow] with AutoCloseable { + private val logger: Logger = LoggerFactory.getLogger(classOf[PinotDataWriter[InternalRow]]) + logger.info("PinotDataWriter created with writeOptions: {}, partitionId: {}, taskId: {}", + (writeOptions, partitionId, taskId)) + + val tableName: String = writeOptions.tableName Review Comment: nit: I haven't written a lot of Scala code, but why are these three vals (`tableName`, `savePath`, `bufferedRecordReader`) not private? ########## pinot-connectors/pinot-spark-3-connector/src/main/scala/org/apache/pinot/connector/spark/v3/datasource/PinotDataWriter.scala: ########## @@ -0,0 +1,246 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.apache.pinot.connector.spark.v3.datasource + +import org.apache.commons.io.FileUtils +import org.apache.pinot.common.utils.TarGzCompressionUtils +import org.apache.pinot.connector.spark.common.PinotDataSourceWriteOptions +import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} +import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl +import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig +import org.apache.pinot.spi.config.table.{IndexingConfig, SegmentsValidationAndRetentionConfig, TableConfig, TableCustomConfig, TenantConfig} +import org.apache.pinot.spi.data.readers.GenericRow +import org.apache.pinot.spi.data.Schema +import org.apache.pinot.spi.ingestion.batch.spec.Constants +import org.apache.pinot.spi.utils.DataSizeUtils +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.types.StructType +import org.slf4j.{Logger, LoggerFactory} + +import java.io.File +import java.nio.file.Files +import java.util.regex.Pattern + +class PinotDataWriter[InternalRow]( + partitionId: Int, + taskId: Long, + writeOptions: PinotDataSourceWriteOptions, + writeSchema: StructType, + pinotSchema: Schema) + extends DataWriter[org.apache.spark.sql.catalyst.InternalRow] with AutoCloseable { + private val logger: Logger = LoggerFactory.getLogger(classOf[PinotDataWriter[InternalRow]]) + logger.info("PinotDataWriter created with writeOptions: {}, partitionId: {}, taskId: {}", + (writeOptions, partitionId, taskId)) + + val tableName: String = writeOptions.tableName + val savePath: String = writeOptions.savePath + val bufferedRecordReader: PinotBufferedRecordReader = new PinotBufferedRecordReader() + + override def write(record: catalyst.InternalRow): Unit = { + bufferedRecordReader.write(internalRowToGenericRow(record)) + } + + override def commit(): WriterCommitMessage = { + val segmentName = getSegmentName + val segmentDir = generateSegment(segmentName) + val segmentTarFile = 
tarSegmentDir(segmentName, segmentDir) + pushSegmentTarFile(segmentTarFile) + new SuccessWriterCommitMessage(segmentName) + } + + // This method is used to generate the segment name based on the format provided in the write options Review Comment: nit: doc comments https://docs.scala-lang.org/style/scaladoc.html ########## pinot-connectors/pinot-spark-common/src/main/scala/org/apache/pinot/connector/spark/common/PinotDataSourceWriteOptions.scala: ########## @@ -0,0 +1,80 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.connector.spark.common + +import java.util + +object PinotDataSourceWriteOptions { + val CONFIG_TABLE_NAME = "table" + val CONFIG_SEGMENT_NAME_FORMAT = "segmentNameFormat" + val CONFIG_PATH = "path" + val CONFIG_INVERTED_INDEX_COLUMNS = "invertedIndexColumns" + val CONFIG_NO_DICTIONARY_COLUMNS = "noDictionaryColumns" + val CONFIG_BLOOM_FILTER_COLUMNS = "bloomFilterColumns" + val CONFIG_RANGE_INDEX_COLUMNS = "rangeIndexColumns" + val CONFIG_TIME_COLUMN_NAME = "timeColumnName" + + private[pinot] def from(options: util.Map[String, String]): PinotDataSourceWriteOptions = { Review Comment: We should also support an option to disable tar-ball creation. 
With that, one would be able to write segments in the v3 format to whatever location they specify. Can we leave a TODO for it for now? ########## pinot-connectors/pinot-spark-3-connector/src/main/scala/org/apache/pinot/connector/spark/v3/datasource/PinotDataWriter.scala: ########## @@ -0,0 +1,246 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.connector.spark.v3.datasource + +import org.apache.commons.io.FileUtils +import org.apache.pinot.common.utils.TarGzCompressionUtils +import org.apache.pinot.connector.spark.common.PinotDataSourceWriteOptions +import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} +import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl +import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig +import org.apache.pinot.spi.config.table.{IndexingConfig, SegmentsValidationAndRetentionConfig, TableConfig, TableCustomConfig, TenantConfig} +import org.apache.pinot.spi.data.readers.GenericRow +import org.apache.pinot.spi.data.Schema +import org.apache.pinot.spi.ingestion.batch.spec.Constants +import org.apache.pinot.spi.utils.DataSizeUtils +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.types.StructType +import org.slf4j.{Logger, LoggerFactory} + +import java.io.File +import java.nio.file.Files +import java.util.regex.Pattern + +class PinotDataWriter[InternalRow]( + partitionId: Int, + taskId: Long, + writeOptions: PinotDataSourceWriteOptions, + writeSchema: StructType, + pinotSchema: Schema) + extends DataWriter[org.apache.spark.sql.catalyst.InternalRow] with AutoCloseable { + private val logger: Logger = LoggerFactory.getLogger(classOf[PinotDataWriter[InternalRow]]) + logger.info("PinotDataWriter created with writeOptions: {}, partitionId: {}, taskId: {}", + (writeOptions, partitionId, taskId)) + + val tableName: String = writeOptions.tableName + val savePath: String = writeOptions.savePath + val bufferedRecordReader: PinotBufferedRecordReader = new PinotBufferedRecordReader() + + override def write(record: catalyst.InternalRow): Unit = { + bufferedRecordReader.write(internalRowToGenericRow(record)) Review Comment: For use-cases where users may want to support segment names that depend on the values, we can add some sort of an interface which looks like the following. 
That would allow users to produce separate segments split by any given dimension. ``` interface SegmentNameGenerator { void init(partitionNum, ... other context metadata); void consume(Record record); String generate(); } ``` @cbalci: thoughts? ########## pinot-connectors/pinot-spark-3-connector/src/main/scala/org/apache/pinot/connector/spark/v3/datasource/PinotDataWriter.scala: ########## @@ -0,0 +1,246 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.apache.pinot.connector.spark.v3.datasource + +import org.apache.commons.io.FileUtils +import org.apache.pinot.common.utils.TarGzCompressionUtils +import org.apache.pinot.connector.spark.common.PinotDataSourceWriteOptions +import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} +import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl +import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig +import org.apache.pinot.spi.config.table.{IndexingConfig, SegmentsValidationAndRetentionConfig, TableConfig, TableCustomConfig, TenantConfig} +import org.apache.pinot.spi.data.readers.GenericRow +import org.apache.pinot.spi.data.Schema +import org.apache.pinot.spi.ingestion.batch.spec.Constants +import org.apache.pinot.spi.utils.DataSizeUtils +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.types.StructType +import org.slf4j.{Logger, LoggerFactory} + +import java.io.File +import java.nio.file.Files +import java.util.regex.Pattern + +class PinotDataWriter[InternalRow]( + partitionId: Int, + taskId: Long, + writeOptions: PinotDataSourceWriteOptions, + writeSchema: StructType, + pinotSchema: Schema) + extends DataWriter[org.apache.spark.sql.catalyst.InternalRow] with AutoCloseable { + private val logger: Logger = LoggerFactory.getLogger(classOf[PinotDataWriter[InternalRow]]) + logger.info("PinotDataWriter created with writeOptions: {}, partitionId: {}, taskId: {}", + (writeOptions, partitionId, taskId)) + + val tableName: String = writeOptions.tableName + val savePath: String = writeOptions.savePath + val bufferedRecordReader: PinotBufferedRecordReader = new PinotBufferedRecordReader() + + override def write(record: catalyst.InternalRow): Unit = { + bufferedRecordReader.write(internalRowToGenericRow(record)) + } + + override def commit(): WriterCommitMessage = { + val segmentName = getSegmentName + val segmentDir = generateSegment(segmentName) + val segmentTarFile = 
tarSegmentDir(segmentName, segmentDir) + pushSegmentTarFile(segmentTarFile) + new SuccessWriterCommitMessage(segmentName) + } + + // This method is used to generate the segment name based on the format provided in the write options + // The format can contain variables like {partitionId} + // Currently supported variables are `partitionId`, `table` + // It also supports the following, python inspired format specifier for digit formatting: + // `{partitionId:05}` + // which will zero pad partitionId up to five characters. + // + // Some examples: + // "{partitionId}_{table}" -> "12_airlineStats" + // "{partitionId:05}_{table}" -> "00012_airlineStats" + // "{table}_{partitionId}" -> "airlineStats_12" + // "{table}_20240805" -> "airlineStats_20240805" + private[pinot] def getSegmentName: String = { + val format = writeOptions.segmentNameFormat + val variables = Map( + "partitionId" -> partitionId, + "table" -> tableName, + ) + + val pattern = Pattern.compile("\\{(\\w+)(?::(\\d+))?}") + val matcher = pattern.matcher(format) + + val buffer = new StringBuffer() + while (matcher.find()) { + val variableName = matcher.group(1) + val formatSpecifier = matcher.group(2) + val value = variables(variableName) + + val formattedValue = formatSpecifier match { + case null => value.toString + case spec => String.format(s"%${spec}d", value.asInstanceOf[Number]) + } + + matcher.appendReplacement(buffer, formattedValue) + } + matcher.appendTail(buffer) + + buffer.toString + } + + private[pinot] def generateSegment(segmentName: String): File = { + val outputDir = Files.createTempDirectory(classOf[PinotDataWriter[InternalRow]].getName).toFile + val indexingConfig = getIndexConfig + val segmentGeneratorConfig = getSegmentGenerationConfig(segmentName, indexingConfig, outputDir) + + logger.info("Creating segment with indexConfig: {} and segmentGeneratorConfig config: {}", + (indexingConfig, segmentGeneratorConfig)) + + // create segment and return output directory + val driver = new 
SegmentIndexCreationDriverImpl() + driver.init(segmentGeneratorConfig, bufferedRecordReader) + driver.build() + outputDir + } + + private def getIndexConfig: IndexingConfig = { + val indexingConfig = new IndexingConfig + indexingConfig.setInvertedIndexColumns(java.util.Arrays.asList(writeOptions.invertedIndexColumns:_*)) + indexingConfig.setNoDictionaryColumns(java.util.Arrays.asList(writeOptions.noDictionaryColumns:_*)) + indexingConfig.setBloomFilterColumns(java.util.Arrays.asList(writeOptions.bloomFilterColumns:_*)) + indexingConfig.setRangeIndexColumns(java.util.Arrays.asList(writeOptions.rangeIndexColumns:_*)) + indexingConfig + } + + private def getSegmentGenerationConfig(segmentName: String, + indexingConfig: IndexingConfig, + outputDir: File, + ): SegmentGeneratorConfig = { + // Mostly dummy tableConfig, sufficient for segment generation purposes + val tableConfig = new TableConfig( + tableName, + "OFFLINE", + new SegmentsValidationAndRetentionConfig(), + new TenantConfig(null, null, null), + indexingConfig, + new TableCustomConfig(null), + null, null, null, null, null, null, null, + null, null, null, null, false, null, null, + null) + + val segmentGeneratorConfig = new SegmentGeneratorConfig(tableConfig, pinotSchema) + segmentGeneratorConfig.setTableName(tableName) + segmentGeneratorConfig.setSegmentName(segmentName) + segmentGeneratorConfig.setOutDir(outputDir.getAbsolutePath) + segmentGeneratorConfig + } + + private def pushSegmentTarFile(segmentTarFile: File): Unit = { + // TODO Support file systems other than local and HDFS + val fs = org.apache.hadoop.fs.FileSystem.get(new java.net.URI(savePath), new org.apache.hadoop.conf.Configuration()) + val destPath = new org.apache.hadoop.fs.Path(savePath + "/" + segmentTarFile.getName) + fs.copyFromLocalFile(new org.apache.hadoop.fs.Path(segmentTarFile.getAbsolutePath), destPath) + + logger.info("Pushed segment tar file {} to: {}", (segmentTarFile.getName, destPath)) + } + + private def 
internalRowToGenericRow(record: catalyst.InternalRow): GenericRow = { + val gr = new GenericRow() + + writeSchema.fields.zipWithIndex foreach { case(field, idx) => + field.dataType match { + case org.apache.spark.sql.types.StringType => + gr.putValue(field.name, record.getString(idx)) + case org.apache.spark.sql.types.IntegerType => + gr.putValue(field.name, record.getInt(idx)) + case org.apache.spark.sql.types.LongType => + gr.putValue(field.name, record.getLong(idx)) + case org.apache.spark.sql.types.FloatType => + gr.putValue(field.name, record.getFloat(idx)) + case org.apache.spark.sql.types.DoubleType => + gr.putValue(field.name, record.getDouble(idx)) + case org.apache.spark.sql.types.BooleanType => + gr.putValue(field.name, record.getBoolean(idx)) + case org.apache.spark.sql.types.ByteType => + gr.putValue(field.name, record.getByte(idx)) + case org.apache.spark.sql.types.ShortType => + gr.putValue(field.name, record.getShort(idx)) + case org.apache.spark.sql.types.ArrayType(elementType, _) => + elementType match { + case org.apache.spark.sql.types.StringType => + gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[String])) + case org.apache.spark.sql.types.IntegerType => + gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[Int])) + case org.apache.spark.sql.types.LongType => + gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[Long])) + case org.apache.spark.sql.types.FloatType => + gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[Float])) + case org.apache.spark.sql.types.DoubleType => + gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[Double])) + case org.apache.spark.sql.types.BooleanType => + gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[Boolean])) + case org.apache.spark.sql.types.ByteType => + gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[Byte])) + case org.apache.spark.sql.types.ShortType => + 
gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[Short])) + case _ => + throw new UnsupportedOperationException("Unsupported data type") + } + case _ => + throw new UnsupportedOperationException("Unsupported data type") Review Comment: nit: improve error message to say "Unsupported type: " + elementType ########## pinot-connectors/pinot-spark-3-connector/src/test/scala/org/apache/pinot/connector/spark/v3/datasource/PinotDataWriterTest.scala: ########## @@ -0,0 +1,173 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.connector.spark.v3.datasource + +import org.apache.pinot.connector.spark.common.PinotDataSourceWriteOptions +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.connector.write.WriterCommitMessage +import org.scalatest.matchers.should.Matchers +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.pinot.common.utils.TarGzCompressionUtils +import org.apache.pinot.spi.data.readers.GenericRow +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.scalatest.BeforeAndAfter +import org.scalatest.funsuite.AnyFunSuite + +import java.io.File +import java.net.URI +import java.nio.file.{Files, Paths} +import scala.io.Source + +class PinotDataWriterTest extends AnyFunSuite with Matchers with BeforeAndAfter { + + var tmpDir: File = _ + + before { + tmpDir = Files.createTempDirectory("pinot-spark-connector-write-test").toFile + } + + after { + if (tmpDir.exists()) { + tmpDir.listFiles().foreach(_.delete()) + tmpDir.delete() + } + } + + test("Initialize buffer and accept records") { + val writeOptions = PinotDataSourceWriteOptions( + tableName = "testTable", + savePath = "/tmp/pinot", + timeColumnName = "ts", + segmentNameFormat = "{table}_{partitionId:03}", + invertedIndexColumns = Array("name"), + noDictionaryColumns = Array("age"), + bloomFilterColumns = Array("name"), + rangeIndexColumns = Array() + ) + val writeSchema = StructType(Seq( + StructField("name", StringType, nullable = false), + StructField("age", IntegerType, nullable = false) + )) + + val pinotSchema = SparkToPinotTypeTranslator.translate(writeSchema, writeOptions.tableName) + val writer = new PinotDataWriter[InternalRow](0, 0, writeOptions, writeSchema, pinotSchema) + + val record1 = new TestInternalRow(Array[Any]("Alice", 30)) + val record2 = new TestInternalRow(Array[Any]("Bob", 25)) + + writer.write(record1) + 
writer.write(record2) + + val writeBuffer = writer.bufferedRecordReader + writer.bufferedRecordReader.hasNext shouldBe true + writeBuffer.next() shouldBe a[GenericRow] + writeBuffer.next() shouldBe a[GenericRow] + + writer.close() + writeBuffer.hasNext shouldBe false + } + + test("Should create segment file on commit") { + // create tmp directory with test name + tmpDir = Files.createTempDirectory("pinot-spark-connector-test").toFile + + val writeOptions = PinotDataSourceWriteOptions( + tableName = "testTable", + savePath = tmpDir.getAbsolutePath, + timeColumnName = "ts", + segmentNameFormat = "{table}_{partitionId:03}", + invertedIndexColumns = Array("name"), + noDictionaryColumns = Array("age"), + bloomFilterColumns = Array("name"), + rangeIndexColumns = Array() + ) + val writeSchema = StructType(Seq( + StructField("name", StringType, nullable = false), + StructField("age", IntegerType, nullable = false) + )) + val pinotSchema = SparkToPinotTypeTranslator.translate(writeSchema, writeOptions.tableName) + val writer = new PinotDataWriter[InternalRow](0, 0, writeOptions, writeSchema, pinotSchema) + val record1 = new TestInternalRow(Array[Any]("Alice", 30)) + writer.write(record1) + + val commitMessage: WriterCommitMessage = writer.commit() + commitMessage shouldBe a[SuccessWriterCommitMessage] + + // Verify that the segment is created and stored in the target location + val fs = FileSystem.get(new URI(writeOptions.savePath), new org.apache.hadoop.conf.Configuration()) + val segmentPath = new Path(writeOptions.savePath + "/testTable_000.tar.gz") + fs.exists(segmentPath) shouldBe true + + // Verify the contents of the segment tar file + TarGzCompressionUtils.untar( + new File(writeOptions.savePath + "/testTable_000.tar.gz"), + new File(writeOptions.savePath)) + val untarDir = Paths.get(writeOptions.savePath + "/testTable_000/v3/") + Files.exists(untarDir) shouldBe true + + val segmentFiles = Files.list(untarDir).toArray.map(_.toString) + segmentFiles should contain 
(untarDir + "/creation.meta") + segmentFiles should contain (untarDir + "/index_map") + segmentFiles should contain (untarDir + "/metadata.properties") + segmentFiles should contain (untarDir + "/columns.psf") + + // Verify basic metadata content + val metadataSrc = Source.fromFile(untarDir + "/metadata.properties") + val metadataContent = metadataSrc.getLines.mkString("\n") + metadataSrc.close() + metadataContent should include ("segment.name = testTable_000") + } + + test("getSegmentName should format segment name correctly with custom format") { + val testCases = Seq( Review Comment: can you also add some invalid cases here? (e.g. `""`, `"{table"`). ########## pinot-connectors/pinot-spark-3-connector/src/main/scala/org/apache/pinot/connector/spark/v3/datasource/PinotDataWriter.scala: ########## @@ -0,0 +1,246 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.pinot.connector.spark.v3.datasource + +import org.apache.commons.io.FileUtils +import org.apache.pinot.common.utils.TarGzCompressionUtils +import org.apache.pinot.connector.spark.common.PinotDataSourceWriteOptions +import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} +import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl +import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig +import org.apache.pinot.spi.config.table.{IndexingConfig, SegmentsValidationAndRetentionConfig, TableConfig, TableCustomConfig, TenantConfig} +import org.apache.pinot.spi.data.readers.GenericRow +import org.apache.pinot.spi.data.Schema +import org.apache.pinot.spi.ingestion.batch.spec.Constants +import org.apache.pinot.spi.utils.DataSizeUtils +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.types.StructType +import org.slf4j.{Logger, LoggerFactory} + +import java.io.File +import java.nio.file.Files +import java.util.regex.Pattern + +class PinotDataWriter[InternalRow]( Review Comment: Let's add some doc comments here. Specifically, you can call out how this writer can be helpful. ########## pinot-connectors/pinot-spark-3-connector/src/main/scala/org/apache/pinot/connector/spark/v3/datasource/PinotDataWriter.scala: ########## @@ -0,0 +1,246 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.connector.spark.v3.datasource + +import org.apache.commons.io.FileUtils +import org.apache.pinot.common.utils.TarGzCompressionUtils +import org.apache.pinot.connector.spark.common.PinotDataSourceWriteOptions +import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} +import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl +import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig +import org.apache.pinot.spi.config.table.{IndexingConfig, SegmentsValidationAndRetentionConfig, TableConfig, TableCustomConfig, TenantConfig} +import org.apache.pinot.spi.data.readers.GenericRow +import org.apache.pinot.spi.data.Schema +import org.apache.pinot.spi.ingestion.batch.spec.Constants +import org.apache.pinot.spi.utils.DataSizeUtils +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.types.StructType +import org.slf4j.{Logger, LoggerFactory} + +import java.io.File +import java.nio.file.Files +import java.util.regex.Pattern + +class PinotDataWriter[InternalRow]( + partitionId: Int, + taskId: Long, + writeOptions: PinotDataSourceWriteOptions, + writeSchema: StructType, + pinotSchema: Schema) + extends DataWriter[org.apache.spark.sql.catalyst.InternalRow] with AutoCloseable { + private val logger: Logger = LoggerFactory.getLogger(classOf[PinotDataWriter[InternalRow]]) + logger.info("PinotDataWriter created with writeOptions: {}, partitionId: {}, taskId: {}", + (writeOptions, partitionId, taskId)) + + val tableName: String = writeOptions.tableName + val 
savePath: String = writeOptions.savePath + val bufferedRecordReader: PinotBufferedRecordReader = new PinotBufferedRecordReader() + + override def write(record: catalyst.InternalRow): Unit = { + bufferedRecordReader.write(internalRowToGenericRow(record)) + } + + override def commit(): WriterCommitMessage = { + val segmentName = getSegmentName + val segmentDir = generateSegment(segmentName) + val segmentTarFile = tarSegmentDir(segmentName, segmentDir) + pushSegmentTarFile(segmentTarFile) + new SuccessWriterCommitMessage(segmentName) + } + + // This method is used to generate the segment name based on the format provided in the write options + // The format can contain variables like {partitionId} + // Currently supported variables are `partitionId`, `table` + // It also supports the following, python inspired format specifier for digit formatting: + // `{partitionId:05}` + // which will zero pad partitionId up to five characters. + // + // Some examples: + // "{partitionId}_{table}" -> "12_airlineStats" + // "{partitionId:05}_{table}" -> "00012_airlineStats" + // "{table}_{partitionId}" -> "airlineStats_12" + // "{table}_20240805" -> "airlineStats_20240805" + private[pinot] def getSegmentName: String = { + val format = writeOptions.segmentNameFormat + val variables = Map( + "partitionId" -> partitionId, + "table" -> tableName, + ) + + val pattern = Pattern.compile("\\{(\\w+)(?::(\\d+))?}") + val matcher = pattern.matcher(format) + + val buffer = new StringBuffer() + while (matcher.find()) { + val variableName = matcher.group(1) + val formatSpecifier = matcher.group(2) + val value = variables(variableName) + + val formattedValue = formatSpecifier match { + case null => value.toString + case spec => String.format(s"%${spec}d", value.asInstanceOf[Number]) + } + + matcher.appendReplacement(buffer, formattedValue) + } + matcher.appendTail(buffer) + + buffer.toString + } + + private[pinot] def generateSegment(segmentName: String): File = { + val outputDir = 
Files.createTempDirectory(classOf[PinotDataWriter[InternalRow]].getName).toFile + val indexingConfig = getIndexConfig + val segmentGeneratorConfig = getSegmentGenerationConfig(segmentName, indexingConfig, outputDir) + + logger.info("Creating segment with indexConfig: {} and segmentGeneratorConfig config: {}", + (indexingConfig, segmentGeneratorConfig)) + + // create segment and return output directory + val driver = new SegmentIndexCreationDriverImpl() + driver.init(segmentGeneratorConfig, bufferedRecordReader) + driver.build() + outputDir + } + + private def getIndexConfig: IndexingConfig = { + val indexingConfig = new IndexingConfig + indexingConfig.setInvertedIndexColumns(java.util.Arrays.asList(writeOptions.invertedIndexColumns:_*)) + indexingConfig.setNoDictionaryColumns(java.util.Arrays.asList(writeOptions.noDictionaryColumns:_*)) + indexingConfig.setBloomFilterColumns(java.util.Arrays.asList(writeOptions.bloomFilterColumns:_*)) + indexingConfig.setRangeIndexColumns(java.util.Arrays.asList(writeOptions.rangeIndexColumns:_*)) + indexingConfig + } + + private def getSegmentGenerationConfig(segmentName: String, + indexingConfig: IndexingConfig, + outputDir: File, + ): SegmentGeneratorConfig = { + // Mostly dummy tableConfig, sufficient for segment generation purposes + val tableConfig = new TableConfig( + tableName, + "OFFLINE", + new SegmentsValidationAndRetentionConfig(), + new TenantConfig(null, null, null), + indexingConfig, + new TableCustomConfig(null), + null, null, null, null, null, null, null, + null, null, null, null, false, null, null, + null) + + val segmentGeneratorConfig = new SegmentGeneratorConfig(tableConfig, pinotSchema) + segmentGeneratorConfig.setTableName(tableName) + segmentGeneratorConfig.setSegmentName(segmentName) + segmentGeneratorConfig.setOutDir(outputDir.getAbsolutePath) + segmentGeneratorConfig + } + + private def pushSegmentTarFile(segmentTarFile: File): Unit = { + // TODO Support file systems other than local and HDFS + val fs = 
org.apache.hadoop.fs.FileSystem.get(new java.net.URI(savePath), new org.apache.hadoop.conf.Configuration()) + val destPath = new org.apache.hadoop.fs.Path(savePath + "/" + segmentTarFile.getName) + fs.copyFromLocalFile(new org.apache.hadoop.fs.Path(segmentTarFile.getAbsolutePath), destPath) + + logger.info("Pushed segment tar file {} to: {}", (segmentTarFile.getName, destPath)) + } + + private def internalRowToGenericRow(record: catalyst.InternalRow): GenericRow = { + val gr = new GenericRow() + + writeSchema.fields.zipWithIndex foreach { case(field, idx) => + field.dataType match { + case org.apache.spark.sql.types.StringType => + gr.putValue(field.name, record.getString(idx)) + case org.apache.spark.sql.types.IntegerType => + gr.putValue(field.name, record.getInt(idx)) + case org.apache.spark.sql.types.LongType => + gr.putValue(field.name, record.getLong(idx)) + case org.apache.spark.sql.types.FloatType => + gr.putValue(field.name, record.getFloat(idx)) + case org.apache.spark.sql.types.DoubleType => + gr.putValue(field.name, record.getDouble(idx)) + case org.apache.spark.sql.types.BooleanType => + gr.putValue(field.name, record.getBoolean(idx)) + case org.apache.spark.sql.types.ByteType => + gr.putValue(field.name, record.getByte(idx)) + case org.apache.spark.sql.types.ShortType => + gr.putValue(field.name, record.getShort(idx)) + case org.apache.spark.sql.types.ArrayType(elementType, _) => + elementType match { + case org.apache.spark.sql.types.StringType => + gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[String])) + case org.apache.spark.sql.types.IntegerType => + gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[Int])) + case org.apache.spark.sql.types.LongType => + gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[Long])) + case org.apache.spark.sql.types.FloatType => + gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[Float])) + case org.apache.spark.sql.types.DoubleType => 
+ gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[Double])) + case org.apache.spark.sql.types.BooleanType => + gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[Boolean])) + case org.apache.spark.sql.types.ByteType => + gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[Byte])) + case org.apache.spark.sql.types.ShortType => + gr.putValue(field.name, record.getArray(idx).array.map(_.asInstanceOf[Short])) + case _ => + throw new UnsupportedOperationException("Unsupported data type") Review Comment: nit: improve error message to say "Unsupported array-type: array<%s>".format(elementType) -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org