This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git
The following commit(s) were added to refs/heads/master by this push: new 8781834 [feature] read doris via arrow flight sql (#227) 8781834 is described below commit 8781834868291c57d4d00eafc3a0fcda2b6914e2 Author: gnehil <adamlee...@gmail.com> AuthorDate: Tue Aug 27 16:24:53 2024 +0800 [feature] read doris via arrow flight sql (#227) --- spark-doris-connector/pom.xml | 130 +++++++++++++++++++-- .../doris/spark/cfg/ConfigurationOptions.java | 5 + .../org/apache/doris/spark/cfg/SparkSettings.java | 19 ++- .../doris/spark/rest/PartitionDefinition.java | 34 +++--- .../org/apache/doris/spark/rest/RestService.java | 25 ++-- .../apache/doris/spark/serialization/RowBatch.java | 81 ++++++++----- .../doris/spark/rdd/AbstractDorisRDDIterator.scala | 6 +- .../doris/spark/rdd/AbstractValueReader.scala | 36 ++++++ .../doris/spark/rdd/ScalaADBCValueReader.scala | 128 ++++++++++++++++++++ .../org/apache/doris/spark/rdd/ScalaDorisRDD.scala | 23 ++-- .../apache/doris/spark/rdd/ScalaValueReader.scala | 10 +- .../spark/sql/ScalaDorisRowADBCValueReader.scala | 50 ++++++++ .../apache/doris/spark/sql/ScalaDorisRowRDD.scala | 12 +- .../scala/org/apache/doris/spark/sql/Utils.scala | 30 +++++ .../org/apache/doris/spark/sql/TestUtils.scala | 27 ++++- 15 files changed, 510 insertions(+), 106 deletions(-) diff --git a/spark-doris-connector/pom.xml b/spark-doris-connector/pom.xml index a9fd180..bb82e6b 100644 --- a/spark-doris-connector/pom.xml +++ b/spark-doris-connector/pom.xml @@ -73,10 +73,10 @@ <scala.version>2.12.10</scala.version> <scala.major.version>2.12</scala.major.version> <libthrift.version>0.16.0</libthrift.version> - <arrow.version>13.0.0</arrow.version> + <arrow.version>15.0.2</arrow.version> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <project.scm.id>github</project.scm.id> - <netty.version>4.1.77.Final</netty.version> + <netty.version>4.1.104.Final</netty.version> <fasterxml.jackson.version>2.13.5</fasterxml.jackson.version> <thrift-service.version>1.0.1</thrift-service.version> <testcontainers.version>1.17.6</testcontainers.version> @@ -94,12 +94,6 @@ </exclusion> </exclusions> </dependency> - <dependency> - <groupId>io.netty</groupId> - <artifactId>netty-all</artifactId> - <version>${netty.version}</version> - <scope>provided</scope> - </dependency> <dependency> <groupId>org.apache.spark</groupId> @@ -248,6 +242,93 @@ <version>4.5.13</version> </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-hive_${scala.major.version}</artifactId> + <version>${spark.version}</version> + <scope>provided</scope> + </dependency> + + <dependency> + <groupId>org.apache.arrow.adbc</groupId> + <artifactId>adbc-driver-flight-sql</artifactId> + <version>0.13.0</version> + <exclusions> + <exclusion> + <groupId>org.apache.arrow</groupId> + <artifactId>flight-sql-jdbc-core</artifactId> + </exclusion> + <exclusion> + <artifactId>arrow-memory-netty</artifactId> + <groupId>org.apache.arrow</groupId> + </exclusion> + <exclusion> + <artifactId>arrow-memory-core</artifactId> + <groupId>org.apache.arrow</groupId> + </exclusion> + <exclusion> + <artifactId>arrow-format</artifactId> + <groupId>org.apache.arrow</groupId> + </exclusion> + <exclusion> + <artifactId>arrow-vector</artifactId> + <groupId>org.apache.arrow</groupId> + </exclusion> + <exclusion> + <artifactId>grpc-netty</artifactId> + <groupId>io.grpc</groupId> + </exclusion> + </exclusions> + </dependency> + + <dependency> + <groupId>io.grpc</groupId> + <artifactId>grpc-netty</artifactId> + <version>1.60.0</version> 
+ <exclusions> + <exclusion> + <artifactId>netty-codec-http2</artifactId> + <groupId>io.netty</groupId> + </exclusion> + <exclusion> + <artifactId>netty-handler-proxy</artifactId> + <groupId>io.netty</groupId> + </exclusion> + <exclusion> + <artifactId>netty-transport-native-unix-common</artifactId> + <groupId>io.netty</groupId> + </exclusion> + </exclusions> + </dependency> + + <dependency> + <groupId>io.netty</groupId> + <artifactId>netty-codec-http2</artifactId> + <version>${netty.version}</version> + </dependency> + <dependency> + <groupId>io.netty</groupId> + <artifactId>netty-handler-proxy</artifactId> + <version>${netty.version}</version> + </dependency> + <dependency> + <groupId>io.netty</groupId> + <artifactId>netty-transport-native-unix-common</artifactId> + <version>${netty.version}</version> + </dependency> + + <dependency> + <groupId>org.apache.arrow</groupId> + <artifactId>flight-sql-jdbc-core</artifactId> + <version>${arrow.version}</version> + </dependency> + + <dependency> + <groupId>com.google.protobuf</groupId> + <artifactId>protobuf-java</artifactId> + <version>3.22.3</version> + </dependency> + </dependencies> <build> @@ -291,7 +372,7 @@ <plugin> <groupId>net.alchim31.maven</groupId> <artifactId>scala-maven-plugin</artifactId> - <version>3.2.1</version> + <version>3.4.1</version> <executions> <execution> <id>scala-compile-first</id> @@ -317,8 +398,20 @@ <plugin> <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-shade-plugin</artifactId> - <version>3.2.1</version> + <version>3.4.1</version> <configuration> + <filters> + <filter> + <!-- Do not copy the signatures in the META-INF folder. + Otherwise, this might cause SecurityExceptions when using the JAR. --> + <artifact>*:*</artifact> + <excludes> + <exclude>META-INF/*.SF</exclude> + <exclude>META-INF/*.DSA</exclude> + <exclude>META-INF/*.RSA</exclude> + </excludes> + </filter> + </filters> <artifactSet> <excludes> <exclude>com.google.code.findbugs:*</exclude> @@ -355,7 +448,20 @@ <pattern>org.apache.http</pattern> <shadedPattern>org.apache.doris.shaded.org.apache.http</shadedPattern> </relocation> + <relocation> + <pattern>io.grpc</pattern> + <shadedPattern>org.apache.doris.shaded.io.grpc</shadedPattern> + </relocation> + <relocation> + <pattern>com.google</pattern> + <shadedPattern>org.apache.doris.shaded.com.google</shadedPattern> + </relocation> </relocations> + <transformers> + <transformer + implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/> + </transformers> + <!-- <minimizeJar>true</minimizeJar> --> </configuration> <executions> <execution> @@ -370,8 +476,8 @@ <groupId>org.apache.maven.plugins</groupId> <artifactId>maven-compiler-plugin</artifactId> <configuration> - <source>8</source> - <target>8</target> + <source>1.8</source> + <target>1.8</target> </configuration> </plugin> <plugin> diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java index 8f64c74..68f4ba8 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java @@ -161,4 +161,9 @@ public interface ConfigurationOptions { "off_mode" ))); + String DORIS_READ_MODE = "doris.read.mode"; + String DORIS_READ_MODE_DEFAULT = "thrift"; + + String DORIS_ARROW_FLIGHT_SQL_PORT = "doris.arrow-flight-sql.port"; + } diff --git 
a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/SparkSettings.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/SparkSettings.java index 39fcd75..1448d2f 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/SparkSettings.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/SparkSettings.java @@ -17,16 +17,14 @@ package org.apache.doris.spark.cfg; -import java.util.Properties; - -import org.apache.spark.SparkConf; - import com.google.common.base.Preconditions; - +import org.apache.spark.SparkConf; import scala.Option; import scala.Serializable; import scala.Tuple2; +import java.util.Properties; + public class SparkSettings extends Settings implements Serializable { private final SparkConf cfg; @@ -36,6 +34,16 @@ public class SparkSettings extends Settings implements Serializable { this.cfg = cfg; } + public static SparkSettings fromProperties(Properties props) { + SparkConf sparkConf = new SparkConf(); + props.forEach((k, v) -> { + if (k instanceof String) { + sparkConf.set((String) k, v.toString()); + } + }); + return new SparkSettings(sparkConf); + } + public SparkSettings copy() { return new SparkSettings(cfg.clone()); } @@ -74,4 +82,5 @@ public class SparkSettings extends Settings implements Serializable { return props; } + } diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/PartitionDefinition.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/PartitionDefinition.java index 0c2aae3..baa517a 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/PartitionDefinition.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/PartitionDefinition.java @@ -17,16 +17,16 @@ package org.apache.doris.spark.rest; +import org.apache.doris.spark.cfg.PropertiesSettings; +import org.apache.doris.spark.cfg.Settings; +import org.apache.doris.spark.exception.IllegalArgumentException; + import java.io.Serializable; import java.util.Collections; import java.util.HashSet; import java.util.Objects; import java.util.Set; -import org.apache.doris.spark.cfg.PropertiesSettings; -import org.apache.doris.spark.cfg.Settings; -import org.apache.doris.spark.exception.IllegalArgumentException; - /** * Doris RDD partition info. 
*/ @@ -124,12 +124,12 @@ public class PartitionDefinition implements Serializable, Comparable<PartitionDe return false; } PartitionDefinition that = (PartitionDefinition) o; - return Objects.equals(database, that.database) && - Objects.equals(table, that.table) && - Objects.equals(beAddress, that.beAddress) && - Objects.equals(tabletIds, that.tabletIds) && - Objects.equals(queryPlan, that.queryPlan) && - Objects.equals(serializedSettings, that.serializedSettings); + return Objects.equals(database, that.database) + && Objects.equals(table, that.table) + && Objects.equals(beAddress, that.beAddress) + && Objects.equals(tabletIds, that.tabletIds) + && Objects.equals(queryPlan, that.queryPlan) + && Objects.equals(serializedSettings, that.serializedSettings); } @Override @@ -144,12 +144,12 @@ public class PartitionDefinition implements Serializable, Comparable<PartitionDe @Override public String toString() { - return "PartitionDefinition{" + - ", database='" + database + '\'' + - ", table='" + table + '\'' + - ", beAddress='" + beAddress + '\'' + - ", tabletIds=" + tabletIds + - ", queryPlan='" + queryPlan + '\'' + - '}'; + return "PartitionDefinition{" + + ", database='" + database + '\'' + + ", table='" + table + '\'' + + ", beAddress='" + beAddress + '\'' + + ", tabletIds=" + tabletIds + + ", queryPlan='" + queryPlan + '\'' + + '}'; } } diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/RestService.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/RestService.java index 3f3516f..50432a6 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/RestService.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/RestService.java @@ -42,6 +42,7 @@ import org.apache.doris.spark.rest.models.QueryPlan; import org.apache.doris.spark.rest.models.Schema; import org.apache.doris.spark.rest.models.Tablet; import org.apache.doris.spark.sql.SchemaUtils; +import org.apache.doris.spark.sql.Utils; import org.apache.doris.spark.util.HttpUtil; import org.apache.doris.spark.util.URLs; @@ -51,7 +52,6 @@ import com.fasterxml.jackson.databind.JsonMappingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.json.JsonMapper; import com.google.common.annotations.VisibleForTesting; -import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.lang3.StringUtils; import org.apache.http.HttpHeaders; import org.apache.http.HttpStatus; @@ -64,6 +64,7 @@ import org.apache.http.entity.StringEntity; import org.apache.http.impl.client.CloseableHttpClient; import org.apache.http.util.EntityUtils; import org.slf4j.Logger; +import scala.Option; import java.io.IOException; import java.io.Serializable; @@ -227,23 +228,11 @@ public class RestService implements Serializable { String[] tableIdentifiers = parseIdentifier(cfg.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER), logger); String readFields = cfg.getProperty(ConfigurationOptions.DORIS_READ_FIELD, "*"); - if (!"*".equals(readFields)) { - String[] readFieldArr = readFields.split(","); - String[] bitmapColumns = cfg.getProperty(SchemaUtils.DORIS_BITMAP_COLUMNS(), "").split(","); - String[] hllColumns = cfg.getProperty(SchemaUtils.DORIS_HLL_COLUMNS(), "").split(","); - for (int i = 0; i < readFieldArr.length; i++) { - String readFieldName = readFieldArr[i].replaceAll("`", ""); - if (ArrayUtils.contains(bitmapColumns, readFieldName) - || ArrayUtils.contains(hllColumns, readFieldName)) { - readFieldArr[i] = "'READ UNSUPPORTED' AS " + 
readFieldArr[i]; - } - } - readFields = StringUtils.join(readFieldArr, ","); - } - String sql = "select " + readFields + " from `" + tableIdentifiers[0] + "`.`" + tableIdentifiers[1] + "`"; - if (!StringUtils.isEmpty(cfg.getProperty(ConfigurationOptions.DORIS_FILTER_QUERY))) { - sql += " where " + cfg.getProperty(ConfigurationOptions.DORIS_FILTER_QUERY); - } + String[] bitmapColumns = cfg.getProperty(SchemaUtils.DORIS_BITMAP_COLUMNS(), "").split(","); + String[] hllColumns = cfg.getProperty(SchemaUtils.DORIS_HLL_COLUMNS(), "").split(","); + String sql = Utils.generateQueryStatement(readFields.split(","), bitmapColumns, hllColumns, + "`" + tableIdentifiers[0] + "`.`" + tableIdentifiers[1] + "`", + cfg.getProperty(ConfigurationOptions.DORIS_FILTER_QUERY, ""), Option.empty()); logger.debug("Query SQL Sending to Doris FE is: '{}'.", sql); String finalSql = sql; diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java index 319bd3c..c3e70e9 100644 --- a/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java +++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/serialization/RowBatch.java @@ -45,6 +45,7 @@ import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.MapVector; import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.impl.UnionMapReader; +import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.ipc.ArrowStreamReader; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.Types.MinorType; @@ -79,6 +80,10 @@ import java.util.Objects; */ public class RowBatch { private static final Logger logger = LoggerFactory.getLogger(RowBatch.class); + + private final List<Row> rowBatch = new ArrayList<>(); + private final ArrowReader arrowReader; + private final Schema schema; private static final ZoneId DEFAULT_ZONE_ID = ZoneId.systemDefault(); private static final DateTimeFormatter DATE_TIME_FORMATTER = new DateTimeFormatterBuilder() @@ -93,10 +98,7 @@ public class RowBatch { private final DateTimeFormatter dateTimeV2Formatter = DateTimeFormatter.ofPattern(DATETIMEV2_PATTERN); private final DateTimeFormatter dateFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd"); - private final List<Row> rowBatch = new ArrayList<>(); - private final ArrowStreamReader arrowStreamReader; - private final RootAllocator rootAllocator; - private final Schema schema; + private RootAllocator rootAllocator = null; // offset for iterate the rowBatch private int offsetInRowBatch = 0; private int rowCountInOneBatch = 0; @@ -104,32 +106,15 @@ public class RowBatch { private List<FieldVector> fieldVectors; public RowBatch(TScanBatchResult nextResult, Schema schema) throws DorisException { - this.schema = schema; + this.rootAllocator = new RootAllocator(Integer.MAX_VALUE); - this.arrowStreamReader = new ArrowStreamReader( - new ByteArrayInputStream(nextResult.getRows()), - rootAllocator - ); + this.arrowReader = new ArrowStreamReader(new ByteArrayInputStream(nextResult.getRows()), rootAllocator); + this.schema = schema; + try { - VectorSchemaRoot root = arrowStreamReader.getVectorSchemaRoot(); - while (arrowStreamReader.loadNextBatch()) { - fieldVectors = root.getFieldVectors(); - if (fieldVectors.size() > schema.size()) { - logger.error("Data schema size '{}' should not be bigger than arrow field size '{}'.", - schema.size(), 
fieldVectors.size()); - throw new DorisException("Load Doris data failed, schema size of fetch data is wrong."); - } - if (fieldVectors.isEmpty() || root.getRowCount() == 0) { - logger.debug("One batch in arrow has no data."); - continue; - } - rowCountInOneBatch = root.getRowCount(); - // init the rowBatch - for (int i = 0; i < rowCountInOneBatch; ++i) { - rowBatch.add(new Row(fieldVectors.size())); - } - convertArrowToRowBatch(); - readRowCount += root.getRowCount(); + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + while (arrowReader.loadNextBatch()) { + readBatch(root); } } catch (Exception e) { logger.error("Read Doris Data failed because: ", e); @@ -137,6 +122,42 @@ public class RowBatch { } finally { close(); } + + } + + public RowBatch(ArrowReader reader, Schema schema) throws DorisException { + + this.arrowReader = reader; + this.schema = schema; + + try { + VectorSchemaRoot root = arrowReader.getVectorSchemaRoot(); + readBatch(root); + } catch (Exception e) { + logger.error("Read Doris Data failed because: ", e); + throw new DorisException(e.getMessage()); + } + + } + + private void readBatch(VectorSchemaRoot root) throws DorisException { + fieldVectors = root.getFieldVectors(); + if (fieldVectors.size() > schema.size()) { + logger.error("Data schema size '{}' should not be bigger than arrow field size '{}'.", + schema.size(), fieldVectors.size()); + throw new DorisException("Load Doris data failed, schema size of fetch data is wrong."); + } + if (fieldVectors.isEmpty() || root.getRowCount() == 0) { + logger.debug("One batch in arrow has no data."); + return; + } + rowCountInOneBatch = root.getRowCount(); + // init the rowBatch + for (int i = 0; i < rowCountInOneBatch; ++i) { + rowBatch.add(new Row(fieldVectors.size())); + } + convertArrowToRowBatch(); + readRowCount += root.getRowCount(); } public static LocalDateTime longToLocalDateTime(long time) { @@ -505,8 +526,8 @@ public class RowBatch { public void close() { try { - if (arrowStreamReader != null) { - arrowStreamReader.close(); + if (arrowReader != null) { + arrowReader.close(); } if (rootAllocator != null) { rootAllocator.close(); diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractDorisRDDIterator.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractDorisRDDIterator.scala index 902c634..8e5f661 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractDorisRDDIterator.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractDorisRDDIterator.scala @@ -33,14 +33,14 @@ private[spark] abstract class AbstractDorisRDDIterator[T]( private var closed = false // the reader obtain data from Doris BE - private lazy val reader = { + private lazy val reader: AbstractValueReader = { initialized = true val settings = partition.settings() initReader(settings) val valueReaderName = settings.getProperty(DORIS_VALUE_READER_CLASS) - logger.debug(s"Use value reader '$valueReaderName'.") + logger.info(s"Use value reader '$valueReaderName'.") val cons = Class.forName(valueReaderName).getDeclaredConstructor(classOf[PartitionDefinition], classOf[Settings]) - cons.newInstance(partition, settings).asInstanceOf[ScalaValueReader] + cons.newInstance(partition, settings).asInstanceOf[AbstractValueReader] } context.addTaskCompletionListener(new TaskCompletionListener() { diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractValueReader.scala 
b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractValueReader.scala new file mode 100644 index 0000000..3c8acf6 --- /dev/null +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/AbstractValueReader.scala @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.spark.rdd + +import org.apache.doris.spark.serialization.RowBatch + +trait AbstractValueReader { + + protected var rowBatch: RowBatch = _ + + def hasNext: Boolean + + /** + * get next value. + * @return next value + */ + def next: AnyRef + + def close(): Unit + +} diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaADBCValueReader.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaADBCValueReader.scala new file mode 100644 index 0000000..3cd3da0 --- /dev/null +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaADBCValueReader.scala @@ -0,0 +1,128 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.spark.rdd + +import org.apache.arrow.adbc.core.{AdbcConnection, AdbcDriver, AdbcStatement} +import org.apache.arrow.adbc.driver.flightsql.FlightSqlDriver +import org.apache.arrow.flight.Location +import org.apache.arrow.memory.{BufferAllocator, RootAllocator} +import org.apache.arrow.vector.ipc.ArrowReader +import org.apache.commons.lang3.exception.ExceptionUtils +import org.apache.doris.spark.cfg.{ConfigurationOptions, Settings, SparkSettings} +import org.apache.doris.spark.exception.ShouldNeverHappenException +import org.apache.doris.spark.rest.{PartitionDefinition, RestService} +import org.apache.doris.spark.serialization.RowBatch +import org.apache.doris.spark.sql.{SchemaUtils, Utils} +import org.apache.doris.spark.util.ErrorMessages.SHOULD_NOT_HAPPEN_MESSAGE +import org.apache.spark.internal.Logging + +import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.JavaConverters._ +import scala.collection.mutable + +class ScalaADBCValueReader(partition: PartitionDefinition, settings: Settings) extends AbstractValueReader with Logging { + + private[this] val eos: AtomicBoolean = new AtomicBoolean(false) + + private lazy val schema = RestService.getSchema(SparkSettings.fromProperties(settings.asProperties()), log) + + private lazy val conn: AdbcConnection = { + // val loader = ClassLoader.getSystemClassLoader + // val classesField = classOf[ClassLoader].getDeclaredField("classes") + // classesField.setAccessible(true) + // val classes = classesField.get(loader).asInstanceOf[java.util.Vector[Any]] + // classes.forEach(clazz => println(clazz.asInstanceOf[Class[_]].getName)) + // Class.forName("org.apache.doris.shaded.org.apache.arrow.memory.RootAllocator") + var allocator: BufferAllocator = null + try { + allocator = new RootAllocator() + } catch { + case e: Throwable => println(ExceptionUtils.getStackTrace(e)) + throw e; + } + val driver = new FlightSqlDriver(allocator) + val params = mutable.HashMap[String, AnyRef]().asJava + AdbcDriver.PARAM_URI.set(params, Location.forGrpcInsecure( + settings.getProperty(ConfigurationOptions.DORIS_FENODES).split(":")(0), + settings.getIntegerProperty(ConfigurationOptions.DORIS_ARROW_FLIGHT_SQL_PORT) + ).getUri.toString) + AdbcDriver.PARAM_USERNAME.set(params, settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_USER)) + AdbcDriver.PARAM_PASSWORD.set(params, settings.getProperty(ConfigurationOptions.DORIS_REQUEST_AUTH_PASSWORD)) + val database = driver.open(params) + database.connect() + } + + private lazy val stmt: AdbcStatement = conn.createStatement() + + private lazy val queryResult: AdbcStatement.QueryResult = { + val flightSql = Utils.generateQueryStatement(settings.getProperty(ConfigurationOptions.DORIS_READ_FIELD, "*").split(","), + settings.getProperty(SchemaUtils.DORIS_BITMAP_COLUMNS, "").split(","), + settings.getProperty(SchemaUtils.DORIS_HLL_COLUMNS, "").split(","), + s"`${partition.getDatabase}`.`${partition.getTable}`", + settings.getProperty(ConfigurationOptions.DORIS_FILTER_QUERY, ""), + Some(partition) + ) + log.info(s"flightSql: $flightSql") + stmt.setSqlQuery(flightSql) + stmt.executeQuery() + } + + private lazy val arrowReader: ArrowReader = queryResult.getReader + + override def hasNext: Boolean = { + if (!eos.get && (rowBatch == null || !rowBatch.hasNext)) { + eos.set(!arrowReader.loadNextBatch()) + if (!eos.get) { + rowBatch = new RowBatch(arrowReader, schema) + } + } + !eos.get + } + + /** + * get next value. 
+ * + * @return next value + */ + override def next: AnyRef = { + if (!hasNext) { + logError(SHOULD_NOT_HAPPEN_MESSAGE) + throw new ShouldNeverHappenException + } + rowBatch.next + } + + override def close(): Unit = { + if (rowBatch != null) { + rowBatch.close() + } + if (arrowReader != null) { + arrowReader.close() + } + if (queryResult != null) { + queryResult.close() + } + if (stmt != null) { + stmt.close() + } + if (conn != null) { + conn.close() + } + } + +} diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaDorisRDD.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaDorisRDD.scala index 0ff8bbd..768611c 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaDorisRDD.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaDorisRDD.scala @@ -18,29 +18,32 @@ package org.apache.doris.spark.rdd import scala.reflect.ClassTag - import org.apache.doris.spark.cfg.ConfigurationOptions.DORIS_VALUE_READER_CLASS -import org.apache.doris.spark.cfg.Settings +import org.apache.doris.spark.cfg.{ConfigurationOptions, Settings} import org.apache.doris.spark.rest.PartitionDefinition - import org.apache.spark.{Partition, SparkContext, TaskContext} private[spark] class ScalaDorisRDD[T: ClassTag]( - sc: SparkContext, - params: Map[String, String] = Map.empty) - extends AbstractDorisRDD[T](sc, params) { + sc: SparkContext, + params: Map[String, String] = Map.empty) + extends AbstractDorisRDD[T](sc, params) { override def compute(split: Partition, context: TaskContext): ScalaDorisRDDIterator[T] = { new ScalaDorisRDDIterator(context, split.asInstanceOf[DorisPartition].dorisPartition) } } private[spark] class ScalaDorisRDDIterator[T]( - context: TaskContext, - partition: PartitionDefinition) - extends AbstractDorisRDDIterator[T](context, partition) { + context: TaskContext, + partition: PartitionDefinition) + extends AbstractDorisRDDIterator[T](context, partition) { override def initReader(settings: Settings): Unit = { - settings.setProperty(DORIS_VALUE_READER_CLASS, classOf[ScalaValueReader].getName) + settings.getProperty(ConfigurationOptions.DORIS_READ_MODE, + ConfigurationOptions.DORIS_READ_MODE_DEFAULT).toUpperCase match { + case "THRIFT" => settings.setProperty(DORIS_VALUE_READER_CLASS, classOf[ScalaValueReader].getName) + case "ARROW" => settings.setProperty(DORIS_VALUE_READER_CLASS, classOf[ScalaADBCValueReader].getName) + case mode: String => throw new IllegalArgumentException(s"Unsupported read mode: $mode") + } } override def createValue(value: Object): T = { diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala index f9124a6..16707b8 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala @@ -45,7 +45,7 @@ import scala.util.control.Breaks * @param partition Doris RDD partition * @param settings request configuration */ -class ScalaValueReader(partition: PartitionDefinition, settings: Settings) extends Logging { +class ScalaValueReader(partition: PartitionDefinition, settings: Settings) extends AbstractValueReader with Logging { private[this] lazy val client = new BackendClient(new Routing(partition.getBeAddress), settings) @@ -53,8 +53,6 @@ class ScalaValueReader(partition: PartitionDefinition, settings: Settings) exten 
private[this] val eos: AtomicBoolean = new AtomicBoolean(false) - protected var rowBatch: RowBatch = _ - // flag indicate if support deserialize Arrow to RowBatch asynchronously private[this] lazy val deserializeArrowToRowBatchAsync: Boolean = Try { settings.getProperty(DORIS_DESERIALIZE_ARROW_ASYNC, DORIS_DESERIALIZE_ARROW_ASYNC_DEFAULT.toString).toBoolean @@ -173,7 +171,7 @@ class ScalaValueReader(partition: PartitionDefinition, settings: Settings) exten * read data and cached in rowBatch. * @return true if hax next value */ - def hasNext: Boolean = { + override def hasNext: Boolean = { var hasNext = false if (deserializeArrowToRowBatchAsync && asyncThreadStarted) { // support deserialize Arrow to RowBatch asynchronously @@ -219,7 +217,7 @@ class ScalaValueReader(partition: PartitionDefinition, settings: Settings) exten * get next value. * @return next value */ - def next: AnyRef = { + override def next: AnyRef = { if (!hasNext) { logError(SHOULD_NOT_HAPPEN_MESSAGE) throw new ShouldNeverHappenException @@ -227,7 +225,7 @@ class ScalaValueReader(partition: PartitionDefinition, settings: Settings) exten rowBatch.next } - def close(): Unit = { + override def close(): Unit = { val closeParams = new TScanCloseParams closeParams.setContextId(contextId) lockClient(_.closeScanner(closeParams)) diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowADBCValueReader.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowADBCValueReader.scala new file mode 100644 index 0000000..a658cdc --- /dev/null +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowADBCValueReader.scala @@ -0,0 +1,50 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.spark.sql + +import org.apache.doris.spark.cfg.ConfigurationOptions.DORIS_READ_FIELD +import org.apache.doris.spark.cfg.Settings +import org.apache.doris.spark.exception.ShouldNeverHappenException +import org.apache.doris.spark.rdd.ScalaADBCValueReader +import org.apache.doris.spark.rest.PartitionDefinition +import org.apache.doris.spark.util.ErrorMessages.SHOULD_NOT_HAPPEN_MESSAGE +import org.slf4j.{Logger, LoggerFactory} + +import scala.collection.JavaConverters._ + +class ScalaDorisRowADBCValueReader(partition: PartitionDefinition, settings: Settings) + extends ScalaADBCValueReader(partition, settings) { + + private val logger: Logger = LoggerFactory.getLogger(classOf[ScalaDorisRowADBCValueReader].getName) + + val rowOrder: Seq[String] = settings.getProperty(DORIS_READ_FIELD).split(",") + + override def next: AnyRef = { + if (!hasNext) { + logger.error(SHOULD_NOT_HAPPEN_MESSAGE) + throw new ShouldNeverHappenException + } + val row: ScalaDorisRow = new ScalaDorisRow(rowOrder) + rowBatch.next.asScala.zipWithIndex.foreach{ + case (s, index) if index < row.values.size => row.values.update(index, s) + case _ => // nothing + } + row + } + +} diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala index b31a54d..6c3bd35 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/ScalaDorisRowRDD.scala @@ -18,10 +18,9 @@ package org.apache.doris.spark.sql import org.apache.doris.spark.cfg.ConfigurationOptions.DORIS_VALUE_READER_CLASS -import org.apache.doris.spark.cfg.Settings +import org.apache.doris.spark.cfg.{ConfigurationOptions, Settings} import org.apache.doris.spark.rdd.{AbstractDorisRDD, AbstractDorisRDDIterator, DorisPartition} import org.apache.doris.spark.rest.PartitionDefinition - import org.apache.spark.{Partition, SparkContext, TaskContext} import org.apache.spark.sql.Row import org.apache.spark.sql.types.StructType @@ -43,8 +42,13 @@ private[spark] class ScalaDorisRowRDDIterator( struct: StructType) extends AbstractDorisRDDIterator[Row](context, partition) { - override def initReader(settings: Settings) = { - settings.setProperty(DORIS_VALUE_READER_CLASS, classOf[ScalaDorisRowValueReader].getName) + override def initReader(settings: Settings): Unit = { + settings.getProperty(ConfigurationOptions.DORIS_READ_MODE, + ConfigurationOptions.DORIS_READ_MODE_DEFAULT).toUpperCase match { + case "THRIFT" => settings.setProperty (DORIS_VALUE_READER_CLASS, classOf[ScalaDorisRowValueReader].getName) + case "ARROW" => settings.setProperty (DORIS_VALUE_READER_CLASS, classOf[ScalaDorisRowADBCValueReader].getName) + case mode: String => throw new IllegalArgumentException(s"Unsupported read mode: $mode") + } } override def createValue(value: Object): Row = { diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala index 0400b04..2404584 100644 --- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala +++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala @@ -20,6 +20,7 @@ package org.apache.doris.spark.sql import org.apache.commons.lang3.StringUtils import org.apache.doris.spark.cfg.ConfigurationOptions import org.apache.doris.spark.exception.DorisException +import 
org.apache.doris.spark.rest.PartitionDefinition import org.apache.spark.sql.jdbc.JdbcDialect import org.apache.spark.sql.sources._ import org.slf4j.Logger @@ -28,6 +29,7 @@ import java.sql.{Date, Timestamp} import java.time.{Duration, LocalDate} import java.util.concurrent.locks.LockSupport import scala.annotation.tailrec +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.util.{Failure, Success, Try} @@ -201,4 +203,32 @@ private[spark] object Utils { case Failure(exception) => Failure(exception) } } + + def generateQueryStatement(readColumns: Array[String], bitmapColumns: Array[String], hllColumns: Array[String], + tableName: String, queryFilter: String, partitionOpt: Option[PartitionDefinition] = None): String = { + + val columns = { + val finalReadColumns = readColumns.clone() + if (finalReadColumns(0) != "*" && bitmapColumns.nonEmpty && hllColumns.nonEmpty) { + for (i <- finalReadColumns.indices) { + finalReadColumns(i) + val readFieldName = finalReadColumns(i).replaceAll("`", "") + if (bitmapColumns.contains(readFieldName) || hllColumns.contains(readFieldName)) { + finalReadColumns(i) = "'READ UNSUPPORTED' AS " + finalReadColumns(i) + } + } + } + finalReadColumns.mkString(",") + } + + val tabletClause = partitionOpt match { + case Some(partition) => s"TABLET(${partition.getTabletIds.asScala.mkString(",")})" + case None => "" + } + val whereClause = if (queryFilter.isEmpty) "" else s"WHERE $queryFilter" + + s"SELECT $columns FROM $tableName $tabletClause $whereClause".trim + + } + } diff --git a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestUtils.scala b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestUtils.scala index 7e7919a..e9db609 100644 --- a/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestUtils.scala +++ b/spark-doris-connector/src/test/scala/org/apache/doris/spark/sql/TestUtils.scala @@ -17,14 +17,17 @@ package org.apache.doris.spark.sql -import org.apache.doris.spark.cfg.ConfigurationOptions +import org.apache.doris.spark.cfg.{ConfigurationOptions, PropertiesSettings, Settings} import org.apache.doris.spark.exception.DorisException +import org.apache.doris.spark.rest.PartitionDefinition import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.hamcrest.core.StringStartsWith.startsWith import org.junit._ import org.slf4j.LoggerFactory +import scala.collection.JavaConverters._ + class TestUtils extends ExpectedExceptionTest { private lazy val logger = LoggerFactory.getLogger(classOf[TestUtils]) @@ -132,4 +135,26 @@ class TestUtils extends ExpectedExceptionTest { thrown.expectMessage(startsWith(s"${ConfigurationOptions.DORIS_REQUEST_AUTH_USER} cannot use in Doris Datasource,")) Utils.params(parameters6, logger) } + + @Test + def testGenerateQueryStatement(): Unit = { + + val readColumns = Array[String]("*") + + val partition = new PartitionDefinition("db", "tbl1", new PropertiesSettings(), "127.0.0.1:8060", Set[java.lang.Long](1L).asJava, "") + Assert.assertEquals("SELECT * FROM `db`.`tbl1` TABLET(1)", + Utils.generateQueryStatement(readColumns, Array[String](), Array[String](), "`db`.`tbl1`", "", Some(partition))) + + val readColumns1 = Array[String]("`c1`","`c2`","`c3`") + + val bitmapColumns = Array[String]("c2") + val hllColumns = Array[String]("c3") + + val where = "c1 = 10" + + Assert.assertEquals("SELECT `c1`,'READ UNSUPPORTED' AS `c2`,'READ UNSUPPORTED' AS `c3` FROM `db`.`tbl1` WHERE c1 = 10", + Utils.generateQueryStatement(readColumns1, 
bitmapColumns, hllColumns, "`db`.`tbl1`", where)) + + } + }
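For context, below is a minimal sketch of how a Spark job could opt into the Arrow Flight SQL read path added by this commit. The `doris.read.mode` option (default `thrift`) and `doris.arrow-flight-sql.port` option are the keys introduced in this patch; the remaining option keys, the table identifier, host names, ports, and credentials are illustrative placeholders in the connector's existing configuration style, not part of this change.

```scala
import org.apache.spark.sql.SparkSession

object ArrowFlightReadExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("doris-arrow-flight-read")
      .getOrCreate()

    // Read a Doris table through the connector. Setting doris.read.mode to "arrow"
    // routes the scan through the new ADBC / Arrow Flight SQL value readers instead
    // of the default Thrift-based reader; doris.arrow-flight-sql.port must point at
    // the FE's Arrow Flight SQL port (placeholder value below).
    val df = spark.read
      .format("doris")
      .option("doris.table.identifier", "example_db.example_table") // placeholder table
      .option("doris.fenodes", "fe_host:8030")                      // placeholder FE HTTP address
      .option("user", "root")                                       // placeholder credentials
      .option("password", "")
      .option("doris.read.mode", "arrow")            // new in this commit; default is "thrift"
      .option("doris.arrow-flight-sql.port", "9090") // new in this commit; placeholder port
      .load()

    df.show()
    spark.stop()
  }
}
```

With `doris.read.mode` left at its `thrift` default, behavior is unchanged; per the pattern matches added in `ScalaDorisRDDIterator.initReader` and `ScalaDorisRowRDDIterator.initReader`, the ADBC value readers are selected only when the mode resolves to `ARROW`, and any other value raises `IllegalArgumentException`.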