This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git


The following commit(s) were added to refs/heads/master by this push:
     new 7a60cb3  [feature](connector) support partial limit push down  (#257)
7a60cb3 is described below

commit 7a60cb31bd6e57c5e78ad6358903ac60cb80268c
Author: gnehil <adamlee...@gmail.com>
AuthorDate: Mon Jan 13 18:32:25 2025 +0800

    [feature](connector) support partial limit push down  (#257)
---
 .../spark/client/entity/DorisReaderPartition.java   | 21 ++++++++++++++++++++-
 .../spark/client/read/AbstractThriftReader.java     |  8 ++++++++
 .../spark/client/read/DorisFlightSqlReader.java     |  3 ++-
 .../spark/client/read/ReaderPartitionGenerator.java | 20 +++++++++++++-------
 .../apache/doris/spark/read/AbstractDorisScan.scala |  8 +++++---
 .../doris/spark/read/DorisPartitionReader.scala     |  2 +-
 .../apache/doris/spark/read/DorisScanBuilder.scala  | 14 +++++++++++---
 .../org/apache/doris/spark/read/DorisScanV2.scala   |  4 +++-
 .../apache/doris/spark/read/DorisScanBuilder.scala  | 14 +++++++++++---
 .../org/apache/doris/spark/read/DorisScanV2.scala   |  4 +++-
 .../apache/doris/spark/read/DorisScanBuilder.scala  | 14 +++++++++++---
 .../org/apache/doris/spark/read/DorisScanV2.scala   |  4 +++-
 12 files changed, 91 insertions(+), 25 deletions(-)
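
For context, a minimal usage sketch of what this change enables (illustrative
only: the FE address, table name, and credentials below are placeholders, and
"spark" is assumed to be an existing SparkSession):

    // A LIMIT in the Spark plan is now offered to the connector through
    // SupportsPushDownLimit and forwarded to every reader partition.
    val df = spark.read
      .format("doris")
      .option("doris.fenodes", "fe_host:8030")     // placeholder FE address
      .option("doris.table.identifier", "db.tbl")  // placeholder table
      .option("doris.request.auth.user", "root")   // placeholder credentials
      .option("doris.request.auth.password", "")
      .load()
      .limit(10) // the planner calls DorisScanBuilder.pushLimit(10)
    df.show()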

diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/entity/DorisReaderPartition.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/entity/DorisReaderPartition.java
index 242c9d7..aa75319 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/entity/DorisReaderPartition.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/entity/DorisReaderPartition.java
@@ -32,6 +32,7 @@ public class DorisReaderPartition implements Serializable {
     private final String opaquedQueryPlan;
     private final String[] readColumns;
     private final String[] filters;
+    private final Integer limit;
     private final DorisConfig config;
 
     public DorisReaderPartition(String database, String table, Backend backend, Long[] tablets, String opaquedQueryPlan, String[] readColumns, String[] filters, DorisConfig config) {
@@ -42,6 +43,19 @@ public class DorisReaderPartition implements Serializable {
         this.opaquedQueryPlan = opaquedQueryPlan;
         this.readColumns = readColumns;
         this.filters = filters;
+        this.limit = -1;
+        this.config = config;
+    }
+
+    public DorisReaderPartition(String database, String table, Backend backend, Long[] tablets, String opaquedQueryPlan, String[] readColumns, String[] filters, Integer limit, DorisConfig config) {
+        this.database = database;
+        this.table = table;
+        this.backend = backend;
+        this.tablets = tablets;
+        this.opaquedQueryPlan = opaquedQueryPlan;
+        this.readColumns = readColumns;
+        this.filters = filters;
+        this.limit = limit;
         this.config = config;
     }
 
@@ -78,6 +92,10 @@ public class DorisReaderPartition implements Serializable {
         return filters;
     }
 
+    public Integer getLimit() {
+        return limit;
+    }
+
     @Override
     public boolean equals(Object o) {
         if (o == null || getClass() != o.getClass()) return false;
@@ -89,11 +107,12 @@ public class DorisReaderPartition implements Serializable {
                 && Objects.equals(opaquedQueryPlan, that.opaquedQueryPlan)
                 && Objects.deepEquals(readColumns, that.readColumns)
                 && Objects.deepEquals(filters, that.filters)
+                && Objects.equals(limit, that.limit)
                 && Objects.equals(config, that.config);
     }
 
     @Override
     public int hashCode() {
-        return Objects.hash(database, table, backend, Arrays.hashCode(tablets), opaquedQueryPlan, Arrays.hashCode(readColumns), Arrays.hashCode(filters), config);
+        return Objects.hash(database, table, backend, Arrays.hashCode(tablets), opaquedQueryPlan, Arrays.hashCode(readColumns), Arrays.hashCode(filters), limit, config);
     }
 }
\ No newline at end of file
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
index 608e30c..c533de8 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/AbstractThriftReader.java
@@ -71,6 +71,8 @@ public abstract class AbstractThriftReader extends DorisReader {
 
     private final Thread asyncThread;
 
+    private int readCount = 0;
+
     protected AbstractThriftReader(DorisReaderPartition partition) throws Exception {
         super(partition);
         this.frontend = new DorisFrontendClient(config);
@@ -132,6 +134,9 @@ public abstract class AbstractThriftReader extends DorisReader {
 
     @Override
     public boolean hasNext() throws DorisException {
+        if (partition.getLimit() > 0 && readCount >= partition.getLimit()) {
+            return false;
+        }
         boolean hasNext = false;
         if (isAsync && asyncThread != null && asyncThread.isAlive()) {
             if (rowBatch == null || !rowBatch.hasNext()) {
@@ -186,6 +191,9 @@ public abstract class AbstractThriftReader extends DorisReader {
         if (!hasNext()) {
             throw new RuntimeException("No more elements");
         }
+        if (partition.getLimit() > 0) {
+            readCount++;
+        }
         return rowBatch.next().toArray();
     }
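
The thrift read path enforces the limit on the client side. A distilled model
of the guard added above (a standalone sketch, not connector code): once
"limit" rows have been handed out, hasNext() reports exhaustion, and a
non-positive limit (the -1 default) disables the check.

    // Minimal model of the readCount guard in AbstractThriftReader.
    class LimitedIterator[T](underlying: Iterator[T], limit: Int) extends Iterator[T] {
      private var readCount = 0
      override def hasNext: Boolean =
        (limit <= 0 || readCount < limit) && underlying.hasNext
      override def next(): T = {
        if (!hasNext) throw new NoSuchElementException("No more elements")
        if (limit > 0) readCount += 1
        underlying.next()
      }
    }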
 
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/DorisFlightSqlReader.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/DorisFlightSqlReader.java
index 4561f40..7a2d34e 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/DorisFlightSqlReader.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/DorisFlightSqlReader.java
@@ -149,7 +149,8 @@ public class DorisFlightSqlReader extends DorisReader {
         String fullTableName = config.getValue(DorisOptions.DORIS_TABLE_IDENTIFIER);
         String tablets = String.format("TABLET(%s)", StringUtils.join(partition.getTablets(), ","));
         String predicates = partition.getFilters().length == 0 ? "" : " WHERE " + String.join(" AND ", partition.getFilters());
-        return String.format("SELECT %s FROM %s %s%s", columns, fullTableName, tablets, predicates);
+        String limit = partition.getLimit() > 0 ? " LIMIT " + partition.getLimit() : "";
+        return String.format("SELECT %s FROM %s %s%s%s", columns, fullTableName, tablets, predicates, limit);
     }
 
     protected Schema processDorisSchema(DorisReaderPartition partition, final Schema originSchema) throws Exception {
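
On the Arrow Flight SQL path the limit lands directly in the generated query
instead. A sketch of the string generateSql() now produces, with made-up
columns, tablet IDs, and filter:

    // Roughly: SELECT c1,c2 FROM db.tbl TABLET(10001,10002) WHERE c1 > 0 LIMIT 10
    val columns    = "c1,c2"
    val table      = "db.tbl"
    val tablets    = "TABLET(10001,10002)"
    val predicates = " WHERE c1 > 0"
    val limit      = 10
    val limitPart  = if (limit > 0) s" LIMIT $limit" else ""
    println(s"SELECT $columns FROM $table $tablets$predicates$limitPart")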
diff --git a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/ReaderPartitionGenerator.java b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/ReaderPartitionGenerator.java
index 580b29e..4aa660a 100644
--- a/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/ReaderPartitionGenerator.java
+++ b/spark-doris-connector/spark-doris-connector-base/src/main/java/org/apache/doris/spark/client/read/ReaderPartitionGenerator.java
@@ -27,6 +27,8 @@ import org.apache.doris.spark.rest.models.Field;
 import org.apache.doris.spark.rest.models.QueryPlan;
 import org.apache.doris.spark.rest.models.Schema;
 import org.apache.doris.spark.util.DorisDialects;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -38,6 +40,8 @@ import java.util.stream.Collectors;
 
 public class ReaderPartitionGenerator {
 
+    private static final Logger LOG = LoggerFactory.getLogger(ReaderPartitionGenerator.class);
+
     /*
      * for spark 2
      */
@@ -51,14 +55,14 @@ public class ReaderPartitionGenerator {
         }
         String[] filters = config.contains(DorisOptions.DORIS_FILTER_QUERY) ?
                 config.getValue(DorisOptions.DORIS_FILTER_QUERY).split("\\.") : new String[0];
-        return generatePartitions(config, originReadCols, filters);
+        return generatePartitions(config, originReadCols, filters, -1);
     }
 
     /*
      * for spark 3
      */
     public static DorisReaderPartition[] generatePartitions(DorisConfig config,
-                                                            String[] fields, String[] filters) throws Exception {
+                                                            String[] fields, String[] filters, Integer limit) throws Exception {
         DorisFrontendClient frontend = new DorisFrontendClient(config);
         String fullTableName = config.getValue(DorisOptions.DORIS_TABLE_IDENTIFIER);
         String[] tableParts = fullTableName.split("\\.");
@@ -69,13 +73,15 @@ public class ReaderPartitionGenerator {
             originReadCols = frontend.getTableAllColumns(db, table);
         }
         String[] finalReadColumns = getFinalReadColumns(config, frontend, db, table, originReadCols);
-        String sql = "SELECT " + String.join(",", finalReadColumns) + " FROM `" + db + "`.`" + table + "`" +
-                (filters.length == 0 ? "" : " WHERE " + String.join(" AND ", filters));
+        String finalReadColumnString = String.join(",", finalReadColumns);
+        String finalWhereClauseString = filters.length == 0 ? "" : " WHERE " + String.join(" AND ", filters);
+        String sql = "SELECT " + finalReadColumnString + " FROM `" + db + "`.`" + table + "`" + finalWhereClauseString;
+        LOG.info("get query plan for table " + db + "." + table + ", sql: " + sql);
         QueryPlan queryPlan = frontend.getQueryPlan(db, table, sql);
         Map<String, List<Long>> beToTablets = mappingBeToTablets(queryPlan);
         int maxTabletSize = config.getValue(DorisOptions.DORIS_TABLET_SIZE);
         return distributeTabletsToPartitions(db, table, beToTablets, queryPlan.getOpaqued_query_plan(), maxTabletSize,
-                finalReadColumns, filters, config);
+                finalReadColumns, filters, config, limit);
     }
 
     @VisibleForTesting
@@ -106,7 +112,7 @@
                                                                         Map<String, List<Long>> beToTablets,
                                                                         String opaquedQueryPlan, int maxTabletSize,
                                                                         String[] readColumns, String[] predicates,
-                                                                        DorisConfig config) {
+                                                                        DorisConfig config, Integer limit) {
         List<DorisReaderPartition> partitions = new ArrayList<>();
         beToTablets.forEach((backendStr, tabletIds) -> {
             List<Long> distinctTablets = new ArrayList<>(new HashSet<>(tabletIds));
@@ -115,7 +121,7 @@ public class ReaderPartitionGenerator {
                 Long[] tablets = distinctTablets.subList(offset, Math.min(offset + maxTabletSize, distinctTablets.size())).toArray(new Long[0]);
                 offset += maxTabletSize;
                 partitions.add(new DorisReaderPartition(database, table, new Backend(backendStr), tablets,
-                        opaquedQueryPlan, readColumns, predicates, config));
+                        opaquedQueryPlan, readColumns, predicates, limit, config));
             }
         });
         return partitions.toArray(new DorisReaderPartition[0]);
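
One consequence of attaching the limit to every generated partition (my
summary, not part of the commit message): the pushdown is partial by design,
because each backend enforces the limit independently and Spark performs the
final truncation. A back-of-envelope illustration with hypothetical numbers:

    // With N reader partitions each capped at L rows, up to N * L rows can
    // still cross the wire; Spark's own LIMIT node trims the result to L.
    val numPartitions = 8  // hypothetical
    val pushedLimit   = 10 // hypothetical
    val worstCaseRows = numPartitions * pushedLimit // 80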
diff --git a/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/AbstractDorisScan.scala b/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/AbstractDorisScan.scala
index f1666ad..34bff24 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/AbstractDorisScan.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/AbstractDorisScan.scala
@@ -35,7 +35,7 @@ abstract class AbstractDorisScan(config: DorisConfig, schema: StructType) extend
   override def toBatch: Batch = this
 
   override def planInputPartitions(): Array[InputPartition] = {
-    ReaderPartitionGenerator.generatePartitions(config, schema.names, compiledFilters()).map(toInputPartition)
+    ReaderPartitionGenerator.generatePartitions(config, schema.names, compiledFilters(), getLimit).map(toInputPartition)
   }
 
 
@@ -44,10 +44,12 @@ abstract class AbstractDorisScan(config: DorisConfig, schema: StructType) extend
   }
 
   private def toInputPartition(rp: DorisReaderPartition): DorisInputPartition =
-    DorisInputPartition(rp.getDatabase, rp.getTable, rp.getBackend, rp.getTablets.map(_.toLong), rp.getOpaquedQueryPlan, rp.getReadColumns, rp.getFilters)
+    DorisInputPartition(rp.getDatabase, rp.getTable, rp.getBackend, rp.getTablets.map(_.toLong), rp.getOpaquedQueryPlan, rp.getReadColumns, rp.getFilters, rp.getLimit)
 
   protected def compiledFilters(): Array[String]
 
+  protected def getLimit: Int = -1
+
 }
 
-case class DorisInputPartition(database: String, table: String, backend: Backend, tablets: Array[Long], opaquedQueryPlan: String, readCols: Array[String], predicates: Array[String]) extends InputPartition
+case class DorisInputPartition(database: String, table: String, backend: Backend, tablets: Array[Long], opaquedQueryPlan: String, readCols: Array[String], predicates: Array[String], limit: Int = -1) extends InputPartition
diff --git a/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/DorisPartitionReader.scala b/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/DorisPartitionReader.scala
index 5b2be1a..42be1c5 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/DorisPartitionReader.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3-base/src/main/scala/org/apache/doris/spark/read/DorisPartitionReader.scala
@@ -34,7 +34,7 @@ class DorisPartitionReader(inputPartition: InputPartition, schema: StructType, m
   private implicit def toReaderPartition(inputPart: DorisInputPartition): DorisReaderPartition = {
     val tablets = inputPart.tablets.map(java.lang.Long.valueOf)
     new DorisReaderPartition(inputPart.database, inputPart.table, inputPart.backend, tablets,
-      inputPart.opaquedQueryPlan, inputPart.readCols, inputPart.predicates, config)
+      inputPart.opaquedQueryPlan, inputPart.readCols, inputPart.predicates, inputPart.limit, config)
   }
 
   private lazy val reader: DorisReader = {
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
index cc8ddd2..61a9e20 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
@@ -20,17 +20,20 @@ package org.apache.doris.spark.read
 import org.apache.doris.spark.config.{DorisConfig, DorisOptions}
 import org.apache.doris.spark.read.expression.V2ExpressionBuilder
 import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownV2Filters}
+import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownLimit, SupportsPushDownV2Filters}
 import org.apache.spark.sql.types.StructType
 
 class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema)
-  with SupportsPushDownV2Filters {
+  with SupportsPushDownV2Filters
+  with SupportsPushDownLimit {
 
   private var pushDownPredicates: Array[Predicate] = Array[Predicate]()
 
   private val expressionBuilder = new V2ExpressionBuilder(config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT))
 
-  override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates)
+  private var limitSize: Int = -1
+
+  override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates, limitSize)
 
   override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
     val (pushed, unsupported) = predicates.partition(predicate => {
@@ -42,4 +45,9 @@ class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisSca
 
   override def pushedPredicates(): Array[Predicate] = pushDownPredicates
 
+  override def pushLimit(i: Int): Boolean = {
+    limitSize = i
+    true
+  }
+
 }
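
The same DorisScanBuilder change is repeated verbatim for the Spark 3.4 and
3.5 modules below. A condensed model of the DataSource V2 hook being
implemented (a sketch, with one assumption worth verifying: in the Spark
releases I checked, SupportsPushDownLimit.isPartiallyPushed() defaults to
true, so Spark still applies the final LIMIT even when pushLimit returns
true):

    import org.apache.spark.sql.connector.read.SupportsPushDownLimit

    // Remember the limit Spark offers and report it as pushed; the scan
    // later hands it to each reader partition.
    trait LimitPushdown extends SupportsPushDownLimit {
      protected var limitSize: Int = -1
      override def pushLimit(limit: Int): Boolean = {
        limitSize = limit
        true // accepted; Spark treats the push as partial and re-applies LIMIT
      }
    }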
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
index 634257a..503ad04 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.3/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
@@ -23,10 +23,12 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.types.StructType
 
-class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate]) extends AbstractDorisScan(config, schema) with Logging {
+class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate], limit: Int) extends AbstractDorisScan(config, schema) with Logging {
   override protected def compiledFilters(): Array[String] = {
     val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT)
     val v2ExpressionBuilder = new V2ExpressionBuilder(inValueLengthLimit)
     filters.map(e => Option[String](v2ExpressionBuilder.build(e))).filter(_.isDefined).map(_.get)
   }
+
+  override protected def getLimit: Int = limit
 }
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
index cc8ddd2..61a9e20 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
@@ -20,17 +20,20 @@ package org.apache.doris.spark.read
 import org.apache.doris.spark.config.{DorisConfig, DorisOptions}
 import org.apache.doris.spark.read.expression.V2ExpressionBuilder
 import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownV2Filters}
+import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownLimit, SupportsPushDownV2Filters}
 import org.apache.spark.sql.types.StructType
 
 class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema)
-  with SupportsPushDownV2Filters {
+  with SupportsPushDownV2Filters
+  with SupportsPushDownLimit {
 
   private var pushDownPredicates: Array[Predicate] = Array[Predicate]()
 
   private val expressionBuilder = new V2ExpressionBuilder(config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT))
 
-  override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates)
+  private var limitSize: Int = -1
+
+  override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates, limitSize)
 
   override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
     val (pushed, unsupported) = predicates.partition(predicate => {
@@ -42,4 +45,9 @@ class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisSca
 
   override def pushedPredicates(): Array[Predicate] = pushDownPredicates
 
+  override def pushLimit(i: Int): Boolean = {
+    limitSize = i
+    true
+  }
+
 }
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
index 634257a..503ad04 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.4/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
@@ -23,10 +23,12 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.types.StructType
 
-class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate]) extends AbstractDorisScan(config, schema) with Logging {
+class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate], limit: Int) extends AbstractDorisScan(config, schema) with Logging {
   override protected def compiledFilters(): Array[String] = {
     val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT)
     val v2ExpressionBuilder = new V2ExpressionBuilder(inValueLengthLimit)
     filters.map(e => Option[String](v2ExpressionBuilder.build(e))).filter(_.isDefined).map(_.get)
   }
+
+  override protected def getLimit: Int = limit
 }
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala b/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
index cc8ddd2..61a9e20 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanBuilder.scala
@@ -20,17 +20,20 @@ package org.apache.doris.spark.read
 import org.apache.doris.spark.config.{DorisConfig, DorisOptions}
 import org.apache.doris.spark.read.expression.V2ExpressionBuilder
 import org.apache.spark.sql.connector.expressions.filter.Predicate
-import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownV2Filters}
+import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownLimit, SupportsPushDownV2Filters}
 import org.apache.spark.sql.types.StructType
 
 class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema)
-  with SupportsPushDownV2Filters {
+  with SupportsPushDownV2Filters
+  with SupportsPushDownLimit {
 
   private var pushDownPredicates: Array[Predicate] = Array[Predicate]()
 
   private val expressionBuilder = new V2ExpressionBuilder(config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT))
 
-  override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates)
+  private var limitSize: Int = -1
+
+  override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates, limitSize)
 
   override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
     val (pushed, unsupported) = predicates.partition(predicate => {
@@ -42,4 +45,9 @@ class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisSca
 
   override def pushedPredicates(): Array[Predicate] = pushDownPredicates
 
+  override def pushLimit(i: Int): Boolean = {
+    limitSize = i
+    true
+  }
+
 }
diff --git a/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala b/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
index 634257a..503ad04 100644
--- a/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
+++ b/spark-doris-connector/spark-doris-connector-spark-3.5/src/main/scala/org/apache/doris/spark/read/DorisScanV2.scala
@@ -23,10 +23,12 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connector.expressions.filter.Predicate
 import org.apache.spark.sql.types.StructType
 
-class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate]) extends AbstractDorisScan(config, schema) with Logging {
+class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate], limit: Int) extends AbstractDorisScan(config, schema) with Logging {
   override protected def compiledFilters(): Array[String] = {
     val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT)
     val v2ExpressionBuilder = new V2ExpressionBuilder(inValueLengthLimit)
     filters.map(e => Option[String](v2ExpressionBuilder.build(e))).filter(_.isDefined).map(_.get)
   }
+
+  override protected def getLimit: Int = limit
 }

