This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 7d4c1b8  [SPARK-25121][SQL] Supports multi-part table names for broadcast hint resolution
7d4c1b8 is described below

commit 7d4c1b894ef32170b421f51c42cf30198c35c21b
Author: Takeshi Yamamuro <[email protected]>
AuthorDate: Thu Mar 19 20:11:04 2020 -0700

    [SPARK-25121][SQL] Supports multi-part table names for broadcast hint resolution
    
    ### What changes were proposed in this pull request?
    
    This PR fixes broadcast hint resolution so that it respects the database name in a table identifier.
    Currently, Spark ignores the database name in multi-part identifiers:
    ```
    scala> sql("CREATE DATABASE testDb")
    scala> spark.range(10).write.saveAsTable("testDb.t")
    
    // without this patch: the hint on `testDb.t` is not applied (note `BuildLeft`: the `Range` side is broadcast)
    scala> spark.range(10).join(spark.table("testDb.t"), "id").hint("broadcast", "testDb.t").explain
    == Physical Plan ==
    *(2) Project [id#24L]
    +- *(2) BroadcastHashJoin [id#24L], [id#26L], Inner, BuildLeft
       :- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false]))
       :  +- *(1) Range (0, 10, step=1, splits=4)
       +- *(2) Project [id#26L]
          +- *(2) Filter isnotnull(id#26L)
             +- *(2) FileScan parquet testdb.t[id#26L] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/maropu/Repositories/spark/spark-2.3.1-bin-hadoop2.7/spark-warehouse..., PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:bigint>
    
    // with this patch: the hint is applied and `testDb.t` is broadcast (note `BuildRight`)
    scala> spark.range(10).join(spark.table("testDb.t"), "id").hint("broadcast", "testDb.t").explain
    == Physical Plan ==
    *(2) Project [id#3L]
    +- *(2) BroadcastHashJoin [id#3L], [id#5L], Inner, BuildRight
       :- *(2) Range (0, 10, step=1, splits=4)
       +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, true]))
          +- *(1) Project [id#5L]
             +- *(1) Filter isnotnull(id#5L)
                +- *(1) FileScan parquet testdb.t[id#5L] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/maropu/Repositories/spark/spark-master/spark-warehouse/testdb.db/t], PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:bigint>
    ```
    
    This PR is based on https://github.com/apache/spark/pull/22198
    
    ### Why are the changes needed?
    
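    For better usability: a database-qualified table name in a broadcast hint (e.g. `testDb.t`) should match the corresponding relation instead of being silently ignored.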
    
    ### Does this PR introduce any user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Added unit tests.
    
    Closes #27935 from maropu/SPARK-25121-2.
    
    Authored-by: Takeshi Yamamuro <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
    (cherry picked from commit ca499e94091ae62a6ee76ea779d7b2b4cf2dbc5c)
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 .../sql/catalyst/analysis/HintErrorLogger.scala    |  7 +-
 .../spark/sql/catalyst/analysis/ResolveHints.scala | 71 +++++++++++-----
 .../spark/sql/catalyst/expressions/package.scala   |  2 +-
 .../spark/sql/catalyst/plans/logical/hints.scala   |  3 +-
 .../spark/sql/catalyst/analysis/AnalysisTest.scala |  2 +
 .../sql/catalyst/analysis/ResolveHintsSuite.scala  | 48 +++++++++++
 .../sql/catalyst/analysis/TestRelations.scala      |  2 +
 .../org/apache/spark/sql/DataFrameJoinSuite.scala  | 98 +++++++++++++++++++++-
 .../spark/sql/execution/GlobalTempViewSuite.scala  | 24 +++++-
 9 files changed, 230 insertions(+), 27 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala
index c6e0c74..71c6d40 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala
@@ -24,15 +24,16 @@ import 
org.apache.spark.sql.catalyst.plans.logical.{HintErrorHandler, HintInfo}
  * The hint error handler that logs warnings for each hint error.
  */
 object HintErrorLogger extends HintErrorHandler with Logging {
+  import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
 
   override def hintNotRecognized(name: String, parameters: Seq[Any]): Unit = {
     logWarning(s"Unrecognized hint: ${hintToPrettyString(name, parameters)}")
   }
 
   override def hintRelationsNotFound(
-      name: String, parameters: Seq[Any], invalidRelations: Set[String]): Unit 
= {
-    invalidRelations.foreach { n =>
-      logWarning(s"Count not find relation '$n' specified in hint " +
+      name: String, parameters: Seq[Any], invalidRelations: Set[Seq[String]]): 
Unit = {
+    invalidRelations.foreach { ident =>
+      logWarning(s"Count not find relation '${ident.quoted}' specified in hint 
" +
         s"'${hintToPrettyString(name, parameters)}'.")
     }
   }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index 5b77d67..81de086 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -64,31 +64,59 @@ object ResolveHints {
             
_.toUpperCase(Locale.ROOT)).contains(hintName.toUpperCase(Locale.ROOT))))
     }
 
+    // This method checks if given multi-part identifiers are matched with 
each other.
+    // The [[ResolveJoinStrategyHints]] rule is applied before the resolution 
batch
+    // in the analyzer and we cannot semantically compare them at this stage.
+    // Therefore, we follow a simple rule; they match if an identifier in a 
hint
+    // is a tail of an identifier in a relation. This process is independent 
of a session
+    // catalog (`currentDb` in [[SessionCatalog]]) and it just compares them 
literally.
+    //
+    // For example,
+    //  * in a query `SELECT /*+ BROADCAST(t) */ * FROM db1.t JOIN t`,
+    //    the broadcast hint will match both tables, `db1.t` and `t`,
+    //    even when the current db is `db2`.
+    //  * in a query `SELECT /*+ BROADCAST(default.t) */ * FROM default.t JOIN 
t`,
+    //    the broadcast hint will match the left-side table only, `default.t`.
+    private def matchedIdentifier(identInHint: Seq[String], identInQuery: 
Seq[String]): Boolean = {
+      if (identInHint.length <= identInQuery.length) {
+        identInHint.zip(identInQuery.takeRight(identInHint.length))
+          .forall { case (i1, i2) => resolver(i1, i2) }
+      } else {
+        false
+      }
+    }
+
+    private def extractIdentifier(r: SubqueryAlias): Seq[String] = {
+      r.identifier.qualifier :+ r.identifier.name
+    }
+
     private def applyJoinStrategyHint(
         plan: LogicalPlan,
-        relations: mutable.HashSet[String],
+        relationsInHint: Set[Seq[String]],
+        relationsInHintWithMatch: mutable.HashSet[Seq[String]],
         hintName: String): LogicalPlan = {
       // Whether to continue recursing down the tree
       var recurse = true
 
+      def matchedIdentifierInHint(identInQuery: Seq[String]): Boolean = {
+        relationsInHint.find(matchedIdentifier(_, identInQuery))
+          .map(relationsInHintWithMatch.add).nonEmpty
+      }
+
       val newNode = CurrentOrigin.withOrigin(plan.origin) {
         plan match {
           case ResolvedHint(u @ UnresolvedRelation(ident), hint)
-              if relations.exists(resolver(_, ident.last)) =>
-            relations.remove(ident.last)
+              if matchedIdentifierInHint(ident) =>
             ResolvedHint(u, createHintInfo(hintName).merge(hint, 
hintErrorHandler))
 
           case ResolvedHint(r: SubqueryAlias, hint)
-              if relations.exists(resolver(_, r.alias)) =>
-            relations.remove(r.alias)
+              if matchedIdentifierInHint(extractIdentifier(r)) =>
             ResolvedHint(r, createHintInfo(hintName).merge(hint, 
hintErrorHandler))
 
-          case u @ UnresolvedRelation(ident) if relations.exists(resolver(_, 
ident.last)) =>
-            relations.remove(ident.last)
+          case UnresolvedRelation(ident) if matchedIdentifierInHint(ident) =>
             ResolvedHint(plan, createHintInfo(hintName))
 
-          case r: SubqueryAlias if relations.exists(resolver(_, r.alias)) =>
-            relations.remove(r.alias)
+          case r: SubqueryAlias if 
matchedIdentifierInHint(extractIdentifier(r)) =>
             ResolvedHint(plan, createHintInfo(hintName))
 
           case _: ResolvedHint | _: View | _: With | _: SubqueryAlias =>
@@ -107,7 +135,9 @@ object ResolveHints {
       }
 
       if ((plan fastEquals newNode) && recurse) {
-        newNode.mapChildren(child => applyJoinStrategyHint(child, relations, 
hintName))
+        newNode.mapChildren { child =>
+          applyJoinStrategyHint(child, relationsInHint, 
relationsInHintWithMatch, hintName)
+        }
       } else {
         newNode
       }
@@ -120,17 +150,19 @@ object ResolveHints {
           ResolvedHint(h.child, createHintInfo(h.name))
         } else {
           // Otherwise, find within the subtree query plans to apply the hint.
-          val relationNames = h.parameters.map {
-            case tableName: String => tableName
-            case tableId: UnresolvedAttribute => tableId.name
+          val relationNamesInHint = h.parameters.map {
+            case tableName: String => 
UnresolvedAttribute.parseAttributeName(tableName)
+            case tableId: UnresolvedAttribute => tableId.nameParts
             case unsupported => throw new AnalysisException("Join strategy 
hint parameter " +
               s"should be an identifier or string but was $unsupported 
(${unsupported.getClass}")
-          }
-          val relationNameSet = new mutable.HashSet[String]
-          relationNames.foreach(relationNameSet.add)
-
-          val applied = applyJoinStrategyHint(h.child, relationNameSet, h.name)
-          hintErrorHandler.hintRelationsNotFound(h.name, h.parameters, 
relationNameSet.toSet)
+          }.toSet
+          val relationsInHintWithMatch = new mutable.HashSet[Seq[String]]
+          val applied = applyJoinStrategyHint(
+            h.child, relationNamesInHint, relationsInHintWithMatch, h.name)
+
+          // Filters unmatched relation identifiers in the hint
+          val unmatchedIdents = relationNamesInHint -- relationsInHintWithMatch
+          hintErrorHandler.hintRelationsNotFound(h.name, h.parameters, 
unmatchedIdents)
           applied
         }
     }
@@ -246,5 +278,4 @@ object ResolveHints {
         h.child
     }
   }
-
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 1b59056..8bf1f19 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -196,7 +196,7 @@ package object expressions  {
       // For example, consider an example where "cat" is the catalog name, 
"db1" is the database
       // name, "a" is the table name and "b" is the column name and "c" is the 
struct field name.
       // If the name parts is cat.db1.a.b.c, then Attribute will match
-      // Attribute(b, qualifier("cat", "db1, "a")) and List("c") will be the 
second element
+      // Attribute(b, qualifier("cat", "db1", "a")) and List("c") will be the 
second element
       var matches: (Seq[Attribute], Seq[String]) = nameParts match {
         case catalogPart +: dbPart +: tblPart +: name +: nestedFields =>
           val key = (catalogPart.toLowerCase(Locale.ROOT), 
dbPart.toLowerCase(Locale.ROOT),
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index f26e566..a325b61 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -186,7 +186,8 @@ trait HintErrorHandler {
    * @param parameters the hint parameters
    * @param invalidRelations the set of relation names that cannot be 
associated
    */
-  def hintRelationsNotFound(name: String, parameters: Seq[Any], 
invalidRelations: Set[String]): Unit
+  def hintRelationsNotFound(
+    name: String, parameters: Seq[Any], invalidRelations: Set[Seq[String]]): 
Unit
 
   /**
    * Callback for a join hint specified on a relation that is not part of a 
join.
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index 3f8d409..4473c20 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -45,6 +45,8 @@ trait AnalysisTest extends PlanTest {
     catalog.createTempView("TaBlE", TestRelations.testRelation, 
overrideIfExists = true)
     catalog.createTempView("TaBlE2", TestRelations.testRelation2, 
overrideIfExists = true)
     catalog.createTempView("TaBlE3", TestRelations.testRelation3, 
overrideIfExists = true)
+    catalog.createGlobalTempView("TaBlE4", TestRelations.testRelation4, 
overrideIfExists = true)
+    catalog.createGlobalTempView("TaBlE5", TestRelations.testRelation5, 
overrideIfExists = true)
     new Analyzer(catalog, conf) {
       override val extendedResolutionRules = EliminateSubqueryAliases +: 
extendedAnalysisRules
     }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
index 5e66c03..ca7d284 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
@@ -241,4 +241,52 @@ class ResolveHintsSuite extends AnalysisTest {
       Project(testRelation.output, testRelation),
       caseSensitive = false)
   }
+
+  test("Supports multi-part table names for broadcast hint resolution") {
+    // local temp table (single-part identifier case)
+    checkAnalysis(
+      UnresolvedHint("MAPJOIN", Seq("table", "table2"),
+        table("TaBlE").join(table("TaBlE2"))),
+      Join(
+        ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
+        ResolvedHint(testRelation2, HintInfo(strategy = Some(BROADCAST))),
+        Inner,
+        None,
+        JoinHint.NONE),
+      caseSensitive = false)
+
+    checkAnalysis(
+      UnresolvedHint("MAPJOIN", Seq("TaBlE", "table2"),
+        table("TaBlE").join(table("TaBlE2"))),
+      Join(
+        ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
+        testRelation2,
+        Inner,
+        None,
+        JoinHint.NONE),
+      caseSensitive = true)
+
+    // global temp table (multi-part identifier case)
+    checkAnalysis(
+      UnresolvedHint("MAPJOIN", Seq("GlOBal_TeMP.table4", "table5"),
+        table("global_temp", "table4").join(table("global_temp", "table5"))),
+      Join(
+        ResolvedHint(testRelation4, HintInfo(strategy = Some(BROADCAST))),
+        ResolvedHint(testRelation5, HintInfo(strategy = Some(BROADCAST))),
+        Inner,
+        None,
+        JoinHint.NONE),
+      caseSensitive = false)
+
+    checkAnalysis(
+      UnresolvedHint("MAPJOIN", Seq("global_temp.TaBlE4", "table5"),
+        table("global_temp", "TaBlE4").join(table("global_temp", "TaBlE5"))),
+      Join(
+        ResolvedHint(testRelation4, HintInfo(strategy = Some(BROADCAST))),
+        testRelation5,
+        Inner,
+        None,
+        JoinHint.NONE),
+      caseSensitive = true)
+  }
 }
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
index e12e272..33b6029 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
@@ -44,6 +44,8 @@ object TestRelations {
     AttributeReference("g", StringType)(),
     AttributeReference("h", MapType(IntegerType, IntegerType))())
 
+  val testRelation5 = LocalRelation(AttributeReference("i", StringType)())
+
   val nestedRelation = LocalRelation(
     AttributeReference("top", StructType(
       StructField("duplicateField", StringType) ::
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index c7545bc..6b772e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -17,10 +17,14 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, LeftOuter, 
RightOuter}
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan, 
Project}
+import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Filter, 
HintInfo, Join, JoinHint, LogicalPlan, Project}
+import org.apache.spark.sql.connector.catalog.CatalogManager
+import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
 import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -322,4 +326,96 @@ class DataFrameJoinSuite extends QueryTest
       }
     }
   }
+
+  test("Supports multi-part names for broadcast hint resolution") {
+    val (table1Name, table2Name) = ("t1", "t2")
+
+    withTempDatabase { dbName =>
+      withTable(table1Name, table2Name) {
+        withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+          spark.range(50).write.saveAsTable(s"$dbName.$table1Name")
+          spark.range(100).write.saveAsTable(s"$dbName.$table2Name")
+
+          def checkIfHintApplied(df: DataFrame): Unit = {
+            val sparkPlan = df.queryExecution.executedPlan
+            val broadcastHashJoins = sparkPlan.collect { case p: 
BroadcastHashJoinExec => p }
+            assert(broadcastHashJoins.size == 1)
+            val broadcastExchanges = broadcastHashJoins.head.collect {
+              case p: BroadcastExchangeExec => p
+            }
+            assert(broadcastExchanges.size == 1)
+            val tables = broadcastExchanges.head.collect {
+              case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => 
tableIdent
+            }
+            assert(tables.size == 1)
+            assert(tables.head === TableIdentifier(table1Name, Some(dbName)))
+          }
+
+          def checkIfHintNotApplied(df: DataFrame): Unit = {
+            val sparkPlan = df.queryExecution.executedPlan
+            val broadcastHashJoins = sparkPlan.collect { case p: 
BroadcastHashJoinExec => p }
+            assert(broadcastHashJoins.isEmpty)
+          }
+
+          def sqlTemplate(tableName: String, hintTableName: String): DataFrame 
= {
+            sql(s"SELECT /*+ BROADCASTJOIN($hintTableName) */ * " +
+              s"FROM $tableName, $dbName.$table2Name " +
+              s"WHERE $tableName.id = $table2Name.id")
+          }
+
+          def dfTemplate(tableName: String, hintTableName: String): DataFrame 
= {
+            spark.table(tableName).join(spark.table(s"$dbName.$table2Name"), 
"id")
+              .hint("broadcast", hintTableName)
+          }
+
+          sql(s"USE $dbName")
+
+          checkIfHintApplied(sqlTemplate(table1Name, table1Name))
+          checkIfHintApplied(sqlTemplate(s"$dbName.$table1Name", 
s"$dbName.$table1Name"))
+          checkIfHintApplied(sqlTemplate(s"$dbName.$table1Name", table1Name))
+          checkIfHintNotApplied(sqlTemplate(table1Name, 
s"$dbName.$table1Name"))
+
+          checkIfHintApplied(dfTemplate(table1Name, table1Name))
+          checkIfHintApplied(dfTemplate(s"$dbName.$table1Name", 
s"$dbName.$table1Name"))
+          checkIfHintApplied(dfTemplate(s"$dbName.$table1Name", table1Name))
+          checkIfHintApplied(dfTemplate(table1Name, s"$dbName.$table1Name"))
+          checkIfHintApplied(dfTemplate(table1Name,
+            s"${CatalogManager.SESSION_CATALOG_NAME}.$dbName.$table1Name"))
+
+          withView("tv") {
+            sql(s"CREATE VIEW tv AS SELECT * FROM $dbName.$table1Name")
+            checkIfHintApplied(sqlTemplate("tv", "tv"))
+            checkIfHintNotApplied(sqlTemplate("tv", s"$dbName.tv"))
+
+            checkIfHintApplied(dfTemplate("tv", "tv"))
+            checkIfHintApplied(dfTemplate("tv", s"$dbName.tv"))
+          }
+        }
+      }
+    }
+  }
+
+  test("The same table name exists in two databases for broadcast hint 
resolution") {
+    val (db1Name, db2Name) = ("db1", "db2")
+
+    withDatabase(db1Name, db2Name) {
+      withTable("t") {
+        withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+          sql(s"CREATE DATABASE $db1Name")
+          sql(s"CREATE DATABASE $db2Name")
+          spark.range(1).write.saveAsTable(s"$db1Name.t")
+          spark.range(1).write.saveAsTable(s"$db2Name.t")
+
+          // Checks if a broadcast hint applied in both sides
+          val statement = s"SELECT /*+ BROADCASTJOIN(t) */ * FROM $db1Name.t, 
$db2Name.t " +
+            s"WHERE $db1Name.t.id = $db2Name.t.id"
+          sql(statement).queryExecution.optimizedPlan match {
+            case Join(_, _, _, _, JoinHint(Some(HintInfo(Some(BROADCAST))),
+              Some(HintInfo(Some(BROADCAST))))) =>
+            case _ => fail("broadcast hint not found in both tables")
+          }
+        }
+      }
+    }
+  }
 }
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala
index 7fbfa73..28e82aa 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalog.Table
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, Join, 
JoinHint}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types.StructType
 
@@ -170,4 +171,25 @@ class GlobalTempViewSuite extends QueryTest with 
SharedSparkSession {
         isTemporary = true).toString)
     }
   }
+
+  test("broadcast hint on global temp view") {
+    withGlobalTempView("v1") {
+      spark.range(10).createGlobalTempView("v1")
+      withTempView("v2") {
+        spark.range(10).createTempView("v2")
+
+        withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+          Seq(
+            "SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id = 
v2.id",
+            "SELECT /*+ MAPJOIN(global_temp.v1) */ * FROM global_temp.v1, v2 
WHERE v1.id = v2.id"
+          ).foreach { statement =>
+            sql(statement).queryExecution.optimizedPlan match {
+              case Join(_, _, _, _, JoinHint(Some(HintInfo(Some(BROADCAST))), 
None)) =>
+              case _ => fail("broadcast hint not found in a left-side table")
+            }
+          }
+        }
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to