This is an automated email from the ASF dual-hosted git repository.

xxyu pushed a commit to branch kylin-on-parquet-v2
in repository https://gitbox.apache.org/repos/asf/kylin.git


The following commit(s) were added to refs/heads/kylin-on-parquet-v2 by this 
push:
     new 959edeb  KYLIN-4829 Support to use thread-level SparkSession to 
execute query
959edeb is described below

commit 959edebdb262afd63308ba32031abcf889670245
Author: Zhichao Zhang <441586...@qq.com>
AuthorDate: Mon Nov 30 17:16:29 2020 +0800

    KYLIN-4829 Support to use thread-level SparkSession to execute query
---
 kylin-spark-project/kylin-spark-common/pom.xml     |   2 +-
 .../datasource/ResetShufflePartition.scala         |  13 +-
 .../org/apache/spark/utils/SparderUtils.scala}     |  34 ++---
 kylin-spark-project/kylin-spark-engine/pom.xml     |   2 +-
 .../scala/org/apache/spark/sql/KylinSparkEnv.scala |  75 +----------
 kylin-spark-project/kylin-spark-query/pom.xml      |   5 +
 .../kylin/query/monitor/SparderContextCanary.java  |  10 +-
 .../kylin/query/pushdown/SparkSubmitter.java       |   1 +
 .../scala/org/apache/kylin/query/UdfManager.scala  |  31 ++++-
 .../kylin/query/pushdown/SparkSqlClient.scala      |   7 +-
 .../kylin/query/runtime/plans/ResultPlan.scala     |  15 ++-
 .../scala/org/apache/spark/sql/KylinSession.scala  |   2 +-
 .../org/apache/spark/sql/SparderContext.scala      |  51 ++++----
 .../apache/spark/sql/SparderContextFacade.scala}   |  36 +++---
 .../query/monitor/SparderContextCanaryTest.java    |  13 +-
 .../apache/spark/sql/SparderContextFacadeTest.java | 143 +++++++++++++++++++++
 kylin-spark-project/kylin-spark-test/pom.xml       |   6 +-
 17 files changed, 280 insertions(+), 166 deletions(-)

diff --git a/kylin-spark-project/kylin-spark-common/pom.xml 
b/kylin-spark-project/kylin-spark-common/pom.xml
index 9309647..44be54b 100644
--- a/kylin-spark-project/kylin-spark-common/pom.xml
+++ b/kylin-spark-project/kylin-spark-common/pom.xml
@@ -45,7 +45,7 @@
         <dependency>
             <groupId>org.apache.kylin</groupId>
             <artifactId>kylin-spark-metadata</artifactId>
-            <version>4.0.0-SNAPSHOT</version>
+            <version>${project.version}</version>
         </dependency>
     </dependencies>
     
diff --git 
a/kylin-spark-project/kylin-spark-common/src/main/scala/org/apache/spark/sql/execution/datasource/ResetShufflePartition.scala
 
b/kylin-spark-project/kylin-spark-common/src/main/scala/org/apache/spark/sql/execution/datasource/ResetShufflePartition.scala
index aaed1d9..b048283 100644
--- 
a/kylin-spark-project/kylin-spark-common/src/main/scala/org/apache/spark/sql/execution/datasource/ResetShufflePartition.scala
+++ 
b/kylin-spark-project/kylin-spark-common/src/main/scala/org/apache/spark/sql/execution/datasource/ResetShufflePartition.scala
@@ -17,25 +17,26 @@
  */
 package org.apache.spark.sql.execution.datasource
 
-import org.apache.kylin.common.{KylinConfig, QueryContext, QueryContextFacade}
+import org.apache.kylin.common.{KylinConfig, QueryContextFacade}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.SparkSession
+import org.apache.spark.utils.SparderUtils
 
 trait ResetShufflePartition extends Logging {
+  val PARTITION_SPLIT_BYTES: Long = 
KylinConfig.getInstanceFromEnv.getQueryPartitionSplitSizeMB * 1024 * 1024 // 
64MB
 
   def setShufflePartitions(bytes: Long, sparkSession: SparkSession): Unit = {
     QueryContextFacade.current().addAndGetSourceScanBytes(bytes)
-    val defaultParallelism = sparkSession.sparkContext.defaultParallelism
+    val defaultParallelism = 
SparderUtils.getTotalCore(sparkSession.sparkContext.getConf)
     val kylinConfig = KylinConfig.getInstanceFromEnv
     val partitionsNum = if (kylinConfig.getSparkSqlShufflePartitions != -1) {
       kylinConfig.getSparkSqlShufflePartitions
     } else {
-      Math.min(QueryContextFacade.current().getSourceScanBytes / (
-        KylinConfig.getInstanceFromEnv.getQueryPartitionSplitSizeMB * 1024 * 
1024 * 2) + 1,
+      Math.min(QueryContextFacade.current().getSourceScanBytes / 
PARTITION_SPLIT_BYTES + 1,
         defaultParallelism).toInt
     }
-    
//sparkSession.sessionState.conf.setLocalProperty("spark.sql.shuffle.partitions",
-    //  partitionsNum.toString)
+    // when hitting cube, this will override the value of 
'spark.sql.shuffle.partitions'
+    sparkSession.conf.set("spark.sql.shuffle.partitions", 
partitionsNum.toString)
     logInfo(s"Set partition to $partitionsNum, total bytes 
${QueryContextFacade.current().getSourceScanBytes}")
   }
 }
diff --git 
a/kylin-spark-project/kylin-spark-query/src/main/java/org/apache/kylin/query/pushdown/SparkSubmitter.java
 
b/kylin-spark-project/kylin-spark-common/src/main/scala/org/apache/spark/utils/SparderUtils.scala
similarity index 51%
copy from 
kylin-spark-project/kylin-spark-query/src/main/java/org/apache/kylin/query/pushdown/SparkSubmitter.java
copy to 
kylin-spark-project/kylin-spark-common/src/main/scala/org/apache/spark/utils/SparderUtils.scala
index 664e6be..fdc83cc 100644
--- 
a/kylin-spark-project/kylin-spark-query/src/main/java/org/apache/kylin/query/pushdown/SparkSubmitter.java
+++ 
b/kylin-spark-project/kylin-spark-common/src/main/scala/org/apache/spark/utils/SparderUtils.scala
@@ -16,25 +16,27 @@
  * limitations under the License.
  */
 
-package org.apache.kylin.query.pushdown;
+package org.apache.spark.utils
 
-import org.apache.kylin.common.util.Pair;
-import org.apache.kylin.engine.spark.metadata.cube.StructField;
-import org.apache.spark.sql.SparderContext;
-import org.apache.spark.sql.SparkSession;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import org.apache.spark.SparkConf
 
-import java.util.List;
-import java.util.UUID;
+object SparderUtils {
 
-public class SparkSubmitter {
-    public static final Logger logger = 
LoggerFactory.getLogger(SparkSubmitter.class);
-
-    public static PushdownResponse submitPushDownTask(String sql) {
-        SparkSession ss = SparderContext.getSparkSession();
-        Pair<List<List<String>>, List<StructField>> pair = 
SparkSqlClient.executeSql(ss, sql, UUID.randomUUID());
-        return new PushdownResponse(pair.getSecond(), pair.getFirst());
+  def getTotalCore(sparkConf: SparkConf): Int = {
+    if (sparkConf.get("spark.master").startsWith("local")) {
+      return 1
     }
+    val instances = getExecutorNum(sparkConf)
+    val cores = sparkConf.get("spark.executor.cores").toInt
+    Math.max(instances * cores, 1)
+  }
 
+  def getExecutorNum(sparkConf: SparkConf): Int = {
+    if (sparkConf.get("spark.dynamicAllocation.enabled", "false").toBoolean) {
+      val maxExecutors = sparkConf.get("spark.dynamicAllocation.maxExecutors", 
Int.MaxValue.toString).toInt
+      maxExecutors
+    } else {
+      sparkConf.get("spark.executor.instances").toInt
+    }
+  }
 }
diff --git a/kylin-spark-project/kylin-spark-engine/pom.xml 
b/kylin-spark-project/kylin-spark-engine/pom.xml
index 3632a4c..9954afe 100644
--- a/kylin-spark-project/kylin-spark-engine/pom.xml
+++ b/kylin-spark-project/kylin-spark-engine/pom.xml
@@ -40,7 +40,7 @@
         <dependency>
             <groupId>org.apache.kylin</groupId>
             <artifactId>kylin-spark-metadata</artifactId>
-            <version>4.0.0-SNAPSHOT</version>
+            <version>${project.version}</version>
         </dependency>
         <dependency>
             <groupId>org.apache.kylin</groupId>
diff --git 
a/kylin-spark-project/kylin-spark-engine/src/main/scala/org/apache/spark/sql/KylinSparkEnv.scala
 
b/kylin-spark-project/kylin-spark-engine/src/main/scala/org/apache/spark/sql/KylinSparkEnv.scala
index bb39dc5..3ef6104 100644
--- 
a/kylin-spark-project/kylin-spark-engine/src/main/scala/org/apache/spark/sql/KylinSparkEnv.scala
+++ 
b/kylin-spark-project/kylin-spark-engine/src/main/scala/org/apache/spark/sql/KylinSparkEnv.scala
@@ -25,19 +25,11 @@ object KylinSparkEnv extends Logging {
        @volatile
        private var spark: SparkSession = _
 
-       val _cuboid = new ThreadLocal[Dataset[Row]]
        val _needCompute = new ThreadLocal[JBoolean] {
                override protected def initialValue = false
        }
 
-       @volatile
-       private var initializingThread: Thread = null
-
        def getSparkSession: SparkSession = withClassLoad {
-               if (spark == null || spark.sparkContext.isStopped) {
-                       logInfo("Init spark.")
-                       initSpark()
-               }
                spark
        }
 
@@ -45,76 +37,15 @@ object KylinSparkEnv extends Logging {
                spark = sparkSession
        }
 
-       def init(): Unit = withClassLoad {
-               getSparkSession
-       }
-
        def withClassLoad[T](body: => T): T = {
-               //    val originClassLoad = 
Thread.currentThread().getContextClassLoader
+               // val originClassLoad = 
Thread.currentThread().getContextClassLoader
                // fixme aron
-               //        
Thread.currentThread().setContextClassLoader(ClassLoaderUtils.getSparkClassLoader)
+               // 
Thread.currentThread().setContextClassLoader(ClassLoaderUtils.getSparkClassLoader)
                val t = body
-               //    
Thread.currentThread().setContextClassLoader(originClassLoad)
+               // Thread.currentThread().setContextClassLoader(originClassLoad)
                t
        }
 
-       def initSpark(): Unit = withClassLoad {
-               this.synchronized {
-                       if (initializingThread == null && (spark == null || 
spark.sparkContext.isStopped)) {
-                               initializingThread = new Thread(new Runnable {
-                                       override def run(): Unit = {
-                                               try {
-                                                       val sparkSession = 
System.getProperty("spark.local") match {
-                                                               case "true" =>
-                                                                       
SparkSession.builder
-                                                                               
        .master("local")
-                                                                               
        .appName("sparder-test-sql-context")
-                                                                               
        .enableHiveSupport()
-                                                                               
        .getOrCreate()
-                                                               case _ =>
-                                                                       
SparkSession.builder
-                                                                               
        .appName("sparder-sql-context")
-                                                                               
        .master("yarn-client")
-                                                                               
        //if user defined other master in kylin.properties,
-                                                                               
        // it will get overwrite later in 
org.apache.spark.sql.KylinSession.KylinBuilder.initSparkConf
-                                                                               
        .enableHiveSupport()
-                                                                               
        .getOrCreate()
-                                                       }
-                                                       spark = sparkSession
-                                                       logInfo("Spark context 
started successfully with stack trace:")
-                                                       
logInfo(Thread.currentThread().getStackTrace.mkString("\n"))
-                                                       logInfo(
-                                                               "Class loader: 
" + Thread
-                                                                               
.currentThread()
-                                                                               
.getContextClassLoader
-                                                                               
.toString)
-                                               } catch {
-                                                       case throwable: 
Throwable =>
-                                                               logError("Error 
for initializing spark ", throwable)
-                                               } finally {
-                                                       logInfo("Setting 
initializing Spark thread to null.")
-                                                       initializingThread = 
null
-                                               }
-                                       }
-                               })
-
-                               logInfo("Initializing Spark thread starting.")
-                               initializingThread.start()
-                       }
-
-                       if (initializingThread != null) {
-                               logInfo("Initializing Spark, waiting for done.")
-                               initializingThread.join()
-                       }
-               }
-       }
-
-       def getCuboid: Dataset[Row] = _cuboid.get
-
-       def setCuboid(cuboid: Dataset[Row]): Unit = {
-               _cuboid.set(cuboid)
-       }
-
        def skipCompute(): Unit = {
                _needCompute.set(true)
        }
diff --git a/kylin-spark-project/kylin-spark-query/pom.xml 
b/kylin-spark-project/kylin-spark-query/pom.xml
index 59493eb..2e6f7e4 100644
--- a/kylin-spark-project/kylin-spark-query/pom.xml
+++ b/kylin-spark-project/kylin-spark-query/pom.xml
@@ -106,6 +106,11 @@
             <artifactId>kylin-core-storage</artifactId>
         </dependency>
 
+        <dependency>
+            <groupId>org.apache.kylin</groupId>
+            <artifactId>kylin-core-common</artifactId>
+        </dependency>
+
         <!--For update spark job info from cluster-->
         <dependency>
             <groupId>org.apache.kylin</groupId>
diff --git 
a/kylin-spark-project/kylin-spark-query/src/main/java/org/apache/kylin/query/monitor/SparderContextCanary.java
 
b/kylin-spark-project/kylin-spark-query/src/main/java/org/apache/kylin/query/monitor/SparderContextCanary.java
index d0950c1..4066574 100644
--- 
a/kylin-spark-project/kylin-spark-query/src/main/java/org/apache/kylin/query/monitor/SparderContextCanary.java
+++ 
b/kylin-spark-project/kylin-spark-query/src/main/java/org/apache/kylin/query/monitor/SparderContextCanary.java
@@ -21,7 +21,6 @@ package org.apache.kylin.query.monitor;
 import org.apache.kylin.common.KylinConfig;
 import org.apache.spark.api.java.JavaFutureAction;
 import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.KylinSparkEnv;
 import org.apache.spark.sql.SparderContext;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -87,7 +86,8 @@ public class SparderContextCanary {
                 errorAccumulated = Math.max(errorAccumulated + 1, 
THRESHOLD_TO_RESTART_SPARK);
             } else {
                 try {
-                    JavaSparkContext jsc = 
JavaSparkContext.fromSparkContext(SparderContext.getSparkSession().sparkContext());
+                    JavaSparkContext jsc =
+                            
JavaSparkContext.fromSparkContext(SparderContext.getOriginalSparkSession().sparkContext());
                     jsc.setLocalProperty("spark.scheduler.pool", "vip_tasks");
 
                     long t = System.currentTimeMillis();
@@ -118,11 +118,7 @@ public class SparderContextCanary {
                 try {
                     // Take repair action if error accumulated exceeds 
threshold
                     logger.warn("Repairing sparder context");
-                    if ("true".equals(System.getProperty("spark.local"))) {
-                        
SparderContext.setSparkSession(KylinSparkEnv.getSparkSession());
-                    } else {
-                        SparderContext.restartSpark();
-                    }
+                    SparderContext.restartSpark();
                 } catch (Throwable th) {
                     logger.error("Restart sparder context failed.", th);
                 }
diff --git 
a/kylin-spark-project/kylin-spark-query/src/main/java/org/apache/kylin/query/pushdown/SparkSubmitter.java
 
b/kylin-spark-project/kylin-spark-query/src/main/java/org/apache/kylin/query/pushdown/SparkSubmitter.java
index 664e6be..2d31822 100644
--- 
a/kylin-spark-project/kylin-spark-query/src/main/java/org/apache/kylin/query/pushdown/SparkSubmitter.java
+++ 
b/kylin-spark-project/kylin-spark-query/src/main/java/org/apache/kylin/query/pushdown/SparkSubmitter.java
@@ -34,6 +34,7 @@ public class SparkSubmitter {
     public static PushdownResponse submitPushDownTask(String sql) {
         SparkSession ss = SparderContext.getSparkSession();
         Pair<List<List<String>>, List<StructField>> pair = 
SparkSqlClient.executeSql(ss, sql, UUID.randomUUID());
+        SparderContext.closeThreadSparkSession();
         return new PushdownResponse(pair.getSecond(), pair.getFirst());
     }
 
diff --git 
a/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/kylin/query/UdfManager.scala
 
b/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/kylin/query/UdfManager.scala
index 4044e58..18c01c3 100644
--- 
a/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/kylin/query/UdfManager.scala
+++ 
b/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/kylin/query/UdfManager.scala
@@ -24,15 +24,18 @@ import java.util.concurrent.atomic.AtomicReference
 import org.apache.kylin.shaded.com.google.common.cache.{Cache, CacheBuilder, 
RemovalListener, RemovalNotification}
 import org.apache.kylin.metadata.datatype.DataType
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.{FunctionEntity, KylinFunctions, SparkSession}
+import org.apache.spark.sql.{FunctionEntity, KylinFunctions, 
SparderContextFacade, SparkSession}
 import org.apache.spark.sql.catalyst.expressions.SparderAggFun
 
 class UdfManager(sparkSession: SparkSession) extends Logging {
   private var udfCache: Cache[String, String] = _
 
-  KylinFunctions.builtin.foreach { case FunctionEntity(name, info, builder) =>
-    sparkSession.sessionState.functionRegistry.registerFunction(name, info, 
builder)
+  private def registerBuiltInFunc(): Unit = {
+    KylinFunctions.builtin.foreach { case FunctionEntity(name, info, builder) 
=>
+      sparkSession.sessionState.functionRegistry.registerFunction(name, info, 
builder)
+    }
   }
+
   udfCache = CacheBuilder.newBuilder
     .maximumSize(100)
     .expireAfterWrite(1, TimeUnit.HOURS)
@@ -74,14 +77,34 @@ object UdfManager {
   private val defaultSparkSession: AtomicReference[SparkSession] =
     new AtomicReference[SparkSession]
 
+  /**
+   * create UdfManager for original SparkSession
+   */
   def create(sparkSession: SparkSession): Unit = {
     val manager = new UdfManager(sparkSession)
+    manager.registerBuiltInFunc
     defaultManager.set(manager)
     defaultSparkSession.set(sparkSession)
   }
 
-  def register(dataType: DataType, func: String): String = {
+  /**
+   * register for original SparkSession
+   */
+  def registerForOriginal(dataType: DataType, func: String): String = {
     defaultManager.get().doRegister(dataType, func)
   }
 
+  /**
+   * create UdfManager for thread local SparkSession
+   */
+  def createWithoutBuildInFunc(sparkSession: SparkSession): UdfManager = {
+    new UdfManager(sparkSession)
+  }
+
+  /**
+   * register for thread local SparkSession
+   */
+  def register(dataType: DataType, func: String): String = {
+    SparderContextFacade.current().getSecond.doRegister(dataType, func)
+  }
 }
diff --git 
a/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/kylin/query/pushdown/SparkSqlClient.scala
 
b/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/kylin/query/pushdown/SparkSqlClient.scala
index 5ce4f8f..0d8b769 100644
--- 
a/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/kylin/query/pushdown/SparkSqlClient.scala
+++ 
b/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/kylin/query/pushdown/SparkSqlClient.scala
@@ -64,9 +64,9 @@ object SparkSqlClient {
                                val paths = 
ResourceDetectUtils.getPaths(df.queryExecution.sparkPlan)
                                val sourceTableSize = 
ResourceDetectUtils.getResourceSize(paths: _*) + "b"
                                val partitions = Math.max(1, 
JavaUtils.byteStringAsMb(sourceTableSize) / basePartitionSize).toString
-                               
//df.sparkSession.sessionState.conf.setLocalProperty("spark.sql.shuffle.partitions",
-                               //      partitions)
-                               logger.info(s"Auto set 
spark.sql.shuffle.partitions $partitions")
+                               
df.sparkSession.conf.set("spark.sql.shuffle.partitions", partitions)
+                               logger.info(s"Auto set 
spark.sql.shuffle.partitions to $partitions, the total sources " +
+                                       s"size is ${sourceTableSize}")
                        } catch {
                                case e: Throwable =>
                                        logger.error("Auto set 
spark.sql.shuffle.partitions failed.", e)
@@ -103,7 +103,6 @@ object SparkSqlClient {
                                }
                                else throw e
                } finally {
-                       
//df.sparkSession.sessionState.conf.setLocalProperty("spark.sql.shuffle.partitions",
 null)
                        HadoopUtil.setCurrentConfiguration(null)
                }
        }
diff --git 
a/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/kylin/query/runtime/plans/ResultPlan.scala
 
b/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/kylin/query/runtime/plans/ResultPlan.scala
index 65ae2b0..d840207 100644
--- 
a/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/kylin/query/runtime/plans/ResultPlan.scala
+++ 
b/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/kylin/query/runtime/plans/ResultPlan.scala
@@ -31,6 +31,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, SparderContext}
 import org.apache.spark.sql.hive.utils.QueryMetricUtils
 import org.apache.spark.sql.utils.SparkTypeUtil
+import org.apache.spark.utils.SparderUtils
 
 import scala.collection.JavaConverters._
 
@@ -41,7 +42,6 @@ object ResultType extends Enumeration {
 }
 
 object ResultPlan extends Logging {
-  val PARTITION_SPLIT_BYTES: Long = 
KylinConfig.getInstanceFromEnv.getQueryPartitionSplitSizeMB * 1024 * 1024 // 
64MB
 
   def collectEnumerable(
     df: DataFrame,
@@ -70,17 +70,18 @@ object ResultPlan extends Logging {
       kylinConfig = 
ProjectManager.getInstance(kylinConfig).getProject(projectName).getConfig
     }
     var pool = "heavy_tasks"
+    val sparderTotalCores = 
SparderUtils.getTotalCore(df.sparkSession.sparkContext.getConf)
+    // this value of partition num only effects when querying from snapshot 
tables
     val partitionsNum =
       if (kylinConfig.getSparkSqlShufflePartitions != -1) {
         kylinConfig.getSparkSqlShufflePartitions
       } else {
-        Math.min(QueryContextFacade.current().getSourceScanBytes / 
PARTITION_SPLIT_BYTES + 1,
-          SparderContext.getTotalCore).toInt
+        sparderTotalCores
       }
 
     if (QueryContextFacade.current().isHighPriorityQuery) {
       pool = "vip_tasks"
-    } else if (partitionsNum <= SparderContext.getTotalCore) {
+    } else if (partitionsNum <= sparderTotalCores) {
       pool = "lightweight_tasks"
     }
 
@@ -96,8 +97,8 @@ object ResultPlan extends Logging {
     QueryContextFacade.current().setSparkPool(pool)
     val queryId = QueryContextFacade.current().getQueryId
     sparkContext.setLocalProperty(QueryToExecutionIDCache.KYLIN_QUERY_ID_KEY, 
queryId)
-    
//df.sparkSession.sessionState.conf.setLocalProperty("spark.sql.shuffle.partitions",
-    //  partitionsNum.toString)
+    df.sparkSession.conf.set("spark.sql.shuffle.partitions", 
partitionsNum.toString)
+    logInfo(s"Set partition to $partitionsNum")
     QueryContextFacade.current().setDataset(df)
 
     sparkContext.setJobGroup(jobGroup,
@@ -153,7 +154,6 @@ object ResultPlan extends Logging {
     val r = body
     // remember clear local properties.
     df.sparkSession.sparkContext.setLocalProperty("spark.scheduler.pool", null)
-    
//df.sparkSession.sessionState.conf.setLocalProperty("spark.sql.shuffle.partitions",
 null)
     SparderContext.setDF(df)
     TableScanPlan.cacheDf.get().clear()
     HadoopUtil.setCurrentConfiguration(null)
@@ -178,6 +178,7 @@ object ResultPlan extends Logging {
           }
       }
     SparderContext.cleanQueryInfo()
+    SparderContext.closeThreadSparkSession()
     result
   }
 }
diff --git 
a/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/spark/sql/KylinSession.scala
 
b/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/spark/sql/KylinSession.scala
index 7ae0937..e3c532e 100644
--- 
a/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/spark/sql/KylinSession.scala
+++ 
b/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/spark/sql/KylinSession.scala
@@ -23,7 +23,7 @@ import java.nio.file.Paths
 
 import org.apache.hadoop.security.UserGroupInformation
 import org.apache.kylin.common.KylinConfig
-import org.apache.kylin.query.{UdfManager}
+import org.apache.kylin.query.UdfManager
 import org.apache.spark.internal.Logging
 import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
 import org.apache.spark.sql.SparkSession.Builder
diff --git 
a/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/spark/sql/SparderContext.scala
 
b/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/spark/sql/SparderContext.scala
index 638a9ac..ba4c7b7 100644
--- 
a/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/spark/sql/SparderContext.scala
+++ 
b/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/spark/sql/SparderContext.scala
@@ -50,7 +50,7 @@ object SparderContext extends Logging {
   @volatile
   var master_app_url: String = _
 
-  def getSparkSession: SparkSession = withClassLoad {
+  def getOriginalSparkSession: SparkSession = withClassLoad {
     if (spark == null || spark.sparkContext.isStopped) {
       logInfo("Init spark.")
       initSpark()
@@ -58,6 +58,16 @@ object SparderContext extends Logging {
     spark
   }
 
+  def getSparkSession: SparkSession = {
+    logInfo(s"Current thread ${Thread.currentThread().getId} create a 
SparkSession.")
+    SparderContextFacade.current().getFirst
+  }
+
+  def closeThreadSparkSession(): Unit = {
+    logInfo(s"Remove SparkSession from thread ${Thread.currentThread().getId}")
+    SparderContextFacade.remove()
+  }
+
   def setSparkSession(sparkSession: SparkSession): Unit = {
     spark = sparkSession
     UdfManager.create(sparkSession)
@@ -93,35 +103,26 @@ object SparderContext extends Logging {
     }
   }
 
+  def stopSpark(): Unit = withClassLoad {
+    this.synchronized {
+      if (spark != null && !spark.sparkContext.isStopped) {
+        Utils.tryWithSafeFinally {
+          spark.stop()
+        } {
+          SparkContext.clearActiveContext
+        }
+      }
+    }
+  }
+
   def init(): Unit = withClassLoad {
-    getSparkSession
+    getOriginalSparkSession
   }
 
   def getSparkConf(key: String): String = {
     getSparkSession.sparkContext.conf.get(key)
   }
 
-  def getTotalCore: Int = {
-    val sparkConf = getSparkSession.sparkContext.getConf
-    if (sparkConf.get("spark.master").startsWith("local")) {
-      return 1
-    }
-    val instances = getExecutorNum(sparkConf)
-    val cores = sparkConf.get("spark.executor.cores").toInt
-    Math.max(instances * cores, 1)
-  }
-
-  def getExecutorNum(sparkConf: SparkConf): Int = {
-    if (sparkConf.get("spark.dynamicAllocation.enabled", "false").toBoolean) {
-      val maxExecutors = sparkConf.get("spark.dynamicAllocation.maxExecutors", 
Int.MaxValue.toString).toInt
-      logInfo(s"Use spark.dynamicAllocation.maxExecutors:$maxExecutors as num 
instances of executors.")
-      maxExecutors
-    } else {
-      sparkConf.get("spark.executor.instances").toInt
-    }
-  }
-
-
   def initSpark(): Unit = withClassLoad {
     this.synchronized {
       if (initializingThread == null && (spark == null || 
spark.sparkContext.isStopped)) {
@@ -240,10 +241,10 @@ object SparderContext extends Logging {
    * @return The body return
    */
   def withClassLoad[T](body: => T): T = {
-    //    val originClassLoad = Thread.currentThread().getContextClassLoader
+    // val originClassLoad = Thread.currentThread().getContextClassLoader
     
Thread.currentThread().setContextClassLoader(ClassLoaderUtils.getSparkClassLoader)
     val t = body
-    //    Thread.currentThread().setContextClassLoader(originClassLoad)
+    // Thread.currentThread().setContextClassLoader(originClassLoad)
     t
   }
 
diff --git 
a/kylin-spark-project/kylin-spark-query/src/main/java/org/apache/kylin/query/pushdown/SparkSubmitter.java
 
b/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/spark/sql/SparderContextFacade.scala
similarity index 50%
copy from 
kylin-spark-project/kylin-spark-query/src/main/java/org/apache/kylin/query/pushdown/SparkSubmitter.java
copy to 
kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/spark/sql/SparderContextFacade.scala
index 664e6be..386e9fc 100644
--- 
a/kylin-spark-project/kylin-spark-query/src/main/java/org/apache/kylin/query/pushdown/SparkSubmitter.java
+++ 
b/kylin-spark-project/kylin-spark-query/src/main/scala/org/apache/spark/sql/SparderContextFacade.scala
@@ -16,25 +16,29 @@
  * limitations under the License.
  */
 
-package org.apache.kylin.query.pushdown;
+package org.apache.spark.sql
 
-import org.apache.kylin.common.util.Pair;
-import org.apache.kylin.engine.spark.metadata.cube.StructField;
-import org.apache.spark.sql.SparderContext;
-import org.apache.spark.sql.SparkSession;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
+import org.apache.spark.internal.Logging
 
-import java.util.List;
-import java.util.UUID;
+import org.apache.kylin.common.threadlocal.InternalThreadLocal
+import org.apache.kylin.common.util.Pair
+import org.apache.kylin.query.UdfManager
 
-public class SparkSubmitter {
-    public static final Logger logger = 
LoggerFactory.getLogger(SparkSubmitter.class);
+object SparderContextFacade extends Logging {
 
-    public static PushdownResponse submitPushDownTask(String sql) {
-        SparkSession ss = SparderContext.getSparkSession();
-        Pair<List<List<String>>, List<StructField>> pair = 
SparkSqlClient.executeSql(ss, sql, UUID.randomUUID());
-        return new PushdownResponse(pair.getSecond(), pair.getFirst());
+  final val CURRENT_SPARKSESSION: InternalThreadLocal[Pair[SparkSession, 
UdfManager]] =
+    new InternalThreadLocal[Pair[SparkSession, UdfManager]]()
+
+  def current(): Pair[SparkSession, UdfManager] = {
+    if (CURRENT_SPARKSESSION.get() == null) {
+      val spark = SparderContext.getOriginalSparkSession.cloneSession()
+      CURRENT_SPARKSESSION.set(new Pair[SparkSession, UdfManager](spark,
+        UdfManager.createWithoutBuildInFunc(spark)))
     }
+    CURRENT_SPARKSESSION.get()
+  }
 
-}
+  def remove(): Unit = {
+    CURRENT_SPARKSESSION.remove()
+  }
+}
\ No newline at end of file
diff --git 
a/kylin-spark-project/kylin-spark-query/src/test/java/org/apache/kylin/query/monitor/SparderContextCanaryTest.java
 
b/kylin-spark-project/kylin-spark-query/src/test/java/org/apache/kylin/query/monitor/SparderContextCanaryTest.java
index 7a22892..a27264a 100644
--- 
a/kylin-spark-project/kylin-spark-query/src/test/java/org/apache/kylin/query/monitor/SparderContextCanaryTest.java
+++ 
b/kylin-spark-project/kylin-spark-query/src/test/java/org/apache/kylin/query/monitor/SparderContextCanaryTest.java
@@ -22,7 +22,6 @@ import org.apache.kylin.common.KylinConfig;
 import org.apache.kylin.common.util.TempMetadataBuilder;
 import org.apache.kylin.engine.spark.LocalWithSparkSessionTest;
 import org.apache.kylin.job.exception.SchedulerException;
-import org.apache.spark.sql.KylinSparkEnv;
 import org.apache.spark.sql.SparderContext;
 import org.junit.After;
 import org.junit.Assert;
@@ -34,11 +33,19 @@ public class SparderContextCanaryTest extends 
LocalWithSparkSessionTest {
     @Before
     public void setup() throws SchedulerException {
         super.setup();
-        SparderContext.setSparkSession(KylinSparkEnv.getSparkSession());
+        KylinConfig conf = KylinConfig.getInstanceFromEnv();
+        // the default value of kylin.query.spark-conf.spark.master is yarn,
+        // which will read from kylin-defaults.properties
+        conf.setProperty("kylin.query.spark-conf.spark.master", "local");
+        // create a new SparkSession of Sparder
+        SparderContext.initSpark();
     }
 
     @After
     public void after() {
+        SparderContext.stopSpark();
+        KylinConfig.getInstanceFromEnv()
+                .setProperty("kylin.query.spark-conf.spark.master", "yarn");
         super.after();
     }
 
@@ -49,7 +56,7 @@ public class SparderContextCanaryTest extends 
LocalWithSparkSessionTest {
         Assert.assertTrue(SparderContext.isSparkAvailable());
 
         // stop sparder and check again, the sparder context should 
auto-restart
-        SparderContext.getSparkSession().stop();
+        SparderContext.getOriginalSparkSession().stop();
         Assert.assertFalse(SparderContext.isSparkAvailable());
 
         SparderContextCanary.monitor();
diff --git 
a/kylin-spark-project/kylin-spark-query/src/test/java/org/apache/spark/sql/SparderContextFacadeTest.java
 
b/kylin-spark-project/kylin-spark-query/src/test/java/org/apache/spark/sql/SparderContextFacadeTest.java
new file mode 100644
index 0000000..9d6b606
--- /dev/null
+++ 
b/kylin-spark-project/kylin-spark-query/src/test/java/org/apache/spark/sql/SparderContextFacadeTest.java
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql;
+
+import org.apache.kylin.common.KylinConfig;
+import org.apache.kylin.engine.spark.LocalWithSparkSessionTest;
+import org.apache.kylin.job.exception.SchedulerException;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.Callable;
+import java.util.concurrent.CompletionService;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorCompletionService;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class SparderContextFacadeTest extends LocalWithSparkSessionTest {
+
+    private static final Logger logger = 
LoggerFactory.getLogger(SparderContextFacadeTest.class);
+    private static final Integer TEST_SIZE = 16 * 1024 * 1024;
+
+    @Override
+    @Before
+    public void setup() throws SchedulerException {
+        super.setup();
+        KylinConfig conf = KylinConfig.getInstanceFromEnv();
+        // the default value of kylin.query.spark-conf.spark.master is yarn,
+        // which will read from kylin-defaults.properties
+        conf.setProperty("kylin.query.spark-conf.spark.master", "local");
+        // Init Sparder
+        SparderContext.getOriginalSparkSession();
+    }
+
+    @After
+    public void after() {
+        SparderContext.stopSpark();
+        KylinConfig.getInstanceFromEnv()
+                .setProperty("kylin.query.spark-conf.spark.master", "yarn");
+        super.after();
+    }
+
+    @Test
+    public void testThreadSparkSession() throws InterruptedException, 
ExecutionException {
+        ThreadPoolExecutor executor = new ThreadPoolExecutor(5, 5, 1,
+                TimeUnit.DAYS, new LinkedBlockingQueue<>(5));
+
+        // test the thread local SparkSession
+        CompletionService<Throwable> service = 
runThreadSparkSessionTest(executor, false);
+
+        for (int i = 1; i <= 5; i++) {
+            Assert.assertNull(service.take().get());
+        }
+
+        // test the original SparkSession, it must throw errors.
+        service = runThreadSparkSessionTest(executor, true);
+        boolean hasError = false;
+        for (int i = 1; i <= 5; i++) {
+            if (service.take().get() != null) {
+                hasError = true;
+            }
+        }
+        Assert.assertTrue(hasError);
+
+        executor.shutdown();
+    }
+
+    protected CompletionService<Throwable> 
runThreadSparkSessionTest(ThreadPoolExecutor executor,
+                                                        boolean isOriginal) {
+        List<TestCallable> tasks = new ArrayList<>();
+        for (int i = 1; i <= 5; i++) {
+            tasks.add(new TestCallable(String.valueOf(TEST_SIZE * i), 
String.valueOf(i), isOriginal));
+        }
+
+        CompletionService<Throwable> service = new 
ExecutorCompletionService<>(executor);
+        for (TestCallable task : tasks) {
+            service.submit(task);
+        }
+        return service;
+    }
+
+    class TestCallable implements Callable<Throwable> {
+
+        private String maxPartitionBytes = null;
+        private String shufflePartitions = null;
+        private boolean isOriginal = false;
+
+        TestCallable(String maxPartitionBytes, String shufflePartitions, 
boolean isOriginal) {
+            this.maxPartitionBytes = maxPartitionBytes;
+            this.shufflePartitions = shufflePartitions;
+            this.isOriginal = isOriginal;
+        }
+
+        @Override
+        public Throwable call() throws Exception {
+            try {
+                SparkSession ss = null;
+                if (!this.isOriginal) {
+                    ss = SparderContext.getSparkSession();
+                } else {
+                    ss = SparderContext.getOriginalSparkSession();
+                }
+                ss.conf().set("spark.sql.files.maxPartitionBytes", 
this.maxPartitionBytes);
+                ss.conf().set("spark.sql.shuffle.partitions", 
this.shufflePartitions);
+
+                Thread.sleep((new Random()).nextInt(2) * 1000L);
+                Assert.assertEquals(this.maxPartitionBytes,
+                        ss.conf().get("spark.sql.files.maxPartitionBytes"));
+                Assert.assertEquals(this.shufflePartitions,
+                        ss.conf().get("spark.sql.shuffle.partitions"));
+            } catch (Throwable th) {
+                logger.error("Test thread local SparkSession error: ", th);
+                return th;
+            }
+            logger.info("Test thread local SparkSession successfully: {}");
+            return null;
+        }
+    }
+}
diff --git a/kylin-spark-project/kylin-spark-test/pom.xml 
b/kylin-spark-project/kylin-spark-test/pom.xml
index a3fb20f..4221983 100644
--- a/kylin-spark-project/kylin-spark-test/pom.xml
+++ b/kylin-spark-project/kylin-spark-test/pom.xml
@@ -47,12 +47,12 @@
         <dependency>
             <groupId>org.apache.kylin</groupId>
             <artifactId>kylin-spark-engine</artifactId>
-            <version>4.0.0-SNAPSHOT</version>
+            <version>${project.version}</version>
         </dependency>
         <dependency>
             <groupId>org.apache.kylin</groupId>
             <artifactId>kylin-spark-query</artifactId>
-            <version>4.0.0-SNAPSHOT</version>
+            <version>${project.version}</version>
         </dependency>
 
         <dependency>
@@ -142,7 +142,7 @@
         <dependency>
             <groupId>org.apache.kylin</groupId>
             <artifactId>kylin-spark-query</artifactId>
-            <version>4.0.0-SNAPSHOT</version>
+            <version>${project.version}</version>
             <type>test-jar</type>
             <scope>test</scope>
         </dependency>

Reply via email to