This is an automated email from the ASF dual-hosted git repository.

liyang pushed a commit to branch kylin5
in repository https://gitbox.apache.org/repos/asf/kylin.git
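For context, the patch below adds support for Spark's split function in SQL queries and computed columns (KYLIN-5769). The behaviour it exposes matches plain Spark SQL; the sketch that follows is a standalone illustration, with the table name and sample data modeled on the new test query in NBuildAndQueryMetricsTest ("select split(account_id, '-')[0] from test_account limit 30") rather than taken from a Kylin deployment.

    // Standalone sketch against plain Spark, not Kylin; data and names are illustrative.
    import org.apache.spark.sql.SparkSession

    object SplitFunctionSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[1]").appName("split-sketch").getOrCreate()
        import spark.implicits._

        Seq("10000000-01", "10000001-02").toDF("account_id")
          .createOrReplaceTempView("test_account")

        // split(str, regex) yields array<string>; [0] takes the first element
        // (Spark SQL array indexing is zero-based).
        spark.sql("select split(account_id, '-')[0] from test_account limit 30").show(false)
        spark.stop()
      }
    }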
commit 738cf4af01fca030751eaff0b1ea3a59340d83eb Author: Pengfei Zhan <dethr...@gmail.com> AuthorDate: Wed Jul 5 22:12:00 2023 +0800 KYLIN-5769 Support split function --- .../apache/kylin/metadata/datatype/DataType.java | 2 +- .../rest/controller/NBuildAndQueryMetricsTest.java | 15 ++- .../kylin/query/engine/QueryRoutingEngine.java | 80 +++++++++------ .../org/apache/kylin/query/udf/SparkStringUDF.java | 7 +- .../engine/spark/builder/CreateFlatTable.scala | 26 ++--- .../job/stage/build/FlatTableAndDictBase.scala | 82 +++++++++------- .../spark/smarter/IndexDependencyParser.scala | 7 +- .../spark/utils/ComputedColumnEvalUtilTest.java | 2 + .../kylin/query/runtime/ExpressionConverter.scala | 2 +- .../kylin/query/runtime/plan/TableScanPlan.scala | 15 +-- .../org/apache/spark/sql/SparderTypeUtilTest.scala | 29 +++++- .../org/apache/spark/sql/SparderTypeUtil.scala | 107 ++++++++++++--------- .../apache/spark/sql/common/SparderQueryTest.scala | 75 +++++++-------- 13 files changed, 254 insertions(+), 195 deletions(-) diff --git a/src/core-metadata/src/main/java/org/apache/kylin/metadata/datatype/DataType.java b/src/core-metadata/src/main/java/org/apache/kylin/metadata/datatype/DataType.java index 81e5881ef5..20057ad919 100644 --- a/src/core-metadata/src/main/java/org/apache/kylin/metadata/datatype/DataType.java +++ b/src/core-metadata/src/main/java/org/apache/kylin/metadata/datatype/DataType.java @@ -153,7 +153,7 @@ public class DataType implements Serializable { LEGACY_TYPE_MAP.put("hllc16", "hllc(16)"); } - private static final ConcurrentMap<DataType, DataType> CACHE = new ConcurrentHashMap<DataType, DataType>(); + private static final ConcurrentMap<DataType, DataType> CACHE = new ConcurrentHashMap<>(); public static final DataType ANY = DataType.getType(ANY_STR); diff --git a/src/kylin-server-it/src/test/java/org/apache/kylin/rest/controller/NBuildAndQueryMetricsTest.java b/src/kylin-server-it/src/test/java/org/apache/kylin/rest/controller/NBuildAndQueryMetricsTest.java index 7ff9b40cf9..00f8522c4a 100644 --- a/src/kylin-server-it/src/test/java/org/apache/kylin/rest/controller/NBuildAndQueryMetricsTest.java +++ b/src/kylin-server-it/src/test/java/org/apache/kylin/rest/controller/NBuildAndQueryMetricsTest.java @@ -32,6 +32,9 @@ import org.apache.kylin.common.util.TempMetadataBuilder; import org.apache.kylin.engine.spark.ExecutableUtils; import org.apache.kylin.engine.spark.job.NSparkCubingJob; import org.apache.kylin.engine.spark.merger.AfterBuildResourceMerger; +import org.apache.kylin.guava30.shaded.common.base.Preconditions; +import org.apache.kylin.guava30.shaded.common.collect.Lists; +import org.apache.kylin.guava30.shaded.common.collect.Sets; import org.apache.kylin.job.engine.JobEngineConfig; import org.apache.kylin.job.execution.ExecutableState; import org.apache.kylin.job.execution.NExecutableManager; @@ -77,10 +80,6 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.security.authentication.TestingAuthenticationToken; import org.springframework.security.core.context.SecurityContextHolder; -import org.apache.kylin.guava30.shaded.common.base.Preconditions; -import org.apache.kylin.guava30.shaded.common.collect.Lists; -import org.apache.kylin.guava30.shaded.common.collect.Sets; - import lombok.val; import lombok.var; import lombok.extern.slf4j.Slf4j; @@ -181,7 +180,7 @@ public class NBuildAndQueryMetricsTest extends AbstractMVCIntegrationTestCase { projectManager.updateProject(projectInstance, projectInstanceUpdate.getName(), 
projectInstanceUpdate.getDescription(), projectInstanceUpdate.getOverrideKylinProps()); - Preconditions.checkArgument(projectInstance != null); + Preconditions.checkNotNull(projectInstance); for (String table : projectInstance.getTables()) { if (!"DEFAULT.TEST_KYLIN_FACT".equals(table) && !"DEFAULT.TEST_ACCOUNT".equals(table)) { @@ -254,6 +253,12 @@ public class NBuildAndQueryMetricsTest extends AbstractMVCIntegrationTestCase { assertMetric(sql, 30); } + @Test + public void testSplitFunction() throws Exception { + String sql = "select split(account_id, '-')[0] from test_account limit 30"; + assertMetric(sql, 30); + } + @Test public void testMetricsScanForTableIndex() throws Exception { String sql = "select count(distinct case when trans_id > 100 then order_id else 0 end)," diff --git a/src/query/src/main/java/org/apache/kylin/query/engine/QueryRoutingEngine.java b/src/query/src/main/java/org/apache/kylin/query/engine/QueryRoutingEngine.java index 0aa304edf6..081edc0d0a 100644 --- a/src/query/src/main/java/org/apache/kylin/query/engine/QueryRoutingEngine.java +++ b/src/query/src/main/java/org/apache/kylin/query/engine/QueryRoutingEngine.java @@ -40,6 +40,7 @@ import org.apache.kylin.common.KylinConfig; import org.apache.kylin.common.QueryContext; import org.apache.kylin.common.QueryTrace; import org.apache.kylin.common.debug.BackdoorToggles; +import org.apache.kylin.common.exception.CalciteNotSupportException; import org.apache.kylin.common.exception.KylinException; import org.apache.kylin.common.exception.NewQueryRefuseException; import org.apache.kylin.common.exception.TargetSegmentNotFoundException; @@ -47,6 +48,8 @@ import org.apache.kylin.common.persistence.transaction.TransactionException; import org.apache.kylin.common.persistence.transaction.UnitOfWork; import org.apache.kylin.common.persistence.transaction.UnitOfWorkParams; import org.apache.kylin.common.util.DBUtils; +import org.apache.kylin.guava30.shaded.common.annotations.VisibleForTesting; +import org.apache.kylin.guava30.shaded.common.collect.Lists; import org.apache.kylin.metadata.project.NProjectLoader; import org.apache.kylin.metadata.project.NProjectManager; import org.apache.kylin.metadata.query.NativeQueryRealization; @@ -68,9 +71,6 @@ import org.apache.spark.SparkException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.kylin.guava30.shaded.common.annotations.VisibleForTesting; -import org.apache.kylin.guava30.shaded.common.collect.Lists; - import lombok.val; public class QueryRoutingEngine { @@ -133,39 +133,63 @@ public class QueryRoutingEngine { setParam(queryExec, i, queryParams.getParams()[i]); } } + + // execute query and get result from kylin layouts return execute(correctedSql, queryExec); }, queryParams.getProject()); } catch (TransactionException e) { - Throwable cause = e.getCause(); - if (cause instanceof SQLException && cause.getCause() instanceof KylinException) { - throw (SQLException) cause; - } - if (shouldPushdown(cause, queryParams)) { - return pushDownQuery((SQLException) cause, queryParams); - } else { - throw e; - } + return handleTransactionException(queryParams, e); } catch (SQLException e) { - if (e.getCause() instanceof KylinException) { - if (checkIfRetryQuery(e.getCause())) { - NProjectLoader.removeCache(); - return queryWithSqlMassage(queryParams); + return handleSqlException(queryParams, e); + } catch (AssertionError e) { + return handleAssertionError(queryParams, e); + } finally { + QueryResultMasks.remove(); + } + } + + private QueryResult 
handleTransactionException(QueryParams queryParams, TransactionException e) + throws SQLException { + Throwable cause = e.getCause(); + if (cause instanceof SQLException && cause.getCause() instanceof KylinException) { + throw (SQLException) cause; + } + if (!(cause instanceof SQLException)) { + throw e; + } + if (shouldPushdown(cause, queryParams)) { + return pushDownQuery((SQLException) cause, queryParams); + } else { + throw e; + } + } + + private QueryResult handleSqlException(QueryParams queryParams, SQLException e) throws Exception { + if (e.getCause() instanceof KylinException) { + if (checkIfRetryQuery(e.getCause())) { + NProjectLoader.removeCache(); + return queryWithSqlMassage(queryParams); + } else { + if (e.getCause() instanceof NewQueryRefuseException && shouldPushdown(e, queryParams)) { + return pushDownQuery(e, queryParams); } else { - if (e.getCause() instanceof NewQueryRefuseException && shouldPushdown(e, queryParams)) { - return pushDownQuery(e, queryParams); - } else { - throw e; - } + throw e; } } - if (shouldPushdown(e, queryParams)) { - return pushDownQuery(e, queryParams); - } else { - throw e; - } - } finally { - QueryResultMasks.remove(); } + if (shouldPushdown(e, queryParams)) { + return pushDownQuery(e, queryParams); + } + throw e; + } + + private QueryResult handleAssertionError(QueryParams queryParams, AssertionError e) throws SQLException { + // for example: split('abc', 'b') will jump into this AssertionError + if (e.getMessage().equals("OTHER")) { + SQLException ex = new SQLException(e.getMessage(), new CalciteNotSupportException()); + return pushDownQuery(ex, queryParams); + } + throw e; } public boolean checkIfRetryQuery(Throwable cause) { diff --git a/src/query/src/main/java/org/apache/kylin/query/udf/SparkStringUDF.java b/src/query/src/main/java/org/apache/kylin/query/udf/SparkStringUDF.java index ab5d6fab4a..5ce28e9851 100644 --- a/src/query/src/main/java/org/apache/kylin/query/udf/SparkStringUDF.java +++ b/src/query/src/main/java/org/apache/kylin/query/udf/SparkStringUDF.java @@ -90,11 +90,16 @@ public class SparkStringUDF implements NotConstant { throw new CalciteNotSupportException(); } - public String[] SPLIT(@Parameter(name = "str1") String exp1, @Parameter(name = "str2") String exp2) + public String[] SPLIT(@Parameter(name = "str") Object str, @Parameter(name = "regex") Object regex) throws CalciteNotSupportException { throw new CalciteNotSupportException(); } + public String[] SPLIT(@Parameter(name = "str") Object str, @Parameter(name = "regex") Object regex, + @Parameter(name = "limit") Object limit) throws CalciteNotSupportException { + throw new CalciteNotSupportException(); + } + public String SUBSTRING_INDEX(@Parameter(name = "str1") String exp1, @Parameter(name = "str2") String exp2, @Parameter(name = "num2") Integer exp3) throws CalciteNotSupportException { throw new CalciteNotSupportException(); diff --git a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/CreateFlatTable.scala b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/CreateFlatTable.scala index ddc005e71e..b659b87969 100644 --- a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/CreateFlatTable.scala +++ b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/builder/CreateFlatTable.scala @@ -23,9 +23,11 @@ import java.util.Locale import org.apache.commons.lang3.StringUtils import org.apache.kylin.engine.spark.builder.DFBuilderHelper._ import 
org.apache.kylin.engine.spark.job.NSparkCubingUtil._ +import org.apache.kylin.engine.spark.job.stage.build.FlatTableAndDictBase import org.apache.kylin.engine.spark.job.{FlatTableHelper, TableMetaManager} import org.apache.kylin.engine.spark.utils.SparkDataSource._ import org.apache.kylin.engine.spark.utils.{LogEx, LogUtils} +import org.apache.kylin.guava30.shaded.common.collect.Sets import org.apache.kylin.metadata.cube.cuboid.NSpanningTree import org.apache.kylin.metadata.cube.model.{NCubeJoinedFlatTableDesc, NDataSegment} import org.apache.kylin.metadata.model._ @@ -37,8 +39,6 @@ import scala.collection.mutable import scala.collection.parallel.ForkJoinTaskSupport import scala.concurrent.forkjoin.ForkJoinPool -import org.apache.kylin.guava30.shaded.common.collect.Sets - @Deprecated class CreateFlatTable(val flatTable: IJoinedFlatTableDesc, @@ -232,24 +232,14 @@ object CreateFlatTable extends LogEx { s"Invalid join condition of fact table: $rootFactDesc,fk: ${fk.mkString(",")}," + s" lookup table:$lookupDesc, pk: ${pk.mkString(",")}") } - val equiConditionColPairs = fk.zip(pk).map(joinKey => - col(convertFromDot(joinKey._1.getBackTickIdentity)) - .equalTo(col(convertFromDot(joinKey._2.getBackTickIdentity)))) logInfo(s"Lookup table schema ${lookupDataset.schema.treeString}") - if (join.getNonEquiJoinCondition != null) { - var condition = NonEquiJoinConditionBuilder.convert(join.getNonEquiJoinCondition) - if (!equiConditionColPairs.isEmpty) { - condition = condition && equiConditionColPairs.reduce(_ && _) - } - logInfo(s"Root table ${rootFactDesc.getIdentity}, join table ${lookupDesc.getAlias}, non-equi condition: ${condition.toString()}") - afterJoin = afterJoin.join(lookupDataset, condition, joinType) - } else { - val condition = equiConditionColPairs.reduce(_ && _) - logInfo(s"Root table ${rootFactDesc.getIdentity}, join table ${lookupDesc.getAlias}, condition: ${condition.toString()}") - afterJoin = afterJoin.join(lookupDataset, condition, joinType) - - } + val condition = FlatTableAndDictBase.getCondition(join) + val nonEquiv = if (join.getNonEquiJoinCondition == null) "" else "non-equi " + logInfo(s"Root table ${rootFactDesc.getIdentity}," + + s" join table ${lookupDesc.getAlias}," + + s" ${nonEquiv}condition: ${condition.toString()}") + afterJoin = afterJoin.join(lookupDataset, condition, joinType) } afterJoin } diff --git a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/job/stage/build/FlatTableAndDictBase.scala b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/job/stage/build/FlatTableAndDictBase.scala index 16ea4763d3..9cd5487904 100644 --- a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/job/stage/build/FlatTableAndDictBase.scala +++ b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/job/stage/build/FlatTableAndDictBase.scala @@ -18,6 +18,10 @@ package org.apache.kylin.engine.spark.job.stage.build +import java.math.BigInteger +import java.util.concurrent.{CountDownLatch, TimeUnit} +import java.util.{Locale, Objects, Timer, TimerTask} + import org.apache.commons.lang3.StringUtils import org.apache.hadoop.fs.Path import org.apache.kylin.common.util.HadoopUtil @@ -32,8 +36,8 @@ import org.apache.kylin.engine.spark.job.{FiltersUtil, SegmentJob, TableMetaMana import org.apache.kylin.engine.spark.model.SegmentFlatTableDesc import org.apache.kylin.engine.spark.model.planner.{CuboIdToLayoutUtils, FlatTableToCostUtils} import 
org.apache.kylin.engine.spark.smarter.IndexDependencyParser -import org.apache.kylin.engine.spark.utils.{LogEx, SparkConfHelper} import org.apache.kylin.engine.spark.utils.SparkDataSource._ +import org.apache.kylin.engine.spark.utils.{LogEx, SparkConfHelper} import org.apache.kylin.guava30.shaded.common.collect.Sets import org.apache.kylin.metadata.cube.cuboid.AdaptiveSpanningTree import org.apache.kylin.metadata.cube.cuboid.AdaptiveSpanningTree.AdaptiveTreeBuilder @@ -47,9 +51,6 @@ import org.apache.spark.sql.types.StructField import org.apache.spark.sql.util.SparderTypeUtil import org.apache.spark.utils.ProxyThreadUtils -import java.math.BigInteger -import java.util.concurrent.{CountDownLatch, TimeUnit} -import java.util.{Locale, Objects, Timer, TimerTask} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.parallel.ForkJoinTaskSupport @@ -139,7 +140,7 @@ abstract class FlatTableAndDictBase(private val jobContext: SegmentJob, if (inferFiltersEnabled) { FiltersUtil.initFilters(tableDesc, lookupTableDSMap) } - val jointDS = joinFactTableWithLookupTables(fastFactTableDS, lookupTableDSMap, dataModel, sparkSession) + val jointDS = joinFactTableWithLookupTables(fastFactTableDS, lookupTableDSMap, dataModel) concatCCs(jointDS, factTableCCs) } else { concatCCs(fastFactTableDS, factTableCCs) @@ -216,7 +217,7 @@ abstract class FlatTableAndDictBase(private val jobContext: SegmentJob, FiltersUtil.initFilters(tableDesc, lookupTables) } - val jointTable = joinFactTableWithLookupTables(factTable, lookupTables, dataModel, sparkSession) + val jointTable = joinFactTableWithLookupTables(factTable, lookupTables, dataModel) buildDictIfNeed(concatCCs(jointTable, factTableCCs), selectColumnsNotInTables(factTable, lookupTables.values.toSeq, dictCols), selectColumnsNotInTables(factTable, lookupTables.values.toSeq, encodeCols)) @@ -704,11 +705,16 @@ object FlatTableAndDictBase extends LogEx { newDS.select(selectedColumns: _*) } - def wrapAlias(originDS: Dataset[Row], alias: String): Dataset[Row] = { - val newFields = originDS.schema.fields.map(f => - convertFromDot("`" + alias + "`" + "." + "`" + f.name + "`")).toSeq + def wrapAlias(originDS: Dataset[Row], alias: String, needLog: Boolean = true): Dataset[Row] = { + val newFields = originDS.schema.fields + .map(f => { + val aliasDotColName = "`" + alias + "`" + "." 
+ "`" + f.name + "`" + convertFromDot(aliasDotColName) + }).toSeq val newDS = originDS.toDF(newFields: _*) - logInfo(s"Wrap ALIAS ${originDS.schema.treeString} TO ${newDS.schema.treeString}") + if (needLog) { + logInfo(s"Wrap ALIAS ${originDS.schema.treeString} TO ${newDS.schema.treeString}") + } newDS } @@ -716,17 +722,17 @@ object FlatTableAndDictBase extends LogEx { def joinFactTableWithLookupTables(rootFactDataset: Dataset[Row], lookupTableDatasetMap: mutable.Map[JoinTableDesc, Dataset[Row]], model: NDataModel, - ss: SparkSession): Dataset[Row] = { + needLog: Boolean = true): Dataset[Row] = { lookupTableDatasetMap.foldLeft(rootFactDataset)( (joinedDataset: Dataset[Row], tuple: (JoinTableDesc, Dataset[Row])) => - joinTableDataset(model.getRootFactTable.getTableDesc, tuple._1, joinedDataset, tuple._2, ss)) + joinTableDataset(model.getRootFactTable.getTableDesc, tuple._1, joinedDataset, tuple._2, needLog)) } def joinTableDataset(rootFactDesc: TableDesc, lookupDesc: JoinTableDesc, rootFactDataset: Dataset[Row], lookupDataset: Dataset[Row], - ss: SparkSession): Dataset[Row] = { + needLog: Boolean = true): Dataset[Row] = { var afterJoin = rootFactDataset val join = lookupDesc.getJoin if (join != null && !StringUtils.isEmpty(join.getType)) { @@ -738,34 +744,42 @@ object FlatTableAndDictBase extends LogEx { s"Invalid join condition of fact table: $rootFactDesc,fk: ${fk.mkString(",")}," + s" lookup table:$lookupDesc, pk: ${pk.mkString(",")}") } - val equiConditionColPairs = fk.zip(pk).map(joinKey => - col(convertFromDot(joinKey._1.getBackTickIdentity)) - .equalTo(col(convertFromDot(joinKey._2.getBackTickIdentity)))) - logInfo(s"Lookup table schema ${lookupDataset.schema.treeString}") - - if (join.getNonEquiJoinCondition != null) { - val condition: Column = getCondition(join, equiConditionColPairs) - logInfo(s"Root table ${rootFactDesc.getIdentity}, join table ${lookupDesc.getAlias}, non-equi condition: ${condition.toString()}") - afterJoin = afterJoin.join(lookupDataset, condition, joinType) + if (needLog) { + logInfo(s"Lookup table schema ${lookupDataset.schema.treeString}") + } + + val condition = getCondition(join) + if (needLog) { + val nonEquiv = if (join.getNonEquiJoinCondition == null) "" else "non-equi " + logInfo(s"Root table ${rootFactDesc.getIdentity}," + + s" join table ${lookupDesc.getAlias}," + + s" ${nonEquiv} condition: ${condition.toString()}") + } + + if (join.getNonEquiJoinCondition == null && inferFiltersEnabled) { + afterJoin = afterJoin.join(FiltersUtil.inferFilters(pk, lookupDataset), condition, joinType) } else { - val condition = equiConditionColPairs.reduce(_ && _) - logInfo(s"Root table ${rootFactDesc.getIdentity}, join table ${lookupDesc.getAlias}, condition: ${condition.toString()}") - if (inferFiltersEnabled) { - afterJoin = afterJoin.join(FiltersUtil.inferFilters(pk, lookupDataset), condition, joinType) - } else { - afterJoin = afterJoin.join(lookupDataset, condition, joinType) - } + afterJoin = afterJoin.join(lookupDataset, condition, joinType) } } afterJoin } - def getCondition(join: JoinDesc, equiConditionColPairs: Array[Column]): Column = { - var condition = NonEquiJoinConditionBuilder.convert(join.getNonEquiJoinCondition) - if (!equiConditionColPairs.isEmpty) { - condition = condition && equiConditionColPairs.reduce(_ && _) + def getCondition(join: JoinDesc): Column = { + val pk = join.getPrimaryKeyColumns + val fk = join.getForeignKeyColumns + + val equalPairs = fk.zip(pk).map(joinKey => { + val fkIdentity = convertFromDot(joinKey._1.getBackTickIdentity) + 
val pkIdentity = convertFromDot(joinKey._2.getBackTickIdentity) + col(fkIdentity).equalTo(col(pkIdentity)) + }).reduce(_ && _) + + if (join.getNonEquiJoinCondition == null) { + equalPairs + } else { + NonEquiJoinConditionBuilder.convert(join.getNonEquiJoinCondition) && equalPairs } - condition } def changeSchemeToColumnId(ds: Dataset[Row], tableDesc: SegmentFlatTableDesc): Dataset[Row] = { diff --git a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/smarter/IndexDependencyParser.scala b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/smarter/IndexDependencyParser.scala index 757f16c5c9..44e52099d7 100644 --- a/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/smarter/IndexDependencyParser.scala +++ b/src/spark-project/engine-spark/src/main/scala/org/apache/kylin/engine/spark/smarter/IndexDependencyParser.scala @@ -24,6 +24,7 @@ import org.apache.commons.collections.CollectionUtils import org.apache.commons.lang3.StringUtils import org.apache.kylin.engine.spark.job.NSparkCubingUtil import org.apache.kylin.engine.spark.job.stage.build.FlatTableAndDictBase +import org.apache.kylin.guava30.shaded.common.collect.{Lists, Maps, Sets} import org.apache.kylin.metadata.cube.model.LayoutEntity import org.apache.kylin.metadata.model._ import org.apache.kylin.query.util.PushDownUtil @@ -34,8 +35,6 @@ import org.apache.spark.sql.{Dataset, Row, SparderEnv, SparkSession} import scala.collection.JavaConverters._ import scala.collection.mutable -import org.apache.kylin.guava30.shaded.common.collect.{Lists, Maps, Sets} - class IndexDependencyParser(val model: NDataModel) { private val ccTableNameAliasMap = Maps.newHashMap[String, util.Set[String]] @@ -113,7 +112,7 @@ class IndexDependencyParser(val model: NDataModel) { model.getJoinTables.asScala.map((joinTable: JoinTableDesc) => { joinTableDFMap.put(joinTable, generateDatasetOnTable(ss, joinTable.getTableRef)) }) - val df = FlatTableAndDictBase.joinFactTableWithLookupTables(rootDF, joinTableDFMap, model, ss) + val df = FlatTableAndDictBase.joinFactTableWithLookupTables(rootDF, joinTableDFMap, model, needLog = false) val filterCondition = model.getFilterCondition if (StringUtils.isNotEmpty(filterCondition)) { val massagedCondition = PushDownUtil.massageExpression(model, model.getProject, filterCondition, null) @@ -127,7 +126,7 @@ class IndexDependencyParser(val model: NDataModel) { val structType = SchemaProcessor.buildSchemaWithRawTable(tableCols) val alias = tableRef.getAlias val dataset = ss.createDataFrame(Lists.newArrayList[Row], structType).alias(alias) - FlatTableAndDictBase.wrapAlias(dataset, alias) + FlatTableAndDictBase.wrapAlias(dataset, alias, needLog = false) } private def initTableNames(): Unit = { diff --git a/src/spark-project/engine-spark/src/test/java/org/apache/kylin/engine/spark/utils/ComputedColumnEvalUtilTest.java b/src/spark-project/engine-spark/src/test/java/org/apache/kylin/engine/spark/utils/ComputedColumnEvalUtilTest.java index 786d1f88c3..4937fbdbe9 100644 --- a/src/spark-project/engine-spark/src/test/java/org/apache/kylin/engine/spark/utils/ComputedColumnEvalUtilTest.java +++ b/src/spark-project/engine-spark/src/test/java/org/apache/kylin/engine/spark/utils/ComputedColumnEvalUtilTest.java @@ -365,6 +365,8 @@ public class ComputedColumnEvalUtilTest extends NLocalWithSparkSessionTest { exprTypes.put("TEST_MEASURE.FLAG", "BOOLEAN"); exprTypes.put("NOT TEST_MEASURE.FLAG", "BOOLEAN"); + exprTypes.put("SPLIT(TEST_MEASURE.NAME1, '[ABC]')", 
"array<string>"); + exprTypes.put("SPLIT(TEST_MEASURE.NAME1, '[ABC]', 2)", "array<string>"); AtomicInteger ccId = new AtomicInteger(0); List<ComputedColumnDesc> newCCs = exprTypes.keySet().stream().map(expr -> { diff --git a/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/ExpressionConverter.scala b/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/ExpressionConverter.scala index 1e3d392e75..d08a37d450 100644 --- a/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/ExpressionConverter.scala +++ b/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/ExpressionConverter.scala @@ -57,7 +57,7 @@ object ExpressionConverter { "input_file_name", "monotonically_increasing_id", "now", "spark_partition_id", "uuid" ) - private val varArgsFunc = mutable.HashSet("months_between", "locate", "rtrim", "from_unixtime", "to_date", "to_timestamp") + private val varArgsFunc = mutable.HashSet("months_between", "locate", "rtrim", "from_unixtime", "to_date", "to_timestamp", "split") private val bitmapUDF = mutable.HashSet("intersect_count_by_col", "subtract_bitmap_value", "subtract_bitmap_uuid"); diff --git a/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/plan/TableScanPlan.scala b/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/plan/TableScanPlan.scala index 95ade3ef1d..5c0b0511b7 100644 --- a/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/plan/TableScanPlan.scala +++ b/src/spark-project/sparder/src/main/scala/org/apache/kylin/query/runtime/plan/TableScanPlan.scala @@ -28,7 +28,7 @@ import org.apache.kylin.metadata.cube.cuboid.NLayoutCandidate import org.apache.kylin.metadata.cube.gridtable.NLayoutToGridTableMapping import org.apache.kylin.metadata.cube.model.{LayoutEntity, NDataSegment, NDataflow} import org.apache.kylin.metadata.model.{DeriveInfo, FunctionDesc, NTableMetadataManager, ParameterDesc, TblColRef} -import org.apache.kylin.metadata.realization.{HybridRealization, IRealization} +import org.apache.kylin.metadata.realization.HybridRealization import org.apache.kylin.metadata.tuple.TupleInfo import org.apache.kylin.query.implicits.sessionToQueryContext import org.apache.kylin.query.relnode.{KapRel, OLAPContext} @@ -436,13 +436,10 @@ object TableScanPlan extends LogEx { def createLookupTable(rel: KapRel): DataFrame = { val start = System.currentTimeMillis() - val session = SparderEnv.getSparkSession val olapContext = rel.getContext - var instance: IRealization = null - if (olapContext.realization.isInstanceOf[NDataflow]) { - instance = olapContext.realization.asInstanceOf[NDataflow] - } else { - instance = olapContext.realization.asInstanceOf[HybridRealization] + val instance = olapContext.realization match { + case dataflow: NDataflow => dataflow + case _ => olapContext.realization.asInstanceOf[HybridRealization] } val tableMetadataManager = NTableMetadataManager.getInstance(instance.getConfig, instance.getProject) @@ -450,9 +447,7 @@ object TableScanPlan extends LogEx { val snapshotResPath = tableMetadataManager.getTableDesc(lookupTableName).getLastSnapshotPath val config = instance.getConfig val dataFrameTableName = instance.getProject + "@" + lookupTableName - val lookupDf = SparderLookupManager.getOrCreate(dataFrameTableName, - snapshotResPath, - config) + val lookupDf = SparderLookupManager.getOrCreate(dataFrameTableName, snapshotResPath, config) val olapTable = olapContext.firstTableScan.getOlapTable val alisTableName = 
olapContext.firstTableScan.getBackupAlias diff --git a/src/spark-project/sparder/src/test/scala/org/apache/spark/sql/SparderTypeUtilTest.scala b/src/spark-project/sparder/src/test/scala/org/apache/spark/sql/SparderTypeUtilTest.scala index 5c41e25965..13ecb18661 100644 --- a/src/spark-project/sparder/src/test/scala/org/apache/spark/sql/SparderTypeUtilTest.scala +++ b/src/spark-project/sparder/src/test/scala/org/apache/spark/sql/SparderTypeUtilTest.scala @@ -17,18 +17,26 @@ */ package org.apache.spark.sql +import java.io.File import java.sql.Types import org.apache.calcite.rel.`type`.RelDataTypeSystem import org.apache.calcite.sql.`type`.SqlTypeFactoryImpl +import org.apache.commons.io.FileUtils +import org.apache.kylin.common.KylinConfig +import org.apache.kylin.guava30.shaded.common.io.Files import org.apache.kylin.metadata.datatype.DataType import org.apache.kylin.query.schema.OLAPTable import org.apache.spark.sql.common.SparderBaseFunSuite import org.apache.spark.sql.types.{DataTypes, StructField} import org.apache.spark.sql.util.SparderTypeUtil +import scala.collection.immutable + class SparderTypeUtilTest extends SparderBaseFunSuite { - val dataTypes = List(DataType.getType("decimal(19,4)"), + + val dataTypes: immutable.Seq[DataType] = List( + DataType.getType("decimal(19,4)"), DataType.getType("char(50)"), DataType.getType("varchar(1000)"), DataType.getType("date"), @@ -40,10 +48,10 @@ class SparderTypeUtilTest extends SparderBaseFunSuite { DataType.getType("float"), DataType.getType("double"), DataType.getType("decimal(38,19)"), - DataType.getType("numeric(5,4)") + DataType.getType("numeric(5,4)"), + DataType.getType("ARRAY<STRING>") ) - test("Test decimal") { val dt = DataType.getType("decimal(19,4)") val dataTp = DataTypes.createDecimalType(19, 4) @@ -51,7 +59,7 @@ class SparderTypeUtilTest extends SparderBaseFunSuite { assert(dataTp.sameType(dataType)) val sparkTp = SparderTypeUtil.toSparkType(dt) assert(dataTp.sameType(sparkTp)) - val sparkTpSum = SparderTypeUtil.toSparkType(dt, true) + val sparkTpSum = SparderTypeUtil.toSparkType(dt, isSum = true) assert(DataTypes.createDecimalType(29, 4).sameType(sparkTpSum)) } @@ -72,6 +80,7 @@ class SparderTypeUtilTest extends SparderBaseFunSuite { } test("Test convertSqlTypeToSparkType") { + kylinConfig val typeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT) dataTypes.map(dt => { val relDataType = OLAPTable.createSqlType(typeFactory, dt, true) @@ -79,7 +88,16 @@ class SparderTypeUtilTest extends SparderBaseFunSuite { }) } + private def kylinConfig = { + val tmpHome = Files.createTempDir + System.setProperty("KYLIN_HOME", tmpHome.getAbsolutePath) + FileUtils.touch(new File(tmpHome.getAbsolutePath + "/kylin.properties")) + KylinConfig.setKylinConfigForLocalTest(tmpHome.getCanonicalPath) + KylinConfig.getInstanceFromEnv + } + test("test convertSparkFieldToJavaField") { + kylinConfig val typeFactory = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT) dataTypes.map(dt => { val relDataType = OLAPTable.createSqlType(typeFactory, dt, true) @@ -93,6 +111,9 @@ class SparderTypeUtilTest extends SparderBaseFunSuite { assert(Types.DECIMAL == structField.getDataType) assert(relDataType.getPrecision == structField.getPrecision) assert(relDataType.getScale == structField.getScale) + } else if (dt.getName.startsWith("array")) { + assert(structField.getDataType == Types.OTHER) + assert(relDataType.getSqlTypeName.getJdbcOrdinal == Types.ARRAY) } else { assert(relDataType.getSqlTypeName.getJdbcOrdinal == structField.getDataType) } diff --git 
a/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/SparderTypeUtil.scala b/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/SparderTypeUtil.scala index eeda408931..83499cda60 100644 --- a/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/SparderTypeUtil.scala +++ b/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/SparderTypeUtil.scala @@ -24,7 +24,7 @@ import org.apache.calcite.rex.RexLiteral import org.apache.calcite.sql.`type`.SqlTypeName import org.apache.calcite.util.NlsString import org.apache.kylin.common.KylinConfig -import org.apache.kylin.common.util.DateFormat +import org.apache.kylin.common.util.{DateFormat, StringHelper} import org.apache.kylin.metadata.datatype.DataType import org.apache.spark.internal.Logging import org.apache.spark.sql.Column @@ -43,7 +43,9 @@ import java.util.{GregorianCalendar, Locale, TimeZone} import scala.collection.{immutable, mutable} object SparderTypeUtil extends Logging { - val DATETIME_FAMILY = List("time", "date", "timestamp", "datetime") + + private val DATETIME_FAMILY: immutable.Seq[String] = List("time", "date", "timestamp", "datetime") + private val COMMA: String = "___COMMA___" def isDateTimeFamilyType(dataType: String): Boolean = { DATETIME_FAMILY.contains(dataType.toLowerCase()) @@ -53,10 +55,6 @@ object SparderTypeUtil extends Logging { "date".equalsIgnoreCase(dataType) } - def isDateTime(sqlTypeName: SqlTypeName): Boolean = { - SqlTypeName.DATETIME_TYPES.contains(sqlTypeName) - } - // scalastyle:off def kylinTypeToSparkResultType(dataTp: DataType): org.apache.spark.sql.types.DataType = { dataTp.getName match { @@ -81,6 +79,7 @@ object SparderTypeUtil extends Logging { case "bitmap" => LongType case "dim_dc" => LongType case "boolean" => BooleanType + case "array<string>" => ArrayType(StringType) case _ => throw new IllegalArgumentException } } @@ -116,6 +115,7 @@ object SparderTypeUtil extends Logging { case tp if tp.startsWith("extendedcolumn") => BinaryType case tp if tp.startsWith("percentile") => BinaryType case tp if tp.startsWith("raw") => BinaryType + case "array<string>" => ArrayType(StringType) case "any" => StringType case _ => throw new IllegalArgumentException(dataTp.toString) } @@ -142,6 +142,7 @@ object SparderTypeUtil extends Logging { case tp if tp.startsWith("varchar") => s.toString case tp if tp.startsWith("char") => s.toString case "boolean" => java.lang.Boolean.parseBoolean(s.toString) + case "array<string>" => s.toString.split(COMMA) case noSupport => throw new IllegalArgumentException(s"No supported data type: $noSupport") } } @@ -161,6 +162,8 @@ object SparderTypeUtil extends Logging { case SqlTypeName.DATE => DateType case SqlTypeName.TIMESTAMP => TimestampType case SqlTypeName.BOOLEAN => BooleanType + case SqlTypeName.ARRAY if dt.getComponentType.getSqlTypeName == SqlTypeName.VARCHAR => ArrayType(StringType) + case SqlTypeName.OTHER if dt.getComponentType == null => ArrayType(StringType) // handle null case SqlTypeName.ANY => StringType case _ => throw new IllegalArgumentException(s"unsupported SqlTypeName $dt") @@ -179,14 +182,14 @@ object SparderTypeUtil extends Logging { case DateType => SqlTypeName.DATE.getName case TimestampType => SqlTypeName.TIMESTAMP.getName case BooleanType => SqlTypeName.BOOLEAN.getName - case decimalType: DecimalType => - SqlTypeName.DECIMAL.getName + "(" + decimalType.precision + "," + decimalType.scale + ")" + case decimalType: DecimalType => decimalType.sql + case arrayType: ArrayType => arrayType.simpleString 
case _ => throw new IllegalArgumentException(s"unsupported SqlTypeName $dt") } } - def getValueFromNlsString(s: NlsString): String = { + private def getValueFromNlsString(s: NlsString): String = { if (!KylinConfig.getInstanceFromEnv.isQueryEscapedLiteral) { val ret = new StringBuilder ret.append("'") @@ -199,11 +202,11 @@ object SparderTypeUtil extends Logging { } } - def getValueFromRexLit(literal: RexLiteral) = { + def getValueFromRexLit(literal: RexLiteral): Any = { val ret = literal.getValue match { case s: NlsString => getValueFromNlsString(s) - case g: GregorianCalendar => + case _: GregorianCalendar => if (literal.getTypeName.getName.equals("DATE")) { new Date(DateTimeUtils.stringToTimestamp(UTF8String.fromString(literal.toString), ZoneId.systemDefault()).get / 1000) } else { @@ -250,7 +253,17 @@ object SparderTypeUtil extends Logging { } def convertToStringWithCalciteType(rawValue: Any, relType: RelDataType, wrapped: Boolean = false): String = { - val formatStringValue = (value: String) => if (wrapped) "\"" + value + "\"" else value + val formatStringValue = (value: String) => if (wrapped) StringHelper.doubleQuote(value) else value + val formatArray = (value: String) => { + if (value.startsWith("WrappedArray")) { + val s = value.stripPrefix("WrappedArray") + s.substring(1, s.length - 1) + .split(",").toStream.map(_.trim) + .mkString(COMMA) + } else { + StringHelper.doubleQuote(value) + } + } (rawValue, relType.getSqlTypeName) match { case (null, _) => null @@ -270,6 +283,7 @@ object SparderTypeUtil extends Logging { case (value: java.sql.Timestamp, SqlTypeName.CHAR | SqlTypeName.VARCHAR) => formatStringValue(DateFormat.castTimestampToString(value.getTime)) case (value: java.sql.Date, SqlTypeName.CHAR | SqlTypeName.VARCHAR) => formatStringValue(DateFormat.formatToDateStr(value.getTime)) case (value, SqlTypeName.CHAR | SqlTypeName.VARCHAR) => formatStringValue(value.toString) + case (value: String, SqlTypeName.ARRAY) => formatArray(value) // cast type to align with relType case (value: Any, SqlTypeName.DECIMAL) => new java.math.BigDecimal(value.toString) @@ -297,13 +311,14 @@ object SparderTypeUtil extends Logging { case (dt: java.sql.Date, _) => DateFormat.formatToDateStr(dt.getTime) case (str: java.lang.String, _) => formatStringValue(str) case (value: mutable.WrappedArray.ofRef[AnyRef], _) => - value.array.map(v => convertToStringWithCalciteType(v, relType, true)).mkString("[", ",", "]") + value.array.map(v => convertToStringWithCalciteType(v, relType, wrapped = true)).mkString("[", ",", "]") case (value: mutable.WrappedArray[Any], _) => - value.array.map(v => convertToStringWithCalciteType(v, relType, true)).mkString("[", ",", "]") + value.array.map(v => convertToStringWithCalciteType(v, relType, wrapped = true)).mkString("[", ",", "]") case (value: immutable.Map[Any, Any], _) => - value - .map(v => convertToStringWithCalciteType(v._1, relType, true) + ":" + convertToStringWithCalciteType(v._2, relType, true)) - .mkString("{", ",", "}") + value.map(v => + convertToStringWithCalciteType(v._1, relType, wrapped = true) + ":" + + convertToStringWithCalciteType(v._2, relType, wrapped = true) + ).mkString("{", ",", "}") case (value: Array[Byte], _) => new String(value) case (other, _) => other.toString } @@ -336,14 +351,7 @@ object SparderTypeUtil extends Logging { } else { try { val a: Any = sqlTypeName match { - case SqlTypeName.DECIMAL => - if (s.isInstanceOf[java.lang.Double] || s - .isInstanceOf[java.lang.Float] || s.toString.contains(".")) { - new java.math.BigDecimal(s.toString) 
- .setScale(rowType.getScale, BigDecimal.ROUND_HALF_EVEN) - } else { - new java.math.BigDecimal(s.toString) - } + case SqlTypeName.DECIMAL => transferDecimal(s, rowType) case SqlTypeName.CHAR => s.toString case SqlTypeName.VARCHAR => s.toString case SqlTypeName.INTEGER => s.toString.toInt @@ -352,7 +360,7 @@ object SparderTypeUtil extends Logging { case SqlTypeName.BIGINT => s.toString.toLong case SqlTypeName.FLOAT => java.lang.Double.parseDouble(s.toString) case SqlTypeName.DOUBLE => java.lang.Double.parseDouble(s.toString) - case SqlTypeName.DATE => { + case SqlTypeName.DATE => // time over here is with timezone. val string = s.toString if (string.contains("-")) { @@ -372,9 +380,8 @@ object SparderTypeUtil extends Logging { DateFormat.stringToMillis(string) / 1000 } } - } - case SqlTypeName.TIMESTAMP | SqlTypeName.TIME => { - var ts = s.asInstanceOf[Timestamp].toString + case SqlTypeName.TIMESTAMP | SqlTypeName.TIME => + val ts = s.asInstanceOf[Timestamp].toString if (toCalcite) { // current ts is local timezone ,org.apache.calcite.avatica.util.AbstractCursor.TimeFromNumberAccessor need to utc DateTimeUtils.stringToTimestamp(UTF8String.fromString(ts), TimeZone.getTimeZone("UTC").toZoneId).get / 1000 @@ -382,20 +389,29 @@ object SparderTypeUtil extends Logging { // ms to s s.asInstanceOf[Timestamp].getTime / 1000 } - } case SqlTypeName.BOOLEAN => s; case _ => s.toString } a } catch { - case th: Throwable => - logWarning(s"""convertStringToValue failed: {"v": "${s}", "cls": "${s.getClass}", "type": "$sqlTypeName"}""") + case _: Throwable => + logWarning(s"""convertStringToValue failed: {"v": "$s", "cls": "${s.getClass}", "type": "$sqlTypeName"}""") // fixme aron never come to here, for coverage ignore. safetyConvertStringToValue(s, rowType, toCalcite) } } } + private def transferDecimal(s: Any, rowType: RelDataType) = { + if (s.isInstanceOf[JDouble] || s + .isInstanceOf[JFloat] || s.toString.contains(".")) { + new BigDecimal(s.toString) + .setScale(rowType.getScale, BigDecimal.ROUND_HALF_EVEN) + } else { + new BigDecimal(s.toString) + } + } + def kylinRawTableSQLTypeToSparkType(dataTp: DataType): org.apache.spark.sql.types.DataType = { dataTp.getName match { case "decimal" | "numeric" => DecimalType(dataTp.getPrecision, dataTp.getScale) @@ -417,21 +433,15 @@ object SparderTypeUtil extends Logging { case "bitmap" => LongType case "dim_dc" => LongType case "boolean" => BooleanType + case "array<string>" => ArrayType(StringType) case noSupport => throw new IllegalArgumentException(s"No supported data type: $noSupport") } } - def safetyConvertStringToValue(s: Any, rowType: RelDataType, toCalcite: Boolean): Any = { + private def safetyConvertStringToValue(s: Any, rowType: RelDataType, toCalcite: Boolean): Any = { try { rowType.getSqlTypeName match { - case SqlTypeName.DECIMAL => - if (s.isInstanceOf[java.lang.Double] || s - .isInstanceOf[java.lang.Float] || s.toString.contains(".")) { - new java.math.BigDecimal(s.toString) - .setScale(rowType.getScale, BigDecimal.ROUND_HALF_EVEN) - } else { - new java.math.BigDecimal(s.toString) - } + case SqlTypeName.DECIMAL => transferDecimal(s, rowType) case SqlTypeName.CHAR => s.toString case SqlTypeName.VARCHAR => s.toString case SqlTypeName.INTEGER => s.toString.toDouble.toInt @@ -440,7 +450,7 @@ object SparderTypeUtil extends Logging { case SqlTypeName.BIGINT => s.toString.toDouble.toLong case SqlTypeName.FLOAT => java.lang.Float.parseFloat(s.toString) case SqlTypeName.DOUBLE => java.lang.Double.parseDouble(s.toString) - case SqlTypeName.DATE => { + case 
SqlTypeName.DATE => // time over here is with timezone. val string = s.toString if (string.contains("-")) { @@ -459,16 +469,14 @@ object SparderTypeUtil extends Logging { DateFormat.stringToMillis(string) } } - } - case SqlTypeName.TIMESTAMP | SqlTypeName.TIME => { - var ts = s.asInstanceOf[Timestamp].getTime + case SqlTypeName.TIMESTAMP | SqlTypeName.TIME => + val ts = s.asInstanceOf[Timestamp].getTime if (toCalcite) { ts } else { // ms to s ts / 1000 } - } case SqlTypeName.BOOLEAN => s; case _ => s.toString } @@ -484,7 +492,7 @@ object SparderTypeUtil extends Logging { calciteTimestamp / 1000 } - def toCalciteTimestamp(sparkTimestamp: Long): Long = { + private def toCalciteTimestamp(sparkTimestamp: Long): Long = { sparkTimestamp * 1000 } @@ -551,6 +559,11 @@ object SparderTypeUtil extends Logging { case StringType => builder.setDataType(Types.VARCHAR) builder.setDataTypeName("VARCHAR") + // This is more better, but calcite RelNode digest seems not support. + // + // case arrayType: ArrayType => + // builder.setDataType(Types.ARRAY) + // builder.setDataTypeName(arrayType.sql) case _ => builder.setDataType(Types.OTHER) builder.setDataTypeName(field.dataType.sql) diff --git a/src/spark-project/spark-common/src/test/java/org/apache/spark/sql/common/SparderQueryTest.scala b/src/spark-project/spark-common/src/test/java/org/apache/spark/sql/common/SparderQueryTest.scala index 363b915f99..df90161512 100644 --- a/src/spark-project/spark-common/src/test/java/org/apache/spark/sql/common/SparderQueryTest.scala +++ b/src/spark-project/spark-common/src/test/java/org/apache/spark/sql/common/SparderQueryTest.scala @@ -17,25 +17,20 @@ */ package org.apache.spark.sql.common -import org.apache.kylin.common.{KapConfig, QueryContext} - import java.sql.Types -import java.util.{List, TimeZone} +import java.util.TimeZone + import org.apache.kylin.metadata.query.StructField import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.functions._ -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.SparderTypeUtil -import org.apache.spark.sql.{DataFrame, Row, SparkSession} - -import java.util +import org.apache.spark.sql.{DataFrame, Row} object SparderQueryTest extends Logging { def same(sparkDF: DataFrame, - kylinAnswer: DataFrame, - checkOrder: Boolean = false): Boolean = { + kylinAnswer: DataFrame, + checkOrder: Boolean = false): Boolean = { checkAnswerBySeq(castDataType(sparkDF, kylinAnswer), kylinAnswer.collect(), checkOrder) match { case Some(errorMessage) => logInfo(errorMessage) @@ -54,15 +49,15 @@ object SparderQueryTest extends Logging { } /** - * Runs the plan and makes sure the answer matches the expected result. - * If there was exception during the execution or the contents of the DataFrame does not - * match the expected result, an error message will be returned. Otherwise, a [[None]] will - * be returned. - * - * @param sparkDF the [[org.apache.spark.sql.DataFrame]] to be executed - * @param kylinAnswer the expected result in a [[Seq]] of [[org.apache.spark.sql.Row]]s. - * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice. - */ + * Runs the plan and makes sure the answer matches the expected result. + * If there was exception during the execution or the contents of the DataFrame does not + * match the expected result, an error message will be returned. Otherwise, a [[None]] will + * be returned. 
+ * + * @param sparkDF the [[org.apache.spark.sql.DataFrame]] to be executed + * @param kylinAnswer the expected result in a [[Seq]] of [[org.apache.spark.sql.Row]]s. + * @param checkToRDD whether to verify deserialization to an RDD. This runs the query twice. + */ def checkAnswerBySeq(sparkDF: DataFrame, kylinAnswer: Seq[Row], checkOrder: Boolean = false, @@ -90,7 +85,7 @@ object SparderQueryTest extends Logging { |Timezone: ${TimeZone.getDefault} |Timezone Env: ${sys.env.getOrElse("TZ", "")} | - |${sparkDF.queryExecution} + |${sparkDF.queryExecution} |== Results == |$results """.stripMargin @@ -145,15 +140,15 @@ object SparderQueryTest extends Logging { /** - * Runs the plan and makes sure the answer is within absTol of the expected result. - * - * @param actualAnswer the actual result in a [[Row]]. - * @param expectedAnswer the expected result in a[[Row]]. - * @param absTol the absolute tolerance between actual and expected answers. - */ + * Runs the plan and makes sure the answer is within absTol of the expected result. + * + * @param actualAnswer the actual result in a [[Row]]. + * @param expectedAnswer the expected result in a[[Row]]. + * @param absTol the absolute tolerance between actual and expected answers. + */ protected def checkAggregatesWithTol(actualAnswer: Row, expectedAnswer: Row, - absTol: Double) = { + absTol: Double): Unit = { require(actualAnswer.length == expectedAnswer.length, s"actual answer length ${actualAnswer.length} != " + s"expected answer length ${expectedAnswer.length}") @@ -170,17 +165,6 @@ object SparderQueryTest extends Logging { } } - def checkAsyncResultData(dataFrame: DataFrame, - sparkSession: SparkSession): Unit = { - val path = KapConfig.getInstanceFromEnv.getAsyncResultBaseDir(null) + "/" + - QueryContext.current.getQueryId - val rows = sparkSession.read - .format("org.apache.spark.sql.execution.datasources.csv.CSVFileFormat") - .load(path) - val maybeString = checkAnswer(dataFrame, rows) - - } - private def isCharType(dataType: Integer): Boolean = { if (dataType == Types.CHAR || dataType == Types.VARCHAR) { return true @@ -190,7 +174,7 @@ object SparderQueryTest extends Logging { private def isIntType(structField: StructField): Boolean = { val intList = scala.collection.immutable.List(Types.TINYINT, Types.SMALLINT, Types.INTEGER, Types.BIGINT) - val dataType = structField.getDataType(); + val dataType = structField.getDataType if (intList.contains(dataType)) { return true } @@ -199,19 +183,23 @@ object SparderQueryTest extends Logging { private def isBinaryType(structField: StructField): Boolean = { val intList = scala.collection.immutable.List(Types.BINARY, Types.VARBINARY, Types.LONGVARBINARY) - val dataType = structField.getDataType(); + val dataType = structField.getDataType; if (intList.contains(dataType)) { return true } false } + private def isArrayType(field: StructField): Boolean = { + field.getDataTypeName == "ARRAY<STRING>" || field.getDataTypeName == "ARRAY" + } + def isSameDataType(cubeStructField: StructField, sparkStructField: StructField): Boolean = { - if (cubeStructField.getDataType() == sparkStructField.getDataType()) { + if (cubeStructField.getDataType == sparkStructField.getDataType) { return true } // calcite dataTypeName = "ANY" - if (cubeStructField.getDataType() == 2000) { + if (cubeStructField.getDataType == 2000) { return true } if (isCharType(cubeStructField.getDataType()) && isCharType(sparkStructField.getDataType())) { @@ -223,6 +211,9 @@ object SparderQueryTest extends Logging { if (isBinaryType(cubeStructField) && 
isBinaryType(sparkStructField)) { return true } + if (isArrayType(cubeStructField) && isArrayType(sparkStructField)) { + return true + } false }
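A note on how the new array<string> results are carried through SparderTypeUtil: the patch joins a Spark WrappedArray result into one string using the ___COMMA___ placeholder (the formatArray helper inside convertToStringWithCalciteType) and splits it back in the matching string-to-value conversion. The round trip can be sketched in isolation; the object name, the restoreArray helper, and the sample value below are illustrative only, while the placeholder and the WrappedArray handling mirror the patch.

    // Standalone sketch of the array<string> round trip introduced above.
    object ArrayResultRoundTrip {
      private val COMMA = "___COMMA___"

      // Mirrors the WrappedArray branch of formatArray: strip the "WrappedArray(...)"
      // wrapper and join the elements with the placeholder so the value survives as a
      // plain string. (The real code double-quotes non-array values instead.)
      def formatArray(value: String): String =
        if (value.startsWith("WrappedArray")) {
          val inner = value.stripPrefix("WrappedArray")
          inner.substring(1, inner.length - 1).split(",").map(_.trim).mkString(COMMA)
        } else {
          value
        }

      // Mirrors the new "array<string>" branch that splits the stored string back
      // into individual elements.
      def restoreArray(value: String): Array[String] = value.split(COMMA)

      def main(args: Array[String]): Unit = {
        val raw = "WrappedArray(10000000, 01)" // how Spark renders a split() result as text
        val stored = formatArray(raw)          // "10000000___COMMA___01"
        println(restoreArray(stored).mkString("[", ", ", "]"))
      }
    }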