This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch master
in repository 
https://gitbox.apache.org/repos/asf/incubator-doris-spark-connector.git

commit 293bd8f2f270842a2014c857634038581fedc8b9
Author: 924060929 <924060...@qq.com>
AuthorDate: Thu Mar 4 17:48:59 2021 +0800

    [Spark-Doris-Connector][Bug-Fix] Resolve deserialize exception when Spark 
Doris Connector is in async deserialize mode (#5336)
    
    Resolve deserialize exception when Spark Doris Connector is in async 
deserialize mode
    Co-authored-by: lanhuajian <lanhuaj...@sankuai.com>
---
 .../apache/doris/spark/rdd/ScalaValueReader.scala  | 39 +++++++++++++++++++---
 1 file changed, 34 insertions(+), 5 deletions(-)

diff --git a/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala 
b/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala
index 1d22c42..f3334b9 100644
--- a/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala
+++ b/src/main/scala/org/apache/doris/spark/rdd/ScalaValueReader.scala
@@ -19,6 +19,7 @@ package org.apache.doris.spark.rdd
 
 import java.util.concurrent.atomic.AtomicBoolean
 import java.util.concurrent._
+import java.util.concurrent.locks.{Condition, Lock, ReentrantLock}
 
 import scala.collection.JavaConversions._
 import scala.util.Try
@@ -46,11 +47,14 @@ class ScalaValueReader(partition: PartitionDefinition, 
settings: Settings) {
   protected val logger = Logger.getLogger(classOf[ScalaValueReader])
 
   protected val client = new BackendClient(new 
Routing(partition.getBeAddress), settings)
+  protected val clientLock =
+    if (deserializeArrowToRowBatchAsync) new ReentrantLock()
+    else new NoOpLock
   protected var offset = 0
   protected var eos: AtomicBoolean = new AtomicBoolean(false)
   protected var rowBatch: RowBatch = _
   // flag indicate if support deserialize Arrow to RowBatch asynchronously
-  protected var deserializeArrowToRowBatchAsync: Boolean = Try {
+  protected lazy val deserializeArrowToRowBatchAsync: Boolean = Try {
     settings.getProperty(DORIS_DESERIALIZE_ARROW_ASYNC, 
DORIS_DESERIALIZE_ARROW_ASYNC_DEFAULT.toString).toBoolean
   } getOrElse {
     logger.warn(ErrorMessages.PARSE_BOOL_FAILED_MESSAGE, 
DORIS_DESERIALIZE_ARROW_ASYNC, 
settings.getProperty(DORIS_DESERIALIZE_ARROW_ASYNC))
@@ -123,7 +127,7 @@ class ScalaValueReader(partition: PartitionDefinition, 
settings: Settings) {
     params
   }
 
-  protected val openResult: TScanOpenResult = client.openScanner(openParams)
+  protected val openResult: TScanOpenResult = 
lockClient(_.openScanner(openParams))
   protected val contextId: String = openResult.getContext_id
   protected val schema: Schema =
     SchemaUtils.convertToSchema(openResult.getSelected_columns)
@@ -134,7 +138,7 @@ class ScalaValueReader(partition: PartitionDefinition, 
settings: Settings) {
       nextBatchParams.setContext_id(contextId)
       while (!eos.get) {
         nextBatchParams.setOffset(offset)
-        val nextResult = client.getNext(nextBatchParams)
+        val nextResult = lockClient(_.getNext(nextBatchParams))
         eos.set(nextResult.isEos)
         if (!eos.get) {
           val rowBatch = new RowBatch(nextResult, schema)
@@ -192,7 +196,7 @@ class ScalaValueReader(partition: PartitionDefinition, 
settings: Settings) {
         val nextBatchParams = new TScanNextBatchParams
         nextBatchParams.setContext_id(contextId)
         nextBatchParams.setOffset(offset)
-        val nextResult = client.getNext(nextBatchParams)
+        val nextResult = lockClient(_.getNext(nextBatchParams))
         eos.set(nextResult.isEos)
         if (!eos.get) {
           rowBatch = new RowBatch(nextResult, schema)
@@ -218,6 +222,31 @@ class ScalaValueReader(partition: PartitionDefinition, 
settings: Settings) {
   def close(): Unit = {
     val closeParams = new TScanCloseParams
     closeParams.context_id = contextId
-    client.closeScanner(closeParams)
+    lockClient(_.closeScanner(closeParams))
+  }
+
+  private def lockClient[T](action: BackendClient => T): T = {
+    clientLock.lock()
+    try {
+      action(client)
+    } finally {
+      clientLock.unlock()
+    }
+  }
+
+  private class NoOpLock extends Lock {
+    override def lock(): Unit = {}
+
+    override def lockInterruptibly(): Unit = {}
+
+    override def tryLock(): Boolean = true
+
+    override def tryLock(time: Long, unit: TimeUnit): Boolean = true
+
+    override def unlock(): Unit = {}
+
+    override def newCondition(): Condition = {
+      throw new UnsupportedOperationException("NoOpLock can't provide a 
condition")
+    }
   }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to