This is an automated email from the ASF dual-hosted git repository.

xxyu pushed a commit to branch kylin5
in repository https://gitbox.apache.org/repos/asf/kylin.git

commit 1b9d519bbb87268d20248dec27819c793b368c28
Author: fengguangyuan <qq272101...@gmail.com>
AuthorDate: Thu May 25 10:25:37 2023 +0800

    KYLIN-5704 Avoid exceptions on iterating filters and support pruning segments with cast-in expressions on dimension columns.
    
    ---------
    
    Co-authored-by: Guangyuan Feng <guangyuan.f...@kyligence.io>
---
 .../org/apache/kylin/newten/NFilePruningTest.java  | 19 +++++--
 .../sql/execution/datasource/FilePruner.scala      | 61 +++++++++++++++-------
 2 files changed, 57 insertions(+), 23 deletions(-)

diff --git a/src/kylin-it/src/test/java/org/apache/kylin/newten/NFilePruningTest.java b/src/kylin-it/src/test/java/org/apache/kylin/newten/NFilePruningTest.java
index 93c99122ae..63578ae755 100644
--- a/src/kylin-it/src/test/java/org/apache/kylin/newten/NFilePruningTest.java
+++ b/src/kylin-it/src/test/java/org/apache/kylin/newten/NFilePruningTest.java
@@ -19,6 +19,8 @@
 
 package org.apache.kylin.newten;
 
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.List;
@@ -31,6 +33,7 @@ import org.apache.kylin.common.util.Pair;
 import org.apache.kylin.common.util.RandomUtil;
 import org.apache.kylin.common.util.TempMetadataBuilder;
 import org.apache.kylin.engine.spark.NLocalWithSparkSessionTest;
+import org.apache.kylin.guava30.shaded.common.collect.Lists;
 import org.apache.kylin.job.engine.JobEngineConfig;
 import org.apache.kylin.job.impl.threadpool.NDefaultScheduler;
 import org.apache.kylin.junit.TimeZoneTestRunner;
@@ -61,8 +64,6 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.sparkproject.guava.collect.Sets;
 
-import org.apache.kylin.guava30.shaded.common.collect.Lists;
-
 import lombok.val;
 import scala.runtime.AbstractFunction1;
 
@@ -510,9 +511,19 @@ public class NFilePruningTest extends NLocalWithSparkSessionTest implements Adap
         String modelId = "3f152495-44de-406c-9abf-b11d4132aaed";
         overwriteSystemProp("kylin.engine.persist-flattable-enabled", "true");
         buildMultiSegAndMerge("3f152495-44de-406c-9abf-b11d4132aaed");
-        populateSSWithCSVData(getTestConfig(), getProject(), SparderEnv.getSparkSession());
+        val ss = SparderEnv.getSparkSession();
+        populateSSWithCSVData(getTestConfig(), getProject(), ss);
 
         val lessThanEquality = base + "where TEST_KYLIN_FACT.ORDER_ID <= 10";
+        val castIn = base + "where test_date_enc in ('2014-12-20', '2014-12-21')";
+
+        Instant startDate = DateFormat.stringToDate("2014-12-20").toInstant();
+        String[] dates = new String[ss.sqlContext().conf().optimizerInSetConversionThreshold() + 1];
+        for (int i = 0; i < dates.length; i++) {
+            dates[i] = "'" + DateFormat.formatToDateStr(startDate.plus(i, ChronoUnit.DAYS).toEpochMilli()) + "'";
+        }
+        val largeCastIn = base + "where test_date_enc in (" + String.join(",", dates) + ")";
+
         val in = base + "where TEST_KYLIN_FACT.ORDER_ID in (4998, 4999)";
         val lessThan = base + "where TEST_KYLIN_FACT.ORDER_ID < 10";
         val and = base + "where PRICE < -99 AND TEST_KYLIN_FACT.ORDER_ID = 1";
@@ -526,7 +537,9 @@ public class NFilePruningTest extends NLocalWithSparkSessionTest implements Adap
         expectedRanges.add(segmentRange1);
         expectedRanges.add(segmentRange2);
 
+        assertResultsAndScanFiles(modelId, largeCastIn, 1, false, expectedRanges);
         assertResultsAndScanFiles(modelId, lessThanEquality, 2, false, expectedRanges);
+        assertResultsAndScanFiles(modelId, castIn, 1, false, expectedRanges);
         assertResultsAndScanFiles(modelId, in, 1, false, expectedRanges);
         assertResultsAndScanFiles(modelId, lessThan, 1, false, expectedRanges);
         assertResultsAndScanFiles(modelId, and, 1, false, expectedRanges);
diff --git a/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/execution/datasource/FilePruner.scala b/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/execution/datasource/FilePruner.scala
index 483e35aa50..aaf33b3e4b 100644
--- a/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/execution/datasource/FilePruner.scala
+++ b/src/spark-project/spark-common/src/main/scala/org/apache/spark/sql/execution/datasource/FilePruner.scala
@@ -35,7 +35,7 @@ import org.apache.kylin.metadata.project.NProjectManager
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.analysis.Resolver
 import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, EmptyRow, Expression, Literal}
-import org.apache.spark.sql.catalyst.{InternalRow, expressions}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions}
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types.StructType
@@ -236,7 +236,7 @@ class FilePruner(val session: SparkSession,
     val project = dataflow.getProject
     val projectKylinConfig = NProjectManager.getInstance(KylinConfig.getInstanceFromEnv).getProject(project).getConfig
 
-    var selected = prunedSegmentDirs;
+    var selected = prunedSegmentDirs
     if (projectKylinConfig.isSkipEmptySegments) {
       selected = afterPruning("pruning empty segment", null, selected) {
         (_, segDirs) => pruneEmptySegments(segDirs)
@@ -376,17 +376,17 @@ class FilePruner(val session: SparkSession,
 
   private def pruneSegments(filters: Seq[Expression],
                            segDirs: Seq[SegmentDirectory]): Seq[SegmentDirectory] = {
-
-    val filteredStatuses = if (filters.isEmpty) {
+    val reducedFilter = filters.map(filter => convertCastFilter(filter))
+      .flatMap(f => DataSourceStrategy.translateFilter(f, true))
+      .reduceLeftOption(And)
+    if (reducedFilter.isEmpty) {
       segDirs
     } else {
-      val reducedFilter = filters.toList.map(filter => convertCastFilter(filter))
-        .flatMap(f => DataSourceStrategy.translateFilter(f, true)).reduceLeft(And)
       segDirs.filter {
         e => {
           if (dataflow.getSegment(e.segmentID).isOffsetCube) {
             val ksRange = dataflow.getSegment(e.segmentID).getKSRange
-            SegFilters(ksRange.getStart, ksRange.getEnd, pattern).foldStreamingFilter(reducedFilter) match {
+            SegFilters(ksRange.getStart, ksRange.getEnd, pattern).foldStreamingFilter(reducedFilter.get) match {
               case Trivial(true) => true
               case Trivial(false) => false
             }
@@ -394,7 +394,7 @@ class FilePruner(val session: SparkSession,
             val tsRange = dataflow.getSegment(e.segmentID).getTSRange
             val start = DateFormat.getFormatTimeStamp(tsRange.getStart.toString, pattern)
             val end = DateFormat.getFormatTimeStamp(tsRange.getEnd.toString, pattern)
-            SegFilters(start, end, pattern).foldFilter(reducedFilter) match {
+            SegFilters(start, end, pattern).foldFilter(reducedFilter.get) match {
               case Trivial(true) => true
               case Trivial(false) => false
             }
@@ -402,25 +402,23 @@ class FilePruner(val session: SparkSession,
         }
       }
     }
-    filteredStatuses
   }
 
   private def pruneSegmentsDimRange(filters: Seq[Expression],
                                    segDirs: Seq[SegmentDirectory]): Seq[SegmentDirectory] = {
     val hitColumns = Sets.newHashSet[String]()
     val project = options.getOrElse("project", "")
-    val filteredStatuses = if (filters.isEmpty) {
+    val reducedFilters = translateToSourceFilter(filters)
+
+    val filteredStatuses = if (reducedFilters.isEmpty) {
       segDirs
     } else {
-      val reducedFilter = filters.toList.map(filter => convertCastFilter(filter))
-        .flatMap(f => DataSourceStrategy.translateFilter(f, true)).reduceLeft(And)
-
       segDirs.filter {
         e => {
          val dimRange = dataflow.getSegment(e.segmentID).getDimensionRangeInfoMap
           if (dimRange != null && !dimRange.isEmpty) {
            SegDimFilters(dimRange, dataflow.getIndexPlan.getEffectiveDimCols, dataflow.getId, project, hitColumns)
-              .foldFilter(reducedFilter) match {
+              .foldFilter(reducedFilters.get) match {
               case Trivial(true) => true
               case Trivial(false) => false
             }
@@ -434,18 +432,39 @@ class FilePruner(val session: SparkSession,
     filteredStatuses
   }
 
+  private def translateToSourceFilter(filters: Seq[Expression]): Option[Filter] = {
+    filters.map(filter => convertCastFilter(filter))
+      .flatMap(f => {
+        DataSourceStrategy.translateFilter(f, true) match {
+          case v @ Some(_) => v
+          case None =>
+            // special cases which are forced pushed down by Kylin
+            f match {
+              case expressions.In(e @ expressions.Cast(a: Attribute, _, _, _), list)
+                if list.forall(_.isInstanceOf[Literal]) =>
+                val hSet = list.map(_.eval(EmptyRow))
+                val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
+                Some(In(a.name, hSet.toArray.map(toScala)))
+              case expressions.InSet(e @ expressions.Cast(a: Attribute, _, _, _), set) =>
+                val toScala = CatalystTypeConverters.createToScalaConverter(e.dataType)
+                Some(In(a.name, set.toArray.map(toScala)))
+              case _ => None
+            }
+        }
+      }).reduceLeftOption(And)
+  }
+
   private def pruneShards(filters: Seq[Expression],
                          segDirs: Seq[SegmentDirectory]): Seq[SegmentDirectory] = {
-    val filteredStatuses = if (layout.getShardByColumns.size() != 1) {
+    val normalizedFiltersAndExpr = filters.reduceOption(expressions.And)
+    if (layout.getShardByColumns.size() != 1 || normalizedFiltersAndExpr.isEmpty) {
       segDirs
     } else {
-      val normalizedFiltersAndExpr = filters.reduce(expressions.And)
-
       val pruned = segDirs.map { case SegmentDirectory(segID, partitions, files) =>
         val partitionNumber = dataflow.getSegment(segID).getLayout(layout.getId).getPartitionNum
         require(partitionNumber > 0, "Shards num with shard by col should greater than 0.")
 
-        val bitSet = getExpressionShards(normalizedFiltersAndExpr, shardByColumn.name, partitionNumber)
+        val bitSet = getExpressionShards(normalizedFiltersAndExpr.get, shardByColumn.name, partitionNumber)
 
         val selected = files.filter(f => {
           val partitionId = FilePruner.getPartitionId(f.getPath)
@@ -455,7 +474,6 @@ class FilePruner(val session: SparkSession,
       }
       pruned
     }
-    filteredStatuses
   }
 
   override lazy val inputFiles: Array[String] = Array.empty[String]
@@ -517,7 +535,10 @@ class FilePruner(val session: SparkSession,
     }
   }
 
-  //  translate for filter type match
+  /**
+   * Note: This is a dangerous method that extracts the Cast value in comparison expressions without considering
+   * precision loss for numerics etc., so it can produce an expression with different semantics.
+   */
   private def convertCastFilter(filter: Expression): Expression = {
     filter match {
       case expressions.EqualTo(expressions.Cast(a: Attribute, _, _, _), Literal(v, t)) =>

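For context (this note and the sketch below are editorial additions, not part of the patch): the new translateToSourceFilter path hand-translates cast-in predicates that Spark's DataSourceStrategy.translateFilter declines to translate, so that a query such as "where test_date_enc in ('2014-12-20', '2014-12-21')" can still prune segments by dimension range. A minimal standalone sketch of the same idea, assuming a Spark 3.x catalyst API; the object name, column name TEST_DATE_ENC, and main method are illustrative assumptions, not code from the commit:

// Editorial sketch only, not part of the patch. Assumes Spark 3.x catalyst classes on the classpath.
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, EmptyRow, Literal}
import org.apache.spark.sql.sources
import org.apache.spark.sql.types.{DateType, StringType}

object CastInTranslationSketch {

  // Rewrites CAST(attr) IN (literal, ...) into a source-level In filter, mirroring the idea of
  // translateToSourceFilter above (the real method first tries DataSourceStrategy.translateFilter
  // and also handles the InSet form).
  def translateCastIn(expr: expressions.Expression): Option[sources.Filter] = expr match {
    case expressions.In(c: expressions.Cast, list) if list.forall(_.isInstanceOf[Literal]) =>
      c.child match {
        case a: Attribute =>
          // The literals carry the cast's target type; evaluate them and convert to Scala values.
          val toScala = CatalystTypeConverters.createToScalaConverter(c.dataType)
          Some(sources.In(a.name, list.map(l => toScala(l.eval(EmptyRow))).toArray))
        case _ => None
      }
    case _ => None
  }

  def main(args: Array[String]): Unit = {
    // A toy predicate equivalent to: test_date_enc IN ('2014-12-20', '2014-12-21')
    val col = AttributeReference("TEST_DATE_ENC", DateType)()
    val predicate = expressions.In(
      expressions.Cast(col, StringType),
      Seq(Literal("2014-12-20"), Literal("2014-12-21")))
    // Expected shape of the result: Some(In(TEST_DATE_ENC, [2014-12-20,2014-12-21]))
    println(translateCastIn(predicate))
  }
}

The committed code additionally matches expressions.InSet, the form Spark's optimizer rewrites large IN lists into once they exceed spark.sql.optimizer.inSetConversionThreshold, which is what the largeCastIn query in the test above exercises.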