This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git

The following commit(s) were added to refs/heads/master by this push:
     new 11f4976  [fix] streaming write execution plan error (#135)
11f4976 is described below

commit 11f4976be1aae4c041ea66e3e83487aa2614c947
Author: gnehil <adamlee...@gmail.com>
AuthorDate: Tue Sep 5 13:56:25 2023 +0800

    [fix] streaming write execution plan error (#135)
    
    * fix streaming write error and add json data pass through option
    * handle stream pass through, force set read_json_by_line is true when format is json
---
 .../doris/spark/cfg/ConfigurationOptions.java      |  6 ++
 .../apache/doris/spark/load/DorisStreamLoad.java   | 50 +++++++++++
 .../doris/spark/sql/DorisStreamLoadSink.scala      |  2 +-
 .../apache/doris/spark/writer/DorisWriter.scala    | 97 +++++++++++++++++-----
 4 files changed, 135 insertions(+), 20 deletions(-)
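
A minimal usage sketch of the new option in a Structured Streaming job. The surrounding sink options (doris.fenodes, doris.table.identifier, user, password, doris.sink.properties.format) are the connector's existing settings and the values shown are only illustrative; verify them against the connector documentation:

    // hypothetical streaming write; jsonDf already holds a single JSON string column
    val query = jsonDf.writeStream
      .format("doris")
      .option("doris.fenodes", "fe_host:8030")              // assumed FE address
      .option("doris.table.identifier", "db.table")         // assumed target table
      .option("user", "root")
      .option("password", "")
      .option("doris.sink.properties.format", "json")       // stream load format
      .option("doris.sink.streaming.passthrough", "true")   // option added in this patch
      .option("checkpointLocation", "/tmp/doris_ckpt")      // example checkpoint path
      .start()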

diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
index 2ab200d..09c0416 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
@@ -100,4 +100,10 @@ public interface ConfigurationOptions {
     String DORIS_SINK_ENABLE_2PC = "doris.sink.enable-2pc";
     boolean DORIS_SINK_ENABLE_2PC_DEFAULT = false;
 
+    /**
+     * pass through json data when sink to doris in streaming mode
+     */
+    String DORIS_SINK_STREAMING_PASSTHROUGH = "doris.sink.streaming.passthrough";
+    boolean DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT = false;
+
 }
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
index 4a7b1e0..ac920cd 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
@@ -96,6 +96,8 @@ public class DorisStreamLoad implements Serializable {
 
     private boolean readJsonByLine = false;
 
+    private boolean streamingPassthrough = false;
+
     public DorisStreamLoad(SparkSettings settings) {
         String[] dbTable = settings.getProperty(ConfigurationOptions.DORIS_TABLE_IDENTIFIER).split("\\.");
         this.db = dbTable[0];
@@ -121,6 +123,8 @@ public class DorisStreamLoad implements Serializable {
             }
         }
         LINE_DELIMITER = escapeString(streamLoadProp.getOrDefault("line_delimiter", "\n"));
+        this.streamingPassthrough = settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH,
+                ConfigurationOptions.DORIS_SINK_STREAMING_PASSTHROUGH_DEFAULT);
     }
 
     public String getLoadUrlStr() {
@@ -196,6 +200,38 @@ public class DorisStreamLoad implements Serializable {
 
     }
 
+    public List<Integer> loadStream(List<List<Object>> rows, String[] dfColumns, Boolean enable2PC)
+            throws StreamLoadException, JsonProcessingException {
+
+        List<String> loadData;
+
+        if (this.streamingPassthrough) {
+            handleStreamPassThrough();
+            loadData = passthrough(rows);
+        } else {
+            loadData = parseLoadData(rows, dfColumns);
+        }
+
+        List<Integer> txnIds = new ArrayList<>(loadData.size());
+
+        try {
+            for (String data : loadData) {
+                txnIds.add(load(data, enable2PC));
+            }
+        } catch (StreamLoadException e) {
+            if (enable2PC && !txnIds.isEmpty()) {
+                LOG.error("load batch failed, abort previously pre-committed transactions");
+                for (Integer txnId : txnIds) {
+                    abort(txnId);
+                }
+            }
+            throw e;
+        }
+
+        return txnIds;
+
+    }
+
     public int load(String value, Boolean enable2PC) throws StreamLoadException {
 
         String label = generateLoadLabel();
@@ -442,4 +478,18 @@ public class DorisStreamLoad implements Serializable {
         return hexData;
     }
 
+    private void handleStreamPassThrough() {
+
+        if ("json".equalsIgnoreCase(fileType)) {
+            LOG.info("handle stream pass through, force set read_json_by_line is true for json format");
+            streamLoadProp.put("read_json_by_line", "true");
+            streamLoadProp.remove("strip_outer_array");
+        }
+
+    }
+
+    private List<String> passthrough(List<List<Object>> values) {
+        return values.stream().map(list -> list.get(0).toString()).collect(Collectors.toList());
+    }
+
 }
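
A note on the passthrough path above: passthrough() forwards only the first column of each row as the raw load payload, so a caller would typically reduce the streaming DataFrame to a single JSON string column before writing. A minimal sketch (sourceDf is a hypothetical input streaming DataFrame; names are illustrative):

    import org.apache.spark.sql.functions.{col, struct, to_json}

    // serialize every column of each row into one JSON string column named "value"
    val jsonDf = sourceDf.select(to_json(struct(sourceDf.columns.map(col): _*)).alias("value"))
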
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
index 342e940..d1a2b74 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/DorisStreamLoadSink.scala
@@ -34,7 +34,7 @@ private[sql] class DorisStreamLoadSink(sqlContext: SQLContext, settings: SparkSe
     if (batchId <= latestBatchId) {
       logger.info(s"Skipping already committed batch $batchId")
     } else {
-      writer.write(data)
+      writer.writeStream(data)
       latestBatchId = batchId
     }
   }
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
index 2b918e8..e32267e 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
@@ -22,6 +22,9 @@ import org.apache.doris.spark.listener.DorisTransactionListener
 import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, DorisStreamLoad}
 import org.apache.doris.spark.sql.Utils
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.CollectionAccumulator
 import org.slf4j.{Logger, LoggerFactory}
 
 import java.io.IOException
@@ -76,28 +79,13 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
      * flush data to Doris and do retry when flush error
      *
      */
-    def flush(batch: Iterable[util.List[Object]], dfColumns: Array[String]): Unit = {
+    def flush(batch: Seq[util.List[Object]], dfColumns: Array[String]): Unit = {
       Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
-        dorisStreamLoader.loadV2(batch.toList.asJava, dfColumns, enable2PC)
+        dorisStreamLoader.loadV2(batch.asJava, dfColumns, enable2PC)
       } match {
-        case Success(txnIds) => if (enable2PC) txnIds.asScala.foreach(txnId => preCommittedTxnAcc.add(txnId))
+        case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc)
         case Failure(e) =>
-          if (enable2PC) {
-            // if task run failed, acc value will not be returned to driver,
-            // should abort all pre committed transactions inside the task
-            logger.info("load task failed, start aborting previously pre-committed transactions")
-            val abortFailedTxnIds = mutable.Buffer[Int]()
-            preCommittedTxnAcc.value.asScala.foreach(txnId => {
-              Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
-                dorisStreamLoader.abort(txnId)
-              } match {
-                case Success(_) =>
-                case Failure(_) => abortFailedTxnIds += txnId
-              }
-            })
-            if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(","))
-            preCommittedTxnAcc.reset()
-          }
+          if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
           throw new IOException(
             s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e)
       }
@@ -105,5 +93,76 @@ class DorisWriter(settings: SparkSettings) extends Serializable {
 
   }
 
+  def writeStream(dataFrame: DataFrame): Unit = {
+
+    val sc = dataFrame.sqlContext.sparkContext
+    val preCommittedTxnAcc = sc.collectionAccumulator[Int]("preCommittedTxnAcc")
+    if (enable2PC) {
+      sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, dorisStreamLoader))
+    }
+
+    var resultRdd = dataFrame.queryExecution.toRdd
+    val schema = dataFrame.schema
+    val dfColumns = dataFrame.columns
+    if (Objects.nonNull(sinkTaskPartitionSize)) {
+      resultRdd = if (sinkTaskUseRepartition) resultRdd.repartition(sinkTaskPartitionSize) else resultRdd.coalesce(sinkTaskPartitionSize)
+    }
+    resultRdd
+      .foreachPartition(partition => {
+        partition
+          .grouped(batchSize)
+          .foreach(batch =>
+            flush(batch, dfColumns))
+      })
+
+    /**
+     * flush data to Doris and do retry when flush error
+     *
+     */
+    def flush(batch: Seq[InternalRow], dfColumns: Array[String]): Unit = {
+      Utils.retry[util.List[Integer], Exception](maxRetryTimes, Duration.ofMillis(batchInterValMs.toLong), logger) {
+        dorisStreamLoader.loadStream(convertToObjectList(batch, schema), dfColumns, enable2PC)
+      } match {
+        case Success(txnIds) => if (enable2PC) handleLoadSuccess(txnIds.asScala, preCommittedTxnAcc)
+        case Failure(e) =>
+          if (enable2PC) handleLoadFailure(preCommittedTxnAcc)
+          throw new IOException(
+            s"Failed to load batch data on BE: ${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} retry times.", e)
+      }
+    }
+
+    def convertToObjectList(rows: Seq[InternalRow], schema: StructType): util.List[util.List[Object]] = {
+      rows.map(row => {
+        row.toSeq(schema).map(_.asInstanceOf[AnyRef]).toList.asJava
+      }).asJava
+    }
+
+  }
+
+  private def handleLoadSuccess(txnIds: mutable.Buffer[Integer], acc: CollectionAccumulator[Int]): Unit = {
+    txnIds.foreach(txnId => acc.add(txnId))
+  }
+
+  def handleLoadFailure(acc: CollectionAccumulator[Int]): Unit = {
+    // if task run failed, acc value will not be returned to driver,
+    // should abort all pre committed transactions inside the task
+    logger.info("load task failed, start aborting previously pre-committed transactions")
+    if (acc.isZero) {
+      logger.info("no pre-committed transactions, skip abort")
+      return
+    }
+    val abortFailedTxnIds = mutable.Buffer[Int]()
+    acc.value.asScala.foreach(txnId => {
+      Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
+        dorisStreamLoader.abort(txnId)
+      } match {
+        case Success(_) =>
+        case Failure(_) => abortFailedTxnIds += txnId
+      }
+    })
+    if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: {}", abortFailedTxnIds.mkString(","))
+    acc.reset()
+  }
+
 
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org
