This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris-spark-connector.git

commit e73f7964089f5681a1644989ddedfb0bbcdd3b50
Author: huzk <1040080...@qq.com>
AuthorDate: Mon Aug 16 22:40:43 2021 +0800

    [Feature] Support spark connector sink data to Doris (#6256)
    
    support spark connector write DataFrame to Doris
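    
    A minimal usage sketch (the FE address, database, table, and credentials
    below are placeholders for illustration, not values shipped with this
    patch):
    
        import org.apache.spark.sql.SparkSession
        val spark = SparkSession.builder().master("local").getOrCreate()
        import spark.implicits._
        val df = Seq((1, "doris"), (2, "spark")).toDF("id", "name")
        df.write.format("doris")
          .option("feHostPort", "your_fe_host:8030") // Doris FE HTTP address
          .option("dbName", "example_db")
          .option("tbName", "example_tbl")
          .option("user", "root")
          .option("password", "")
          .save()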
---
 .../org/apache/doris/spark/DorisStreamLoad.java    | 179 +++++++++++++++++++++
 .../spark/exception/StreamLoadException.java}      |  30 ++--
 .../org/apache/doris/spark/rest/RestService.java   |  75 +++++++--
 .../apache/doris/spark/rest/models/Backend.java}   |  27 ++--
 .../apache/doris/spark/rest/models/BackendRow.java |  64 ++++++++
 .../doris/spark/rest/models/RespContent.java       |  96 +++++++++++
 .../doris/spark/sql/DorisSourceProvider.scala      | 119 +++++++++++++-
 .../apache/doris/spark/sql/DorisWriterOption.scala |  41 +++++
 ...eProvider.scala => DorisWriterOptionKeys.scala} |  18 +--
 .../doris/spark/sql/DataframeSinkDoris.scala}      |  32 +++-
 10 files changed, 631 insertions(+), 50 deletions(-)

diff --git a/src/main/java/org/apache/doris/spark/DorisStreamLoad.java b/src/main/java/org/apache/doris/spark/DorisStreamLoad.java
new file mode 100644
index 0000000..0de3746
--- /dev/null
+++ b/src/main/java/org/apache/doris/spark/DorisStreamLoad.java
@@ -0,0 +1,179 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+package org.apache.doris.spark;
+
+import com.fasterxml.jackson.databind.ObjectMapper;
+import org.apache.doris.spark.exception.StreamLoadException;
+import org.apache.doris.spark.rest.models.RespContent;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.BufferedOutputStream;
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.Serializable;
+import java.net.HttpURLConnection;
+import java.net.URL;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Base64;
+import java.util.Calendar;
+import java.util.List;
+import java.util.UUID;
+
+/**
+ * DorisStreamLoad submits batches of rows to a Doris BE through the stream load HTTP API.
+ **/
+public class DorisStreamLoad implements Serializable {
+
+    private static final Logger LOG = LoggerFactory.getLogger(DorisStreamLoad.class);
+
+    private static final List<String> DORIS_SUCCESS_STATUS = new ArrayList<>(Arrays.asList("Success", "Publish Timeout"));
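+    // the stream load endpoint on a BE: http://<host>:<http_port>/api/<db>/<table>/_stream_load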
+    private static final String loadUrlPattern = "http://%s/api/%s/%s/_stream_load?";
+    private String user;
+    private String passwd;
+    private String loadUrlStr;
+    private String hostPort;
+    private String db;
+    private String tbl;
+    private String authEncoding;
+
+    public DorisStreamLoad(String hostPort, String db, String tbl, String user, String passwd) {
+        this.hostPort = hostPort;
+        this.db = db;
+        this.tbl = tbl;
+        this.user = user;
+        this.passwd = passwd;
+        this.loadUrlStr = String.format(loadUrlPattern, hostPort, db, tbl);
+        this.authEncoding = Base64.getEncoder().encodeToString(String.format("%s:%s", user, passwd).getBytes(StandardCharsets.UTF_8));
+    }
+
+    public String getLoadUrlStr() {
+        return loadUrlStr;
+    }
+
+    public String getHostPort() {
+        return hostPort;
+    }
+
+    public void setHostPort(String hostPort) {
+        this.hostPort = hostPort;
+        this.loadUrlStr = String.format(loadUrlPattern, hostPort, this.db, this.tbl);
+    }
+
+
+    private HttpURLConnection getConnection(String urlStr, String label) throws IOException {
+        URL url = new URL(urlStr);
+        HttpURLConnection conn = (HttpURLConnection) url.openConnection();
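+        // stream load is an HTTP PUT with basic auth, sent straight to the chosen BE, so redirects are not followed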
+        conn.setInstanceFollowRedirects(false);
+        conn.setRequestMethod("PUT");
+        // reuse the auth header precomputed in the constructor
+        conn.setRequestProperty("Authorization", "Basic " + authEncoding);
+        conn.addRequestProperty("Expect", "100-continue");
+        conn.addRequestProperty("Content-Type", "text/plain; charset=UTF-8");
+        conn.addRequestProperty("label", label);
+        conn.setDoOutput(true);
+        conn.setDoInput(true);
+        return conn;
+    }
+
+    public static class LoadResponse {
+        public int status;
+        public String respMsg;
+        public String respContent;
+
+        public LoadResponse(int status, String respMsg, String respContent) {
+            this.status = status;
+            this.respMsg = respMsg;
+            this.respContent = respContent;
+        }
+        @Override
+        public String toString() {
+            StringBuilder sb = new StringBuilder();
+            sb.append("status: ").append(status);
+            sb.append(", resp msg: ").append(respMsg);
+            sb.append(", resp content: ").append(respContent);
+            return sb.toString();
+        }
+    }
+
+    public void load(String value) throws StreamLoadException {
+        LoadResponse loadResponse = loadBatch(value);
+        LOG.info("stream load response: {}", loadResponse);
+        if (loadResponse.status != 200) {
+            throw new StreamLoadException("stream load error: " + loadResponse.respContent);
+        } else {
+            ObjectMapper obj = new ObjectMapper();
+            try {
+                RespContent respContent = obj.readValue(loadResponse.respContent, RespContent.class);
+                if (!DORIS_SUCCESS_STATUS.contains(respContent.getStatus())) {
+                    throw new StreamLoadException("stream load error: " + respContent.getMessage());
+                }
+            } catch (IOException e) {
+                throw new StreamLoadException(e);
+            }
+        }
+    }
+
+    private LoadResponse loadBatch(String value) {
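+        // build a unique load label; Doris uses the label to identify and deduplicate load jobs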
+        Calendar calendar = Calendar.getInstance();
+        String label = String.format("audit_%s%02d%02d_%02d%02d%02d_%s",
+                calendar.get(Calendar.YEAR), calendar.get(Calendar.MONTH) + 1, 
calendar.get(Calendar.DAY_OF_MONTH),
+                calendar.get(Calendar.HOUR_OF_DAY), 
calendar.get(Calendar.MINUTE), calendar.get(Calendar.SECOND),
+                UUID.randomUUID().toString().replaceAll("-", ""));
+
+        HttpURLConnection beConn = null;
+        try {
+            // open a connection to the chosen BE and send the whole batch
+            beConn = getConnection(loadUrlStr, label);
+            BufferedOutputStream bos = new BufferedOutputStream(beConn.getOutputStream());
+            bos.write(value.getBytes(StandardCharsets.UTF_8));
+            bos.close();
+
+            // read the response
+            int status = beConn.getResponseCode();
+            String respMsg = beConn.getResponseMessage();
+            InputStream stream = (InputStream) beConn.getContent();
+            BufferedReader br = new BufferedReader(new InputStreamReader(stream, StandardCharsets.UTF_8));
+            StringBuilder response = new StringBuilder();
+            String line;
+            while ((line = br.readLine()) != null) {
+                response.append(line);
+            }
+            LOG.debug("stream load with label: {}, response code: {}, msg: {}, content: {}",
+                    label, status, respMsg, response.toString());
+            return new LoadResponse(status, respMsg, response.toString());
+
+        } catch (Exception e) {
+            String err = "failed to execute stream load with label: " + label;
+            LOG.warn(err, e);
+            return new LoadResponse(-1, e.getMessage(), err);
+        } finally {
+            if (beConn != null) {
+                beConn.disconnect();
+            }
+        }
+    }
+}
diff --git a/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala b/src/main/java/org/apache/doris/spark/exception/StreamLoadException.java
similarity index 54%
copy from src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
copy to src/main/java/org/apache/doris/spark/exception/StreamLoadException.java
index d2df4d0..ec9f77f 100644
--- a/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
+++ b/src/main/java/org/apache/doris/spark/exception/StreamLoadException.java
@@ -15,16 +15,24 @@
 // specific language governing permissions and limitations
 // under the License.
 
-package org.apache.doris.spark.sql
+package org.apache.doris.spark.exception;
 
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
-
-private[sql] class DorisSourceProvider extends DataSourceRegister with RelationProvider with Logging {
-  override def shortName(): String = "doris"
-
-  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
-    new DorisRelation(sqlContext, Utils.params(parameters, log))
-  }
+public class StreamLoadException extends Exception {
+    public StreamLoadException() {
+        super();
+    }
+    public StreamLoadException(String message) {
+        super(message);
+    }
+    public StreamLoadException(String message, Throwable cause) {
+        super(message, cause);
+    }
+    public StreamLoadException(Throwable cause) {
+        super(cause);
+    }
+    protected StreamLoadException(String message, Throwable cause,
+                                  boolean enableSuppression,
+                                  boolean writableStackTrace) {
+        super(message, cause, enableSuppression, writableStackTrace);
+    }
 }
diff --git a/src/main/java/org/apache/doris/spark/rest/RestService.java b/src/main/java/org/apache/doris/spark/rest/RestService.java
index ec9cfec..10126e8 100644
--- a/src/main/java/org/apache/doris/spark/rest/RestService.java
+++ b/src/main/java/org/apache/doris/spark/rest/RestService.java
@@ -49,33 +49,29 @@ import java.util.List;
 import java.util.ArrayList;
 import java.util.Set;
 import java.util.HashSet;
+import java.util.stream.Collectors;
 
 import org.apache.commons.io.IOUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.doris.spark.cfg.ConfigurationOptions;
 import org.apache.doris.spark.cfg.Settings;
+import org.apache.doris.spark.cfg.SparkSettings;
 import org.apache.doris.spark.exception.ConnectedFailedException;
 import org.apache.doris.spark.exception.DorisException;
 import org.apache.doris.spark.exception.IllegalArgumentException;
 import org.apache.doris.spark.exception.ShouldNeverHappenException;
+import org.apache.doris.spark.rest.models.Backend;
+import org.apache.doris.spark.rest.models.BackendRow;
 import org.apache.doris.spark.rest.models.QueryPlan;
 import org.apache.doris.spark.rest.models.Schema;
 import org.apache.doris.spark.rest.models.Tablet;
+import org.apache.doris.spark.sql.DorisWriterOption;
 import org.apache.http.HttpStatus;
-import org.apache.http.auth.AuthScope;
-import org.apache.http.auth.UsernamePasswordCredentials;
-import org.apache.http.client.CredentialsProvider;
 import org.apache.http.client.config.RequestConfig;
-import org.apache.http.client.methods.CloseableHttpResponse;
 import org.apache.http.client.methods.HttpGet;
 import org.apache.http.client.methods.HttpPost;
 import org.apache.http.client.methods.HttpRequestBase;
-import org.apache.http.client.protocol.HttpClientContext;
 import org.apache.http.entity.StringEntity;
-import org.apache.http.impl.client.BasicCredentialsProvider;
-import org.apache.http.impl.client.CloseableHttpClient;
-import org.apache.http.impl.client.HttpClients;
-import org.apache.http.util.EntityUtils;
 import org.codehaus.jackson.JsonParseException;
 import org.codehaus.jackson.map.JsonMappingException;
 import org.codehaus.jackson.map.ObjectMapper;
@@ -91,6 +87,8 @@ public class RestService implements Serializable {
     private static final String API_PREFIX = "/api";
     private static final String SCHEMA = "_schema";
     private static final String QUERY_PLAN = "_query_plan";
+    private static final String BACKENDS = "/rest/v1/system?path=//backends";
+
 
     /**
      * send request to Doris FE and get response json string.
@@ -477,6 +475,65 @@ public class RestService implements Serializable {
     }
 
     /**
+     * Choose a Doris BE node to send the stream load request to.
+     * @param sparkSettings spark settings that carry the request auth properties
+     * @param options write configuration
+     * @param logger slf4j logger
+     * @return the chosen Doris BE node as "ip:http_port"
+     * @throws IllegalArgumentException if no alive BE node is available
+     */
+    @VisibleForTesting
+    public static String randomBackend(SparkSettings sparkSettings, DorisWriterOption options, Logger logger) throws DorisException, IOException {
+        // set user auth
+        sparkSettings.setProperty(DORIS_REQUEST_AUTH_USER, options.user());
+        sparkSettings.setProperty(DORIS_REQUEST_AUTH_PASSWORD, options.password());
+        String feNodes = options.feHostPort();
+        String feNode = randomEndpoint(feNodes, logger);
+        String beUrl = String.format("http://%s" + BACKENDS, feNode);
+        HttpGet httpGet = new HttpGet(beUrl);
+        String response = send(sparkSettings, httpGet, logger);
+        logger.info("backend info: {}", response);
+        List<BackendRow> backends = parseBackend(response, logger);
+        logger.trace("Parse beNodes '{}'.", backends);
+        if (backends == null || backends.isEmpty()) {
+            logger.error(ILLEGAL_ARGUMENT_MESSAGE, "beNodes", backends);
+            throw new IllegalArgumentException("beNodes", 
String.valueOf(backends));
+        }
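+        // shuffle so each writer picks a random alive BE, spreading the write load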
+        Collections.shuffle(backends);
+        BackendRow backend = backends.get(0);
+        return backend.getIP() + ":" + backend.getHttpPort();
+    }
+
+
+
+    static List<BackendRow> parseBackend(String response, Logger logger) throws DorisException, IOException {
+        com.fasterxml.jackson.databind.ObjectMapper mapper = new com.fasterxml.jackson.databind.ObjectMapper();
+        Backend backend;
+        try {
+            backend = mapper.readValue(response, Backend.class);
+        } catch (com.fasterxml.jackson.core.JsonParseException e) {
+            String errMsg = "Doris BE's response is not a json. res: " + 
response;
+            logger.error(errMsg, e);
+            throw new DorisException(errMsg, e);
+        } catch (com.fasterxml.jackson.databind.JsonMappingException e) {
+            String errMsg = "Doris BE's response cannot map to schema. res: " 
+ response;
+            logger.error(errMsg, e);
+            throw new DorisException(errMsg, e);
+        } catch (IOException e) {
+            String errMsg = "Parse Doris BE's response to json failed. res: " 
+ response;
+            logger.error(errMsg, e);
+            throw new DorisException(errMsg, e);
+        }
+
+        if (backend == null) {
+            logger.error(SHOULD_NOT_HAPPEN_MESSAGE);
+            throw new ShouldNeverHappenException();
+        }
+        List<BackendRow> backendRows = backend.getRows().stream().filter(v -> v.getAlive()).collect(Collectors.toList());
+        logger.debug("Parsed alive backend rows: '{}'.", backendRows);
+        return backendRows;
+    }
+
+    /**
      * translate BE tablets map to Doris RDD partition.
      * @param cfg configuration of request
      * @param be2Tablets BE to tablets {@link Map}
diff --git a/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala b/src/main/java/org/apache/doris/spark/rest/models/Backend.java
similarity index 60%
copy from src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
copy to src/main/java/org/apache/doris/spark/rest/models/Backend.java
index d2df4d0..122e71c 100644
--- a/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
+++ b/src/main/java/org/apache/doris/spark/rest/models/Backend.java
@@ -14,17 +14,26 @@
 // KIND, either express or implied.  See the License for the
 // specific language governing permissions and limitations
 // under the License.
+package org.apache.doris.spark.rest.models;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
 
-package org.apache.doris.spark.sql
+import java.util.List;
 
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
+/**
+ * Be response model
+ **/
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class Backend {
 
-private[sql] class DorisSourceProvider extends DataSourceRegister with RelationProvider with Logging {
-  override def shortName(): String = "doris"
+    @JsonProperty(value = "rows")
+    private List<BackendRow> rows;
 
-  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
-    new DorisRelation(sqlContext, Utils.params(parameters, log))
-  }
+    public List<BackendRow> getRows() {
+        return rows;
+    }
+
+    public void setRows(List<BackendRow> rows) {
+        this.rows = rows;
+    }
 }
diff --git a/src/main/java/org/apache/doris/spark/rest/models/BackendRow.java b/src/main/java/org/apache/doris/spark/rest/models/BackendRow.java
new file mode 100644
index 0000000..0e2b385
--- /dev/null
+++ b/src/main/java/org/apache/doris/spark/rest/models/BackendRow.java
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+package org.apache.doris.spark.rest.models;
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class BackendRow {
+
+    @JsonProperty(value = "HttpPort")
+    private String HttpPort;
+
+    @JsonProperty(value = "IP")
+    private String IP;
+
+    @JsonProperty(value = "Alive")
+    private Boolean Alive;
+
+    public String getHttpPort() {
+        return HttpPort;
+    }
+
+    public void setHttpPort(String httpPort) {
+        HttpPort = httpPort;
+    }
+
+    public String getIP() {
+        return IP;
+    }
+
+    public void setIP(String IP) {
+        this.IP = IP;
+    }
+
+    public Boolean getAlive() {
+        return Alive;
+    }
+
+    public void setAlive(Boolean alive) {
+        Alive = alive;
+    }
+
+    @Override
+    public String toString() {
+        return "BackendRow{" +
+                "HttpPort='" + HttpPort + '\'' +
+                ", IP='" + IP + '\'' +
+                ", Alive=" + Alive +
+                '}';
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/org/apache/doris/spark/rest/models/RespContent.java b/src/main/java/org/apache/doris/spark/rest/models/RespContent.java
new file mode 100644
index 0000000..f7fa6ff
--- /dev/null
+++ b/src/main/java/org/apache/doris/spark/rest/models/RespContent.java
@@ -0,0 +1,96 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+package org.apache.doris.spark.rest.models;
+
+import com.fasterxml.jackson.annotation.JsonIgnoreProperties;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+@JsonIgnoreProperties(ignoreUnknown = true)
+public class RespContent {
+
+    @JsonProperty(value = "TxnId")
+    private int TxnId;
+
+    @JsonProperty(value = "Label")
+    private String Label;
+
+    @JsonProperty(value = "Status")
+    private String Status;
+
+    @JsonProperty(value = "ExistingJobStatus")
+    private String ExistingJobStatus;
+
+    @JsonProperty(value = "Message")
+    private String Message;
+
+    @JsonProperty(value = "NumberTotalRows")
+    private long NumberTotalRows;
+
+    @JsonProperty(value = "NumberLoadedRows")
+    private long NumberLoadedRows;
+
+    @JsonProperty(value = "NumberFilteredRows")
+    private int NumberFilteredRows;
+
+    @JsonProperty(value = "NumberUnselectedRows")
+    private int NumberUnselectedRows;
+
+    @JsonProperty(value = "LoadBytes")
+    private long LoadBytes;
+
+    @JsonProperty(value = "LoadTimeMs")
+    private int LoadTimeMs;
+
+    @JsonProperty(value = "BeginTxnTimeMs")
+    private int BeginTxnTimeMs;
+
+    @JsonProperty(value = "StreamLoadPutTimeMs")
+    private int StreamLoadPutTimeMs;
+
+    @JsonProperty(value = "ReadDataTimeMs")
+    private int ReadDataTimeMs;
+
+    @JsonProperty(value = "WriteDataTimeMs")
+    private int WriteDataTimeMs;
+
+    @JsonProperty(value = "CommitAndPublishTimeMs")
+    private int CommitAndPublishTimeMs;
+
+    @JsonProperty(value = "ErrorURL")
+    private String ErrorURL;
+
+    public String getStatus() {
+        return Status;
+    }
+
+    public String getMessage() {
+        return Message;
+    }
+
+    @Override
+    public String toString() {
+        ObjectMapper mapper = new ObjectMapper();
+        try {
+            return mapper.writeValueAsString(this);
+        } catch (JsonProcessingException e) {
+            return "";
+        }
+
+    }
+}
diff --git a/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala b/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
index d2df4d0..774a29c 100644
--- a/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
+++ b/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
@@ -17,14 +17,127 @@
 
 package org.apache.doris.spark.sql
 
+import java.io.IOException
+import java.util.StringJoiner
+
+import org.apache.commons.collections.CollectionUtils
+import org.apache.doris.spark.DorisStreamLoad
+import org.apache.doris.spark.cfg.SparkSettings
+import org.apache.doris.spark.exception.DorisException
+import org.apache.doris.spark.rest.RestService
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
+import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
+import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister, Filter, RelationProvider}
+import org.apache.spark.sql.types.StructType
+import org.json4s.jackson.Json
+
+import scala.collection.mutable.ListBuffer
+import scala.util.Random
+import scala.util.control.Breaks
 
-private[sql] class DorisSourceProvider extends DataSourceRegister with RelationProvider with Logging {
+private[sql] class DorisSourceProvider extends DataSourceRegister with RelationProvider with CreatableRelationProvider with Logging {
   override def shortName(): String = "doris"
 
   override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
     new DorisRelation(sqlContext, Utils.params(parameters, log))
   }
+
+
+  /**
+   * df.save
+   */
+  override def createRelation(sqlContext: SQLContext,
+                              mode: SaveMode, parameters: Map[String, String],
+                              data: DataFrame): BaseRelation = {
+
+    val dorisWriterOption = DorisWriterOption(parameters)
+    val sparkSettings = new SparkSettings(sqlContext.sparkContext.getConf)
+    // choose an available BE node
+    val chosenBeHost = RestService.randomBackend(sparkSettings, dorisWriterOption, log)
+    // init stream loader
+    val dorisStreamLoader = new DorisStreamLoad(chosenBeHost, dorisWriterOption.dbName, dorisWriterOption.tbName, dorisWriterOption.user, dorisWriterOption.password)
+    val fieldDelimiter: String = "\t"
+    val lineDelimiter: String = "\n"
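+    // Doris stream load interprets "\N" as NULL in text format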
+    val NULL_VALUE: String = "\\N"
+
+    val maxRowCount = dorisWriterOption.maxRowCount
+    val maxRetryTimes = dorisWriterOption.maxRetryTimes
+
+    data.foreachPartition(partition => {
+
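+      // rows are rendered as tab-separated lines into a buffer and flushed in batches of maxRowCount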
+      val buffer = ListBuffer[String]()
+      partition.foreach(row => {
+        val value = new StringJoiner(fieldDelimiter)
+        // create one row string
+        for (i <- 0 until row.size) {
+          val field = row.get(i)
+          if (field == null) {
+            value.add(NULL_VALUE)
+          } else {
+            value.add(field.toString)
+          }
+        }
+        // add one row string to buffer
+        buffer += value.toString
+        if (buffer.size >= maxRowCount) {
+          flush()
+        }
+      })
+      // flush any rows left in the buffer
+      if (buffer.nonEmpty) {
+        flush()
+      }
+
+      /**
+       * Flush buffered rows to Doris, retrying with linear back-off on failure.
+       * Throws after maxRetryTimes failed attempts instead of silently dropping the batch.
+       */
+      def flush(): Unit = {
+        val loop = new Breaks
+        loop.breakable {
+          for (i <- 1 to maxRetryTimes) {
+            try {
+              dorisStreamLoader.load(buffer.mkString(lineDelimiter))
+              buffer.clear()
+              loop.break()
+            } catch {
+              case e: Exception =>
+                if (i == maxRetryTimes) {
+                  throw new IOException("unable to flush after " + maxRetryTimes + " attempts", e)
+                }
+                try {
+                  // back off before the next attempt
+                  Thread.sleep(1000L * i)
+                } catch {
+                  case _: InterruptedException =>
+                    Thread.currentThread.interrupt()
+                    throw new IOException("unable to flush; interrupted while waiting for another attempt", e)
+                }
+            }
+          }
+        }
+      }
+
+    })
+    new BaseRelation {
+      override def sqlContext: SQLContext = unsupportedException
+
+      override def schema: StructType = unsupportedException
+
+      override def needConversion: Boolean = unsupportedException
+
+      override def sizeInBytes: Long = unsupportedException
+
+      override def unhandledFilters(filters: Array[Filter]): Array[Filter] = unsupportedException
+
+      private def unsupportedException =
+        throw new UnsupportedOperationException("BaseRelation from doris write 
operation is not usable.")
+    }
+  }
+
 }
diff --git a/src/main/scala/org/apache/doris/spark/sql/DorisWriterOption.scala b/src/main/scala/org/apache/doris/spark/sql/DorisWriterOption.scala
new file mode 100644
index 0000000..69238c7
--- /dev/null
+++ b/src/main/scala/org/apache/doris/spark/sql/DorisWriterOption.scala
@@ -0,0 +1,41 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+package org.apache.doris.spark.sql
+
+import org.apache.doris.spark.exception.DorisException
+
+class DorisWriterOption(val feHostPort: String, val dbName: String, val tbName: String,
+                        val user: String, val password: String,
+                        val maxRowCount: Long, val maxRetryTimes: Int)
+
+object DorisWriterOption {
+  def apply(parameters: Map[String, String]): DorisWriterOption = {
+    val feHostPort: String = parameters.getOrElse(DorisWriterOptionKeys.feHostPort, throw new DorisException("feHostPort is empty"))
+    val dbName: String = parameters.getOrElse(DorisWriterOptionKeys.dbName, throw new DorisException("dbName is empty"))
+    val tbName: String = parameters.getOrElse(DorisWriterOptionKeys.tbName, throw new DorisException("tbName is empty"))
+    val user: String = parameters.getOrElse(DorisWriterOptionKeys.user, throw new DorisException("user is empty"))
+    val password: String = parameters.getOrElse(DorisWriterOptionKeys.password, throw new DorisException("password is empty"))
+    val maxRowCount: Long = parameters.getOrElse(DorisWriterOptionKeys.maxRowCount, "1024").toLong
+    val maxRetryTimes: Int = parameters.getOrElse(DorisWriterOptionKeys.maxRetryTimes, "3").toInt
+    new DorisWriterOption(feHostPort, dbName, tbName, user, password, maxRowCount, maxRetryTimes)
+  }
+}
\ No newline at end of file
diff --git a/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala b/src/main/scala/org/apache/doris/spark/sql/DorisWriterOptionKeys.scala
similarity index 63%
copy from src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
copy to src/main/scala/org/apache/doris/spark/sql/DorisWriterOptionKeys.scala
index d2df4d0..9cadd9f 100644
--- a/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
+++ b/src/main/scala/org/apache/doris/spark/sql/DorisWriterOptionKeys.scala
@@ -14,17 +14,15 @@
 // KIND, either express or implied.  See the License for the
 // specific language governing permissions and limitations
 // under the License.
-
 package org.apache.doris.spark.sql
 
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
-
-private[sql] class DorisSourceProvider extends DataSourceRegister with RelationProvider with Logging {
-  override def shortName(): String = "doris"
+object DorisWriterOptionKeys {
+  val feHostPort="feHostPort"
+  val dbName="dbName"
+  val tbName="tbName"
+  val user="user"
+  val password="password"
+  val maxRowCount="maxRowCount"
+  val maxRetryTimes="maxRetryTimes"
 
-  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
-    new DorisRelation(sqlContext, Utils.params(parameters, log))
-  }
 }
diff --git a/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala b/src/test/scala/org/apache/doris/spark/sql/DataframeSinkDoris.scala
similarity index 55%
copy from src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
copy to src/test/scala/org/apache/doris/spark/sql/DataframeSinkDoris.scala
index d2df4d0..b0df051 100644
--- a/src/main/scala/org/apache/doris/spark/sql/DorisSourceProvider.scala
+++ b/src/test/scala/org/apache/doris/spark/sql/DataframeSinkDoris.scala
@@ -14,17 +14,33 @@
 // KIND, either express or implied.  See the License for the
 // specific language governing permissions and limitations
 // under the License.
-
 package org.apache.doris.spark.sql
 
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider}
+import org.apache.spark.sql.SparkSession
+
+object DataframeSinkDoris {
+  def main(args: Array[String]): Unit = {
+    val spark = SparkSession.builder().master("local").getOrCreate()
+
+    import spark.implicits._
+
+    val mockDataDF = List(
+      (3, "440403001005", "21.cn"),
+      (1, "4404030013005", "22.cn"),
+      (33, null, "23.cn")
+    ).toDF("id", "mi_code", "mi_name")
+    mockDataDF.show(5)
+
+    mockDataDF.write.format("doris")
+      .option("feHostPort", "10.211.55.9:8030")
+      .option("dbName", "example_db")
+      .option("tbName", "test_insert_into")
+      .option("maxRowCount", "1000")
+      .option("user", "root")
+      .option("password", "")
+      .save()
 
-private[sql] class DorisSourceProvider extends DataSourceRegister with RelationProvider with Logging {
-  override def shortName(): String = "doris"
 
-  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
-    new DorisRelation(sqlContext, Utils.params(parameters, log))
   }
+
 }
