This is an automated email from the ASF dual-hosted git repository.

xxyu pushed a commit to branch kylin5
in repository https://gitbox.apache.org/repos/asf/kylin.git
commit 12b23b44e2372db8c95b0391c2e99d3b30adbdbf
Author: Mingming Ge <7mmi...@gmail.com>
AuthorDate: Wed Mar 29 19:17:56 2023 +0800

    KYLIN-5603 fix V3 dict upgrade
---
 .../spark/builder/v3dict/DictionaryBuilder.scala    | 146 +++++++++++++++------
 .../v3dict/GlobalDictionaryBuilderHelper.scala      |   8 +-
 .../v3dict/GlobalDictionaryPlaceHolder.scala        |   3 +-
 .../v3dict/PreCountDistinctTransformer.scala        |  13 +-
 .../job/stage/build/FlatTableAndDictBase.scala      |   3 +-
 .../builder/v3dict/GlobalDictionarySuite.scala      |  31 +++--
 .../v3dict/GlobalDictionaryUpdateSuite.scala        |  17 +--
 .../scala/org/apache/spark/sql/KapFunctions.scala   |   4 +-
 .../sql/catalyst/expressions/KapExpresssions.scala  |   2 +-
 9 files changed, 149 insertions(+), 78 deletions(-)

diff --git a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/DictionaryBuilder.scala b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/DictionaryBuilder.scala
index 872e030d17..3b88c8af43 100644
--- a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/DictionaryBuilder.scala
+++ b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/DictionaryBuilder.scala
@@ -21,6 +21,7 @@ import io.delta.tables.DeltaTable
 import org.apache.hadoop.fs.Path
 import org.apache.kylin.common.KylinConfig
 import org.apache.kylin.common.util.HadoopUtil
+import org.apache.kylin.engine.spark.builder.v3dict.DictBuildMode.{V2UPGRADE, V3APPEND, V3INIT, V3UPGRADE}
 import org.apache.kylin.engine.spark.job.NSparkCubingUtil
 import org.apache.kylin.metadata.model.TblColRef
 import org.apache.spark.dict.{NBucketDictionary, NGlobalDictionaryV2}
@@ -35,6 +36,7 @@ import org.apache.spark.sql.{Dataset, Row, SaveMode, SparkSession}
 import util.retry.blocking.RetryStrategy.RetryStrategyProducer
 import util.retry.blocking.{Failure, Retry, RetryStrategy, Success}
 
+import java.nio.file.Paths
 import scala.collection.mutable.ListBuffer
 import scala.concurrent.duration.DurationInt
 
@@ -43,18 +45,17 @@ object DictionaryBuilder extends Logging {
 
   implicit val retryStrategy: RetryStrategyProducer =
     RetryStrategy.fixedBackOff(retryDuration = 10.seconds, maxAttempts = 5)
 
-  private val config = KylinConfig.getInstanceFromEnv
-
   def buildGlobalDict(
-    project: String,
-    spark: SparkSession,
-    plan: LogicalPlan): LogicalPlan = transformCountDistinct(spark, plan) transform {
+      project: String,
+      spark: SparkSession,
+      plan: LogicalPlan): LogicalPlan = transformCountDistinct(spark, plan) transform {
 
-    case GlobalDictionaryPlaceHolder(expr: String, child: LogicalPlan) =>
+    case GlobalDictionaryPlaceHolder(expr: String, child: LogicalPlan, dbName: String) =>
       spark.sparkContext.setJobDescription(s"Build v3 dict $expr")
-      val tableName = expr.split(NSparkCubingUtil.SEPARATOR).apply(0)
-      val columnName = expr.split(NSparkCubingUtil.SEPARATOR).apply(1)
-      val context = new DictionaryContext(project, tableName, columnName, expr)
+      val catalog = expr.split(NSparkCubingUtil.SEPARATOR)
+      val tableName = catalog.apply(0)
+      val columnName = catalog.apply(1)
+      val context = new DictionaryContext(project, dbName, tableName, columnName, expr)
 
       // concurrent commit may cause delta ConcurrentAppendException.
       // so need retry commit incremental dict to delta table.
@@ -78,9 +79,9 @@ object DictionaryBuilder extends Logging {
    * Use Left anti join to process raw data and dictionary tables.
   */
  private def transformerDictPlan(
-    spark: SparkSession,
-    context: DictionaryContext,
-    plan: LogicalPlan): LogicalPlan = {
+      spark: SparkSession,
+      context: DictionaryContext,
+      plan: LogicalPlan): LogicalPlan = {
 
     val dictPath = getDictionaryPath(context)
     val dictTable: DeltaTable = DeltaTable.forPath(dictPath)
@@ -105,32 +106,47 @@ object DictionaryBuilder extends Logging {
     }
   }
 
+  private def chooseDictBuildMode(context: DictionaryContext): DictBuildMode.Value = {
+    val config = KylinConfig.getInstanceFromEnv
+    if (isExistsV3Dict(context)) {
+      V3APPEND
+    } else if (isExistsOriginalV3Dict(context)) {
+      V3UPGRADE
+    } else if (config.isConvertV3DictEnable && isExistsV2Dict(context)) {
+      V2UPGRADE
+    } else V3INIT
+  }
+
   /**
    * Build an incremental dictionary
   */
   private def incrementBuildDict(
-    spark: SparkSession,
-    plan: LogicalPlan,
-    context: DictionaryContext): Unit = {
-    val config = KylinConfig.getInstanceFromEnv
-    val dictPath = getDictionaryPath(context)
-    if(DeltaTable.isDeltaTable(spark, dictPath)) {
-      mergeIncrementDict(spark, context, plan)
-    } else if (config.isConvertV3DictEnable
-      && isExistsV2Dict(context)
-      && !isExistsV3Dict(context)) {
-      val existsV2DictDF = fetchExistsV2Dict(spark, context)
-      appendDictDF(existsV2DictDF, context)
-      mergeIncrementDict(spark, context, plan)
-    } else {
-      val incrementDictDF = getDataFrame(spark, plan)
-      appendDictDF(incrementDictDF, context)
+      spark: SparkSession,
+      plan: LogicalPlan,
+      context: DictionaryContext): Unit = {
+    val dictMode = chooseDictBuildMode(context)
+    logInfo(s"V3 Dict build mode is $dictMode")
+    dictMode match {
+      case V3INIT =>
+        val dictDF = getDataFrame(spark, plan)
+        initAndSaveDictDF(dictDF, context)
+      case V3APPEND =>
+        mergeIncrementDict(spark, context, plan)
+      // To be delete
+      case V3UPGRADE =>
+        val v3OrigDict = upgradeFromOriginalV3(spark, context)
+        initAndSaveDictDF(v3OrigDict, context)
+        mergeIncrementDict(spark, context, plan)
+      case V2UPGRADE =>
+        val v2Dict = upgradeFromV2(spark, context)
+        initAndSaveDictDF(v2Dict, context)
+        mergeIncrementDict(spark, context, plan)
     }
   }
 
-  private def appendDictDF(dictDF: Dataset[Row], context: DictionaryContext): Unit = {
+  private def initAndSaveDictDF(dictDF: Dataset[Row], context: DictionaryContext): Unit = {
     val dictPath = getDictionaryPath(context)
-    logInfo(s"Append dict values into path $dictPath.")
+    logInfo(s"Save dict values into path $dictPath.")
     dictDF.write.mode(SaveMode.Overwrite).format("delta").save(dictPath)
   }
 
@@ -149,10 +165,17 @@ object DictionaryBuilder extends Logging {
   }
 
   private def isExistsV2Dict(context: DictionaryContext): Boolean = {
+    val config = KylinConfig.getInstanceFromEnv
     val globalDict = new NGlobalDictionaryV2(context.project,
-      context.tableName, context.columnName, config.getHdfsWorkingDirectory)
+      context.dbName + "." + context.tableName, context.columnName, config.getHdfsWorkingDirectory)
     val dictV2Meta = globalDict.getMetaInfo
-    dictV2Meta != null
+    if (dictV2Meta != null) {
+      logInfo(s"Exists V2 dict ${globalDict.getResourceDir}")
+      true
+    } else {
+      logInfo(s"Not exists V2 dict ${globalDict.getResourceDir}")
+      false
+    }
   }
 
   private def isExistsV3Dict(context: DictionaryContext): Boolean = {
@@ -160,10 +183,28 @@ object DictionaryBuilder extends Logging {
     HadoopUtil.getWorkingFileSystem.exists(new Path(dictPath))
   }
 
-  private def fetchExistsV2Dict(spark: SparkSession, context: DictionaryContext): Dataset[Row] = {
+  private def isExistsOriginalV3Dict(context: DictionaryContext): Boolean = {
+    val dictPath = getOriginalDictionaryPath(context)
+    HadoopUtil.getWorkingFileSystem.exists(new Path(dictPath))
+  }
+
+  private def fetchExistsOriginalV3Dict(context: DictionaryContext): Dataset[Row] = {
+    val originalV3DictPath = getOriginalDictionaryPath(context)
+    val v3dictTable = DeltaTable.forPath(originalV3DictPath)
+    v3dictTable.toDF
+  }
+
+  private def transformCountDistinct(session: SparkSession, plan: LogicalPlan): LogicalPlan = {
+    val transformer = new PreCountDistinctTransformer(session)
+    transformer.apply(plan)
+  }
+
+  private def upgradeFromV2(spark: SparkSession, context: DictionaryContext): Dataset[Row] = {
+    val config = KylinConfig.getInstanceFromEnv
     val globalDict = new NGlobalDictionaryV2(context.project,
-      context.tableName, context.columnName, config.getHdfsWorkingDirectory)
+      context.dbName + "." + context.tableName, context.columnName, config.getHdfsWorkingDirectory)
     val dictV2Meta = globalDict.getMetaInfo
+    logInfo(s"Exists V2 dict ${globalDict.getResourceDir} num ${dictV2Meta.getDictCount}")
     val broadcastDict = spark.sparkContext.broadcast(globalDict)
     val dictSchema = new StructType(Array(StructField("dict_key", StringType), StructField("dict_value", LongType)))
@@ -186,12 +227,15 @@ object DictionaryBuilder extends Logging {
     }
   }
 
-  private def transformCountDistinct(session: SparkSession, plan: LogicalPlan): LogicalPlan = {
-    val transformer = new PreCountDistinctTransformer(session)
-    transformer.apply(plan)
+  private def upgradeFromOriginalV3(spark: SparkSession, context: DictionaryContext): Dataset[Row] = {
+    if (isExistsOriginalV3Dict(context)) {
+      fetchExistsOriginalV3Dict(context)
+    } else {
+      spark.emptyDataFrame
+    }
   }
 
-  def getDictionaryPath(context: DictionaryContext): String = {
+  private def getOriginalDictionaryPath(context: DictionaryContext): String = {
     val config = KylinConfig.getInstanceFromEnv
     val workingDir = config.getHdfsWorkingDirectory()
     val dictDir = new Path(context.project, new Path(HadoopUtil.GLOBAL_DICT_V3_STORAGE_ROOT,
@@ -199,9 +243,31 @@ object DictionaryBuilder extends Logging {
     workingDir + dictDir
   }
 
+  def getDictionaryPath(context: DictionaryContext): String = {
+    val config = KylinConfig.getInstanceFromEnv
+    val workingDir = config.getHdfsWorkingDirectory()
+    val dictDir = Paths.get(context.project,
+      HadoopUtil.GLOBAL_DICT_V3_STORAGE_ROOT,
+      context.dbName,
+      context.tableName,
+      context.columnName)
+    workingDir + dictDir
+  }
+
   def wrapCol(ref: TblColRef): String = {
-    NSparkCubingUtil.convertFromDot(ref.getColumnDesc.getBackTickIdentity)
+    NSparkCubingUtil.convertFromDot(ref.getBackTickIdentity)
   }
 }
 
-class DictionaryContext(val project: String, val tableName: String, val columnName: String, val expr: String)
\ No newline at end of file
+class DictionaryContext(
+    val project: String,
+    val dbName: String,
+    val tableName: String,
+    val columnName: String,
+    val expr: String)
+
+object DictBuildMode extends Enumeration {
+
+  val V3UPGRADE, V2UPGRADE, V3APPEND, V3INIT = Value
+
+}
\ No newline at end of file
diff --git a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionaryBuilderHelper.scala b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionaryBuilderHelper.scala
index 523b40d8df..3641f26dc4 100644
--- a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionaryBuilderHelper.scala
+++ b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionaryBuilderHelper.scala
@@ -44,14 +44,14 @@ object GlobalDictionaryBuilderHelper {
       .map(_ => Row(RandomStringUtils.randomAlphabetic(length)))), schema)
   }
 
-  def genDataWithWrapEncodeCol(colName: String, df: Dataset[Row]): Dataset[Row] = {
-    val dictCol = dict_encode_v3(col(colName))
+  def genDataWithWrapEncodeCol(dbName: String, colName: String, df: Dataset[Row]): Dataset[Row] = {
+    val dictCol = dict_encode_v3(col(colName), dbName)
     df.select(df.schema.map(ty => col(ty.name)) ++ Seq(dictCol): _*)
   }
 
-  def genDataWithWrapEncodeCol(spark: SparkSession, colName: String, count: Int, length: Int): Dataset[Row] = {
+  def genDataWithWrapEncodeCol(spark: SparkSession, dbName: String, colName: String, count: Int, length: Int): Dataset[Row] = {
     val df = genRandomData(spark, colName, count, length)
-    val dictCol = dict_encode_v3(col(colName))
+    val dictCol = dict_encode_v3(col(colName), dbName)
     df.select(df.schema.map(ty => col(ty.name)) ++ Seq(dictCol): _*)
   }
 }
diff --git a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionaryPlaceHolder.scala b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionaryPlaceHolder.scala
index 67034d48ed..2ce163fb0e 100644
--- a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionaryPlaceHolder.scala
+++ b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionaryPlaceHolder.scala
@@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
 
 case class GlobalDictionaryPlaceHolder(
     exprName: String,
-    child: LogicalPlan) extends UnaryNode {
+    child: LogicalPlan,
+    dbName: String) extends UnaryNode {
 
   override def output: Seq[Attribute] = child.output
 
diff --git a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/PreCountDistinctTransformer.scala b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/PreCountDistinctTransformer.scala
index 90a7bd1786..f4218df160 100644
--- a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/PreCountDistinctTransformer.scala
+++ b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/v3dict/PreCountDistinctTransformer.scala
@@ -33,28 +33,29 @@ class PreCountDistinctTransformer(spark: SparkSession) extends Rule[LogicalPlan]
     case project@ Project(_, child) =>
       val relatedFields = scala.collection.mutable.Queue[CountDistExprInfo]()
       project.transformExpressions {
-        case DictEncodeV3(child) =>
+        case DictEncodeV3(child, dbName) =>
           val deAttr = AttributeReference("dict_encoded_" + child.prettyName,
             LongType, nullable = false)(NamedExpression.newExprId, Seq.empty[String])
-          relatedFields += CountDistExprInfo(child, deAttr)
+          relatedFields += CountDistExprInfo(child, deAttr, dbName)
           createColumn(deAttr).expr
       }.withNewChildren {
         val dictionaries = relatedFields.map {
-          case CountDistExprInfo(childExpr, encodedAttr) =>
+          case CountDistExprInfo(childExpr, encodedAttr, dbName) =>
             val windowSpec = Window.orderBy(createColumn(childExpr))
             val exprName = childExpr match {
               case ne: NamedExpression => ne.name
              case _ => childExpr.prettyName
             }
+            logInfo(s"Count distinct expr name $exprName")
             val dictPlan = GlobalDictionaryPlaceHolder(exprName, getLogicalPlan(
               getDataFrame(spark, child).groupBy(createColumn(childExpr)).agg(
                 createColumn(childExpr)).select(
                 createColumn(childExpr).cast(StringType) as "dict_key",
-                row_number().over(windowSpec).cast(LongType) as "dict_value")))
+                row_number().over(windowSpec).cast(LongType) as "dict_value")), dbName)
             val key = dictPlan.output.head
             val value = dictPlan.output(1)
             val valueAlias = Alias(value, encodedAttr.name)(encodedAttr.exprId)
-            (Project(Seq(key, valueAlias), dictPlan), (childExpr, encodedAttr))
+            (Project(Seq(key, valueAlias), dictPlan), (childExpr, encodedAttr), dbName)
         }
 
         val result = dictionaries.foldLeft(child) {
@@ -75,4 +76,4 @@ class PreCountDistinctTransformer(spark: SparkSession) extends Rule[LogicalPlan]
   }
 }
 
-case class CountDistExprInfo(childExpr: Expression, encodedAttr: AttributeReference)
+case class CountDistExprInfo(childExpr: Expression, encodedAttr: AttributeReference, dbName: String)
diff --git a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/job/stage/build/FlatTableAndDictBase.scala b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/job/stage/build/FlatTableAndDictBase.scala
index 3117741c9e..5678b2f5f6 100644
--- a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/job/stage/build/FlatTableAndDictBase.scala
+++ b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/job/stage/build/FlatTableAndDictBase.scala
@@ -577,7 +577,8 @@ abstract class FlatTableAndDictBase(private val jobContext: SegmentJob,
     val matchedCols = selectColumnsInTable(table, dictCols)
     val cols = matchedCols.map { dictColumn =>
       val wrapDictCol = DictionaryBuilder.wrapCol(dictColumn)
-      dict_encode_v3(col(wrapDictCol)).alias(wrapDictCol + "_KE_ENCODE")
+      val dbName = dictColumn.getTableRef.getTableDesc.getDatabase
+      dict_encode_v3(col(wrapDictCol), dbName).alias(wrapDictCol + "_KE_ENCODE")
     }.toSeq
     val dictPlan = table
       .select(table.schema.map(ty => col(ty.name)) ++ cols: _*)
diff --git a/src/spark-project/engine-spark/src/test/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionarySuite.scala b/src/spark-project/engine-spark/src/test/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionarySuite.scala
index 76092b927b..e93b933df9 100644
--- a/src/spark-project/engine-spark/src/test/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionarySuite.scala
+++ b/src/spark-project/engine-spark/src/test/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionarySuite.scala
@@ -34,16 +34,17 @@ class GlobalDictionarySuite extends SparderBaseFunSuite with LocalMetadata with
 
   test("KE-35145 Test Continuously Build Dictionary") {
     val project = "p1"
-    val tableName = "t1"
+    val dbName = "db1"
+    val tableName = "t2"
     val colName = "c1"
     val encodeColName: String = tableName + NSparkCubingUtil.SEPARATOR + colName
-    val context = new DictionaryContext(project, tableName, colName, null)
+    val context = new DictionaryContext(project, dbName, tableName, colName, null)
     DeltaTable.createIfNotExists()
       .tableName("original_c1")
       .addColumn(encodeColName, StringType).execute()
     for (_ <- 0 until 10) {
       val originalDF = genRandomData(spark, encodeColName, 100, 1)
-      val df = genDataWithWrapEncodeCol(encodeColName, originalDF)
+      val df = genDataWithWrapEncodeCol(dbName, encodeColName, originalDF)
       DeltaTable.forName("original_c1")
         .merge(originalDF, "1 != 1")
         .whenNotMatched()
@@ -56,7 +57,7 @@ class GlobalDictionarySuite extends SparderBaseFunSuite with LocalMetadata with
     val dictDF = DeltaTable.forPath(dictPath).toDF.agg(count(col("dict_key")))
     val originalDF = spark.sql(
       """
-        |SELECT count(DISTINCT t1_0_DOT_0_c1)
+        |SELECT count(DISTINCT t2_0_DOT_0_c1)
        | FROM default.original_c1
       """.stripMargin)
     checkAnswer(originalDF, dictDF)
@@ -64,10 +65,11 @@ class GlobalDictionarySuite extends SparderBaseFunSuite with LocalMetadata with
 
   test("KE-35145 Test Concurrent Build Dictionary") {
     val project = "p1"
+    val dbName = "db1"
     val tableName = "t1"
     val colName = "c2"
     val encodeColName: String = tableName + NSparkCubingUtil.SEPARATOR + colName
-    val context = new DictionaryContext(project, tableName, colName, null)
+    val context = new DictionaryContext(project, dbName, tableName, colName, null)
     val pool = Executors.newFixedThreadPool(10)
     implicit val ec: ExecutionContextExecutorService =
       ExecutionContext.fromExecutorService(pool)
@@ -78,7 +80,7 @@ class GlobalDictionarySuite extends SparderBaseFunSuite with LocalMetadata with
     val buildDictTask = new Runnable {
       override def run(): Unit = {
         val originalDF = genRandomData(spark, encodeColName, 100, 1)
-        val dictDF = genDataWithWrapEncodeCol(encodeColName, originalDF)
+        val dictDF = genDataWithWrapEncodeCol(dbName, encodeColName, originalDF)
         DeltaTable.forName("original_c2")
           .merge(originalDF, "1 != 1")
           .whenNotMatched()
@@ -104,12 +106,13 @@ class GlobalDictionarySuite extends SparderBaseFunSuite with LocalMetadata with
 
   test("KE-35145 Test the v3 dictionary with random data") {
     val project = "p1"
+    val dbName = "db1"
     val tableName = "t1"
     val colName = "c3"
     val encodeColName: String = tableName + NSparkCubingUtil.SEPARATOR + colName
-    val context = new DictionaryContext(project, tableName, colName, null)
+    val context = new DictionaryContext(project, dbName, tableName, colName, null)
     val df = genRandomData(spark, encodeColName, 1000, 2)
-    val dictDF = genDataWithWrapEncodeCol(encodeColName, df)
+    val dictDF = genDataWithWrapEncodeCol(dbName, encodeColName, df)
     DictionaryBuilder.buildGlobalDict(project, spark, dictDF.queryExecution.analyzed)
 
     val originalDF = df.agg(countDistinct(encodeColName))
@@ -120,10 +123,11 @@ class GlobalDictionarySuite extends SparderBaseFunSuite with LocalMetadata with
 
   test("KE-35145 With null dict value") {
     val project = "p1"
+    val dbName = "db1"
     val tableName = "t1"
     val colName = "c4"
     val encodeColName: String = tableName + NSparkCubingUtil.SEPARATOR + colName
-    val context = new DictionaryContext(project, tableName, colName, null)
+    val context = new DictionaryContext(project, dbName, tableName, colName, null)
     var schema = new StructType
     schema = schema.add(encodeColName, StringType)
     val data = Seq(
@@ -141,7 +145,7 @@ class GlobalDictionarySuite extends SparderBaseFunSuite with LocalMetadata with
       Row("a"),
       Row("b"))
 
-    val dictCol = Seq(dict_encode_v3(col(encodeColName)).alias(colName + "_KE_ENCODE"))
+    val dictCol = Seq(dict_encode_v3(col(encodeColName), dbName).alias(colName + "_KE_ENCODE"))
     val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
 
     val dictDfPlan = df
@@ -157,22 +161,23 @@ class GlobalDictionarySuite extends SparderBaseFunSuite with LocalMetadata with
 
   test("KE-35145 Build dict with null value") {
     val project = "p1"
+    val dbName = "db1"
     val tableName = "t1"
     val colName = "c5"
     val encodeColName: String = tableName + NSparkCubingUtil.SEPARATOR + colName
-    val context = new DictionaryContext(project, tableName, colName, null)
+    val context = new DictionaryContext(project, dbName, tableName, colName, null)
     var schema = new StructType
     schema = schema.add(encodeColName, StringType)
     val data = Seq.empty[Row]
 
-    val dictCol = Seq(dict_encode_v3(col(encodeColName)).alias(encodeColName + "_KE_ENCODE"))
+    val dictCol = Seq(dict_encode_v3(col(encodeColName), dbName).alias(encodeColName + "_KE_ENCODE"))
     val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
     val dictDfPlan = df
       .select(df.schema.map(ty => col(ty.name)) ++ dictCol: _*)
       .queryExecution
       .analyzed
-    DictionaryBuilder.buildGlobalDict("p2", spark, dictDfPlan)
+    DictionaryBuilder.buildGlobalDict(project, spark, dictDfPlan)
 
     val originalDF = df.agg(countDistinct(encodeColName))
     val dictPath: String = DictionaryBuilder.getDictionaryPath(context)
     val dictResultDF = DeltaTable.forPath(dictPath).toDF.agg(count(col("dict_key")))
diff --git a/src/spark-project/engine-spark/src/test/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionaryUpdateSuite.scala b/src/spark-project/engine-spark/src/test/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionaryUpdateSuite.scala
index b76dafbade..71b49a0924 100644
--- a/src/spark-project/engine-spark/src/test/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionaryUpdateSuite.scala
+++ b/src/spark-project/engine-spark/src/test/scala/org/apache/kylin/engine/spark/builder/v3dict/GlobalDictionaryUpdateSuite.scala
@@ -26,7 +26,7 @@ import org.apache.kylin.engine.spark.builder.{DFDictionaryBuilder, DictionaryBui
 import org.apache.kylin.engine.spark.job.NSparkCubingUtil
 import org.apache.kylin.metadata.cube.cuboid.NSpanningTreeFactory
 import org.apache.kylin.metadata.cube.model.{NDataSegment, NDataflow, NDataflowManager}
-import org.apache.kylin.metadata.model.{TableDesc, TblColRef}
+import org.apache.kylin.metadata.model.TblColRef
 import org.apache.spark.dict.{NGlobalDictMetaInfo, NGlobalDictionaryV2}
 import org.apache.spark.sql.common.{LocalMetadata, SharedSparkSession, SparderBaseFunSuite}
 import org.apache.spark.sql.functions.{col, count, countDistinct}
@@ -70,10 +70,11 @@ class GlobalDictionaryUpdateSuite extends SparderBaseFunSuite with LocalMetadata
 
   private def buildV3Dict(dictCol: TblColRef): Unit = {
     val tableName = StringUtils.split(dictCol.getTable, ".").apply(1)
+    val dbName = dictCol.getTableRef.getTableDesc.getDatabase
     val encodeColName: String = tableName + NSparkCubingUtil.SEPARATOR + dictCol.getName
-    val context = new DictionaryContext(DEFAULT_PROJECT, tableName, dictCol.getName, null)
+    val context = new DictionaryContext(DEFAULT_PROJECT, dbName, tableName, dictCol.getName, null)
     val df = genRandomData(spark, encodeColName, 100, 10)
-    val dictDF = genDataWithWrapEncodeCol(encodeColName, df)
+    val dictDF = genDataWithWrapEncodeCol(dbName, encodeColName, df)
     DictionaryBuilder.buildGlobalDict(DEFAULT_PROJECT, spark, dictDF.queryExecution.analyzed)
 
     val originalDF = df.agg(countDistinct(encodeColName))
@@ -85,15 +86,11 @@ class GlobalDictionaryUpdateSuite extends SparderBaseFunSuite with LocalMetadata
   def prepareV2Dict(seg: NDataSegment, randomDataSet: Dataset[Row], dictColSet: util.Set[TblColRef]): NGlobalDictMetaInfo = {
     val dictionaryBuilder = new DFDictionaryBuilder(randomDataSet, seg, randomDataSet.sparkSession, dictColSet)
     val colName = dictColSet.iterator().next()
-    val dictCol = TblColRef.mockup(TableDesc.mockup(colName.getTableRef.getTableDesc.getName),
-      1,
-      colName.getName,
-      "string")
     val bucketPartitionSize = DictionaryBuilderHelper.calculateBucketSize(seg, colName, randomDataSet)
-    dictionaryBuilder.build(dictCol, bucketPartitionSize, randomDataSet)
+    dictionaryBuilder.build(colName, bucketPartitionSize, randomDataSet)
     val dict = new NGlobalDictionaryV2(seg.getProject,
-      dictCol.getTable,
-      dictCol.getName,
+      colName.getTable,
+      colName.getName,
       seg.getConfig.getHdfsWorkingDirectory)
     dict.getMetaInfo
   }
diff --git a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/KapFunctions.scala b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/KapFunctions.scala
index 7d6d9e4cb2..352dfcb977 100644
--- a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/KapFunctions.scala
+++ b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/KapFunctions.scala
@@ -251,8 +251,8 @@ object KapFunctions {
     Column(DictEncode(column.expr, dictParams.expr, bucketSize.expr))
   }
 
-  def dict_encode_v3(column: Column): Column = {
-    Column(DictEncodeV3(column.expr))
+  def dict_encode_v3(column: Column, colName: String): Column = {
+    Column(DictEncodeV3(column.expr, colName))
   }
 
   val builtin: Seq[FunctionEntity] = Seq(
diff --git a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/catalyst/expressions/KapExpresssions.scala b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/catalyst/expressions/KapExpresssions.scala
index 8df821ed8a..32d3d7d4d7 100644
--- a/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/catalyst/expressions/KapExpresssions.scala
+++ b/src/spark-project/sparder/src/main/scala/org/apache/spark/sql/catalyst/expressions/KapExpresssions.scala
@@ -348,7 +348,7 @@ case class Truncate(_left: Expression, _right: Expression) extends BinaryExpress
   }
 }
 
-case class DictEncodeV3(child: Expression) extends UnaryExpression {
+case class DictEncodeV3(child: Expression, col: String) extends UnaryExpression {
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
     defineCodeGen(ctx, ev, c => c)

   override def dataType: DataType = StringType
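
For readers tracing the patch, the control flow that replaces the old if/else in incrementBuildDict can be summarised in a small, self-contained sketch. This is illustrative only and not code from the commit: the object name DictBuildModeSketch and its boolean parameters are invented stand-ins for the isExistsV3Dict / isExistsOriginalV3Dict / isExistsV2Dict checks that DictionaryBuilder performs, and newDictPath merely mirrors the db-aware layout that the new getDictionaryPath builds with java.nio.file.Paths.

    import java.nio.file.Paths

    // Illustrative sketch, not part of the commit. The flags stand in for the
    // Delta-table and HDFS existence checks done by DictionaryBuilder.
    object DictBuildModeSketch extends Enumeration {

      val V3UPGRADE, V2UPGRADE, V3APPEND, V3INIT = Value

      def choose(existsV3Dict: Boolean,
                 existsOriginalV3Dict: Boolean,
                 convertV2Enabled: Boolean,
                 existsV2Dict: Boolean): Value = {
        if (existsV3Dict) V3APPEND                           // dict already at the new db-aware path: merge increments
        else if (existsOriginalV3Dict) V3UPGRADE             // old table-only v3 layout found: copy it, then merge
        else if (convertV2Enabled && existsV2Dict) V2UPGRADE // v2 dict present and conversion enabled: convert, then merge
        else V3INIT                                          // nothing exists yet: write the first dictionary version
      }

      // The upgraded layout sketched from getDictionaryPath:
      // <workingDir><project>/<v3Root>/<db>/<table>/<column>
      def newDictPath(workingDir: String, project: String, v3Root: String,
                      db: String, table: String, column: String): String =
        workingDir + Paths.get(project, v3Root, db, table, column)
    }

Under these assumptions, DictBuildModeSketch.choose(existsV3Dict = false, existsOriginalV3Dict = true, convertV2Enabled = false, existsV2Dict = false) returns V3UPGRADE, which corresponds to the branch that calls upgradeFromOriginalV3 and initAndSaveDictDF before mergeIncrementDict; only when no existing dictionary is found does the builder fall back to V3INIT.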