http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala index cef541f..373d3a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -21,9 +21,9 @@ import java.io.File import scala.collection.JavaConverters._ import scala.util.Try -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkConf import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.util.{Benchmark, Utils} /** @@ -34,12 +34,16 @@ import org.apache.spark.util.{Benchmark, Utils} object ParquetReadBenchmark { val conf = new SparkConf() conf.set("spark.sql.parquet.compression.codec", "snappy") - val sc = new SparkContext("local[1]", "test-sql-context", conf) - val sqlContext = new SQLContext(sc) + + val spark = SparkSession.builder + .master("local[1]") + .appName("test-sql-context") + .config(conf) + .getOrCreate() // Set default configs. Individual cases will change them if necessary. - sqlContext.conf.setConfString(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") - sqlContext.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") def withTempPath(f: File => Unit): Unit = { val path = Utils.createTempDir() @@ -48,17 +52,17 @@ object ParquetReadBenchmark { } def withTempTable(tableNames: String*)(f: => Unit): Unit = { - try f finally tableNames.foreach(sqlContext.dropTempTable) + try f finally tableNames.foreach(spark.catalog.dropTempTable) } def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) - (keys, values).zipped.foreach(sqlContext.conf.setConfString) + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => sqlContext.conf.setConfString(key, value) - case (key, None) => sqlContext.conf.unsetConf(key) + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) } } } @@ -71,18 +75,18 @@ object ParquetReadBenchmark { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql("select cast(id as INT) as id from t1") + spark.range(values).registerTempTable("t1") + spark.sql("select cast(id as INT) as id from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") sqlBenchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(id) from tempTable").collect() + spark.sql("select sum(id) from tempTable").collect() } sqlBenchmark.addCase("SQL Parquet MR") { iter => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - sqlContext.sql("select sum(id) from tempTable").collect() + spark.sql("select sum(id) from tempTable").collect() } } @@ -155,20 +159,20 @@ object ParquetReadBenchmark { def intStringScanBenchmark(values: Int): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1") + spark.range(values).registerTempTable("t1") + spark.sql("select cast(id as INT) as c1, cast(id as STRING) as c2 from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") val benchmark = new Benchmark("Int and String Scan", values) benchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect + spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect } benchmark.addCase("SQL Parquet MR") { iter => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect + spark.sql("select sum(c1), sum(length(c2)) from tempTable").collect } } @@ -189,20 +193,20 @@ object ParquetReadBenchmark { def stringDictionaryScanBenchmark(values: Int): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1") + spark.range(values).registerTempTable("t1") + spark.sql("select cast((id % 200) + 10000 as STRING) as c1 from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") val benchmark = new Benchmark("String Dictionary", values) benchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(length(c1)) from tempTable").collect + spark.sql("select sum(length(c1)) from tempTable").collect } benchmark.addCase("SQL Parquet MR") { iter => withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { - sqlContext.sql("select sum(length(c1)) from tempTable").collect + spark.sql("select sum(length(c1)) from tempTable").collect } } @@ -221,23 +225,23 @@ object ParquetReadBenchmark { def partitionTableScanBenchmark(values: Int): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql("select id % 2 as p, cast(id as INT) as id from t1") + spark.range(values).registerTempTable("t1") + spark.sql("select id % 2 as p, cast(id as INT) as id from t1") .write.partitionBy("p").parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") val benchmark = new Benchmark("Partitioned Table", values) benchmark.addCase("Read data column") { iter => - sqlContext.sql("select sum(id) from tempTable").collect + spark.sql("select sum(id) from tempTable").collect } benchmark.addCase("Read partition column") { iter => - sqlContext.sql("select sum(p) from tempTable").collect + spark.sql("select sum(p) from tempTable").collect } benchmark.addCase("Read both columns") { iter => - sqlContext.sql("select sum(p), sum(id) from tempTable").collect + spark.sql("select sum(p), sum(id) from tempTable").collect } /* @@ -256,16 +260,16 @@ object ParquetReadBenchmark { def stringWithNullsScanBenchmark(values: Int, fractionOfNulls: Double): Unit = { withTempPath { dir => withTempTable("t1", "tempTable") { - sqlContext.range(values).registerTempTable("t1") - sqlContext.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + + spark.range(values).registerTempTable("t1") + spark.sql(s"select IF(rand(1) < $fractionOfNulls, NULL, cast(id as STRING)) as c1, " + s"IF(rand(2) < $fractionOfNulls, NULL, cast(id as STRING)) as c2 from t1") .write.parquet(dir.getCanonicalPath) - sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + spark.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") val benchmark = new Benchmark("String with Nulls Scan", values) benchmark.addCase("SQL Parquet Vectorized") { iter => - sqlContext.sql("select sum(length(c2)) from tempTable where c1 is " + + spark.sql("select sum(length(c2)) from tempTable where c1 is " + "not NULL and c2 is not NULL").collect() }
http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala index 90e3d50..c43b142 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala @@ -453,11 +453,11 @@ class ParquetSchemaSuite extends ParquetSchemaTest { test("schema merging failure error message") { withTempPath { dir => val path = dir.getCanonicalPath - sqlContext.range(3).write.parquet(s"$path/p=1") - sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") + spark.range(3).write.parquet(s"$path/p=1") + spark.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") val message = intercept[SparkException] { - sqlContext.read.option("mergeSchema", "true").parquet(path).schema + spark.read.option("mergeSchema", "true").parquet(path).schema }.getMessage assert(message.contains("Failed merging schema of file")) @@ -466,13 +466,13 @@ class ParquetSchemaSuite extends ParquetSchemaTest { // test for second merging (after read Parquet schema in parallel done) withTempPath { dir => val path = dir.getCanonicalPath - sqlContext.range(3).write.parquet(s"$path/p=1") - sqlContext.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") + spark.range(3).write.parquet(s"$path/p=1") + spark.range(3).selectExpr("CAST(id AS INT) AS id").write.parquet(s"$path/p=2") - sqlContext.sparkContext.conf.set("spark.default.parallelism", "20") + spark.sparkContext.conf.set("spark.default.parallelism", "20") val message = intercept[SparkException] { - sqlContext.read.option("mergeSchema", "true").parquet(path).schema + spark.read.option("mergeSchema", "true").parquet(path).schema }.getMessage assert(message.contains("Failed merging schema:")) http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index e8c524e..b5fc516 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -52,7 +52,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { (true :: false :: Nil).foreach { vectorized => if (!vectorized || testVectorized) { withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> vectorized.toString) { - f(sqlContext.read.parquet(path.toString)) + f(spark.read.parquet(path.toString)) } } } @@ -66,7 +66,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath) + spark.createDataFrame(data).write.parquet(file.getCanonicalPath) f(file.getCanonicalPath) } } @@ -90,14 +90,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { (data: Seq[T], tableName: String, testVectorized: Boolean = true) (f: => Unit): Unit = { withParquetDataFrame(data, testVectorized) { df => - sqlContext.registerDataFrameAsTable(df, tableName) + spark.wrapped.registerDataFrameAsTable(df, tableName) withTempTable(tableName)(f) } } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( data: Seq[T], path: File): Unit = { - sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) + spark.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath) } protected def makeParquetFile[T <: Product: ClassTag: TypeTag]( @@ -173,6 +173,6 @@ private[sql] trait ParquetTest extends SQLTestUtils { protected def readResourceParquetFile(name: String): DataFrame = { val url = Thread.currentThread().getContextClassLoader.getResource(name) - sqlContext.read.parquet(url.toString) + spark.read.parquet(url.toString) } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala index 88a3d87..ff57069 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala @@ -32,7 +32,7 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar |${readParquetSchema(parquetFilePath.toString)} """.stripMargin) - checkAnswer(sqlContext.read.parquet(parquetFilePath.toString), (0 until 10).map { i => + checkAnswer(spark.read.parquet(parquetFilePath.toString), (0 until 10).map { i => val suits = Array("SPADES", "HEARTS", "DIAMONDS", "CLUBS") val nonNullablePrimitiveValues = Seq( @@ -139,7 +139,7 @@ class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with Shar logParquetSchema(path) checkAnswer( - sqlContext.read.parquet(path), + spark.read.parquet(path), Seq( Row(Seq(Seq(0, 1), Seq(2, 3))), Row(Seq(Seq(4, 5), Seq(6, 7))))) http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala index fd56265..08b7eb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/TPCDSBenchmark.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkConf import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.util.Benchmark @@ -36,9 +36,10 @@ object TPCDSBenchmark { conf.set("spark.driver.memory", "3g") conf.set("spark.executor.memory", "3g") conf.set("spark.sql.autoBroadcastJoinThreshold", (20 * 1024 * 1024).toString) + conf.setMaster("local[1]") + conf.setAppName("test-sql-context") - val sc = new SparkContext("local[1]", "test-sql-context", conf) - val sqlContext = new SQLContext(sc) + val spark = SparkSession.builder.config(conf).getOrCreate() // These queries a subset of the TPCDS benchmark queries and are taken from // https://github.com/databricks/spark-sql-perf/blob/master/src/main/scala/com/databricks/spark/ @@ -1186,8 +1187,8 @@ object TPCDSBenchmark { def setupTables(dataLocation: String): Map[String, Long] = { tables.map { tableName => - sqlContext.read.parquet(s"$dataLocation/$tableName").registerTempTable(tableName) - tableName -> sqlContext.table(tableName).count() + spark.read.parquet(s"$dataLocation/$tableName").registerTempTable(tableName) + tableName -> spark.table(tableName).count() }.toMap } @@ -1195,18 +1196,18 @@ object TPCDSBenchmark { require(dataLocation.nonEmpty, "please modify the value of dataLocation to point to your local TPCDS data") val tableSizes = setupTables(dataLocation) - sqlContext.conf.setConfString(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") - sqlContext.conf.setConfString(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + spark.conf.set(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key, "true") + spark.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") tpcds.filter(q => q._1 != "").foreach { case (name: String, query: String) => - val numRows = sqlContext.sql(query).queryExecution.logical.map { + val numRows = spark.sql(query).queryExecution.logical.map { case ur@UnresolvedRelation(t: TableIdentifier, _) => tableSizes.getOrElse(t.table, throw new RuntimeException(s"${t.table} not found.")) case _ => 0L }.sum val benchmark = new Benchmark("TPCDS Snappy (scale = 5)", numRows, 5) benchmark.addCase(name) { i => - sqlContext.sql(query).collect() + spark.sql(query).collect() } benchmark.run() } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 923c0b3..f61fce5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -33,20 +33,20 @@ import org.apache.spark.util.Utils class TextSuite extends QueryTest with SharedSQLContext { test("reading text file") { - verifyFrame(sqlContext.read.format("text").load(testFile)) + verifyFrame(spark.read.format("text").load(testFile)) } test("SQLContext.read.text() API") { - verifyFrame(sqlContext.read.text(testFile).toDF()) + verifyFrame(spark.read.text(testFile).toDF()) } test("SPARK-12562 verify write.text() can handle column name beyond `value`") { - val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf") + val df = spark.read.text(testFile).withColumnRenamed("value", "adwrasdf") val tempFile = Utils.createTempDir() tempFile.delete() df.write.text(tempFile.getCanonicalPath) - verifyFrame(sqlContext.read.text(tempFile.getCanonicalPath).toDF()) + verifyFrame(spark.read.text(tempFile.getCanonicalPath).toDF()) Utils.deleteRecursively(tempFile) } @@ -55,18 +55,18 @@ class TextSuite extends QueryTest with SharedSQLContext { val tempFile = Utils.createTempDir() tempFile.delete() - val df = sqlContext.range(2) + val df = spark.range(2) intercept[AnalysisException] { df.write.text(tempFile.getCanonicalPath) } intercept[AnalysisException] { - sqlContext.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath) + spark.range(2).select(df("id"), df("id") + 1).write.text(tempFile.getCanonicalPath) } } test("SPARK-13503 Support to specify the option for compression codec for TEXT") { - val testDf = sqlContext.read.text(testFile) + val testDf = spark.read.text(testFile) val extensionNameMap = Map("bzip2" -> ".bz2", "deflate" -> ".deflate", "gzip" -> ".gz") extensionNameMap.foreach { case (codecName, extension) => @@ -75,7 +75,7 @@ class TextSuite extends QueryTest with SharedSQLContext { testDf.write.option("compression", codecName).mode(SaveMode.Overwrite).text(tempDirPath) val compressedFiles = new File(tempDirPath).listFiles() assert(compressedFiles.exists(_.getName.endsWith(s".txt$extension"))) - verifyFrame(sqlContext.read.text(tempDirPath).toDF()) + verifyFrame(spark.read.text(tempDirPath).toDF()) } val errMsg = intercept[IllegalArgumentException] { @@ -95,14 +95,14 @@ class TextSuite extends QueryTest with SharedSQLContext { "mapreduce.map.output.compress.codec" -> classOf[GzipCodec].getName ) withTempDir { dir => - val testDf = sqlContext.read.text(testFile) + val testDf = spark.read.text(testFile) val tempDir = Utils.createTempDir() val tempDirPath = tempDir.getAbsolutePath testDf.write.option("compression", "none") .options(extraOptions).mode(SaveMode.Overwrite).text(tempDirPath) val compressedFiles = new File(tempDirPath).listFiles() assert(compressedFiles.exists(!_.getName.endsWith(".txt.gz"))) - verifyFrame(sqlContext.read.options(extraOptions).text(tempDirPath).toDF()) + verifyFrame(spark.read.options(extraOptions).text(tempDirPath).toDF()) } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 8aa0114..4fc52c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -33,7 +33,7 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { } test("debugCodegen") { - val res = codegenString(sqlContext.range(10).groupBy("id").count().queryExecution.executedPlan) + val res = codegenString(spark.range(10).groupBy("id").count().queryExecution.executedPlan) assert(res.contains("Subtree 1 / 2")) assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index b9df43d..730ec43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} -import org.apache.spark.sql.{QueryTest, SQLContext} +import org.apache.spark.sql.{QueryTest, SparkSession} import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.functions._ @@ -34,7 +34,7 @@ import org.apache.spark.sql.functions._ * without serializing the hashed relation, which does not happen in local mode. */ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { - protected var sqlContext: SQLContext = null + protected var spark: SparkSession = null /** * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. @@ -45,26 +45,26 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { .setMaster("local-cluster[2,1,1024]") .setAppName("testing") val sc = new SparkContext(conf) - sqlContext = new SQLContext(sc) + spark = SparkSession.builder.getOrCreate() } override def afterAll(): Unit = { - sqlContext.sparkContext.stop() - sqlContext = null + spark.stop() + spark = null } /** * Test whether the specified broadcast join updates the peak execution memory accumulator. */ private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { - AccumulatorSuite.verifyPeakExecutionMemorySet(sqlContext.sparkContext, name) { - val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") - val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + AccumulatorSuite.verifyPeakExecutionMemorySet(spark.sparkContext, name) { + val df1 = spark.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = spark.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") // Comparison at the end is for broadcast left semi join val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") val df3 = df1.join(broadcast(df2), joinExpression, joinType) val plan = - EnsureRequirements(sqlContext.sessionState.conf).apply(df3.queryExecution.sparkPlan) + EnsureRequirements(spark.sessionState.conf).apply(df3.queryExecution.sparkPlan) assert(plan.collect { case p: T => p }.size === 1) plan.executeCollect() } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 2a4a369..7caeb3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -32,7 +32,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { import testImplicits.newProductEncoder import testImplicits.localSeqToDatasetHolder - private lazy val myUpperCaseData = sqlContext.createDataFrame( + private lazy val myUpperCaseData = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, "A"), Row(2, "B"), @@ -43,7 +43,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { Row(null, "G") )), new StructType().add("N", IntegerType).add("L", StringType)) - private lazy val myLowerCaseData = sqlContext.createDataFrame( + private lazy val myLowerCaseData = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, "a"), Row(2, "b"), @@ -99,7 +99,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { boundCondition, leftPlan, rightPlan) - EnsureRequirements(sqlContext.sessionState.conf).apply(broadcastJoin) + EnsureRequirements(spark.sessionState.conf).apply(broadcastJoin) } def makeShuffledHashJoin( @@ -113,7 +113,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { joins.ShuffledHashJoinExec(leftKeys, rightKeys, Inner, side, None, leftPlan, rightPlan) val filteredJoin = boundCondition.map(FilterExec(_, shuffledHashJoin)).getOrElse(shuffledHashJoin) - EnsureRequirements(sqlContext.sessionState.conf).apply(filteredJoin) + EnsureRequirements(spark.sessionState.conf).apply(filteredJoin) } def makeSortMergeJoin( @@ -124,7 +124,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { rightPlan: SparkPlan) = { val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition, leftPlan, rightPlan) - EnsureRequirements(sqlContext.sessionState.conf).apply(sortMergeJoin) + EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) } test(s"$testName using BroadcastHashJoin (build=left)") { http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index c26cb84..001feb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { - private lazy val left = sqlContext.createDataFrame( + private lazy val left = spark.createDataFrame( sparkContext.parallelize(Seq( Row(1, 2.0), Row(2, 100.0), @@ -42,7 +42,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { Row(null, null) )), new StructType().add("a", IntegerType).add("b", DoubleType)) - private lazy val right = sqlContext.createDataFrame( + private lazy val right = spark.createDataFrame( sparkContext.parallelize(Seq( Row(0, 0.0), Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches @@ -82,7 +82,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext.sessionState.conf).apply( + EnsureRequirements(spark.sessionState.conf).apply( ShuffledHashJoinExec( leftKeys, rightKeys, joinType, buildSide, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), @@ -115,7 +115,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(sqlContext.sessionState.conf).apply( + EnsureRequirements(spark.sessionState.conf).apply( SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index d41e88a..1b82769 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -71,21 +71,21 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { df: DataFrame, expectedNumOfJobs: Int, expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = { - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + val previousExecutionIds = spark.listener.executionIdToData.keySet withSQLConf("spark.sql.codegen.wholeStage" -> "false") { df.collect() } sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + val executionIds = spark.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs + val jobs = spark.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change it to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= expectedNumOfJobs) if (jobs.size == expectedNumOfJobs) { // If we can track all jobs, check the metric values - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + val metricValues = spark.listener.getExecutionMetrics(executionId) val actualMetrics = SparkPlanGraph(SparkPlanInfo.fromSparkPlan( df.queryExecution.executedPlan)).allNodes.filter { node => expectedMetrics.contains(node.id) @@ -128,7 +128,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { // Assume the execution plan is // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Filter(nodeId = 1)) // TODO: update metrics in generated operators - val ds = sqlContext.range(10).filter('id < 5) + val ds = spark.range(10).filter('id < 5) testSparkPlanMetrics(ds.toDF(), 1, Map.empty) } @@ -157,7 +157,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("Sort metrics") { // Assume the execution plan is // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1)) - val ds = sqlContext.range(10).sort('id) + val ds = spark.range(10).sort('id) testSparkPlanMetrics(ds.toDF(), 2, Map.empty) } @@ -169,7 +169,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { withTempTable("testDataForJoin") { // Assume the execution plan is // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( + val df = spark.sql( "SELECT * FROM testData2 JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( 0L -> ("SortMergeJoin", Map( @@ -187,7 +187,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { withTempTable("testDataForJoin") { // Assume the execution plan is // ... -> SortMergeJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( + val df = spark.sql( "SELECT * FROM testData2 left JOIN testDataForJoin ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df, 1, Map( 0L -> ("SortMergeJoin", Map( @@ -195,7 +195,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { "number of output rows" -> 8L))) ) - val df2 = sqlContext.sql( + val df2 = spark.sql( "SELECT * FROM testDataForJoin right JOIN testData2 ON testData2.a = testDataForJoin.a") testSparkPlanMetrics(df2, 1, Map( 0L -> ("SortMergeJoin", Map( @@ -241,7 +241,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { withTempTable("testDataForJoin") { // Assume the execution plan is // ... -> BroadcastNestedLoopJoin(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( + val df = spark.sql( "SELECT * FROM testData2 left JOIN testDataForJoin ON " + "testData2.a * testDataForJoin.a != testData2.a + testDataForJoin.a") testSparkPlanMetrics(df, 3, Map( @@ -269,7 +269,7 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { withTempTable("testDataForJoin") { // Assume the execution plan is // ... -> CartesianProduct(nodeId = 1) -> TungstenProject(nodeId = 0) - val df = sqlContext.sql( + val df = spark.sql( "SELECT * FROM testData2 JOIN testDataForJoin") testSparkPlanMetrics(df, 1, Map( 0L -> ("CartesianProduct", Map( @@ -280,19 +280,19 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext { test("save metrics") { withTempPath { file => - val previousExecutionIds = sqlContext.listener.executionIdToData.keySet + val previousExecutionIds = spark.listener.executionIdToData.keySet // Assume the execution plan is // PhysicalRDD(nodeId = 0) person.select('name).write.format("json").save(file.getAbsolutePath) sparkContext.listenerBus.waitUntilEmpty(10000) - val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds) + val executionIds = spark.listener.executionIdToData.keySet.diff(previousExecutionIds) assert(executionIds.size === 1) val executionId = executionIds.head - val jobs = sqlContext.listener.getExecution(executionId).get.jobs + val jobs = spark.listener.getExecution(executionId).get.jobs // Use "<=" because there is a race condition that we may miss some jobs // TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event. assert(jobs.size <= 1) - val metricValues = sqlContext.listener.getExecutionMetrics(executionId) + val metricValues = spark.listener.getExecutionMetrics(executionId) // Because "save" will create a new DataFrame internally, we cannot get the real metric id. // However, we still can check the value. assert(metricValues.values.toSeq.exists(_ === "2")) http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala index 7b413dd..a7b2cfe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -217,7 +217,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3", SQLConf.FILE_SINK_LOG_CLEANUP_DELAY.key -> "0") { withFileStreamSinkLog { sinkLog => - val fs = sinkLog.metadataPath.getFileSystem(sqlContext.sessionState.newHadoopConf()) + val fs = sinkLog.metadataPath.getFileSystem(spark.sessionState.newHadoopConf()) def listBatchFiles(): Set[String] = { fs.listStatus(sinkLog.metadataPath).map(_.getPath.getName).filter { fileName => @@ -263,7 +263,7 @@ class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { private def withFileStreamSinkLog(f: FileStreamSinkLog => Unit): Unit = { withTempDir { file => - val sinkLog = new FileStreamSinkLog(sqlContext.sparkSession, file.getCanonicalPath) + val sinkLog = new FileStreamSinkLog(spark, file.getCanonicalPath) f(sinkLog) } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 5f92c5b..ef2b479 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -59,63 +59,63 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { test("HDFSMetadataLog: basic") { withTempDir { temp => val dir = new File(temp, "dir") // use non-existent directory to test whether log make the dir - val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, dir.getAbsolutePath) + val metadataLog = new HDFSMetadataLog[String](spark, dir.getAbsolutePath) assert(metadataLog.add(0, "batch0")) assert(metadataLog.getLatest() === Some(0 -> "batch0")) assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.getLatest() === Some(0 -> "batch0")) - assert(metadataLog.get(None, 0) === Array(0 -> "batch0")) + assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) assert(metadataLog.add(1, "batch1")) assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.get(1) === Some("batch1")) assert(metadataLog.getLatest() === Some(1 -> "batch1")) - assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) // Adding the same batch does nothing metadataLog.add(1, "batch1-duplicated") assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.get(1) === Some("batch1")) assert(metadataLog.getLatest() === Some(1 -> "batch1")) - assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) } } testQuietly("HDFSMetadataLog: fallback from FileContext to FileSystem") { - sqlContext.conf.setConfString( + spark.conf.set( s"fs.$scheme.impl", classOf[FakeFileSystem].getName) withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, s"$scheme://$temp") + val metadataLog = new HDFSMetadataLog[String](spark, s"$scheme://$temp") assert(metadataLog.add(0, "batch0")) assert(metadataLog.getLatest() === Some(0 -> "batch0")) assert(metadataLog.get(0) === Some("batch0")) - assert(metadataLog.get(None, 0) === Array(0 -> "batch0")) + assert(metadataLog.get(None, Some(0)) === Array(0 -> "batch0")) - val metadataLog2 = new HDFSMetadataLog[String](sqlContext.sparkSession, s"$scheme://$temp") + val metadataLog2 = new HDFSMetadataLog[String](spark, s"$scheme://$temp") assert(metadataLog2.get(0) === Some("batch0")) assert(metadataLog2.getLatest() === Some(0 -> "batch0")) - assert(metadataLog2.get(None, 0) === Array(0 -> "batch0")) + assert(metadataLog2.get(None, Some(0)) === Array(0 -> "batch0")) } } test("HDFSMetadataLog: restart") { withTempDir { temp => - val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) assert(metadataLog.add(0, "batch0")) assert(metadataLog.add(1, "batch1")) assert(metadataLog.get(0) === Some("batch0")) assert(metadataLog.get(1) === Some("batch1")) assert(metadataLog.getLatest() === Some(1 -> "batch1")) - assert(metadataLog.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) - val metadataLog2 = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + val metadataLog2 = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) assert(metadataLog2.get(0) === Some("batch0")) assert(metadataLog2.get(1) === Some("batch1")) assert(metadataLog2.getLatest() === Some(1 -> "batch1")) - assert(metadataLog2.get(None, 1) === Array(0 -> "batch0", 1 -> "batch1")) + assert(metadataLog2.get(None, Some(1)) === Array(0 -> "batch0", 1 -> "batch1")) } } @@ -127,7 +127,7 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { new Thread() { override def run(): Unit = waiter { val metadataLog = - new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + new HDFSMetadataLog[String](spark, temp.getAbsolutePath) try { var nextBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) nextBatchId += 1 @@ -146,9 +146,10 @@ class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { } waiter.await(timeout(10.seconds), dismissals(10)) - val metadataLog = new HDFSMetadataLog[String](sqlContext.sparkSession, temp.getAbsolutePath) + val metadataLog = new HDFSMetadataLog[String](spark, temp.getAbsolutePath) assert(metadataLog.getLatest() === Some(maxBatchId -> maxBatchId.toString)) - assert(metadataLog.get(None, maxBatchId) === (0 to maxBatchId).map(i => (i, i.toString))) + assert( + metadataLog.get(None, Some(maxBatchId)) === (0 to maxBatchId).map(i => (i, i.toString))) } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 6be94eb..4fa1754 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -27,10 +27,11 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.LocalSparkSession._ import org.apache.spark.LocalSparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{CompletionIterator, Utils} @@ -54,19 +55,18 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("versioning and immutability") { - withSpark(new SparkContext(sparkConf)) { sc => - val sqlContext = new SQLContext(sc) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 val rdd1 = - makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)( + makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( + spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)( increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( + spark.wrapped, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. @@ -79,30 +79,30 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString def makeStoreRDD( - sc: SparkContext, + spark: SparkSession, seq: Seq[String], storeVersion: Int): RDD[(String, Int)] = { - implicit val sqlContext = new SQLContext(sc) - makeRDD(sc, Seq("a")).mapPartitionsWithStateStore( + implicit val sqlContext = spark.wrapped + makeRDD(spark.sparkContext, Seq("a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment) } // Generate RDDs and state store data - withSpark(new SparkContext(sparkConf)) { sc => + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => for (i <- 1 to 20) { - require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) + require(makeStoreRDD(spark, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) } } // With a new context, try using the earlier state store data - withSpark(new SparkContext(sparkConf)) { sc => - assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => + assert(makeStoreRDD(spark, Seq("a"), 20).collect().toSet === Set("a" -> 21)) } } test("usage with iterators - only gets and only puts") { - withSpark(new SparkContext(sparkConf)) { sc => - implicit val sqlContext = new SQLContext(sc) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => + implicit val sqlContext = spark.wrapped val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 @@ -130,15 +130,15 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } } - val rddOfGets1 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( - sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) + val rddOfGets1 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( + spark.wrapped, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) - val rddOfPuts = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + val rddOfPuts = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts) assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) - val rddOfGets2 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( + val rddOfGets2 = makeRDD(spark.sparkContext, Seq("a", "b", "c")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets) assert(rddOfGets2.collect().toSet === Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } @@ -149,8 +149,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn val opId = 0 val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - withSpark(new SparkContext(sparkConf)) { sc => - implicit val sqlContext = new SQLContext(sc) + withSparkSession(SparkSession.builder.config(sparkConf).getOrCreate()) { spark => + implicit val sqlContext = spark.wrapped val coordinatorRef = sqlContext.streams.stateStoreCoordinator coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 0), "host1", "exec1") coordinatorRef.reportActiveInstance(StateStoreId(path, opId, 1), "host2", "exec2") @@ -159,7 +159,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) - val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + val rdd = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) require(rdd.partitions.length === 2) @@ -178,16 +178,20 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("distributed test") { quietly { - withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc => - implicit val sqlContext = new SQLContext(sc) + + withSparkSession( + SparkSession.builder + .config(sparkConf.setMaster("local-cluster[2, 1, 1024]")) + .getOrCreate()) { spark => + implicit val sqlContext = spark.wrapped val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString val opId = 0 - val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + val rdd1 = makeRDD(spark.sparkContext, Seq("a", "b", "a")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( + val rdd2 = makeRDD(spark.sparkContext, Seq("a", "c")).mapPartitionsWithStateStore( sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 67e4484..9eff42a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -98,7 +98,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } } - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame val accumulatorIds = @@ -239,7 +239,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before onJobEnd(JobSucceeded)") { - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onOtherEvent(SparkListenerSQLExecutionStart( @@ -269,7 +269,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") { - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onOtherEvent(SparkListenerSQLExecutionStart( @@ -310,7 +310,7 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("onExecutionEnd happens before onJobEnd(JobFailed)") { - val listener = new SQLListener(sqlContext.sparkContext.conf) + val listener = new SQLListener(spark.sparkContext.conf) val executionId = 0 val df = createTestDataFrame listener.onOtherEvent(SparkListenerSQLExecutionStart( @@ -340,16 +340,16 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { } test("SPARK-11126: no memory leak when running non SQL jobs") { - val previousStageNumber = sqlContext.listener.stageIdToStageMetrics.size - sqlContext.sparkContext.parallelize(1 to 10).foreach(i => ()) - sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + val previousStageNumber = spark.listener.stageIdToStageMetrics.size + spark.sparkContext.parallelize(1 to 10).foreach(i => ()) + spark.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should ignore the non SQL stage - assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber) + assert(spark.listener.stageIdToStageMetrics.size == previousStageNumber) - sqlContext.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) - sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + spark.sparkContext.parallelize(1 to 10).toDF().foreach(i => ()) + spark.sparkContext.listenerBus.waitUntilEmpty(10000) // listener should save the SQL stage - assert(sqlContext.listener.stageIdToStageMetrics.size == previousStageNumber + 1) + assert(spark.listener.stageIdToStageMetrics.size == previousStageNumber + 1) } test("SPARK-13055: history listener only tracks SQL metrics") { @@ -401,8 +401,8 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { val sc = new SparkContext(conf) try { SQLContext.clearSqlListener() - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ + val spark = new SQLContext(sc) + import spark.implicits._ // Run 100 successful executions and 100 failed executions. // Each execution only has one job and one stage. for (i <- 0 until 100) { @@ -418,12 +418,12 @@ class SQLListenerMemoryLeakSuite extends SparkFunSuite { } } sc.listenerBus.waitUntilEmpty(10000) - assert(sqlContext.listener.getCompletedExecutions.size <= 50) - assert(sqlContext.listener.getFailedExecutions.size <= 50) + assert(spark.listener.getCompletedExecutions.size <= 50) + assert(spark.listener.getFailedExecutions.size <= 50) // 50 for successful executions and 50 for failed executions - assert(sqlContext.listener.executionIdToData.size <= 100) - assert(sqlContext.listener.jobIdToExecutionId.size <= 100) - assert(sqlContext.listener.stageIdToStageMetrics.size <= 100) + assert(spark.listener.executionIdToData.size <= 100) + assert(spark.listener.jobIdToExecutionId.size <= 100) + assert(spark.listener.stageIdToStageMetrics.size <= 100) } finally { sc.stop() } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 73c2076..56f848b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -37,13 +37,12 @@ class CatalogSuite with BeforeAndAfterEach with SharedSQLContext { - private def sparkSession: SparkSession = sqlContext.sparkSession - private def sessionCatalog: SessionCatalog = sparkSession.sessionState.catalog + private def sessionCatalog: SessionCatalog = spark.sessionState.catalog private val utils = new CatalogTestUtils { override val tableInputFormat: String = "com.fruit.eyephone.CameraInputFormat" override val tableOutputFormat: String = "com.fruit.eyephone.CameraOutputFormat" - override def newEmptyCatalog(): ExternalCatalog = sparkSession.sharedState.externalCatalog + override def newEmptyCatalog(): ExternalCatalog = spark.sharedState.externalCatalog } private def createDatabase(name: String): Unit = { @@ -87,8 +86,8 @@ class CatalogSuite private def testListColumns(tableName: String, dbName: Option[String]): Unit = { val tableMetadata = sessionCatalog.getTableMetadata(TableIdentifier(tableName, dbName)) val columns = dbName - .map { db => sparkSession.catalog.listColumns(db, tableName) } - .getOrElse { sparkSession.catalog.listColumns(tableName) } + .map { db => spark.catalog.listColumns(db, tableName) } + .getOrElse { spark.catalog.listColumns(tableName) } assume(tableMetadata.schema.nonEmpty, "bad test") assume(tableMetadata.partitionColumnNames.nonEmpty, "bad test") assume(tableMetadata.bucketColumnNames.nonEmpty, "bad test") @@ -108,85 +107,85 @@ class CatalogSuite } test("current database") { - assert(sparkSession.catalog.currentDatabase == "default") + assert(spark.catalog.currentDatabase == "default") assert(sessionCatalog.getCurrentDatabase == "default") createDatabase("my_db") - sparkSession.catalog.setCurrentDatabase("my_db") - assert(sparkSession.catalog.currentDatabase == "my_db") + spark.catalog.setCurrentDatabase("my_db") + assert(spark.catalog.currentDatabase == "my_db") assert(sessionCatalog.getCurrentDatabase == "my_db") val e = intercept[AnalysisException] { - sparkSession.catalog.setCurrentDatabase("unknown_db") + spark.catalog.setCurrentDatabase("unknown_db") } assert(e.getMessage.contains("unknown_db")) } test("list databases") { - assert(sparkSession.catalog.listDatabases().collect().map(_.name).toSet == Set("default")) + assert(spark.catalog.listDatabases().collect().map(_.name).toSet == Set("default")) createDatabase("my_db1") createDatabase("my_db2") - assert(sparkSession.catalog.listDatabases().collect().map(_.name).toSet == + assert(spark.catalog.listDatabases().collect().map(_.name).toSet == Set("default", "my_db1", "my_db2")) dropDatabase("my_db1") - assert(sparkSession.catalog.listDatabases().collect().map(_.name).toSet == + assert(spark.catalog.listDatabases().collect().map(_.name).toSet == Set("default", "my_db2")) } test("list tables") { - assert(sparkSession.catalog.listTables().collect().isEmpty) + assert(spark.catalog.listTables().collect().isEmpty) createTable("my_table1") createTable("my_table2") createTempTable("my_temp_table") - assert(sparkSession.catalog.listTables().collect().map(_.name).toSet == + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table1", "my_table2", "my_temp_table")) dropTable("my_table1") - assert(sparkSession.catalog.listTables().collect().map(_.name).toSet == + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table2", "my_temp_table")) dropTable("my_temp_table") - assert(sparkSession.catalog.listTables().collect().map(_.name).toSet == Set("my_table2")) + assert(spark.catalog.listTables().collect().map(_.name).toSet == Set("my_table2")) } test("list tables with database") { - assert(sparkSession.catalog.listTables("default").collect().isEmpty) + assert(spark.catalog.listTables("default").collect().isEmpty) createDatabase("my_db1") createDatabase("my_db2") createTable("my_table1", Some("my_db1")) createTable("my_table2", Some("my_db2")) createTempTable("my_temp_table") - assert(sparkSession.catalog.listTables("default").collect().map(_.name).toSet == + assert(spark.catalog.listTables("default").collect().map(_.name).toSet == Set("my_temp_table")) - assert(sparkSession.catalog.listTables("my_db1").collect().map(_.name).toSet == + assert(spark.catalog.listTables("my_db1").collect().map(_.name).toSet == Set("my_table1", "my_temp_table")) - assert(sparkSession.catalog.listTables("my_db2").collect().map(_.name).toSet == + assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet == Set("my_table2", "my_temp_table")) dropTable("my_table1", Some("my_db1")) - assert(sparkSession.catalog.listTables("my_db1").collect().map(_.name).toSet == + assert(spark.catalog.listTables("my_db1").collect().map(_.name).toSet == Set("my_temp_table")) - assert(sparkSession.catalog.listTables("my_db2").collect().map(_.name).toSet == + assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet == Set("my_table2", "my_temp_table")) dropTable("my_temp_table") - assert(sparkSession.catalog.listTables("default").collect().map(_.name).isEmpty) - assert(sparkSession.catalog.listTables("my_db1").collect().map(_.name).isEmpty) - assert(sparkSession.catalog.listTables("my_db2").collect().map(_.name).toSet == + assert(spark.catalog.listTables("default").collect().map(_.name).isEmpty) + assert(spark.catalog.listTables("my_db1").collect().map(_.name).isEmpty) + assert(spark.catalog.listTables("my_db2").collect().map(_.name).toSet == Set("my_table2")) val e = intercept[AnalysisException] { - sparkSession.catalog.listTables("unknown_db") + spark.catalog.listTables("unknown_db") } assert(e.getMessage.contains("unknown_db")) } test("list functions") { assert(Set("+", "current_database", "window").subsetOf( - sparkSession.catalog.listFunctions().collect().map(_.name).toSet)) + spark.catalog.listFunctions().collect().map(_.name).toSet)) createFunction("my_func1") createFunction("my_func2") createTempFunction("my_temp_func") - val funcNames1 = sparkSession.catalog.listFunctions().collect().map(_.name).toSet + val funcNames1 = spark.catalog.listFunctions().collect().map(_.name).toSet assert(funcNames1.contains("my_func1")) assert(funcNames1.contains("my_func2")) assert(funcNames1.contains("my_temp_func")) dropFunction("my_func1") dropTempFunction("my_temp_func") - val funcNames2 = sparkSession.catalog.listFunctions().collect().map(_.name).toSet + val funcNames2 = spark.catalog.listFunctions().collect().map(_.name).toSet assert(!funcNames2.contains("my_func1")) assert(funcNames2.contains("my_func2")) assert(!funcNames2.contains("my_temp_func")) @@ -194,14 +193,14 @@ class CatalogSuite test("list functions with database") { assert(Set("+", "current_database", "window").subsetOf( - sparkSession.catalog.listFunctions("default").collect().map(_.name).toSet)) + spark.catalog.listFunctions("default").collect().map(_.name).toSet)) createDatabase("my_db1") createDatabase("my_db2") createFunction("my_func1", Some("my_db1")) createFunction("my_func2", Some("my_db2")) createTempFunction("my_temp_func") - val funcNames1 = sparkSession.catalog.listFunctions("my_db1").collect().map(_.name).toSet - val funcNames2 = sparkSession.catalog.listFunctions("my_db2").collect().map(_.name).toSet + val funcNames1 = spark.catalog.listFunctions("my_db1").collect().map(_.name).toSet + val funcNames2 = spark.catalog.listFunctions("my_db2").collect().map(_.name).toSet assert(funcNames1.contains("my_func1")) assert(!funcNames1.contains("my_func2")) assert(funcNames1.contains("my_temp_func")) @@ -210,14 +209,14 @@ class CatalogSuite assert(funcNames2.contains("my_temp_func")) dropFunction("my_func1", Some("my_db1")) dropTempFunction("my_temp_func") - val funcNames1b = sparkSession.catalog.listFunctions("my_db1").collect().map(_.name).toSet - val funcNames2b = sparkSession.catalog.listFunctions("my_db2").collect().map(_.name).toSet + val funcNames1b = spark.catalog.listFunctions("my_db1").collect().map(_.name).toSet + val funcNames2b = spark.catalog.listFunctions("my_db2").collect().map(_.name).toSet assert(!funcNames1b.contains("my_func1")) assert(!funcNames1b.contains("my_temp_func")) assert(funcNames2b.contains("my_func2")) assert(!funcNames2b.contains("my_temp_func")) val e = intercept[AnalysisException] { - sparkSession.catalog.listFunctions("unknown_db") + spark.catalog.listFunctions("unknown_db") } assert(e.getMessage.contains("unknown_db")) } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala index b87f482..7ead97b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala @@ -33,61 +33,61 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { test("programmatic ways of basic setting and getting") { // Set a conf first. - sqlContext.setConf(testKey, testVal) + spark.conf.set(testKey, testVal) // Clear the conf. - sqlContext.conf.clear() + spark.wrapped.conf.clear() // After clear, only overrideConfs used by unit test should be in the SQLConf. - assert(sqlContext.getAllConfs === TestSQLContext.overrideConfs) + assert(spark.conf.getAll === TestSQLContext.overrideConfs) - sqlContext.setConf(testKey, testVal) - assert(sqlContext.getConf(testKey) === testVal) - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - assert(sqlContext.getAllConfs.contains(testKey)) + spark.conf.set(testKey, testVal) + assert(spark.conf.get(testKey) === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) + assert(spark.conf.getAll.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. - assert(sqlContext.getConf(testKey) === testVal) - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - assert(sqlContext.getAllConfs.contains(testKey)) + assert(spark.conf.get(testKey) === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) + assert(spark.conf.getAll.contains(testKey)) - sqlContext.conf.clear() + spark.wrapped.conf.clear() } test("parse SQL set commands") { - sqlContext.conf.clear() + spark.wrapped.conf.clear() sql(s"set $testKey=$testVal") - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) - assert(sqlContext.getConf(testKey, testVal + "_") === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) + assert(spark.conf.get(testKey, testVal + "_") === testVal) sql("set some.property=20") - assert(sqlContext.getConf("some.property", "0") === "20") + assert(spark.conf.get("some.property", "0") === "20") sql("set some.property = 40") - assert(sqlContext.getConf("some.property", "0") === "40") + assert(spark.conf.get("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" sql(s"set $key=$vs") - assert(sqlContext.getConf(key, "0") === vs) + assert(spark.conf.get(key, "0") === vs) sql(s"set $key=") - assert(sqlContext.getConf(key, "0") === "") + assert(spark.conf.get(key, "0") === "") - sqlContext.conf.clear() + spark.wrapped.conf.clear() } test("deprecated property") { - sqlContext.conf.clear() - val original = sqlContext.conf.numShufflePartitions + spark.wrapped.conf.clear() + val original = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) try{ sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(sqlContext.conf.numShufflePartitions === 10) + assert(spark.conf.get(SQLConf.SHUFFLE_PARTITIONS) === 10) } finally { sql(s"set ${SQLConf.SHUFFLE_PARTITIONS}=$original") } } test("invalid conf value") { - sqlContext.conf.clear() + spark.wrapped.conf.clear() val e = intercept[IllegalArgumentException] { sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10") } @@ -95,35 +95,35 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { } test("Test SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE's method") { - sqlContext.conf.clear() + spark.wrapped.conf.clear() - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100") - assert(sqlContext.conf.targetPostShuffleInputSize === 100) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 100) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1k") - assert(sqlContext.conf.targetPostShuffleInputSize === 1024) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1k") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 1024) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1M") - assert(sqlContext.conf.targetPostShuffleInputSize === 1048576) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1M") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 1048576) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1g") - assert(sqlContext.conf.targetPostShuffleInputSize === 1073741824) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1g") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === 1073741824) - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1") - assert(sqlContext.conf.targetPostShuffleInputSize === -1) + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1") + assert(spark.conf.get(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE) === -1) // Test overflow exception intercept[IllegalArgumentException] { // This value exceeds Long.MaxValue - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "90000000000g") + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "90000000000g") } intercept[IllegalArgumentException] { // This value less than Long.MinValue - sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g") + spark.conf.set(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g") } - sqlContext.conf.clear() + spark.wrapped.conf.clear() } test("SparkSession can access configs set in SparkConf") { http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 47a1017..44d1b9d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -337,39 +337,39 @@ class JDBCSuite extends SparkFunSuite } test("Basic API") { - assert(sqlContext.read.jdbc( + assert(spark.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3) } test("Basic API with FetchSize") { val properties = new Properties properties.setProperty("fetchSize", "2") - assert(sqlContext.read.jdbc( + assert(spark.read.jdbc( urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3) } test("Partitioning via JDBCPartitioningInfo API") { assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties) .collect().length === 3) } test("Partitioning via list-of-where-clauses API") { val parts = Array[String]("THEID < 2", "THEID >= 2") - assert(sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) + assert(spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties) .collect().length === 3) } test("Partitioning on column that might have null values.") { assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties) .collect().length === 4) assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties) .collect().length === 4) // partitioning on a nullable quoted column assert( - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties) + spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties) .collect().length === 4) } @@ -429,9 +429,9 @@ class JDBCSuite extends SparkFunSuite } test("test DATE types") { - val rows = sqlContext.read.jdbc( + val rows = spark.read.jdbc( urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - val cachedRows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val cachedRows = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(1).getAs[java.sql.Date](1) === null) @@ -439,8 +439,8 @@ class JDBCSuite extends SparkFunSuite } test("test DATE types in cache") { - val rows = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() - sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val rows = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect() + spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) .cache().registerTempTable("mycached_date") val cachedRows = sql("select * from mycached_date").collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) @@ -448,7 +448,7 @@ class JDBCSuite extends SparkFunSuite } test("test types for null value") { - val rows = sqlContext.read.jdbc( + val rows = spark.read.jdbc( urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect() assert((0 to 14).forall(i => rows(0).isNullAt(i))) } @@ -495,7 +495,7 @@ class JDBCSuite extends SparkFunSuite test("Remap types via JdbcDialects") { JdbcDialects.registerDialect(testH2Dialect) - val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty) val rows = df.collect() assert(rows(0).get(0).isInstanceOf[String]) @@ -629,7 +629,7 @@ class JDBCSuite extends SparkFunSuite // Regression test for bug SPARK-11788 val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543"); val date = java.sql.Date.valueOf("1995-01-01") - val jdbcDf = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) + val jdbcDf = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties) val rows = jdbcDf.where($"B" > date && $"C" > timestamp).collect() assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01")) assert(rows(0).getAs[java.sql.Timestamp](2) @@ -639,7 +639,7 @@ class JDBCSuite extends SparkFunSuite test("test credentials in the properties are not in plan output") { val df = sql("SELECT * FROM parts") val explain = ExplainCommand(df.queryExecution.logical, extended = true) - sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + spark.executePlan(explain).executedPlan.executeCollect().foreach { r => assert(!List("testPass", "testUser").exists(r.toString.contains)) } // test the JdbcRelation toString output @@ -649,9 +649,9 @@ class JDBCSuite extends SparkFunSuite } test("test credentials in the connection url are not in the plan output") { - val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) val explain = ExplainCommand(df.queryExecution.logical, extended = true) - sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + spark.executePlan(explain).executedPlan.executeCollect().foreach { r => assert(!List("testPass", "testUser").exists(r.toString.contains)) } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala index e23ee66..48fa5f9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala @@ -88,50 +88,50 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { StructField("seq", IntegerType) :: Nil) test("Basic CREATE") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties) - assert(2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) + assert(2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count) assert( - 2 === sqlContext.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) + 2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length) } test("CREATE with overwrite") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) - val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.DROPTEST", properties) - assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(3 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(3 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties) - assert(1 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) + assert(1 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count) + assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length) } test("CREATE then INSERT to append") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url, "TEST.APPENDTEST", new Properties) df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties) - assert(3 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) - assert(2 === sqlContext.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) + assert(3 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties).count) + assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length) } test("CREATE then INSERT to truncate") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr1x2), schema2) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2) df.write.jdbc(url1, "TEST.TRUNCATETEST", properties) df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties) - assert(1 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) + assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count) + assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length) } test("Incompatible INSERT to append") { - val df = sqlContext.createDataFrame(sparkContext.parallelize(arr2x2), schema2) - val df2 = sqlContext.createDataFrame(sparkContext.parallelize(arr2x3), schema3) + val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2) + val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3) df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties) intercept[org.apache.spark.SparkException] { @@ -141,14 +141,14 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter { test("INSERT to JDBC Datasource") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } test("INSERT to JDBC Datasource with overwrite") { sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE") sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE") - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).count) - assert(2 === sqlContext.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count) + assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length) } } http://git-wip-us.apache.org/repos/asf/spark/blob/5bf74b44/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala index 9206113..754aa32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala @@ -24,7 +24,7 @@ private[sql] abstract class DataSourceTest extends QueryTest { // We want to test some edge cases. protected lazy val caseInsensitiveContext: SQLContext = { - val ctx = new SQLContext(sqlContext.sparkContext) + val ctx = new SQLContext(spark.sparkContext) ctx.setConf(SQLConf.CASE_SENSITIVE, false) ctx } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
