This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d31edcb5aac9 [SPARK-52346][SQL][FOLLOW-UP] Fix MaterializeTablesSuite in Maven
d31edcb5aac9 is described below

commit d31edcb5aac958f29852e5fc5dee0d24c5dad97b
Author: Sandy Ryza <[email protected]>
AuthorDate: Sun Jun 8 20:13:21 2025 -0700

    [SPARK-52346][SQL][FOLLOW-UP] Fix MaterializeTablesSuite in Maven
    
    ### What changes were proposed in this pull request?
    
    Fixes the issue identified here: https://github.com/apache/spark/pull/51050#discussion_r2133923514
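
    In short (based on the diff below), the pipelines test suites stop importing implicits from a suite-level `originalSpark` session and instead build on `SharedSparkSession`, binding the shared session to a local `val` inside each test before importing its implicits. A minimal sketch of the pattern:

    ```
    // Before: implicits came from a session created eagerly in the suite constructor
    import originalSpark.implicits._

    // After: each test captures the shared session in a stable val first
    val session = spark
    import session.implicits._
    ```

    The local `val session = spark` is used presumably because Scala's `import x.implicits._` requires a stable identifier, which the `spark` reference provided by the shared fixture is not.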
    
    ### Why are the changes needed?

    Without this change, `MaterializeTablesSuite` fails when the pipelines module is built and tested with Maven.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    ```
    build/mvn clean install -DskipTests -pl sql/pipelines -am
    build/mvn test -pl sql/pipelines
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Closes #51112 from sryza/fix-mvn.
    
    Authored-by: Sandy Ryza <[email protected]>
    Signed-off-by: Gengliang Wang <[email protected]>
---
 .../pipelines/graph/MaterializeTablesSuite.scala   | 52 +++++++++++++++--
 .../graph/TriggeredGraphExecutionSuite.scala       | 43 ++++++++++++--
 .../spark/sql/pipelines/utils/PipelineTest.scala   | 65 +++-------------------
 3 files changed, 95 insertions(+), 65 deletions(-)

diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
index 90d752aff05c..2587f503222e 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala
@@ -33,10 +33,10 @@ import org.apache.spark.util.Utils.exceptionString
  * tables are written with the appropriate schemas.
  */
 class MaterializeTablesSuite extends BaseCoreExecutionTest {
-
-  import originalSpark.implicits._
-
   test("basic") {
+    val session = spark
+    import session.implicits._
+
     materializeGraph(
       new TestGraphRegistrationContext(spark) {
         registerFlow(
@@ -128,6 +128,9 @@ class MaterializeTablesSuite extends BaseCoreExecutionTest {
   }
 
   test("multiple") {
+    val session = spark
+    import session.implicits._
+
     materializeGraph(
       new TestGraphRegistrationContext(spark) {
         registerFlow(
@@ -162,6 +165,9 @@ class MaterializeTablesSuite extends BaseCoreExecutionTest {
   }
 
   test("temporary views don't get materialized") {
+    val session = spark
+    import session.implicits._
+
     materializeGraph(
       new TestGraphRegistrationContext(spark) {
         registerFlow(
@@ -235,6 +241,9 @@ class MaterializeTablesSuite extends BaseCoreExecutionTest {
   }
 
   test("schema matches existing table schema") {
+    val session = spark
+    import session.implicits._
+
     sql(s"CREATE TABLE ${TestGraphRegistrationContext.DEFAULT_DATABASE}.t2(x 
INT)")
     val catalog = 
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
     val identifier = 
Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t2")
@@ -260,6 +269,9 @@ class MaterializeTablesSuite extends BaseCoreExecutionTest {
   }
 
   test("invalid schema merge") {
+    val session = spark
+    import session.implicits._
+
     val streamInts = MemoryStream[Int]
     streamInts.addData(1, 2)
 
@@ -286,6 +298,9 @@ class MaterializeTablesSuite extends BaseCoreExecutionTest {
   }
 
   test("table materialized with specified schema, even if different from 
inferred") {
+    val session = spark
+    import session.implicits._
+
     sql(s"CREATE TABLE ${TestGraphRegistrationContext.DEFAULT_DATABASE}.t4(x 
INT)")
     val catalog = 
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
     val identifier = 
Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t4")
@@ -321,6 +336,9 @@ class MaterializeTablesSuite extends BaseCoreExecutionTest {
   }
 
   test("specified schema incompatible with existing table") {
+    val session = spark
+    import session.implicits._
+
     sql(s"CREATE TABLE ${TestGraphRegistrationContext.DEFAULT_DATABASE}.t6(x 
BOOLEAN)")
     val catalog = 
spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
     val identifier = 
Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t6")
@@ -363,6 +381,9 @@ class MaterializeTablesSuite extends BaseCoreExecutionTest {
   }
 
   test("partition columns with user schema") {
+    val session = spark
+    import session.implicits._
+
     materializeGraph(
       new TestGraphRegistrationContext(spark) {
         registerTable(
@@ -389,6 +410,9 @@ class MaterializeTablesSuite extends BaseCoreExecutionTest {
   }
 
   test("specifying partition column with existing partitioned table") {
+    val session = spark
+    import session.implicits._
+
     sql(
       s"CREATE TABLE ${TestGraphRegistrationContext.DEFAULT_DATABASE}.t7(x 
BOOLEAN, y INT) " +
       s"PARTITIONED BY (x)"
@@ -451,11 +475,13 @@ class MaterializeTablesSuite extends BaseCoreExecutionTest {
   }
 
   test("specifying partition column different from existing partitioned 
table") {
+    val session = spark
+    import session.implicits._
+
     sql(
       s"CREATE TABLE ${TestGraphRegistrationContext.DEFAULT_DATABASE}.t8(x 
BOOLEAN, y INT) " +
       s"PARTITIONED BY (x)"
     )
-    Seq((true, 1), (false, 1)).toDF("x", "y").write.mode("append").saveAsTable("t8")
 
     val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog]
     val identifier = Identifier.of(Array(TestGraphRegistrationContext.DEFAULT_DATABASE), "t8")
@@ -513,6 +539,9 @@ class MaterializeTablesSuite extends BaseCoreExecutionTest {
   }
 
   test("Invalid table properties error during table materialization") {
+    val session = spark
+    import session.implicits._
+
     // Invalid pipelines property
     val graph1 =
       new TestGraphRegistrationContext(spark) {
@@ -550,6 +579,9 @@ class MaterializeTablesSuite extends BaseCoreExecutionTest {
     test(
       s"Complete tables should not evolve schema - isFullRefresh = 
$isFullRefresh"
     ) {
+      val session = spark
+      import session.implicits._
+
       val rawGraph =
         new TestGraphRegistrationContext(spark) {
           registerView("a", query = dfFlowFunc(Seq((1, 2), (2, 3)).toDF("x", 
"y")))
@@ -602,6 +634,9 @@ class MaterializeTablesSuite extends BaseCoreExecutionTest {
     test(
       s"Streaming tables should evolve schema only if not full refresh = 
$isFullRefresh"
     ) {
+      val session = spark
+      import session.implicits._
+
       val streamInts = MemoryStream[Int]
       streamInts.addData(1 until 5: _*)
 
@@ -665,6 +700,9 @@ class MaterializeTablesSuite extends BaseCoreExecutionTest {
   test(
     "materialize only selected tables"
   ) {
+    val session = spark
+    import session.implicits._
+
     val graph = new TestGraphRegistrationContext(spark) {
       registerTable("a", query = Option(dfFlowFunc(Seq((1, 2), (2, 
3)).toDF("x", "y"))))
       registerTable("b", query = Option(sqlFlowFunc(spark, "SELECT x FROM a")))
@@ -707,6 +745,9 @@ class MaterializeTablesSuite extends BaseCoreExecutionTest {
   }
 
   test("tables with arrays and maps") {
+    val session = spark
+    import session.implicits._
+
     val rawGraph =
       new TestGraphRegistrationContext(spark) {
         registerTable("a", query = Option(sqlFlowFunc(spark, "select map(1, 
struct('a', 'b')) m")))
@@ -800,6 +841,9 @@ class MaterializeTablesSuite extends BaseCoreExecutionTest {
   }
 
   test("materializing no tables doesn't throw") {
+    val session = spark
+    import session.implicits._
+
     val graph1 =
      new DataflowGraph(flows = Seq.empty, tables = Seq.empty, views = Seq.empty)
     val graph2 = new TestGraphRegistrationContext(spark) {
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
index 54759c41ace5..7db0c9875dd1 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala
@@ -17,8 +17,6 @@
 
 package org.apache.spark.sql.pipelines.graph
 
-import org.scalatest.concurrent.Eventually.eventually
-import org.scalatest.concurrent.Futures.timeout
 import org.scalatest.time.{Seconds, Span}
 
 import org.apache.spark.sql.{functions, Row}
@@ -34,10 +32,11 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
 
 class TriggeredGraphExecutionSuite extends ExecutionTest {
 
-  import originalSpark.implicits._
-
   /** Returns a Dataset of Longs from the table with the given identifier. */
   private def getTable(identifier: TableIdentifier): Dataset[Long] = {
+    val session = spark
+    import session.implicits._
+
     spark.read.table(identifier.toString).as[Long]
   }
 
@@ -52,6 +51,9 @@ class TriggeredGraphExecutionSuite extends ExecutionTest {
   }
 
   test("basic graph resolution and execution") {
+    val session = spark
+    import session.implicits._
+
     val pipelineDef = new TestGraphRegistrationContext(spark) {
       registerTable("a", query = Option(dfFlowFunc(Seq(1, 2).toDF("x"))))
       registerTable("b", query = Option(readFlowFunc("a")))
@@ -95,6 +97,9 @@ class TriggeredGraphExecutionSuite extends ExecutionTest {
   }
 
   test("graph materialization with streams") {
+    val session = spark
+    import session.implicits._
+
     val pipelineDef = new TestGraphRegistrationContext(spark) {
       registerTable("a", query = Option(dfFlowFunc(Seq(1, 2).toDF("x"))))
       registerTable("b", query = Option(readFlowFunc("a")))
@@ -172,6 +177,9 @@ class TriggeredGraphExecutionSuite extends ExecutionTest {
   }
 
   test("three hop pipeline") {
+    val session = spark
+    import session.implicits._
+
     // Construct pipeline
     val pipelineDef = new TestGraphRegistrationContext(spark) {
       private val ints = MemoryStream[Int]
@@ -245,6 +253,9 @@ class TriggeredGraphExecutionSuite extends ExecutionTest {
   }
 
   test("all events are emitted even if there is no data") {
+    val session = spark
+    import session.implicits._
+
     // Construct pipeline
     val pipelineDef = new TestGraphRegistrationContext(spark) {
       private val ints = MemoryStream[Int]
@@ -285,6 +296,9 @@ class TriggeredGraphExecutionSuite extends ExecutionTest {
   }
 
   test("stream failure causes its downstream to be skipped") {
+    val session = spark
+    import session.implicits._
+
     spark.sql("CREATE TABLE src USING PARQUET AS SELECT * FROM RANGE(10)")
 
     // A UDF which fails immediately
@@ -463,6 +477,9 @@ class TriggeredGraphExecutionSuite extends ExecutionTest {
   }
 
   test("user-specified schema is applied to table") {
+    val session = spark
+    import session.implicits._
+
     val specifiedSchema = new StructType().add("x", "int", nullable = true)
     val pipelineDef = new TestGraphRegistrationContext(spark) {
       registerTable(
@@ -521,6 +538,9 @@ class TriggeredGraphExecutionSuite extends ExecutionTest {
   }
 
   test("stopping a pipeline mid-execution") {
+    val session = spark
+    import session.implicits._
+
     // A UDF which adds a delay
     val delayUDF = functions.udf((_: String) => {
       Thread.sleep(5 * 1000)
@@ -589,6 +609,9 @@ class TriggeredGraphExecutionSuite extends ExecutionTest {
   }
 
   test("two hop pipeline with partitioned graph") {
+    val session = spark
+    import session.implicits._
+
     val pipelineDef = new TestGraphRegistrationContext(spark) {
       registerTable("integer_input", query = Option(dfFlowFunc(Seq(1, 2, 3, 
4).toDF("value"))))
       registerTable(
@@ -643,6 +666,9 @@ class TriggeredGraphExecutionSuite extends ExecutionTest {
   }
 
   test("multiple hop pipeline with merge from multiple sources") {
+    val session = spark
+    import session.implicits._
+
     val pipelineDef = new TestGraphRegistrationContext(spark) {
       registerTable("integer_input", query = Option(dfFlowFunc(Seq(1, 2, 3, 
4).toDF("nums"))))
       registerTable(
@@ -719,6 +745,9 @@ class TriggeredGraphExecutionSuite extends ExecutionTest {
   }
 
   test("multiple hop pipeline with split and merge from single source") {
+    val session = spark
+    import session.implicits._
+
     val pipelineDef = new TestGraphRegistrationContext(spark) {
       registerTable(
         "input_table",
@@ -797,6 +826,9 @@ class TriggeredGraphExecutionSuite extends ExecutionTest {
   }
 
   test("test default flow retry is 2 and event WARN/ERROR levels accordingly") 
{
+    val session = spark
+    import session.implicits._
+
     val fail = functions.udf((x: Int) => {
       throw new RuntimeException("Intentionally failing UDF.")
       x
@@ -843,6 +875,9 @@ class TriggeredGraphExecutionSuite extends ExecutionTest {
   }
 
   test("partial graph updates") {
+    val session = spark
+    import session.implicits._
+
     val ints: MemoryStream[Int] = MemoryStream[Int]
     val pipelineDef = new TestGraphRegistrationContext(spark) {
       ints.addData(1, 2, 3)
diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
index a54c09d9e251..48f2f26fb450 100644
--- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
+++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/utils/PipelineTest.scala
@@ -31,15 +31,16 @@ import org.scalatest.matchers.should.Matchers
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{Column, QueryTest, Row, TypedColumn}
-import org.apache.spark.sql.SparkSession.{clearActiveSession, clearDefaultSession, setActiveSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession, SQLContext}
+import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession}
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.pipelines.graph.{DataflowGraph, PipelineUpdateContextImpl, SqlGraphRegistrationContext}
 import org.apache.spark.sql.pipelines.utils.PipelineTest.{cleanupMetastore, createTempDir}
+import org.apache.spark.sql.test.SharedSparkSession
 
 abstract class PipelineTest
     extends SparkFunSuite
+    with SharedSparkSession
     with BeforeAndAfterAll
     with BeforeAndAfterEach
     with Matchers
@@ -49,11 +50,6 @@ abstract class PipelineTest
 
   final protected val storageRoot = createTempDir()
 
-  var spark: SparkSession = createAndInitializeSpark()
-  val originalSpark: SparkSession = spark.cloneSession()
-
-
-  implicit def sqlContext: SQLContext = spark.sqlContext
   def sql(text: String): DataFrame = spark.sql(text)
 
   protected def startPipelineAndWaitForCompletion(unresolvedDataflowGraph: DataflowGraph): Unit = {
@@ -64,11 +60,10 @@ abstract class PipelineTest
   }
 
   /**
-   * Spark confs for [[originalSpark]]. Spark confs set here will be the default spark confs for
-   * all spark sessions created in tests.
+   * Spark confs set here will be the default spark confs for all spark sessions created in tests.
    */
-  protected def sparkConf: SparkConf = {
-    new SparkConf()
+  override def sparkConf: SparkConf = {
+    super.sparkConf
       .set("spark.sql.shuffle.partitions", "2")
       .set("spark.sql.session.timeZone", "UTC")
   }
@@ -131,38 +126,14 @@ abstract class PipelineTest
     )
   }
 
-  /**
-   * This exists temporarily for compatibility with tests that become invalid when multiple
-   * executors are available.
-   */
-  protected def master = "local[*]"
-
-  /** Creates and returns a initialized spark session. */
-  def createAndInitializeSpark(): SparkSession = {
-    val newSparkSession = SparkSession
-      .builder()
-      .config(sparkConf)
-      .master(master)
-      .getOrCreate()
-    newSparkSession
-  }
-
-  /** Set up the spark session before each test. */
-  protected def initializeSparkBeforeEachTest(): Unit = {
-    clearActiveSession()
-    spark = createAndInitializeSpark()
-    setActiveSession(spark)
-  }
-
   override def beforeEach(): Unit = {
     super.beforeEach()
-    initializeSparkBeforeEachTest()
     cleanupMetastore(spark)
     (catalogInPipelineSpec, databaseInPipelineSpec) match {
       case (Some(catalog), Some(schema)) =>
-        sql(s"CREATE DATABASE IF NOT EXISTS `$catalog`.`$schema`")
+        spark.sql(s"CREATE DATABASE IF NOT EXISTS `$catalog`.`$schema`")
       case _ =>
-        databaseInPipelineSpec.foreach(s => sql(s"CREATE DATABASE IF NOT EXISTS `$s`"))
+        databaseInPipelineSpec.foreach(s => spark.sql(s"CREATE DATABASE IF NOT EXISTS `$s`"))
     }
   }
 
@@ -171,26 +142,6 @@ abstract class PipelineTest
     super.afterEach()
   }
 
-  override def afterAll(): Unit = {
-    try {
-      super.afterAll()
-    } finally {
-      try {
-        if (spark != null) {
-          try {
-            spark.sessionState.catalog.reset()
-          } finally {
-            spark.stop()
-            spark = null
-          }
-        }
-      } finally {
-        clearActiveSession()
-        clearDefaultSession()
-      }
-    }
-  }
-
   override protected def gridTest[A](testNamePrefix: String, testTags: Tag*)(params: Seq[A])(
       testFun: A => Unit): Unit = {
     namedGridTest(testNamePrefix, testTags: _*)(params.map(a => a.toString -> a).toMap)(testFun)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
