This is an automated email from the ASF dual-hosted git repository.

liyang pushed a commit to branch kylin5
in repository https://gitbox.apache.org/repos/asf/kylin.git

commit c4aad6e78b24c36cad6ce18a2d00dfd86527be0d
Author: Yaguang Jia <jiayagu...@foxmail.com>
AuthorDate: Sat Sep 2 09:16:38 2023 +0800

    KYLIN-5808 Optimize performance when saving model/adding CC
---
 .../engine/spark/utils/ComputedColumnEvalUtil.java | 20 +++-----
 .../job/stage/build/FlatTableAndDictBase.scala     | 57 ++++++++++++++++++++
 .../spark/smarter/IndexDependencyParser.scala      | 60 +++++++++++++++-------
 3 files changed, 106 insertions(+), 31 deletions(-)
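
In short: rather than rebuilding (and persisting) the full flat-table DataFrame
through the SparkSession for every computed-column type evaluation,
IndexDependencyParser now assembles the flat table once from Catalyst
LogicalPlans and caches the resulting DataFrame, and ComputedColumnEvalUtil
reuses that cached frame. A minimal Scala sketch of the caching idea,
simplified from the diff below (FlatTableCache is an illustrative name, not
part of the patch):

    import org.apache.spark.sql.{Dataset, Row}

    // Memoize the generated flat table so repeated CC evaluations reuse one
    // DataFrame instead of regenerating and persisting it each time.
    class FlatTableCache(generate: () => Dataset[Row]) {
      private var cached: Option[Dataset[Row]] = None
      def get(): Dataset[Row] = cached.getOrElse {
        val df = generate() // built once, e.g. during initTableNames()
        cached = Some(df)
        df
      }
    }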

diff --git a/src/spark-project/engine-spark/src/main/java/org/apache/kylin/engine/spark/utils/ComputedColumnEvalUtil.java b/src/spark-project/engine-spark/src/main/java/org/apache/kylin/engine/spark/utils/ComputedColumnEvalUtil.java
index 603dbd7a5a..f7e9b7936e 100644
--- a/src/spark-project/engine-spark/src/main/java/org/apache/kylin/engine/spark/utils/ComputedColumnEvalUtil.java
+++ b/src/spark-project/engine-spark/src/main/java/org/apache/kylin/engine/spark/utils/ComputedColumnEvalUtil.java
@@ -26,6 +26,7 @@ import org.apache.kylin.common.exception.QueryErrorCode;
 import org.apache.kylin.common.msg.MsgPicker;
 import org.apache.kylin.engine.spark.job.NSparkCubingUtil;
 import org.apache.kylin.engine.spark.smarter.IndexDependencyParser;
+import org.apache.kylin.guava30.shaded.common.base.Preconditions;
 import org.apache.kylin.metadata.model.BadModelException;
 import org.apache.kylin.metadata.model.ComputedColumnDesc;
 import org.apache.kylin.metadata.model.NDataModel;
@@ -33,13 +34,9 @@ import org.apache.kylin.metadata.model.exception.IllegalCCExpressionException;
 import org.apache.kylin.metadata.model.util.ComputedColumnUtil;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SparderEnv;
-import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.util.SparderTypeUtil;
 import org.springframework.util.CollectionUtils;
 
-import org.apache.kylin.guava30.shaded.common.base.Preconditions;
-
 import lombok.extern.slf4j.Slf4j;
 
 @Slf4j
@@ -62,7 +59,7 @@ public class ComputedColumnEvalUtil {
     public static void evalDataTypeOfCCInAuto(List<ComputedColumnDesc> computedColumns, NDataModel nDataModel,
             int start, int end) {
         try {
-            evalDataTypeOfCC(computedColumns, SparderEnv.getSparkSession(), nDataModel, start, end);
+            evalDataTypeOfCC(computedColumns, nDataModel, start, end);
         } catch (Exception e) {
             if (end - start > 1) { //numbers of CC > 1
                 evalDataTypeOfCCInAuto(computedColumns, nDataModel, start, start + (end - start) / 2);
@@ -80,7 +77,7 @@ public class ComputedColumnEvalUtil {
             return;
         }
         try {
-            evalDataTypeOfCC(computedColumns, SparderEnv.getSparkSession(), nDataModel, 0, computedColumns.size());
+            evalDataTypeOfCC(computedColumns, nDataModel, 0, computedColumns.size());
         } catch (Exception e) {
             evalDataTypeOfCCInManual(computedColumns, nDataModel, 0, computedColumns.size());
         }
@@ -90,7 +87,7 @@ public class ComputedColumnEvalUtil {
             int start, int end) {
         for (int i = start; i < end; i++) {
             try {
-                evalDataTypeOfCC(computedColumns, SparderEnv.getSparkSession(), nDataModel, i, i + 1);
+                evalDataTypeOfCC(computedColumns, nDataModel, i, i + 1);
             } catch (Exception e) {
                 Preconditions.checkNotNull(computedColumns.get(i));
                 throw new IllegalCCExpressionException(QueryErrorCode.CC_EXPRESSION_ILLEGAL,
@@ -101,15 +98,14 @@ public class ComputedColumnEvalUtil {
         }
     }
 
-    private static void evalDataTypeOfCC(List<ComputedColumnDesc> computedColumns, SparkSession ss,
-            NDataModel nDataModel, int start, int end) {
+    private static void evalDataTypeOfCC(List<ComputedColumnDesc> computedColumns, NDataModel nDataModel, int start,
+            int end) {
         IndexDependencyParser parser = new IndexDependencyParser(nDataModel);
-        Dataset<Row> originDf = parser.generateFullFlatTableDF(ss, nDataModel);
-        originDf.persist();
+        Dataset<Row> df = parser.getFullFlatTableDataFrame(nDataModel);
         String[] ccExprArray = computedColumns.subList(start, end).stream() //
                 .map(ComputedColumnDesc::getInnerExpression) //
                 .map(NSparkCubingUtil::convertFromDotWithBackTick).toArray(String[]::new);
-        Dataset<Row> ds = originDf.selectExpr(ccExprArray);
+        Dataset<Row> ds = df.selectExpr(ccExprArray);
         for (int i = start; i < end; i++) {
             String dataType = SparderTypeUtil.convertSparkTypeToSqlType(ds.schema().fields()[i - start].dataType());
             computedColumns.get(i).setDatatype(dataType);
diff --git a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/job/stage/build/FlatTableAndDictBase.scala b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/job/stage/build/FlatTableAndDictBase.scala
index 4634b77096..935eaefaea 100644
--- a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/job/stage/build/FlatTableAndDictBase.scala
+++ b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/job/stage/build/FlatTableAndDictBase.scala
@@ -46,6 +46,8 @@ import org.apache.kylin.metadata.cube.planner.CostBasePlannerUtils
 import org.apache.kylin.metadata.model._
 import org.apache.spark.sql.KapFunctions.dict_encode_v3
 import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LogicalPlan}
 import org.apache.spark.sql.functions.{col, expr}
 import org.apache.spark.sql.types.StructField
 import org.apache.spark.sql.util.SparderTypeUtil
@@ -718,6 +720,19 @@ object FlatTableAndDictBase extends LogEx {
     newDS
   }
 
+  def wrapAlias(originPlan: LogicalPlan, alias: String, needLog: Boolean): LogicalPlan = {
+    val newFields = originPlan.output
+      .map(f => {
+        val aliasDotColName = "`" + alias + "`" + "." + "`" + f.name + "`"
+        convertFromDot(aliasDotColName)
+      })
+    val newDS = SparkOperation.projectAsAlias(newFields, originPlan)
+    if (needLog) {
+      logInfo(s"Wrap ALIAS ${originPlan.schema.treeString} TO 
${newDS.schema.treeString}")
+    }
+    newDS
+  }
+
 
   def joinFactTableWithLookupTables(rootFactDataset: Dataset[Row],
                                     lookupTableDatasetMap: mutable.Map[JoinTableDesc, Dataset[Row]],
@@ -728,6 +743,15 @@ object FlatTableAndDictBase extends LogEx {
         joinTableDataset(model.getRootFactTable.getTableDesc, tuple._1, joinedDataset, tuple._2, needLog))
   }
 
+  def joinFactTableWithLookupTables(rootFactPlan: LogicalPlan,
+                                    lookupTableDatasetMap: mutable.Map[JoinTableDesc, LogicalPlan],
+                                    model: NDataModel,
+                                    needLog: Boolean): LogicalPlan = {
+    lookupTableDatasetMap.foldLeft(rootFactPlan)(
+      (joinedDataset: LogicalPlan, tuple: (JoinTableDesc, LogicalPlan)) =>
+        joinTableLogicalPlan(model.getRootFactTable.getTableDesc, tuple._1, joinedDataset, tuple._2, needLog))
+  }
+
   def joinTableDataset(rootFactDesc: TableDesc,
                        lookupDesc: JoinTableDesc,
                        rootFactDataset: Dataset[Row],
@@ -765,6 +789,39 @@ object FlatTableAndDictBase extends LogEx {
     afterJoin
   }
 
+  def joinTableLogicalPlan(rootFactDesc: TableDesc,
+                           lookupDesc: JoinTableDesc,
+                           rootFactPlan: LogicalPlan,
+                           lookupPlan: LogicalPlan,
+                           needLog: Boolean = true): LogicalPlan = {
+    var afterJoin = rootFactPlan
+    val join = lookupDesc.getJoin
+    if (join != null && !StringUtils.isEmpty(join.getType)) {
+      val joinType = join.getType.toUpperCase(Locale.ROOT)
+      val pk = join.getPrimaryKeyColumns
+      val fk = join.getForeignKeyColumns
+      if (pk.length != fk.length) {
+        throw new RuntimeException(
+          s"Invalid join condition of fact table: $rootFactDesc,fk: 
${fk.mkString(",")}," +
+            s" lookup table:$lookupDesc, pk: ${pk.mkString(",")}")
+      }
+      val equiConditionColPairs = fk.zip(pk).map(joinKey =>
+        col(convertFromDot(joinKey._1.getBackTickIdentity))
+          .equalTo(col(convertFromDot(joinKey._2.getBackTickIdentity))))
+
+      if (join.getNonEquiJoinCondition != null) {
+        val condition: Column = getCondition(join)
+        logInfo(s"Root table ${rootFactDesc.getIdentity}, join table 
${lookupDesc.getAlias}, non-equi condition: ${condition.toString()}")
+        afterJoin = Join(afterJoin, lookupPlan, JoinType.apply(joinType), 
Option.apply(condition.expr), JoinHint.NONE)
+      } else {
+        val condition = equiConditionColPairs.reduce(_ && _)
+        logInfo(s"Root table ${rootFactDesc.getIdentity}, join table 
${lookupDesc.getAlias}, condition: ${condition.toString()}")
+        afterJoin = Join(afterJoin, lookupPlan, JoinType.apply(joinType), 
Option.apply(condition.expr), JoinHint.NONE)
+      }
+    }
+    afterJoin
+  }
+
   def getCondition(join: JoinDesc): Column = {
     val pk = join.getPrimaryKeyColumns
     val fk = join.getForeignKeyColumns
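
As an aside, joinTableLogicalPlan above assembles the join directly as a
Catalyst Join node instead of going through the Dataset join API, presumably
so no intermediate DataFrame has to be created and analyzed per lookup table.
A rough self-contained Scala sketch of that construction (assuming Spark 3.x;
joinPlans is an illustrative helper, not a name from the patch):

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.catalyst.plans.JoinType
    import org.apache.spark.sql.catalyst.plans.logical.{Join, JoinHint, LogicalPlan}

    // Wrap two logical plans in a Join; the condition is an ordinary Column
    // built with col(...)/equalTo, as in the equi-join branch above.
    def joinPlans(left: LogicalPlan, right: LogicalPlan,
                  joinType: String, condition: Column): LogicalPlan =
      Join(left, right, JoinType.apply(joinType), Option(condition.expr), JoinHint.NONE)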
diff --git a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/smarter/IndexDependencyParser.scala b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/smarter/IndexDependencyParser.scala
index 44e52099d7..9da76b7996 100644
--- a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/smarter/IndexDependencyParser.scala
+++ b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/smarter/IndexDependencyParser.scala
@@ -17,9 +17,6 @@
  */
 package org.apache.kylin.engine.spark.smarter
 
-import java.util
-import java.util.Collections
-
 import org.apache.commons.collections.CollectionUtils
 import org.apache.commons.lang3.StringUtils
 import org.apache.kylin.engine.spark.job.NSparkCubingUtil
@@ -28,10 +25,15 @@ import org.apache.kylin.guava30.shaded.common.collect.{Lists, Maps, Sets}
 import org.apache.kylin.metadata.cube.model.LayoutEntity
 import org.apache.kylin.metadata.model._
 import org.apache.kylin.query.util.PushDownUtil
+import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, SubqueryAlias}
 import org.apache.spark.sql.execution.utils.SchemaProcessor
-import org.apache.spark.sql.types.StructField
-import org.apache.spark.sql.{Dataset, Row, SparderEnv, SparkSession}
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.types.{StructField, StructType}
 
+import java.util
+import java.util.Collections
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
@@ -40,8 +42,15 @@ class IndexDependencyParser(val model: NDataModel) {
   private val ccTableNameAliasMap = Maps.newHashMap[String, util.Set[String]]
   private val joinTableAliasMap = Maps.newHashMap[String, util.Set[String]]
   private val allTablesAlias = Sets.newHashSet[String]
+  private var fullFlatTableDF : Option[Dataset[Row]] = None
   initTableNames()
-
+  def getFullFlatTableDataFrame(model: NDataModel): Dataset[Row] = {
+    if (fullFlatTableDF.isDefined) {
+      fullFlatTableDF.get
+    } else {
+      generateFullFlatTableDF(model)
+    }
+  }
   def getRelatedTablesAlias(layouts: util.Collection[LayoutEntity]): util.List[String] = {
     val relatedTables = Sets.newHashSet[String]
     layouts.asScala.foreach(layout => relatedTables.addAll(getRelatedTablesAlias(layout)))
@@ -105,34 +114,47 @@ class IndexDependencyParser(val model: NDataModel) {
     }
   }
 
-  def generateFullFlatTableDF(ss: SparkSession, model: NDataModel): Dataset[Row] = {
-    val rootDF = generateDatasetOnTable(ss, model.getRootFactTable)
+  def generateFullFlatTableDF(model: NDataModel): Dataset[Row] = {
+    val rootLogicalPlan = generateLogicalPlanOnTable(model.getRootFactTable)
     // look up tables
-    val joinTableDFMap = mutable.LinkedHashMap[JoinTableDesc, Dataset[Row]]()
+    val joinTableDFMap = mutable.LinkedHashMap[JoinTableDesc, LogicalPlan]()
     model.getJoinTables.asScala.map((joinTable: JoinTableDesc) => {
-      joinTableDFMap.put(joinTable, generateDatasetOnTable(ss, joinTable.getTableRef))
+      joinTableDFMap.put(joinTable, generateLogicalPlanOnTable(joinTable.getTableRef))
     })
-    val df = FlatTableAndDictBase.joinFactTableWithLookupTables(rootDF, joinTableDFMap, model, needLog = false)
+    val df = FlatTableAndDictBase.joinFactTableWithLookupTables(rootLogicalPlan, joinTableDFMap, model, needLog = false)
     val filterCondition = model.getFilterCondition
     if (StringUtils.isNotEmpty(filterCondition)) {
       val massagedCondition = PushDownUtil.massageExpression(model, model.getProject, filterCondition, null)
-      df.where(NSparkCubingUtil.convertFromDotWithBackTick(massagedCondition))
+      val condition = NSparkCubingUtil.convertFromDot(massagedCondition)
+      SparkOperation.filter(col(condition), df)
     }
-    df
+    SparkInternalAgent.getDataFrame(SparderEnv.getSparkSession, df)
   }
 
-  private def generateDatasetOnTable(ss: SparkSession, tableRef: TableRef): Dataset[Row] = {
+  private def generateLogicalPlanOnTable(tableRef: TableRef): LogicalPlan = {
     val tableCols = tableRef.getColumns.asScala.map(_.getColumnDesc).toArray
     val structType = SchemaProcessor.buildSchemaWithRawTable(tableCols)
     val alias = tableRef.getAlias
-    val dataset = ss.createDataFrame(Lists.newArrayList[Row], structType).alias(alias)
-    FlatTableAndDictBase.wrapAlias(dataset, alias, needLog = false)
+    val fsRelation = LocalRelation(toAttributes(structType))
+    val plan = SubqueryAlias(alias, fsRelation)
+    FlatTableAndDictBase.wrapAlias(plan, alias, needLog = false)
+  }
+
+  def toAttribute(field: StructField): AttributeReference =
+    AttributeReference(field.name, field.dataType, field.nullable, field.metadata)()
+
+  /**
+   * Convert a [[StructType]] into a Seq of [[AttributeReference]].
+   */
+  def toAttributes(schema: StructType): Seq[AttributeReference] = {
+    schema.map(toAttribute)
   }
 
   private def initTableNames(): Unit = {
-    val ccList = model.getComputedColumnDescs
-    val originDf = generateFullFlatTableDF(SparderEnv.getSparkSession, model)
+    fullFlatTableDF = Option(generateFullFlatTableDF(model))
+    val originDf = fullFlatTableDF.get
     val colFields = originDf.schema.fields
+    val ccList = model.getComputedColumnDescs
     val ds = originDf.selectExpr(ccList.asScala.map(_.getInnerExpression)
       .map(NSparkCubingUtil.convertFromDotWithBackTick): _*)
     ccList.asScala.zip(ds.schema.fields).foreach(pair => {
@@ -156,7 +178,7 @@ class IndexDependencyParser(val model: NDataModel) {
 
   def unwrapComputeColumn(ccInnerExpression: String): java.util.Set[TblColRef] = {
     val result: util.Set[TblColRef] = Sets.newHashSet()
-    val originDf = generateFullFlatTableDF(SparderEnv.getSparkSession, model)
+    val originDf = getFullFlatTableDataFrame(model)
     val colFields = originDf.schema.fields
     val ccDs = originDf.selectExpr(NSparkCubingUtil.convertFromDotWithBackTick(ccInnerExpression))
     ccDs.schema.fields.foreach(fieldName => {
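
Similarly, each model table is now represented as a schema-only LogicalPlan
(a LocalRelation wrapped in a SubqueryAlias) rather than as an empty DataFrame
created through ss.createDataFrame. A small self-contained Scala sketch of that
idea (assuming Spark 3.x; emptyPlanFor is an illustrative name, not from the
patch):

    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, SubqueryAlias}
    import org.apache.spark.sql.types.StructType

    // A zero-row relation that only carries the table schema: no Spark job,
    // no persist(), and no dependency on the shared SparkSession.
    def emptyPlanFor(alias: String, schema: StructType): LogicalPlan = {
      val attrs = schema.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())
      SubqueryAlias(alias, LocalRelation(attrs))
    }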
