This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 172d68e3be1b [SPARK-55647][SQL] Improve `ConstantPropagation` for
collated `AttributeReference`s
172d68e3be1b is described below
commit 172d68e3be1b95fe47ed7b75f1de697006d2b304
Author: ilicmarkodb <[email protected]>
AuthorDate: Tue Mar 3 23:32:41 2026 +0800
[SPARK-55647][SQL] Improve `ConstantPropagation` for collated
`AttributeReference`s
### What changes were proposed in this pull request?
The previous change (https://github.com/apache/spark/pull/54435) completely
blocked `ConstantPropagation` for non-binary-stable types, which may regress
query performance. In this PR, I propose improving the rule so that it
replaces collated `AttributeReference`s when it is safe to do so.
### Why are the changes needed?
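Performance improvement: `ConstantPropagation` is re-enabled for collated
`AttributeReference`s in the cases where the replacement is safe.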
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
New tests in `CollationSuite`.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #54515 from ilicmarkodb/improve_ConstantPropagation.
Authored-by: ilicmarkodb <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
.../spark/sql/catalyst/optimizer/expressions.scala | 49 +++--
.../spark/sql/collation/CollationSuite.scala | 230 ++++++++++++++++++++-
2 files changed, 260 insertions(+), 19 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index e406c51e7f8a..53a5e0f7eccf 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -218,28 +218,51 @@ object ConstantPropagation extends Rule[LogicalPlan] {
// substituted into `1 + 1 = 1` if 'c' isn't nullable. If 'c' is nullable
then the enclosing
// NOT prevents us to do the substitution as NOT flips the context
(`nullIsFalse`) of what a
// null result of the enclosed expression means.
- //
- // Also, we shouldn't replace attributes with non-binary-stable data types,
since this can lead
- // to incorrect results. For example:
- // `CREATE TABLE t (c STRING COLLATE UTF8_LCASE);`
- // `INSERT INTO t VALUES ('HELLO'), ('hello');`
- // `SELECT * FROM t WHERE c = 'hello' AND c = 'HELLO' COLLATE UNICODE;`
- // If we replace `c` with `'hello'`, we get `'hello' = 'HELLO' COLLATE
UNICODE` for the right
- // condition, which is false, while the original `c = 'HELLO' COLLATE
UNICODE` is true for
- // 'HELLO' and false for 'hello'.
private def safeToReplace(ar: AttributeReference, nullIsFalse: Boolean) =
- (!ar.nullable || nullIsFalse) && isBinaryStable(ar.dataType)
+ !ar.nullable || nullIsFalse
private def replaceConstants(
condition: Expression,
equalityPredicates: AttributeMap[(Literal, BinaryComparison)]):
Expression = {
val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit,
_)) => attr -> lit })
val predicates = equalityPredicates.values.map(_._2).toSet
- condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
- case b: BinaryComparison if !predicates.contains(b) => b transform {
- case a: AttributeReference => constantsMap.getOrElse(a, a)
+ def replaceInComparison(b: BinaryComparison): Expression = {
+ lazy val collationSafeReplacement = isSameCollationAttrRefComparison(b)
+ b transform {
+ case a: AttributeReference
+ if isBinaryStable(a.dataType) || collationSafeReplacement =>
+ constantsMap.getOrElse(a, a)
}
}
+ condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
+ case b: BinaryComparison if !predicates.contains(b) =>
replaceInComparison(b)
+ }
+ }
+
+ /**
+ * Binary-stable `AttributeReference`s can always be replaced safely.
Non-binary-stable
+ * `AttributeReference`s (i.e., those with a non-`UTF8_BINARY` `StringType`)
are only replaced
+ * when both sides of the comparison are `AttributeReference`s (or
`CollationKey`-wrapped
+ * `AttributeReference`s) with the same `StringType`, preventing
substitution inside
+ * expressions that change the effective collation (e.g., a `Cast`). For
example, given a
+ * column `c STRING COLLATE UTF8_LCASE`:
+ *
+ * `c = 'hello' AND c = 'HELLO' COLLATE UNICODE`
+ *
+ * `c` is added to `constantsMap`. In the right-hand comparison, `c` is
wrapped with a
+ * `Cast` to `UNICODE`, so we don't have an `AttributeReference` vs.
`AttributeReference`
+ * comparison and `c` is not replaced inside the `Cast`, preserving
correctness.
+ */
+ private def isSameCollationAttrRefComparison(b: BinaryComparison): Boolean =
{
+ (b.left, b.right) match {
+ case (AttributeReference(_, st1: StringType, _, _),
+ AttributeReference(_, st2: StringType, _, _)) =>
+ st1 == st2
+ case (CollationKey(AttributeReference(_, st1: StringType, _, _)),
+ CollationKey(AttributeReference(_, st2: StringType, _, _))) =>
+ st1 == st2
+ case _ => false
+ }
}
}
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
index 734c1166c403..b446b29f7c68 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala
@@ -2249,16 +2249,234 @@ class CollationSuite extends DatasourceV2SQLBase with
AdaptiveSparkPlanHelper {
}
}
- test("ConstantPropagation does not replace attributes with non-binary-stable
collation") {
- val tableName = "t1"
- withTable(tableName) {
- sql(s"CREATE TABLE $tableName (c STRING COLLATE UTF8_LCASE)")
- sql(s"INSERT INTO $tableName VALUES ('hello'), ('HELLO')")
+ test("ConstantPropagation: does not replace attributes with
non-binary-stable collation") {
+ withTable("t1") {
+ sql("CREATE TABLE t1 (c STRING COLLATE UTF8_LCASE)")
+ sql("INSERT INTO t1 VALUES ('hello'), ('HELLO')")
+
+ checkAnswer(
+ sql("SELECT * FROM t1 WHERE 'hello' = c AND c = 'HELLO' COLLATE
UNICODE"),
+ Row("HELLO")
+ )
+ }
+ }
+
+ test("ConstantPropagation: does not replace non-binary-stable attributes
with EqualNullSafe") {
+ withTable("t1") {
+ sql("CREATE TABLE t1 (c STRING COLLATE UTF8_LCASE)")
+ sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), (NULL)")
+
+ checkAnswer(
+ sql("SELECT * FROM t1 WHERE 'hello' <=> c AND c <=> 'HELLO' COLLATE
UNICODE"),
+ Row("HELLO")
+ )
+ }
+ }
+
+ test("ConstantPropagation: replaces binary-stable attributes with
contradicting predicates") {
+ withTable("t1") {
+ sql("CREATE TABLE t1 (c STRING)")
+ sql("INSERT INTO t1 VALUES ('hello'), ('world')")
+
+ checkAnswer(
+ sql("SELECT * FROM t1 WHERE c = 'hello' AND c = 'world'"),
+ Seq.empty
+ )
+ }
+ }
+
+ test("ConstantPropagation: replaces binary-stable attributes across
collation cast") {
+ withTable("t1") {
+ sql("CREATE TABLE t1 (c STRING)")
+ sql("INSERT INTO t1 VALUES ('hello'), ('HELLO')")
+
+ checkAnswer(
+ sql("SELECT * FROM t1 WHERE c = 'hello' AND c COLLATE UTF8_LCASE =
'HELLO'"),
+ Row("hello")
+ )
+ }
+ }
+
+ test("ConstantPropagation: does not replace non-binary-stable " +
+ "attributes with explicit CAST collation") {
+ withTable("t1") {
+ sql("CREATE TABLE t1 (c STRING COLLATE UTF8_LCASE)")
+ sql("INSERT INTO t1 VALUES ('hello'), ('HELLO')")
+
+ checkAnswer(
+ sql("""SELECT * FROM t1 WHERE c = 'hello'
+ |AND CAST(c AS STRING COLLATE UNICODE) = 'HELLO'""".stripMargin),
+ Row("HELLO")
+ )
+ }
+ }
+
+ test("ConstantPropagation: replaces non-binary-stable attributes in
same-collation comparison") {
+ withTable("t1") {
+ sql("CREATE TABLE t1 (col1 STRING COLLATE UTF8_LCASE, col2 STRING
COLLATE UTF8_LCASE)")
+ sql("INSERT INTO t1 VALUES ('hello', 'hello'), ('HELLO', 'hello'),
('hello', 'world')")
+
+ checkAnswer(
+ sql("SELECT * FROM t1 WHERE col1 = 'hello' AND col1 = col2"),
+ Seq(Row("hello", "hello"), Row("HELLO", "hello"))
+ )
+ }
+ }
+
+ test("ConstantPropagation: does not replace non-binary-stable attribute " +
+ "in different-collation column comparison") {
+ withTable("t1") {
+ sql("CREATE TABLE t1 (col1 STRING COLLATE UTF8_LCASE, col2 STRING
COLLATE UTF8_LCASE)")
+ sql("INSERT INTO t1 VALUES ('hello', 'hello'), ('HELLO', 'HELLO')")
checkAnswer(
- sql(s"SELECT * FROM $tableName WHERE c = 'hello' AND c = 'HELLO'
COLLATE UNICODE"),
+ sql("SELECT * FROM t1 WHERE col1 = 'hello' AND col1 COLLATE UNICODE =
col2"),
+ Seq(Row("HELLO", "HELLO"), Row("hello", "hello"))
+ )
+ }
+ }
+
+ test("ConstantPropagation: attribute is not propagated from inside NOT") {
+ withTable("t1") {
+ sql("CREATE TABLE t1 (c STRING COLLATE UTF8_LCASE)")
+ sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), ('world')")
+
+ checkAnswer(
+ sql("SELECT * FROM t1 WHERE NOT(c = 'world') AND c = 'HELLO' COLLATE
UNICODE"),
Row("HELLO")
)
}
}
+
+ test("ConstantPropagation: non-binary-stable attribute is not replaced
inside NOT") {
+ withTable("t1") {
+ sql("CREATE TABLE t1 (c STRING COLLATE UTF8_LCASE)")
+ sql("INSERT INTO t1 VALUES ('hello'), ('HELLO')")
+
+ checkAnswer(
+ sql("SELECT * FROM t1 WHERE 'HELLO' = c AND NOT(c = 'HELLO' COLLATE
UNICODE)"),
+ Row("hello")
+ )
+ }
+ }
+
+ test("ConstantPropagation: predicates do not propagate across OR branches") {
+ withTable("t1") {
+ sql("CREATE TABLE t1 (c STRING COLLATE UTF8_LCASE)")
+ sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), ('world')")
+
+ checkAnswer(
+ sql("""SELECT * FROM t1 WHERE (c = 'hello' AND c = 'HELLO' COLLATE
UNICODE)
+ |OR c = 'world'""".stripMargin),
+ Seq(Row("HELLO"), Row("world"))
+ )
+ }
+ }
+
+ test("ConstantPropagation: non-binary-stable join matches
case-insensitively") {
+ withTable("t1", "t2") {
+ sql("CREATE TABLE t1 (a STRING COLLATE UTF8_LCASE)")
+ sql("CREATE TABLE t2 (b STRING COLLATE UTF8_LCASE)")
+ sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), ('world')")
+ sql("INSERT INTO t2 VALUES ('hello')")
+
+ checkAnswer(
+ sql("SELECT t1.a FROM t1 JOIN t2 ON t1.a = t2.b WHERE t1.a = 'hello'"),
+ Seq(Row("hello"), Row("HELLO"))
+ )
+ }
+ }
+
+ test("ConstantPropagation: does not replace non-binary-stable attribute " +
+ "in cross-collation join condition") {
+ withTable("t1", "t2") {
+ sql("CREATE TABLE t1 (a STRING COLLATE UTF8_LCASE)")
+ sql("CREATE TABLE t2 (b STRING COLLATE UTF8_LCASE)")
+ sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), ('world')")
+ sql("INSERT INTO t2 VALUES ('hello')")
+
+ checkAnswer(
+ sql("""SELECT t1.a FROM t1 JOIN t2 ON t1.a = t2.b COLLATE UNICODE
+ |WHERE t1.a = 'hello'""".stripMargin),
+ Row("hello")
+ )
+ }
+ }
+
+ test("ConstantPropagation: does not replace non-binary-stable attribute " +
+ "in cross-collation join filter") {
+ withTable("t1", "t2") {
+ sql("CREATE TABLE t1 (a STRING COLLATE UTF8_LCASE)")
+ sql("CREATE TABLE t2 (b STRING COLLATE UTF8_LCASE)")
+ sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), ('world')")
+ sql("INSERT INTO t2 VALUES ('hello')")
+
+ checkAnswer(
+ sql("""SELECT t1.a FROM t1 JOIN t2 ON t1.a = t2.b
+ |WHERE t1.a = 'hello' AND t1.a = 'HELLO' COLLATE
UNICODE""".stripMargin),
+ Row("HELLO")
+ )
+ }
+ }
+
+ test("ConstantPropagation: binary-stable join correctly replaces
attributes") {
+ withTable("t1", "t2") {
+ sql("CREATE TABLE t1 (a STRING)")
+ sql("CREATE TABLE t2 (b STRING)")
+ sql("INSERT INTO t1 VALUES ('hello'), ('HELLO')")
+ sql("INSERT INTO t2 VALUES ('hello')")
+
+ checkAnswer(
+ sql("SELECT t1.a FROM t1 JOIN t2 ON t1.a = t2.b WHERE t1.a = 'hello'"),
+ Row("hello")
+ )
+ }
+ }
+
+ test("ConstantPropagation: binary-stable join replaces attributes " +
+ "across collation cast in join condition") {
+ withTable("t1", "t2") {
+ sql("CREATE TABLE t1 (a STRING)")
+ sql("CREATE TABLE t2 (b STRING)")
+ sql("INSERT INTO t1 VALUES ('hello'), ('HELLO')")
+ sql("INSERT INTO t2 VALUES ('hello')")
+
+ checkAnswer(
+ sql("""SELECT t1.a FROM t1 JOIN t2 ON t1.a = t2.b COLLATE UTF8_LCASE
+ |WHERE t1.a = 'hello'""".stripMargin),
+ Row("hello")
+ )
+ }
+ }
+
+ test("ConstantPropagation: does not replace non-binary-stable attribute " +
+ "in null-safe join filter") {
+ withTable("t1", "t2") {
+ sql("CREATE TABLE t1 (a STRING COLLATE UTF8_LCASE)")
+ sql("CREATE TABLE t2 (b STRING COLLATE UTF8_LCASE)")
+ sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), (NULL)")
+ sql("INSERT INTO t2 VALUES ('hello'), (NULL)")
+
+ checkAnswer(
+ sql("""SELECT t1.a FROM t1 JOIN t2 ON t1.a <=> t2.b
+ |WHERE t1.a = 'hello' AND t1.a = 'HELLO' COLLATE
UNICODE""".stripMargin),
+ Row("HELLO")
+ )
+ }
+ }
+
+ test("ConstantPropagation: non-binary-stable null-safe join condition " +
+ "matches case-insensitively") {
+ withTable("t1", "t2") {
+ sql("CREATE TABLE t1 (a STRING COLLATE UTF8_LCASE)")
+ sql("CREATE TABLE t2 (b STRING COLLATE UTF8_LCASE)")
+ sql("INSERT INTO t1 VALUES ('hello'), ('HELLO'), ('world')")
+ sql("INSERT INTO t2 VALUES ('hello')")
+
+ checkAnswer(
+ sql("SELECT t1.a FROM t1 JOIN t2 ON t1.a <=> t2.b WHERE t1.a =
'hello'"),
+ Seq(Row("hello"), Row("HELLO"))
+ )
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]