Repository: spark
Updated Branches:
  refs/heads/branch-2.0 e11046457 -> 6347ff512


[SPARK-15647][SQL] Fix Boundary Cases in OptimizeCodegen Rule

#### What changes were proposed in this pull request?

The following condition in the optimizer rule `OptimizeCodegen` is incorrect:
```Scala
branches.size < conf.maxCaseBranchesForCodegen
```

- The number of branches in a CASE WHEN clause should also count the ELSE 
branch, i.e. it should be `branches.size + elseValue.size`.
- `maxCaseBranchesForCodegen` is the inclusive upper bound on the number of 
branches for enabling codegen. Thus, we should use `<=` instead of `<`.

This PR fixes this boundary case and also adds the missing test cases for 
verifying the conf `MAX_CASES_BRANCHES`.

#### How was this patch tested?
Added test cases in `SQLConfSuite`

Author: gatorsmile <[email protected]>

Closes #13392 from gatorsmile/maxCaseWhen.

(cherry picked from commit d67c82e4b647dacd0b789d72c9eaf4dc7d329dbd)
Signed-off-by: Wenchen Fan <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6347ff51
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6347ff51
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6347ff51

Branch: refs/heads/branch-2.0
Commit: 6347ff512d1e11106e44609a59be25a296aef731
Parents: e110464
Author: gatorsmile <[email protected]>
Authored: Tue May 31 10:08:00 2016 -0700
Committer: Wenchen Fan <[email protected]>
Committed: Tue May 31 10:08:10 2016 -0700

----------------------------------------------------------------------
 .../sql/catalyst/optimizer/Optimizer.scala      |  8 ++++--
 .../spark/sql/internal/SQLConfSuite.scala       | 29 ++++++++++++++++++++
 2 files changed, 35 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6347ff51/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 688c77d..93762ad 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -937,8 +937,12 @@ object SimplifyConditionals extends Rule[LogicalPlan] with 
PredicateHelper {
  */
 case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case e @ CaseWhen(branches, _) if branches.size < 
conf.maxCaseBranchesForCodegen =>
-      e.toCodegen()
+    case e: CaseWhen if canCodegen(e) => e.toCodegen()
+  }
+
+  private def canCodegen(e: CaseWhen): Boolean = {
+    val numBranches = e.branches.size + e.elseValue.size
+    numBranches <= conf.maxCaseBranchesForCodegen
   }
 }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6347ff51/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
index 3d4fc75..2cd3f47 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.internal
 
 import org.apache.spark.sql.{QueryTest, Row, SparkSession, SQLContext}
+import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
 
 class SQLConfSuite extends QueryTest with SharedSQLContext {
@@ -219,4 +220,32 @@ class SQLConfSuite extends QueryTest with SharedSQLContext 
{
     }
   }
 
+  test("MAX_CASES_BRANCHES") {
+    withTable("tab1") {
+      spark.range(10).write.saveAsTable("tab1")
+      val sql_one_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 END FROM 
tab1"
+      val sql_two_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 ELSE 0 END 
FROM tab1"
+
+      withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "0") {
+        assert(!sql(sql_one_branch_caseWhen)
+          .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+        assert(!sql(sql_two_branch_caseWhen)
+          .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+      }
+
+      withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "1") {
+        assert(sql(sql_one_branch_caseWhen)
+          .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+        assert(!sql(sql_two_branch_caseWhen)
+          .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+      }
+
+      withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "2") {
+        assert(sql(sql_one_branch_caseWhen)
+          .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+        assert(sql(sql_two_branch_caseWhen)
+          .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to