This is an automated email from the ASF dual-hosted git repository.

diwu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris-spark-connector.git


The following commit(s) were added to refs/heads/master by this push:
     new b1fbc6e  [feature] support two phase commit (#122)
b1fbc6e is described below

commit b1fbc6e5e314100214ef89f8bb11b972515bfc7c
Author: gnehil <adamlee...@gmail.com>
AuthorDate: Fri Jul 28 15:07:51 2023 +0800

    [feature] support two phase commit (#122)
---
 .../doris/spark/cfg/ConfigurationOptions.java      |   3 +
 .../java/org/apache/doris/spark/cfg/Settings.java  |  11 +
 .../apache/doris/spark/load/DorisStreamLoad.java   | 267 +++++++++++++++------
 .../doris/spark/rest/models/RespContent.java       |   4 +
 .../org/apache/doris/spark/util/ResponseUtil.java  |  33 +++
 .../spark/listener/DorisTransactionListener.scala  |  83 +++++++
 .../scala/org/apache/doris/spark/sql/Utils.scala   |   2 +-
 .../apache/doris/spark/writer/DorisWriter.scala    |  34 ++-
 8 files changed, 358 insertions(+), 79 deletions(-)

diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
index 61184f9..2ab200d 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/ConfigurationOptions.java
@@ -97,4 +97,7 @@ public interface ConfigurationOptions {
      */
     String DORIS_IGNORE_TYPE = "doris.ignore-type";
 
+    String DORIS_SINK_ENABLE_2PC = "doris.sink.enable-2pc";
+    boolean DORIS_SINK_ENABLE_2PC_DEFAULT = false;
+
 }
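For reference, the two lines added above are the whole user-facing switch: doris.sink.enable-2pc turns two-phase commit on and defaults to false. A minimal usage sketch from a Spark job follows; the "doris" data source name and the fenodes/table/credential options come from the connector's existing documentation (not from this diff) and may vary by version, and the FE address and target table are placeholders.

    // Hypothetical job: write a DataFrame through the connector with 2PC enabled.
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("doris-2pc-demo").getOrCreate()
    val df = spark.range(10).selectExpr("id", "cast(id as string) as name")

    df.write
      .format("doris")                                     // assumed data source short name
      .option("doris.fenodes", "fe_host:8030")             // assumed FE HTTP address option
      .option("doris.table.identifier", "db.example_tbl")  // assumed target table option
      .option("user", "root")
      .option("password", "")
      .option("doris.sink.enable-2pc", "true")             // added by this commit, default false
      .save()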
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java
index d2e845a..798ec8c 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/cfg/Settings.java
@@ -62,6 +62,17 @@ public abstract class Settings {
         return defaultValue;
     }
 
+    public Boolean getBooleanProperty(String name) {
+        return getBooleanProperty(name, null);
+    }
+
+    public Boolean getBooleanProperty(String name, Boolean defaultValue) {
+        if (getProperty(name) != null) {
+            return Boolean.valueOf(getProperty(name));
+        }
+        return defaultValue;
+    }
+
     public Settings merge(Properties properties) {
         if (properties == null || properties.isEmpty()) {
             return this;
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
index 5341f67..c40420d 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/load/DorisStreamLoad.java
@@ -23,8 +23,10 @@ import org.apache.doris.spark.rest.RestService;
 import org.apache.doris.spark.rest.models.BackendV2;
 import org.apache.doris.spark.rest.models.RespContent;
 import org.apache.doris.spark.util.ListUtils;
+import org.apache.doris.spark.util.ResponseUtil;
 
 import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.core.type.TypeReference;
 import com.fasterxml.jackson.databind.ObjectMapper;
 import com.google.common.cache.CacheBuilder;
 import com.google.common.cache.CacheLoader;
@@ -34,6 +36,7 @@ import org.apache.commons.lang3.StringUtils;
 import org.apache.http.HttpHeaders;
 import org.apache.http.HttpResponse;
 import org.apache.http.HttpStatus;
+import org.apache.http.client.methods.CloseableHttpResponse;
 import org.apache.http.client.methods.HttpPut;
 import org.apache.http.entity.BufferedHttpEntity;
 import org.apache.http.entity.StringEntity;
@@ -74,6 +77,9 @@ public class DorisStreamLoad implements Serializable {
 
     private final static List<String> DORIS_SUCCESS_STATUS = new ArrayList<>(Arrays.asList("Success", "Publish Timeout"));
     private static String loadUrlPattern = "http://%s/api/%s/%s/_stream_load?";
+
+    private static String abortUrlPattern = "http://%s/api/%s/%s/_stream_load_2pc?";
+
     private String user;
     private String passwd;
     private String loadUrlStr;
@@ -99,9 +105,7 @@ public class DorisStreamLoad implements Serializable {
         this.columns = 
settings.getProperty(ConfigurationOptions.DORIS_WRITE_FIELDS);
         this.maxFilterRatio = 
settings.getProperty(ConfigurationOptions.DORIS_MAX_FILTER_RATIO);
         this.streamLoadProp = getStreamLoadProp(settings);
-        cache = CacheBuilder.newBuilder()
-                .expireAfterWrite(cacheExpireTimeout, TimeUnit.MINUTES)
-                .build(new BackendCacheLoader(settings));
+        cache = CacheBuilder.newBuilder().expireAfterWrite(cacheExpireTimeout, 
TimeUnit.MINUTES).build(new BackendCacheLoader(settings));
         fileType = streamLoadProp.getOrDefault("format", "csv");
         if ("csv".equals(fileType)) {
             FIELD_DELIMITER = 
escapeString(streamLoadProp.getOrDefault("column_separator", "\t"));
@@ -121,13 +125,13 @@ public class DorisStreamLoad implements Serializable {
         }
         return loadUrlStr;
     }
+
     private CloseableHttpClient getHttpClient() {
-        HttpClientBuilder httpClientBuilder = HttpClientBuilder.create()
-                .disableRedirectHandling();
+        HttpClientBuilder httpClientBuilder = 
HttpClientBuilder.create().disableRedirectHandling();
         return httpClientBuilder.build();
     }
 
-    private HttpPut getHttpPut(String label, String loadUrlStr) {
+    private HttpPut getHttpPut(String label, String loadUrlStr, Boolean 
enable2PC) {
         HttpPut httpPut = new HttpPut(loadUrlStr);
         httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + authEncoded);
         httpPut.setHeader(HttpHeaders.EXPECT, "100-continue");
@@ -139,6 +143,9 @@ public class DorisStreamLoad implements Serializable {
         if (StringUtils.isNotBlank(maxFilterRatio)) {
             httpPut.setHeader("max_filter_ratio", maxFilterRatio);
         }
+        if (enable2PC) {
+            httpPut.setHeader("two_phase_commit", "true");
+        }
         if (MapUtils.isNotEmpty(streamLoadProp)) {
             streamLoadProp.forEach(httpPut::setHeader);
         }
@@ -158,100 +165,162 @@ public class DorisStreamLoad implements Serializable {
 
         @Override
         public String toString() {
-            return "status: " + status +
-                    ", resp msg: " + respMsg +
-                    ", resp content: " + respContent;
+            return "status: " + status + ", resp msg: " + respMsg + ", resp content: " + respContent;
         }
     }
 
-    public String listToString(List<List<Object>> rows) {
-        return rows.stream().map(row ->
-                row.stream().map(field -> field == null ? NULL_VALUE : 
field.toString())
-                        .collect(Collectors.joining(FIELD_DELIMITER))
-        ).collect(Collectors.joining(LINE_DELIMITER));
-    }
+    public List<Integer> loadV2(List<List<Object>> rows, String[] dfColumns, 
Boolean enable2PC) throws StreamLoadException, JsonProcessingException {
 
+        List<String> loadData = parseLoadData(rows, dfColumns);
+        List<Integer> txnIds = new ArrayList<>(loadData.size());
 
-    public void loadV2(List<List<Object>> rows, String[] dfColumns) throws 
StreamLoadException, JsonProcessingException {
-        if (fileType.equals("csv")) {
-            load(listToString(rows));
-        } else if(fileType.equals("json")) {
-            List<Map<Object, Object>> dataList = new ArrayList<>();
-            try {
-                for (List<Object> row : rows) {
-                    Map<Object, Object> dataMap = new HashMap<>();
-                    if (dfColumns.length == row.size()) {
-                        for (int i = 0; i < dfColumns.length; i++) {
-                            Object col = row.get(i);
-                            if (col instanceof Timestamp) {
-                                dataMap.put(dfColumns[i], col.toString());
-                                continue;
-                            }
-                            dataMap.put(dfColumns[i], col);
-                        }
-                    }
-                    dataList.add(dataMap);
-                }
-            } catch (Exception e) {
-                throw new StreamLoadException("The number of configured 
columns does not match the number of data columns.");
+        try {
+            for (String data : loadData) {
+                txnIds.add(load(data, enable2PC));
             }
-            // splits large collections to normal collection to avoid the 
"Requested array size exceeds VM limit" exception
-            List<String> serializedList = 
ListUtils.getSerializedList(dataList, readJsonByLine ? LINE_DELIMITER : null);
-            for (String serializedRows : serializedList) {
-                load(serializedRows);
+        } catch (StreamLoadException e) {
+            if (enable2PC && !txnIds.isEmpty()) {
+                LOG.error("load batch failed, abort previously pre-committed 
transactions");
+                for (Integer txnId : txnIds) {
+                    abort(txnId);
+                }
             }
-        } else {
-            throw new StreamLoadException(String.format("Unsupported file 
format in stream load: %s.", fileType));
+            throw e;
         }
+
+        return txnIds;
+
     }
 
-    public void load(String value) throws StreamLoadException {
-        LoadResponse loadResponse = loadBatch(value);
+    public int load(String value, Boolean enable2PC) throws 
StreamLoadException {
+
+        String label = generateLoadLabel();
+
+        LoadResponse loadResponse;
+        int responseHttpStatus = -1;
+        try (CloseableHttpClient httpClient = getHttpClient()) {
+            String loadUrlStr = String.format(loadUrlPattern, getBackend(), 
db, tbl);
+            LOG.debug("Stream load Request:{} ,Body:{}", loadUrlStr, value);
+            // only to record the BE node in case of an exception
+            this.loadUrlStr = loadUrlStr;
+
+            HttpPut httpPut = getHttpPut(label, loadUrlStr, enable2PC);
+            httpPut.setEntity(new StringEntity(value, StandardCharsets.UTF_8));
+            HttpResponse httpResponse = httpClient.execute(httpPut);
+            responseHttpStatus = httpResponse.getStatusLine().getStatusCode();
+            String respMsg = httpResponse.getStatusLine().getReasonPhrase();
+            String response = EntityUtils.toString(new 
BufferedHttpEntity(httpResponse.getEntity()), StandardCharsets.UTF_8);
+            loadResponse = new LoadResponse(responseHttpStatus, respMsg, 
response);
+        } catch (IOException e) {
+            e.printStackTrace();
+            String err = "http request exception,load url : " + loadUrlStr + 
",failed to execute spark stream load with label: " + label;
+            LOG.warn(err, e);
+            loadResponse = new LoadResponse(responseHttpStatus, 
e.getMessage(), err);
+        }
+
         if (loadResponse.status != HttpStatus.SC_OK) {
-            LOG.info("Streamload Response HTTP Status Error:{}", loadResponse);
-            throw new StreamLoadException("stream load error: " + 
loadResponse.respContent);
+            LOG.info("Stream load Response HTTP Status Error:{}", 
loadResponse);
+            // throw new StreamLoadException("stream load error: " + 
loadResponse.respContent);
+            throw new StreamLoadException("stream load error");
         } else {
             ObjectMapper obj = new ObjectMapper();
             try {
                 RespContent respContent = 
obj.readValue(loadResponse.respContent, RespContent.class);
                 if (!DORIS_SUCCESS_STATUS.contains(respContent.getStatus())) {
-                    LOG.error("Streamload Response RES STATUS Error:{}", 
loadResponse);
-                    throw new StreamLoadException("stream load error: " + 
loadResponse);
+                    LOG.error("Stream load Response RES STATUS Error:{}", 
loadResponse);
+                    throw new StreamLoadException("stream load error");
                 }
-                LOG.info("Streamload Response:{}", loadResponse);
+                LOG.info("Stream load Response:{}", loadResponse);
+                return respContent.getTxnId();
             } catch (IOException e) {
                 throw new StreamLoadException(e);
             }
         }
+
     }
 
-    private LoadResponse loadBatch(String value) {
-        Calendar calendar = Calendar.getInstance();
-        String label = 
String.format("spark_streamload_%s%02d%02d_%02d%02d%02d_%s",
-                calendar.get(Calendar.YEAR), calendar.get(Calendar.MONTH) + 1, 
calendar.get(Calendar.DAY_OF_MONTH),
-                calendar.get(Calendar.HOUR_OF_DAY), 
calendar.get(Calendar.MINUTE), calendar.get(Calendar.SECOND),
-                UUID.randomUUID().toString().replaceAll("-", ""));
+    public void commit(int txnId) throws StreamLoadException {
 
-        int responseHttpStatus = -1;
-        try (CloseableHttpClient httpClient = getHttpClient()) {
-            String loadUrlStr = String.format(loadUrlPattern, getBackend(), 
db, tbl);
-            LOG.debug("Streamload Request:{} ,Body:{}", loadUrlStr, value);
-            //only to record the BE node in case of an exception
-            this.loadUrlStr = loadUrlStr;
+        try (CloseableHttpClient client = getHttpClient()) {
+
+            String backend = getBackend();
+            String abortUrl = String.format(abortUrlPattern, backend, db, tbl);
+            HttpPut httpPut = new HttpPut(abortUrl);
+            httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + 
authEncoded);
+            httpPut.setHeader(HttpHeaders.EXPECT, "100-continue");
+            httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; 
charset=UTF-8");
+            httpPut.setHeader("txn_operation", "commit");
+            httpPut.setHeader("txn_id", String.valueOf(txnId));
+
+            CloseableHttpResponse response = client.execute(httpPut);
+            int statusCode = response.getStatusLine().getStatusCode();
+            if (statusCode != 200 || response.getEntity() == null) {
+                LOG.warn("commit transaction response: " + 
response.getStatusLine().toString());
+                throw new StreamLoadException("Fail to commit transaction " + 
txnId + " with url " + abortUrl);
+            }
+
+            statusCode = response.getStatusLine().getStatusCode();
+            String reasonPhrase = response.getStatusLine().getReasonPhrase();
+            if (statusCode != 200) {
+                LOG.warn("commit failed with {}, reason {}", backend, 
reasonPhrase);
+                throw new StreamLoadException("stream load error: " + 
reasonPhrase);
+            }
+
+            ObjectMapper mapper = new ObjectMapper();
+            if (response.getEntity() != null) {
+                String loadResult = EntityUtils.toString(response.getEntity());
+                Map<String, String> res = mapper.readValue(loadResult, new 
TypeReference<HashMap<String, String>>() {
+                });
+                if (res.get("status").equals("Fail") && 
!ResponseUtil.isCommitted(res.get("msg"))) {
+                    throw new StreamLoadException("Commit failed " + 
loadResult);
+                } else {
+                    LOG.info("load result {}", loadResult);
+                }
+            }
 
-            HttpPut httpPut = getHttpPut(label, loadUrlStr);
-            httpPut.setEntity(new StringEntity(value, StandardCharsets.UTF_8));
-            HttpResponse httpResponse = httpClient.execute(httpPut);
-            responseHttpStatus = httpResponse.getStatusLine().getStatusCode();
-            String respMsg = httpResponse.getStatusLine().getReasonPhrase();
-            String response = EntityUtils.toString(new 
BufferedHttpEntity(httpResponse.getEntity()), StandardCharsets.UTF_8);
-            return new LoadResponse(responseHttpStatus, respMsg, response);
         } catch (IOException e) {
-            e.printStackTrace();
-            String err = "http request exception,load url : " + loadUrlStr + 
",failed to execute spark streamload with label: " + label;
-            LOG.warn(err, e);
-            return new LoadResponse(responseHttpStatus, e.getMessage(), err);
+            throw new StreamLoadException(e);
         }
+
+    }
+
+    public void abort(int txnId) throws StreamLoadException {
+
+        LOG.info("start abort transaction {}.", txnId);
+
+        try (CloseableHttpClient client = getHttpClient()) {
+            String abortUrl = String.format(abortUrlPattern, getBackend(), db, 
tbl);
+            HttpPut httpPut = new HttpPut(abortUrl);
+            httpPut.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + 
authEncoded);
+            httpPut.setHeader(HttpHeaders.EXPECT, "100-continue");
+            httpPut.setHeader(HttpHeaders.CONTENT_TYPE, "text/plain; 
charset=UTF-8");
+            httpPut.setHeader("txn_operation", "abort");
+            httpPut.setHeader("txn_id", String.valueOf(txnId));
+
+            CloseableHttpResponse response = client.execute(httpPut);
+            int statusCode = response.getStatusLine().getStatusCode();
+            if (statusCode != 200 || response.getEntity() == null) {
+                LOG.warn("abort transaction response: " + 
response.getStatusLine().toString());
+                throw new StreamLoadException("Fail to abort transaction " + 
txnId + " with url " + abortUrl);
+            }
+
+            ObjectMapper mapper = new ObjectMapper();
+            String loadResult = EntityUtils.toString(response.getEntity());
+            Map<String, String> res = mapper.readValue(loadResult, new 
TypeReference<HashMap<String, String>>() {
+            });
+            if (!"Success".equals(res.get("status"))) {
+                if (ResponseUtil.isCommitted(res.get("msg"))) {
+                    throw new StreamLoadException("try abort committed 
transaction, " + "do you recover from old savepoint?");
+                }
+                LOG.warn("Fail to abort transaction. txnId: {}, error: {}", 
txnId, res.get("msg"));
+            }
+
+        } catch (IOException e) {
+            throw new StreamLoadException(e);
+        }
+
+        LOG.info("abort transaction {} succeed.", txnId);
+
     }
 
     public Map<String, String> getStreamLoadProp(SparkSettings sparkSettings) {
@@ -268,7 +337,7 @@ public class DorisStreamLoad implements Serializable {
 
     private String getBackend() {
         try {
-            //get backends from cache
+            // get backends from cache
             List<BackendV2.BackendRowV2> backends = cache.get("backends");
             Collections.shuffle(backends);
             BackendV2.BackendRowV2 backend = backends.get(0);
@@ -301,6 +370,54 @@ public class DorisStreamLoad implements Serializable {
 
     }
 
+    private List<String> parseLoadData(List<List<Object>> rows, String[] 
dfColumns) throws StreamLoadException, JsonProcessingException {
+
+        List<String> loadDataList;
+
+        switch (fileType.toUpperCase()) {
+
+            case "CSV":
+                loadDataList = Collections.singletonList(rows.stream().map(row 
-> row.stream().map(field -> field == null ? NULL_VALUE : 
field.toString()).collect(Collectors.joining(FIELD_DELIMITER))).collect(Collectors.joining(LINE_DELIMITER)));
+                break;
+            case "JSON":
+                List<Map<Object, Object>> dataList = new ArrayList<>();
+                try {
+                    for (List<Object> row : rows) {
+                        Map<Object, Object> dataMap = new HashMap<>();
+                        if (dfColumns.length == row.size()) {
+                            for (int i = 0; i < dfColumns.length; i++) {
+                                Object col = row.get(i);
+                                if (col instanceof Timestamp) {
+                                    dataMap.put(dfColumns[i], col.toString());
+                                    continue;
+                                }
+                                dataMap.put(dfColumns[i], col);
+                            }
+                        }
+                        dataList.add(dataMap);
+                    }
+                } catch (Exception e) {
+                    throw new StreamLoadException("The number of configured 
columns does not match the number of data columns.");
+                }
+                // splits large collections to normal collection to avoid the 
"Requested array size exceeds VM limit" exception
+                loadDataList = ListUtils.getSerializedList(dataList, 
readJsonByLine ? LINE_DELIMITER : null);
+                break;
+            default:
+                throw new StreamLoadException(String.format("Unsupported file 
format in stream load: %s.", fileType));
+
+        }
+
+        return loadDataList;
+
+    }
+
+    private String generateLoadLabel() {
+
+        Calendar calendar = Calendar.getInstance();
+        return String.format("spark_streamload_%s%02d%02d_%02d%02d%02d_%s", 
calendar.get(Calendar.YEAR), calendar.get(Calendar.MONTH) + 1, 
calendar.get(Calendar.DAY_OF_MONTH), calendar.get(Calendar.HOUR_OF_DAY), 
calendar.get(Calendar.MINUTE), calendar.get(Calendar.SECOND), 
UUID.randomUUID().toString().replaceAll("-", ""));
+
+    }
+
     private String escapeString(String hexData) {
         if (hexData.startsWith("\\x") || hexData.startsWith("\\X")) {
             try {
@@ -314,7 +431,7 @@ public class DorisStreamLoad implements Serializable {
                 }
                 return stringBuilder.toString();
             } catch (Exception e) {
-                throw new RuntimeException("escape column_separator or 
line_delimiter error.{}" , e);
+                throw new RuntimeException("escape column_separator or 
line_delimiter error.{}", e);
             }
         }
         return hexData;
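Taken together, the new flow above is: loadV2 pre-commits each serialized batch through _stream_load with the two_phase_commit header and collects the returned TxnIds, and commit()/abort() later finalize each transaction through _stream_load_2pc using the txn_operation and txn_id headers. A rough standalone sketch of that second leg is below; it is not connector API, it only reuses the same Apache HttpClient calls as the code above, and the BE address, database, table and credentials are placeholders.

    // Sketch of the commit/abort request issued against a BE for a pre-committed txn.
    import org.apache.http.HttpHeaders
    import org.apache.http.client.methods.HttpPut
    import org.apache.http.impl.client.HttpClientBuilder
    import org.apache.http.util.EntityUtils
    import java.nio.charset.StandardCharsets
    import java.util.Base64

    def finishTxn(beHostPort: String, db: String, table: String,
                  user: String, password: String, txnId: Int, operation: String): String = {
      val client = HttpClientBuilder.create().disableRedirectHandling().build()
      try {
        val put = new HttpPut(s"http://$beHostPort/api/$db/$table/_stream_load_2pc")
        val auth = Base64.getEncoder.encodeToString(s"$user:$password".getBytes(StandardCharsets.UTF_8))
        put.setHeader(HttpHeaders.AUTHORIZATION, "Basic " + auth)
        put.setHeader(HttpHeaders.EXPECT, "100-continue")
        put.setHeader("txn_operation", operation)          // "commit" or "abort"
        put.setHeader("txn_id", txnId.toString)
        val response = client.execute(put)
        // the response body is a small JSON document with "status" and "msg" fields
        EntityUtils.toString(response.getEntity, StandardCharsets.UTF_8)
      } finally {
        client.close()
      }
    }

In the connector itself this request is built inside DorisStreamLoad.commit() and DorisStreamLoad.abort(), and retries around both are handled by the callers (the transaction listener and the writer).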
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/models/RespContent.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/models/RespContent.java
index f7fa6ff..7829cc2 100644
--- a/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/models/RespContent.java
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/rest/models/RespContent.java
@@ -75,6 +75,10 @@ public class RespContent {
     @JsonProperty(value = "ErrorURL")
     private String ErrorURL;
 
+    public int getTxnId() {
+        return TxnId;
+    }
+
     public String getStatus() {
         return Status;
     }
diff --git a/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ResponseUtil.java b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ResponseUtil.java
new file mode 100644
index 0000000..1b6a66b
--- /dev/null
+++ b/spark-doris-connector/src/main/java/org/apache/doris/spark/util/ResponseUtil.java
@@ -0,0 +1,33 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.util;
+
+import java.util.regex.Pattern;
+
+public class ResponseUtil {
+    public static final Pattern LABEL_EXIST_PATTERN =
+            Pattern.compile("errCode = 2, detailMessage = Label \\[(.*)\\] " +
+                    "has already been used, relate to txn \\[(\\d+)\\]");
+    public static final Pattern COMMITTED_PATTERN =
+            Pattern.compile("errCode = 2, detailMessage = transaction 
\\[(\\d+)\\] " +
+                    "is already \\b(COMMITTED|committed|VISIBLE|visible)\\b, 
not pre-committed.");
+
+    public static boolean isCommitted(String msg) {
+       return COMMITTED_PATTERN.matcher(msg).matches();
+    }
+}
\ No newline at end of file
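ResponseUtil.isCommitted() is what keeps commit and abort sane under retries: commit() treats a "Fail" response whose message says the transaction is already COMMITTED/VISIBLE as success, while abort() turns the same situation into an explicit error because the data is already durable. A tiny illustration follows; the message string is invented to match COMMITTED_PATTERN, and the real FE wording may differ.

    // Illustrative only: a message shaped like the one COMMITTED_PATTERN expects.
    import org.apache.doris.spark.util.ResponseUtil

    val msg = "errCode = 2, detailMessage = transaction [12345] is already VISIBLE, not pre-committed."
    if (ResponseUtil.isCommitted(msg)) {
      // commit() logs and moves on; abort() raises "try abort committed transaction ..."
    }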
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
new file mode 100644
index 0000000..a36e634
--- /dev/null
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/listener/DorisTransactionListener.scala
@@ -0,0 +1,83 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.spark.listener
+
+import org.apache.doris.spark.load.DorisStreamLoad
+import org.apache.doris.spark.sql.Utils
+import org.apache.spark.scheduler._
+import org.apache.spark.util.CollectionAccumulator
+import org.slf4j.{Logger, LoggerFactory}
+
+import java.time.Duration
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.util.{Failure, Success}
+
+class DorisTransactionListener(preCommittedTxnAcc: CollectionAccumulator[Int], 
dorisStreamLoad: DorisStreamLoad)
+  extends SparkListener {
+
+  val logger: Logger = 
LoggerFactory.getLogger(classOf[DorisTransactionListener])
+
+  override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
+    val txnIds: mutable.Buffer[Int] = preCommittedTxnAcc.value.asScala
+    val failedTxnIds = mutable.Buffer[Int]()
+    jobEnd.jobResult match {
+      // if job succeed, commit all transactions
+      case JobSucceeded =>
+        if (txnIds.isEmpty) {
+          logger.warn("job run succeed, but there is no pre-committed txn ids")
+          return
+        }
+        logger.info("job run succeed, start committing transactions")
+        txnIds.foreach(txnId =>
+          Utils.retry(3, Duration.ofSeconds(1), logger) {
+            dorisStreamLoad.commit(txnId)
+          } match {
+            case Success(_) =>
+            case Failure(_) => failedTxnIds += txnId
+          }
+        )
+
+        if (failedTxnIds.nonEmpty) {
+          logger.error("uncommitted txn ids: {}", failedTxnIds.mkString(","))
+        } else {
+          logger.info("commit transaction success")
+        }
+      // if job failed, abort all pre committed transactions
+      case _ =>
+        if (txnIds.isEmpty) {
+          logger.warn("job run failed, but there is no pre-committed txn ids")
+          return
+        }
+        logger.info("job run failed, start aborting transactions")
+        txnIds.foreach(txnId =>
+          Utils.retry(3, Duration.ofSeconds(1), logger) {
+            dorisStreamLoad.abort(txnId)
+          } match {
+            case Success(_) =>
+            case Failure(_) => failedTxnIds += txnId
+          })
+        if (failedTxnIds.nonEmpty) {
+          logger.error("not aborted txn ids: {}", failedTxnIds.mkString(","))
+        } else {
+          logger.info("abort transaction success")
+        }
+    }
+  }
+
+}
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
index 2f3a5bb..2b9c3c1 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/sql/Utils.scala
@@ -174,7 +174,7 @@ private[spark] object Utils {
         Success(result)
       case Failure(exception: T) if retryTimes > 0 =>
         logger.warn(s"Execution failed caused by: ", exception)
-        logger.warn(s"$retryTimes times retry remaining, the next will be in 
${interval.toMillis}ms")
+        logger.warn(s"$retryTimes times retry remaining, the next attempt will 
be in ${interval.toMillis} ms")
         LockSupport.parkNanos(interval.toNanos)
         retry(retryTimes - 1, interval, logger)(f)
       case Failure(exception) => Failure(exception)
diff --git a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
index 3839ff7..2b918e8 100644
--- a/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
+++ b/spark-doris-connector/src/main/scala/org/apache/doris/spark/writer/DorisWriter.scala
@@ -18,6 +18,7 @@
 package org.apache.doris.spark.writer
 
 import org.apache.doris.spark.cfg.{ConfigurationOptions, SparkSettings}
+import org.apache.doris.spark.listener.DorisTransactionListener
 import org.apache.doris.spark.load.{CachedDorisStreamLoadClient, 
DorisStreamLoad}
 import org.apache.doris.spark.sql.Utils
 import org.apache.spark.sql.DataFrame
@@ -28,6 +29,7 @@ import java.time.Duration
 import java.util
 import java.util.Objects
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.util.{Failure, Success}
 
 class DorisWriter(settings: SparkSettings) extends Serializable {
@@ -44,9 +46,19 @@ class DorisWriter(settings: SparkSettings) extends 
Serializable {
   private val batchInterValMs: Integer = 
settings.getIntegerProperty(ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS,
     ConfigurationOptions.DORIS_SINK_BATCH_INTERVAL_MS_DEFAULT)
 
+  private val enable2PC: Boolean = 
settings.getBooleanProperty(ConfigurationOptions.DORIS_SINK_ENABLE_2PC,
+    ConfigurationOptions.DORIS_SINK_ENABLE_2PC_DEFAULT);
+
   private val dorisStreamLoader: DorisStreamLoad = 
CachedDorisStreamLoadClient.getOrCreate(settings)
 
   def write(dataFrame: DataFrame): Unit = {
+
+    val sc = dataFrame.sqlContext.sparkContext
+    val preCommittedTxnAcc = 
sc.collectionAccumulator[Int]("preCommittedTxnAcc")
+    if (enable2PC) {
+      sc.addSparkListener(new DorisTransactionListener(preCommittedTxnAcc, 
dorisStreamLoader))
+    }
+
     var resultRdd = dataFrame.rdd
     val dfColumns = dataFrame.columns
     if (Objects.nonNull(sinkTaskPartitionSize)) {
@@ -65,11 +77,27 @@ class DorisWriter(settings: SparkSettings) extends 
Serializable {
      *
      */
     def flush(batch: Iterable[util.List[Object]], dfColumns: Array[String]): 
Unit = {
-      Utils.retry[Unit, Exception](maxRetryTimes, 
Duration.ofMillis(batchInterValMs.toLong), logger) {
-        dorisStreamLoader.loadV2(batch.toList.asJava, dfColumns)
+      Utils.retry[util.List[Integer], Exception](maxRetryTimes, 
Duration.ofMillis(batchInterValMs.toLong), logger) {
+        dorisStreamLoader.loadV2(batch.toList.asJava, dfColumns, enable2PC)
       } match {
-        case Success(_) =>
+        case Success(txnIds) => if (enable2PC) txnIds.asScala.foreach(txnId => 
preCommittedTxnAcc.add(txnId))
         case Failure(e) =>
+          if (enable2PC) {
+            // if task run failed, acc value will not be returned to driver,
+            // should abort all pre committed transactions inside the task
+            logger.info("load task failed, start aborting previously 
pre-committed transactions")
+            val abortFailedTxnIds = mutable.Buffer[Int]()
+            preCommittedTxnAcc.value.asScala.foreach(txnId => {
+              Utils.retry[Unit, Exception](3, Duration.ofSeconds(1), logger) {
+                dorisStreamLoader.abort(txnId)
+              } match {
+                case Success(_) =>
+                case Failure(_) => abortFailedTxnIds += txnId
+              }
+            })
+            if (abortFailedTxnIds.nonEmpty) logger.warn("not aborted txn ids: 
{}", abortFailedTxnIds.mkString(","))
+            preCommittedTxnAcc.reset()
+          }
           throw new IOException(
             s"Failed to load batch data on BE: 
${dorisStreamLoader.getLoadUrlStr} node and exceeded the max ${maxRetryTimes} 
retry times.", e)
       }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org
