wypoon commented on code in PR #9192: URL: https://github.com/apache/iceberg/pull/9192#discussion_r1430865455
########## spark/v3.5/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala: ########## @@ -40,14 +37,23 @@ import org.apache.spark.sql.types.StructType object ReplaceStaticInvoke extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = - plan.transformWithPruning (_.containsAllPatterns(BINARY_COMPARISON, FILTER)) { + plan.transformWithPruning (_.containsPattern(FILTER)) { case filter @ Filter(condition, _) => - val newCondition = condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) { + val newCondition = condition.transformWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN, INSET)) { case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable => c.withNewChildren(Seq(replaceStaticInvoke(left), right)) case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable => c.withNewChildren(Seq(left, replaceStaticInvoke(right))) + + case in @ In(systemFunction: StaticInvoke, values) + if canReplace(systemFunction) && values.forall(_.foldable) => + in.copy(value = replaceStaticInvoke(systemFunction)) + + case in@InSet(systemFunction: StaticInvoke, _) if canReplace(systemFunction) => Review Comment: Nit: `in @ InSet` to be consistent with the style. 
########## spark/v3.5/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestSystemFunctionPushDownDQL.java: ########## @@ -213,17 +198,77 @@ private void testBucketLongFunction(boolean partitioned) { String query = String.format( "SELECT * FROM %s WHERE system.bucket(5, id) <= %s ORDER BY id", tableName, target); + checkQueryExecution(query, partitioned, lessThanOrEqual(bucket("id", 5), target)); + } + + @Test + public void testBucketLongFunctionInClauseOnUnpartitionedTable() { + createUnpartitionedTable(spark, tableName); + testBucketLongFunctionInClause(false); + } + + @Test + public void testBucketLongFunctionInClauseOnPartitionedTable() { + createPartitionedTable(spark, tableName, "bucket(5, id)"); + testBucketLongFunctionInClause(true); + } + + private void testBucketLongFunctionInClause(boolean partitioned) { + List<Integer> range = IntStream.range(0, 3).boxed().collect(Collectors.toList()); + String rangeAsSql = + range.stream().map(x -> Integer.toString(x)).collect(Collectors.joining(", ")); + String query = + String.format( + "SELECT * FROM %s WHERE system.bucket(5, id) IN (%s) ORDER BY id", + tableName, rangeAsSql); + checkQueryExecution(query, partitioned, in(bucket("id", 5), range.toArray())); + } + + private void checkQueryExecution( + String query, boolean partitioned, org.apache.iceberg.expressions.Expression expression) { Dataset<Row> df = spark.sql(query); LogicalPlan optimizedPlan = df.queryExecution().optimizedPlan(); checkExpressions(optimizedPlan, partitioned, "bucket"); - checkPushedFilters(optimizedPlan, lessThanOrEqual(bucket("id", 5), target)); + checkPushedFilters(optimizedPlan, expression); List<Object[]> actual = rowsToJava(df.collectAsList()); Assertions.assertThat(actual.size()).isEqualTo(5); } + @Test + public void testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiteralsOnPartitionedTable() { + createPartitionedTable(spark, tableName, "bucket(5, id)"); + 
testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiterals(); + } + + @Test + public void testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiteralsOnUnpartitionedTable() { + createUnpartitionedTable(spark, tableName); + testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiterals(); + } + + private void testBucketLongFunctionIsNotReplacedWhenArgumentsAreNotLiterals() { + List<Integer> range = IntStream.range(0, 3).boxed().collect(Collectors.toList()); + String rangeAsSql = + range.stream().map(x -> Integer.toString(x)).collect(Collectors.joining(", ")); + String query = + String.format( + "SELECT * FROM %s WHERE system.bucket(5, id) IN (system.bucket(5, id), 1) ORDER BY id", + tableName, rangeAsSql); Review Comment: `range` and `rangeAsSql` are not needed, since the `IN` clause here doesn't use `rangeAsSql`. ########## spark/v3.5/spark-extensions/src/main/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceStaticInvoke.scala: ########## @@ -40,14 +37,23 @@ import org.apache.spark.sql.types.StructType object ReplaceStaticInvoke extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = - plan.transformWithPruning (_.containsAllPatterns(BINARY_COMPARISON, FILTER)) { + plan.transformWithPruning (_.containsPattern(FILTER)) { case filter @ Filter(condition, _) => - val newCondition = condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) { + val newCondition = condition.transformWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN, INSET)) { case c @ BinaryComparison(left: StaticInvoke, right) if canReplace(left) && right.foldable => c.withNewChildren(Seq(replaceStaticInvoke(left), right)) case c @ BinaryComparison(left, right: StaticInvoke) if canReplace(right) && left.foldable => c.withNewChildren(Seq(left, replaceStaticInvoke(right))) + + case in @ In(systemFunction: StaticInvoke, values) + if canReplace(systemFunction) && values.forall(_.foldable) => Review Comment: Nit: Indent the `if` one more level for readability 
of this block? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: issues-unsubscribe@iceberg.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscribe@iceberg.apache.org For additional commands, e-mail: issues-help@iceberg.apache.org