This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
new 7d4c1b8 [SPARK-25121][SQL] Supports multi-part table names for broadcast hint resolution
7d4c1b8 is described below
commit 7d4c1b894ef32170b421f51c42cf30198c35c21b
Author: Takeshi Yamamuro <[email protected]>
AuthorDate: Thu Mar 19 20:11:04 2020 -0700
[SPARK-25121][SQL] Supports multi-part table names for broadcast hint resolution
### What changes were proposed in this pull request?
This PR fixes the code to respect database names in broadcast table hint resolution.
Currently, Spark ignores the database name in multi-part table names:
```
scala> sql("CREATE DATABASE testDb")
scala> spark.range(10).write.saveAsTable("testDb.t")
// without this patch
scala> spark.range(10).join(spark.table("testDb.t"), "id").hint("broadcast", "testDb.t").explain
== Physical Plan ==
*(2) Project [id#24L]
+- *(2) BroadcastHashJoin [id#24L], [id#26L], Inner, BuildLeft
   :- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, false]))
   :  +- *(1) Range (0, 10, step=1, splits=4)
   +- *(2) Project [id#26L]
      +- *(2) Filter isnotnull(id#26L)
         +- *(2) FileScan parquet testdb.t[id#26L] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/maropu/Repositories/spark/spark-2.3.1-bin-hadoop2.7/spark-warehouse..., PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:bigint>
// with this patch
scala> spark.range(10).join(spark.table("testDb.t"), "id").hint("broadcast", "testDb.t").explain
== Physical Plan ==
*(2) Project [id#3L]
+- *(2) BroadcastHashJoin [id#3L], [id#5L], Inner, BuildRight
   :- *(2) Range (0, 10, step=1, splits=4)
   +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, true]))
      +- *(1) Project [id#5L]
         +- *(1) Filter isnotnull(id#5L)
            +- *(1) FileScan parquet testdb.t[id#5L] Batched: true, Format: Parquet, Location: InMemoryFileIndex[file:/Users/maropu/Repositories/spark/spark-master/spark-warehouse/testdb.db/t], PartitionFilters: [], PushedFilters: [IsNotNull(id)], ReadSchema: struct<id:bigint>
```
This PR comes from https://github.com/apache/spark/pull/22198
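In short, a hint identifier now matches a relation when it is a literal tail of the relation's multi-part name. A minimal sketch of that rule, with `equalsIgnoreCase` standing in for the analyzer's configured resolver (the real logic lives in `ResolveHints.matchedIdentifier` below):
```scala
// Sketch of the tail-matching rule this patch adds; the real code compares
// each name part with the analyzer's resolver, not equalsIgnoreCase.
def matchedIdentifier(identInHint: Seq[String], identInQuery: Seq[String]): Boolean = {
  identInHint.length <= identInQuery.length &&
    identInHint.zip(identInQuery.takeRight(identInHint.length))
      .forall { case (h, q) => h.equalsIgnoreCase(q) }
}

matchedIdentifier(Seq("t"), Seq("testDb", "t"))            // true: `t` matches `testDb.t`
matchedIdentifier(Seq("testDb", "t"), Seq("testDb", "t"))  // true: exact match
matchedIdentifier(Seq("otherDb", "t"), Seq("testDb", "t")) // false: qualifier differs
```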
### Why are the changes needed?
For better usability.
### Does this PR introduce any user-facing change?
No.
### How was this patch tested?
Added unit tests.
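The added tests also cover names qualified with the session catalog. For illustration, a hedged sketch against the `testDb.t` table created above (`spark_catalog` is the default session catalog name, `CatalogManager.SESSION_CATALOG_NAME`):
```scala
// Illustrative spark-shell usage, assuming testDb.t from the example above:
// a fully catalog-qualified name in the hint now resolves as well.
spark.range(10)
  .join(spark.table("testDb.t"), "id")
  .hint("broadcast", "spark_catalog.testDb.t")
  .explain()
```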
Closes #27935 from maropu/SPARK-25121-2.
Authored-by: Takeshi Yamamuro <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit ca499e94091ae62a6ee76ea779d7b2b4cf2dbc5c)
Signed-off-by: Dongjoon Hyun <[email protected]>
---
.../sql/catalyst/analysis/HintErrorLogger.scala | 7 +-
.../spark/sql/catalyst/analysis/ResolveHints.scala | 71 +++++++++++-----
.../spark/sql/catalyst/expressions/package.scala | 2 +-
.../spark/sql/catalyst/plans/logical/hints.scala | 3 +-
.../spark/sql/catalyst/analysis/AnalysisTest.scala | 2 +
.../sql/catalyst/analysis/ResolveHintsSuite.scala | 48 +++++++++++
.../sql/catalyst/analysis/TestRelations.scala | 2 +
.../org/apache/spark/sql/DataFrameJoinSuite.scala | 98 +++++++++++++++++++++-
.../spark/sql/execution/GlobalTempViewSuite.scala | 24 +++++-
9 files changed, 230 insertions(+), 27 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala
index c6e0c74..71c6d40 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HintErrorLogger.scala
@@ -24,15 +24,16 @@ import org.apache.spark.sql.catalyst.plans.logical.{HintErrorHandler, HintInfo}
* The hint error handler that logs warnings for each hint error.
*/
object HintErrorLogger extends HintErrorHandler with Logging {
+ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._
override def hintNotRecognized(name: String, parameters: Seq[Any]): Unit = {
logWarning(s"Unrecognized hint: ${hintToPrettyString(name, parameters)}")
}
override def hintRelationsNotFound(
- name: String, parameters: Seq[Any], invalidRelations: Set[String]): Unit = {
- invalidRelations.foreach { n =>
- logWarning(s"Count not find relation '$n' specified in hint " +
+ name: String, parameters: Seq[Any], invalidRelations: Set[Seq[String]]): Unit = {
+ invalidRelations.foreach { ident =>
+ logWarning(s"Count not find relation '${ident.quoted}' specified in hint " +
s"'${hintToPrettyString(name, parameters)}'.")
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
index 5b77d67..81de086 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala
@@ -64,31 +64,59 @@ object ResolveHints {
_.toUpperCase(Locale.ROOT)).contains(hintName.toUpperCase(Locale.ROOT))))
}
+ // This method checks if given multi-part identifiers are matched with each other.
+ // The [[ResolveJoinStrategyHints]] rule is applied before the resolution batch
+ // in the analyzer and we cannot semantically compare them at this stage.
+ // Therefore, we follow a simple rule; they match if an identifier in a hint
+ // is a tail of an identifier in a relation. This process is independent of a session
+ // catalog (`currentDb` in [[SessionCatalog]]) and it just compares them literally.
+ //
+ // For example,
+ // * in a query `SELECT /*+ BROADCAST(t) */ * FROM db1.t JOIN t`,
+ // the broadcast hint will match both tables, `db1.t` and `t`,
+ // even when the current db is `db2`.
+ // * in a query `SELECT /*+ BROADCAST(default.t) */ * FROM default.t JOIN t`,
+ // the broadcast hint will match the left-side table only, `default.t`.
+ private def matchedIdentifier(identInHint: Seq[String], identInQuery: Seq[String]): Boolean = {
+ if (identInHint.length <= identInQuery.length) {
+ identInHint.zip(identInQuery.takeRight(identInHint.length))
+ .forall { case (i1, i2) => resolver(i1, i2) }
+ } else {
+ false
+ }
+ }
+
+ private def extractIdentifier(r: SubqueryAlias): Seq[String] = {
+ r.identifier.qualifier :+ r.identifier.name
+ }
+
private def applyJoinStrategyHint(
plan: LogicalPlan,
- relations: mutable.HashSet[String],
+ relationsInHint: Set[Seq[String]],
+ relationsInHintWithMatch: mutable.HashSet[Seq[String]],
hintName: String): LogicalPlan = {
// Whether to continue recursing down the tree
var recurse = true
+ def matchedIdentifierInHint(identInQuery: Seq[String]): Boolean = {
+ relationsInHint.find(matchedIdentifier(_, identInQuery))
+ .map(relationsInHintWithMatch.add).nonEmpty
+ }
+
val newNode = CurrentOrigin.withOrigin(plan.origin) {
plan match {
case ResolvedHint(u @ UnresolvedRelation(ident), hint)
- if relations.exists(resolver(_, ident.last)) =>
- relations.remove(ident.last)
+ if matchedIdentifierInHint(ident) =>
ResolvedHint(u, createHintInfo(hintName).merge(hint, hintErrorHandler))
case ResolvedHint(r: SubqueryAlias, hint)
- if relations.exists(resolver(_, r.alias)) =>
- relations.remove(r.alias)
+ if matchedIdentifierInHint(extractIdentifier(r)) =>
ResolvedHint(r, createHintInfo(hintName).merge(hint, hintErrorHandler))
- case u @ UnresolvedRelation(ident) if relations.exists(resolver(_, ident.last)) =>
- relations.remove(ident.last)
+ case UnresolvedRelation(ident) if matchedIdentifierInHint(ident) =>
ResolvedHint(plan, createHintInfo(hintName))
- case r: SubqueryAlias if relations.exists(resolver(_, r.alias)) =>
- relations.remove(r.alias)
+ case r: SubqueryAlias if matchedIdentifierInHint(extractIdentifier(r)) =>
ResolvedHint(plan, createHintInfo(hintName))
case _: ResolvedHint | _: View | _: With | _: SubqueryAlias =>
@@ -107,7 +135,9 @@ object ResolveHints {
}
if ((plan fastEquals newNode) && recurse) {
- newNode.mapChildren(child => applyJoinStrategyHint(child, relations, hintName))
+ newNode.mapChildren { child =>
+ applyJoinStrategyHint(child, relationsInHint, relationsInHintWithMatch, hintName)
+ }
} else {
newNode
}
@@ -120,17 +150,19 @@ object ResolveHints {
ResolvedHint(h.child, createHintInfo(h.name))
} else {
// Otherwise, find within the subtree query plans to apply the hint.
- val relationNames = h.parameters.map {
- case tableName: String => tableName
- case tableId: UnresolvedAttribute => tableId.name
+ val relationNamesInHint = h.parameters.map {
+ case tableName: String => UnresolvedAttribute.parseAttributeName(tableName)
+ case tableId: UnresolvedAttribute => tableId.nameParts
case unsupported => throw new AnalysisException("Join strategy hint parameter " +
s"should be an identifier or string but was $unsupported (${unsupported.getClass}")
- }
- val relationNameSet = new mutable.HashSet[String]
- relationNames.foreach(relationNameSet.add)
-
- val applied = applyJoinStrategyHint(h.child, relationNameSet, h.name)
- hintErrorHandler.hintRelationsNotFound(h.name, h.parameters, relationNameSet.toSet)
+ }.toSet
+ val relationsInHintWithMatch = new mutable.HashSet[Seq[String]]
+ val applied = applyJoinStrategyHint(
+ h.child, relationNamesInHint, relationsInHintWithMatch, h.name)
+
+ // Filters unmatched relation identifiers in the hint
+ val unmatchedIdents = relationNamesInHint -- relationsInHintWithMatch
+ hintErrorHandler.hintRelationsNotFound(h.name, h.parameters, unmatchedIdents)
applied
}
}
@@ -246,5 +278,4 @@ object ResolveHints {
h.child
}
}
-
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 1b59056..8bf1f19 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -196,7 +196,7 @@ package object expressions {
// For example, consider an example where "cat" is the catalog name, "db1" is the database
// name, "a" is the table name and "b" is the column name and "c" is the struct field name.
// If the name parts is cat.db1.a.b.c, then Attribute will match
- // Attribute(b, qualifier("cat", "db1, "a")) and List("c") will be the second element
+ // Attribute(b, qualifier("cat", "db1", "a")) and List("c") will be the second element
var matches: (Seq[Attribute], Seq[String]) = nameParts match {
case catalogPart +: dbPart +: tblPart +: name +: nestedFields =>
val key = (catalogPart.toLowerCase(Locale.ROOT), dbPart.toLowerCase(Locale.ROOT),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
index f26e566..a325b61 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/hints.scala
@@ -186,7 +186,8 @@ trait HintErrorHandler {
* @param parameters the hint parameters
* @param invalidRelations the set of relation names that cannot be associated
*/
- def hintRelationsNotFound(name: String, parameters: Seq[Any], invalidRelations: Set[String]): Unit
+ def hintRelationsNotFound(
+ name: String, parameters: Seq[Any], invalidRelations: Set[Seq[String]]): Unit
/**
* Callback for a join hint specified on a relation that is not part of a join.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
index 3f8d409..4473c20 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala
@@ -45,6 +45,8 @@ trait AnalysisTest extends PlanTest {
catalog.createTempView("TaBlE", TestRelations.testRelation,
overrideIfExists = true)
catalog.createTempView("TaBlE2", TestRelations.testRelation2,
overrideIfExists = true)
catalog.createTempView("TaBlE3", TestRelations.testRelation3,
overrideIfExists = true)
+ catalog.createGlobalTempView("TaBlE4", TestRelations.testRelation4,
overrideIfExists = true)
+ catalog.createGlobalTempView("TaBlE5", TestRelations.testRelation5,
overrideIfExists = true)
new Analyzer(catalog, conf) {
override val extendedResolutionRules = EliminateSubqueryAliases +:
extendedAnalysisRules
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
index 5e66c03..ca7d284 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveHintsSuite.scala
@@ -241,4 +241,52 @@ class ResolveHintsSuite extends AnalysisTest {
Project(testRelation.output, testRelation),
caseSensitive = false)
}
+
+ test("Supports multi-part table names for broadcast hint resolution") {
+ // local temp table (single-part identifier case)
+ checkAnalysis(
+ UnresolvedHint("MAPJOIN", Seq("table", "table2"),
+ table("TaBlE").join(table("TaBlE2"))),
+ Join(
+ ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
+ ResolvedHint(testRelation2, HintInfo(strategy = Some(BROADCAST))),
+ Inner,
+ None,
+ JoinHint.NONE),
+ caseSensitive = false)
+
+ checkAnalysis(
+ UnresolvedHint("MAPJOIN", Seq("TaBlE", "table2"),
+ table("TaBlE").join(table("TaBlE2"))),
+ Join(
+ ResolvedHint(testRelation, HintInfo(strategy = Some(BROADCAST))),
+ testRelation2,
+ Inner,
+ None,
+ JoinHint.NONE),
+ caseSensitive = true)
+
+ // global temp table (multi-part identifier case)
+ checkAnalysis(
+ UnresolvedHint("MAPJOIN", Seq("GlOBal_TeMP.table4", "table5"),
+ table("global_temp", "table4").join(table("global_temp", "table5"))),
+ Join(
+ ResolvedHint(testRelation4, HintInfo(strategy = Some(BROADCAST))),
+ ResolvedHint(testRelation5, HintInfo(strategy = Some(BROADCAST))),
+ Inner,
+ None,
+ JoinHint.NONE),
+ caseSensitive = false)
+
+ checkAnalysis(
+ UnresolvedHint("MAPJOIN", Seq("global_temp.TaBlE4", "table5"),
+ table("global_temp", "TaBlE4").join(table("global_temp", "TaBlE5"))),
+ Join(
+ ResolvedHint(testRelation4, HintInfo(strategy = Some(BROADCAST))),
+ testRelation5,
+ Inner,
+ None,
+ JoinHint.NONE),
+ caseSensitive = true)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
index e12e272..33b6029 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TestRelations.scala
@@ -44,6 +44,8 @@ object TestRelations {
AttributeReference("g", StringType)(),
AttributeReference("h", MapType(IntegerType, IntegerType))())
+ val testRelation5 = LocalRelation(AttributeReference("i", StringType)())
+
val nestedRelation = LocalRelation(
AttributeReference("top", StructType(
StructField("duplicateField", StringType) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index c7545bc..6b772e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -17,10 +17,14 @@
package org.apache.spark.sql
+import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, Join, LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Filter, HintInfo, Join, JoinHint, LogicalPlan, Project}
+import org.apache.spark.sql.connector.catalog.CatalogManager
+import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
@@ -322,4 +326,96 @@ class DataFrameJoinSuite extends QueryTest
}
}
}
+
+ test("Supports multi-part names for broadcast hint resolution") {
+ val (table1Name, table2Name) = ("t1", "t2")
+
+ withTempDatabase { dbName =>
+ withTable(table1Name, table2Name) {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ spark.range(50).write.saveAsTable(s"$dbName.$table1Name")
+ spark.range(100).write.saveAsTable(s"$dbName.$table2Name")
+
+ def checkIfHintApplied(df: DataFrame): Unit = {
+ val sparkPlan = df.queryExecution.executedPlan
+ val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p }
+ assert(broadcastHashJoins.size == 1)
+ val broadcastExchanges = broadcastHashJoins.head.collect {
+ case p: BroadcastExchangeExec => p
+ }
+ assert(broadcastExchanges.size == 1)
+ val tables = broadcastExchanges.head.collect {
+ case FileSourceScanExec(_, _, _, _, _, _, Some(tableIdent)) => tableIdent
+ }
+ assert(tables.size == 1)
+ assert(tables.head === TableIdentifier(table1Name, Some(dbName)))
+ }
+
+ def checkIfHintNotApplied(df: DataFrame): Unit = {
+ val sparkPlan = df.queryExecution.executedPlan
+ val broadcastHashJoins = sparkPlan.collect { case p: BroadcastHashJoinExec => p }
+ assert(broadcastHashJoins.isEmpty)
+ }
+
+ def sqlTemplate(tableName: String, hintTableName: String): DataFrame = {
+ sql(s"SELECT /*+ BROADCASTJOIN($hintTableName) */ * " +
+ s"FROM $tableName, $dbName.$table2Name " +
+ s"WHERE $tableName.id = $table2Name.id")
+ }
+
+ def dfTemplate(tableName: String, hintTableName: String): DataFrame = {
+ spark.table(tableName).join(spark.table(s"$dbName.$table2Name"), "id")
+ .hint("broadcast", hintTableName)
+ }
+
+ sql(s"USE $dbName")
+
+ checkIfHintApplied(sqlTemplate(table1Name, table1Name))
+ checkIfHintApplied(sqlTemplate(s"$dbName.$table1Name", s"$dbName.$table1Name"))
+ checkIfHintApplied(sqlTemplate(s"$dbName.$table1Name", table1Name))
+ checkIfHintNotApplied(sqlTemplate(table1Name, s"$dbName.$table1Name"))
+
+ checkIfHintApplied(dfTemplate(table1Name, table1Name))
+ checkIfHintApplied(dfTemplate(s"$dbName.$table1Name", s"$dbName.$table1Name"))
+ checkIfHintApplied(dfTemplate(s"$dbName.$table1Name", table1Name))
+ checkIfHintApplied(dfTemplate(table1Name, s"$dbName.$table1Name"))
+ checkIfHintApplied(dfTemplate(table1Name,
+ s"${CatalogManager.SESSION_CATALOG_NAME}.$dbName.$table1Name"))
+
+ withView("tv") {
+ sql(s"CREATE VIEW tv AS SELECT * FROM $dbName.$table1Name")
+ checkIfHintApplied(sqlTemplate("tv", "tv"))
+ checkIfHintNotApplied(sqlTemplate("tv", s"$dbName.tv"))
+
+ checkIfHintApplied(dfTemplate("tv", "tv"))
+ checkIfHintApplied(dfTemplate("tv", s"$dbName.tv"))
+ }
+ }
+ }
+ }
+ }
+
+ test("The same table name exists in two databases for broadcast hint
resolution") {
+ val (db1Name, db2Name) = ("db1", "db2")
+
+ withDatabase(db1Name, db2Name) {
+ withTable("t") {
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ sql(s"CREATE DATABASE $db1Name")
+ sql(s"CREATE DATABASE $db2Name")
+ spark.range(1).write.saveAsTable(s"$db1Name.t")
+ spark.range(1).write.saveAsTable(s"$db2Name.t")
+
+ // Checks if a broadcast hint applied in both sides
+ val statement = s"SELECT /*+ BROADCASTJOIN(t) */ * FROM $db1Name.t, $db2Name.t " +
+ s"WHERE $db1Name.t.id = $db2Name.t.id"
+ sql(statement).queryExecution.optimizedPlan match {
+ case Join(_, _, _, _, JoinHint(Some(HintInfo(Some(BROADCAST))),
+ Some(HintInfo(Some(BROADCAST))))) =>
+ case _ => fail("broadcast hint not found in both tables")
+ }
+ }
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala
index 7fbfa73..28e82aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GlobalTempViewSuite.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalog.Table
import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, HintInfo, Join, JoinHint}
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
@@ -170,4 +171,25 @@ class GlobalTempViewSuite extends QueryTest with SharedSparkSession {
isTemporary = true).toString)
}
}
+
+ test("broadcast hint on global temp view") {
+ withGlobalTempView("v1") {
+ spark.range(10).createGlobalTempView("v1")
+ withTempView("v2") {
+ spark.range(10).createTempView("v2")
+
+ withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") {
+ Seq(
+ "SELECT /*+ MAPJOIN(v1) */ * FROM global_temp.v1, v2 WHERE v1.id =
v2.id",
+ "SELECT /*+ MAPJOIN(global_temp.v1) */ * FROM global_temp.v1, v2
WHERE v1.id = v2.id"
+ ).foreach { statement =>
+ sql(statement).queryExecution.optimizedPlan match {
case Join(_, _, _, _, JoinHint(Some(HintInfo(Some(BROADCAST))), None)) =>
+ case _ => fail("broadcast hint not found in a left-side table")
+ }
+ }
+ }
+ }
+ }
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]