http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index a9b1970..a2decad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -29,11 +29,11 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val df = sqlContext.range(100).select($"id", lit(1).as("data")) + val df = spark.range(100).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - sqlContext.read.load(path.getCanonicalPath), + spark.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) @@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val base = sqlContext.range(100) + val base = spark.range(100) val df = base.union(base).select($"id", lit(1).as("data")) df.write.partitionBy("id").save(path.getCanonicalPath) checkAnswer( - sqlContext.read.load(path.getCanonicalPath), + spark.read.load(path.getCanonicalPath), (0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq) Utils.deleteRecursively(path) @@ -58,7 +58,7 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { withTempPath { f => val path = f.getAbsolutePath Seq(1 -> "a").toDF("i", "j").write.partitionBy("i").parquet(path) - assert(sqlContext.read.parquet(path).schema.map(_.name) == Seq("j", "i")) + assert(spark.read.parquet(path).schema.map(_.name) == Seq("j", "i")) } } }
http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala index 3d69c8a..a743cdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/ContinuousQueryManagerSuite.scala @@ -41,13 +41,13 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with override val streamingTimeout = 20.seconds before { - assert(sqlContext.streams.active.isEmpty) - sqlContext.streams.resetTerminated() + assert(spark.streams.active.isEmpty) + spark.streams.resetTerminated() } after { - assert(sqlContext.streams.active.isEmpty) - sqlContext.streams.resetTerminated() + assert(spark.streams.active.isEmpty) + spark.streams.resetTerminated() } testQuietly("listing") { @@ -57,26 +57,26 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with withQueriesOn(ds1, ds2, ds3) { queries => require(queries.size === 3) - assert(sqlContext.streams.active.toSet === queries.toSet) + assert(spark.streams.active.toSet === queries.toSet) val (q1, q2, q3) = (queries(0), queries(1), queries(2)) - assert(sqlContext.streams.get(q1.name).eq(q1)) - assert(sqlContext.streams.get(q2.name).eq(q2)) - assert(sqlContext.streams.get(q3.name).eq(q3)) + assert(spark.streams.get(q1.name).eq(q1)) + assert(spark.streams.get(q2.name).eq(q2)) + assert(spark.streams.get(q3.name).eq(q3)) intercept[IllegalArgumentException] { - sqlContext.streams.get("non-existent-name") + spark.streams.get("non-existent-name") } q1.stop() - assert(sqlContext.streams.active.toSet === Set(q2, q3)) + assert(spark.streams.active.toSet === Set(q2, q3)) val ex1 = withClue("no error while getting non-active query") { intercept[IllegalArgumentException] { - sqlContext.streams.get(q1.name) + spark.streams.get(q1.name) } } assert(ex1.getMessage.contains(q1.name), "error does not contain name of query to be fetched") - assert(sqlContext.streams.get(q2.name).eq(q2)) + assert(spark.streams.get(q2.name).eq(q2)) m2.addData(0) // q2 should terminate with error @@ -86,11 +86,11 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with } withClue("no error while getting non-active query") { intercept[IllegalArgumentException] { - sqlContext.streams.get(q2.name).eq(q2) + spark.streams.get(q2.name).eq(q2) } } - assert(sqlContext.streams.active.toSet === Set(q3)) + assert(spark.streams.active.toSet === Set(q3)) } } @@ -98,7 +98,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with val datasets = Seq.fill(5)(makeDataset._2) withQueriesOn(datasets: _*) { queries => require(queries.size === datasets.size) - assert(sqlContext.streams.active.toSet === queries.toSet) + assert(spark.streams.active.toSet === queries.toSet) // awaitAnyTermination should be blocking testAwaitAnyTermination(ExpectBlocked) @@ -112,7 +112,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with testAwaitAnyTermination(ExpectNotBlocked) // Resetting termination should make awaitAnyTermination() blocking again - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() testAwaitAnyTermination(ExpectBlocked) // Terminate a query asynchronously with exception and see awaitAnyTermination throws @@ -125,7 +125,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with testAwaitAnyTermination(ExpectException[SparkException]) // Resetting termination should make awaitAnyTermination() blocking again - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() testAwaitAnyTermination(ExpectBlocked) // Terminate multiple queries, one with failure and see whether awaitAnyTermination throws @@ -144,7 +144,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with val datasets = Seq.fill(6)(makeDataset._2) withQueriesOn(datasets: _*) { queries => require(queries.size === datasets.size) - assert(sqlContext.streams.active.toSet === queries.toSet) + assert(spark.streams.active.toSet === queries.toSet) // awaitAnyTermination should be blocking or non-blocking depending on timeout values testAwaitAnyTermination( @@ -173,7 +173,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with ExpectNotBlocked, awaitTimeout = 4 seconds, expectedReturnedValue = true) // Resetting termination should make awaitAnyTermination() blocking again - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() testAwaitAnyTermination( ExpectBlocked, awaitTimeout = 4 seconds, @@ -196,7 +196,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with testBehaviorFor = 4 seconds) // Terminate a query asynchronously outside the timeout, awaitAnyTerm should be blocked - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() val q3 = stopRandomQueryAsync(2 seconds, withError = true) testAwaitAnyTermination( ExpectNotBlocked, @@ -214,7 +214,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with // Terminate multiple queries, one with failure and see whether awaitAnyTermination throws // the exception - sqlContext.streams.resetTerminated() + spark.streams.resetTerminated() val q4 = stopRandomQueryAsync(10 milliseconds, withError = false) testAwaitAnyTermination( @@ -238,7 +238,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with val df = ds.toDF val metadataRoot = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - query = sqlContext + query = spark .streams .startQuery( StreamExecution.nextName, @@ -272,10 +272,10 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with def awaitTermFunc(): Unit = { if (awaitTimeout != null && awaitTimeout.toMillis > 0) { - val returnedValue = sqlContext.streams.awaitAnyTermination(awaitTimeout.toMillis) + val returnedValue = spark.streams.awaitAnyTermination(awaitTimeout.toMillis) assert(returnedValue === expectedReturnedValue, "Returned value does not match expected") } else { - sqlContext.streams.awaitAnyTermination() + spark.streams.awaitAnyTermination() } } @@ -287,7 +287,7 @@ class ContinuousQueryManagerSuite extends StreamTest with SharedSQLContext with import scala.concurrent.ExecutionContext.Implicits.global - val activeQueries = sqlContext.streams.active + val activeQueries = spark.streams.active val queryToStop = activeQueries(Random.nextInt(activeQueries.length)) Future { Thread.sleep(stopAfter.toMillis) http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala index c7b2b99..cb53b2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DataFrameReaderWriterSuite.scala @@ -54,18 +54,18 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) override def sourceSchema( - sqlContext: SQLContext, + spark: SQLContext, schema: Option[StructType], providerName: String, parameters: Map[String, String]): (String, StructType) = { LastOptions.parameters = parameters LastOptions.schema = schema - LastOptions.mockStreamSourceProvider.sourceSchema(sqlContext, schema, providerName, parameters) + LastOptions.mockStreamSourceProvider.sourceSchema(spark, schema, providerName, parameters) ("dummySource", fakeSchema) } override def createSource( - sqlContext: SQLContext, + spark: SQLContext, metadataPath: String, schema: Option[StructType], providerName: String, @@ -73,14 +73,14 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { LastOptions.parameters = parameters LastOptions.schema = schema LastOptions.mockStreamSourceProvider.createSource( - sqlContext, metadataPath, schema, providerName, parameters) + spark, metadataPath, schema, providerName, parameters) new Source { override def schema: StructType = fakeSchema override def getOffset: Option[Offset] = Some(new LongOffset(0)) override def getBatch(start: Option[Offset], end: Offset): DataFrame = { - import sqlContext.implicits._ + import spark.implicits._ Seq[Int]().toDS().toDF() } @@ -88,12 +88,12 @@ class DefaultSource extends StreamSourceProvider with StreamSinkProvider { } override def createSink( - sqlContext: SQLContext, + spark: SQLContext, parameters: Map[String, String], partitionColumns: Seq[String]): Sink = { LastOptions.parameters = parameters LastOptions.partitionColumns = partitionColumns - LastOptions.mockStreamSinkProvider.createSink(sqlContext, parameters, partitionColumns) + LastOptions.mockStreamSinkProvider.createSink(spark, parameters, partitionColumns) new Sink { override def addBatch(batchId: Long, data: DataFrame): Unit = {} } @@ -107,11 +107,11 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath after { - sqlContext.streams.active.foreach(_.stop()) + spark.streams.active.foreach(_.stop()) } test("resolve default source") { - sqlContext.read + spark.read .format("org.apache.spark.sql.streaming.test") .stream() .write @@ -122,7 +122,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("resolve full class") { - sqlContext.read + spark.read .format("org.apache.spark.sql.streaming.test.DefaultSource") .stream() .write @@ -136,7 +136,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B val map = new java.util.HashMap[String, String] map.put("opt3", "3") - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .option("opt1", "1") .options(Map("opt2" -> "2")) @@ -164,7 +164,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("partitioning") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() @@ -204,7 +204,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("stream paths") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .option("checkpointLocation", newMetadataDir) .stream("/test") @@ -223,7 +223,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("test different data types for options") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .option("intOpt", 56) .option("boolOpt", false) @@ -253,7 +253,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B /** Start a query with a specific name */ def startQueryWithName(name: String = ""): ContinuousQuery = { - sqlContext.read + spark.read .format("org.apache.spark.sql.streaming.test") .stream("/test") .write @@ -265,7 +265,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B /** Start a query without specifying a name */ def startQueryWithoutName(): ContinuousQuery = { - sqlContext.read + spark.read .format("org.apache.spark.sql.streaming.test") .stream("/test") .write @@ -276,7 +276,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B /** Get the names of active streams */ def activeStreamNames: Set[String] = { - val streams = sqlContext.streams.active + val streams = spark.streams.active val names = streams.map(_.name).toSet assert(streams.length === names.size, s"names of active queries are not unique: $names") names @@ -307,11 +307,11 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B q1.stop() val q5 = startQueryWithName("name") assert(activeStreamNames.contains("name")) - sqlContext.streams.active.foreach(_.stop()) + spark.streams.active.foreach(_.stop()) } test("trigger") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream("/test") @@ -339,11 +339,11 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B val checkpointLocation = newMetadataDir - val df1 = sqlContext.read + val df1 = spark.read .format("org.apache.spark.sql.streaming.test") .stream() - val df2 = sqlContext.read + val df2 = spark.read .format("org.apache.spark.sql.streaming.test") .stream() @@ -355,14 +355,14 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B q.stop() verify(LastOptions.mockStreamSourceProvider).createSource( - sqlContext, + spark.wrapped, checkpointLocation + "/sources/0", None, "org.apache.spark.sql.streaming.test", Map.empty) verify(LastOptions.mockStreamSourceProvider).createSource( - sqlContext, + spark.wrapped, checkpointLocation + "/sources/1", None, "org.apache.spark.sql.streaming.test", @@ -372,35 +372,35 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B private def newTextInput = Utils.createTempDir(namePrefix = "text").getCanonicalPath test("check trigger() can only be called on continuous queries") { - val df = sqlContext.read.text(newTextInput) + val df = spark.read.text(newTextInput) val w = df.write.option("checkpointLocation", newMetadataDir) val e = intercept[AnalysisException](w.trigger(ProcessingTime("10 seconds"))) assert(e.getMessage == "trigger() can only be called on continuous queries;") } test("check queryName() can only be called on continuous queries") { - val df = sqlContext.read.text(newTextInput) + val df = spark.read.text(newTextInput) val w = df.write.option("checkpointLocation", newMetadataDir) val e = intercept[AnalysisException](w.queryName("queryName")) assert(e.getMessage == "queryName() can only be called on continuous queries;") } test("check startStream() can only be called on continuous queries") { - val df = sqlContext.read.text(newTextInput) + val df = spark.read.text(newTextInput) val w = df.write.option("checkpointLocation", newMetadataDir) val e = intercept[AnalysisException](w.startStream()) assert(e.getMessage == "startStream() can only be called on continuous queries;") } test("check startStream(path) can only be called on continuous queries") { - val df = sqlContext.read.text(newTextInput) + val df = spark.read.text(newTextInput) val w = df.write.option("checkpointLocation", newMetadataDir) val e = intercept[AnalysisException](w.startStream("non_exist_path")) assert(e.getMessage == "startStream() can only be called on continuous queries;") } test("check mode(SaveMode) can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -409,7 +409,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check mode(string) can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -418,7 +418,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check bucketBy() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -427,7 +427,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check sortBy() can only be called on non-continuous queries;") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -436,7 +436,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check save(path) can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -445,7 +445,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check save() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -454,7 +454,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check insertInto() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -463,7 +463,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check saveAsTable() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -472,7 +472,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check jdbc() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -481,7 +481,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check json() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -490,7 +490,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check parquet() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -499,7 +499,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check orc() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -508,7 +508,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check text() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write @@ -517,7 +517,7 @@ class DataFrameReaderWriterSuite extends StreamTest with SharedSQLContext with B } test("check csv() can only be called on non-continuous queries") { - val df = sqlContext.read + val df = spark.read .format("org.apache.spark.sql.streaming.test") .stream() val w = df.write http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index e937fc3..6238b74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -40,11 +40,11 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val hadoopConf = sqlContext.sparkContext.hadoopConfiguration + val hadoopConf = spark.sparkContext.hadoopConfiguration val fileFormat = new parquet.DefaultSource() def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = { - val df = sqlContext + val df = spark .range(start, end, 1, numPartitions) .select($"id", lit(100).as("data")) val writer = new FileStreamSinkWriter( @@ -56,7 +56,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { val files1 = writeRange(0, 10, 2) assert(files1.size === 2, s"unexpected number of files: $files1") checkFilesExist(path, files1, "file not written") - checkAnswer(sqlContext.read.load(path.getCanonicalPath), (0 until 10).map(Row(_, 100))) + checkAnswer(spark.read.load(path.getCanonicalPath), (0 until 10).map(Row(_, 100))) // Append and check whether new files are written correctly and old files still exist val files2 = writeRange(10, 20, 3) @@ -64,7 +64,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { assert(files2.intersect(files1).isEmpty, "old files returned") checkFilesExist(path, files2, s"New file not written") checkFilesExist(path, files1, s"Old file not found") - checkAnswer(sqlContext.read.load(path.getCanonicalPath), (0 until 20).map(Row(_, 100))) + checkAnswer(spark.read.load(path.getCanonicalPath), (0 until 20).map(Row(_, 100))) } test("FileStreamSinkWriter - partitioned data") { @@ -72,11 +72,11 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { val path = Utils.createTempDir() path.delete() - val hadoopConf = sqlContext.sparkContext.hadoopConfiguration + val hadoopConf = spark.sparkContext.hadoopConfiguration val fileFormat = new parquet.DefaultSource() def writeRange(start: Int, end: Int, numPartitions: Int): Seq[String] = { - val df = sqlContext + val df = spark .range(start, end, 1, numPartitions) .flatMap(x => Iterator(x, x, x)).toDF("id") .select($"id", lit(100).as("data1"), lit(1000).as("data2")) @@ -103,7 +103,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { checkOneFileWrittenPerKey(0 until 10, files1) val answer1 = (0 until 10).flatMap(x => Iterator(x, x, x)).map(Row(100, 1000, _)) - checkAnswer(sqlContext.read.load(path.getCanonicalPath), answer1) + checkAnswer(spark.read.load(path.getCanonicalPath), answer1) // Append and check whether new files are written correctly and old files still exist val files2 = writeRange(0, 20, 3) @@ -114,7 +114,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { checkOneFileWrittenPerKey(0 until 20, files2) val answer2 = (0 until 20).flatMap(x => Iterator(x, x, x)).map(Row(100, 1000, _)) - checkAnswer(sqlContext.read.load(path.getCanonicalPath), answer1 ++ answer2) + checkAnswer(spark.read.load(path.getCanonicalPath), answer1 ++ answer2) } test("FileStreamSink - unpartitioned writing and batch reading") { @@ -139,7 +139,7 @@ class FileStreamSinkSuite extends StreamTest with SharedSQLContext { query.processAllAvailable() } - val outputDf = sqlContext.read.parquet(outputDir).as[Int] + val outputDf = spark.read.parquet(outputDir).as[Int] checkDataset(outputDf, 1, 2, 3) } finally { http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index a62852b..4b95d65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -103,9 +103,9 @@ class FileStreamSourceTest extends StreamTest with SharedSQLContext { val reader = if (schema.isDefined) { - sqlContext.read.format(format).schema(schema.get) + spark.read.format(format).schema(schema.get) } else { - sqlContext.read.format(format) + spark.read.format(format) } reader.stream(path) } @@ -149,7 +149,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { format: Option[String], path: Option[String], schema: Option[StructType] = None): StructType = { - val reader = sqlContext.read + val reader = spark.read format.foreach(reader.format) schema.foreach(reader.schema) val df = http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala index 50703e5..4efb7cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala @@ -100,7 +100,7 @@ class FileStressSuite extends StreamTest with SharedSQLContext { } writer.start() - val input = sqlContext.read.format("text").stream(inputDir) + val input = spark.read.format("text").stream(inputDir) def startStream(): ContinuousQuery = { val output = input @@ -150,6 +150,6 @@ class FileStressSuite extends StreamTest with SharedSQLContext { streamThread.join() logError(s"Stream restarted $failures times.") - assert(sqlContext.read.parquet(outputDir).distinct().count() == numRecords) + assert(spark.read.parquet(outputDir).distinct().count() == numRecords) } } http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala index 74ca397..09c35bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/MemorySinkSuite.scala @@ -44,13 +44,13 @@ class MemorySinkSuite extends StreamTest with SharedSQLContext { query.processAllAvailable() checkDataset( - sqlContext.table("memStream").as[Int], + spark.table("memStream").as[Int], 1, 2, 3) input.addData(4, 5, 6) query.processAllAvailable() checkDataset( - sqlContext.table("memStream").as[Int], + spark.table("memStream").as[Int], 1, 2, 3, 4, 5, 6) query.stop() http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index bcd3cba..6a8b280 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -94,7 +94,7 @@ class StreamSuite extends StreamTest with SharedSQLContext { .startStream(outputDir.getAbsolutePath) try { query.processAllAvailable() - val outputDf = sqlContext.read.parquet(outputDir.getAbsolutePath).as[Long] + val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long] checkDataset[Long](outputDf, (0L to 10L).toArray: _*) } finally { query.stop() @@ -103,7 +103,7 @@ class StreamSuite extends StreamTest with SharedSQLContext { } } - val df = sqlContext.read.format(classOf[FakeDefaultSource].getName).stream() + val df = spark.read.format(classOf[FakeDefaultSource].getName).stream() assertDF(df) assertDF(df) } @@ -162,13 +162,13 @@ class FakeDefaultSource extends StreamSourceProvider { private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) override def sourceSchema( - sqlContext: SQLContext, + spark: SQLContext, schema: Option[StructType], providerName: String, parameters: Map[String, String]): (String, StructType) = ("fakeSource", fakeSchema) override def createSource( - sqlContext: SQLContext, + spark: SQLContext, metadataPath: String, schema: Option[StructType], providerName: String, @@ -190,7 +190,7 @@ class FakeDefaultSource extends StreamSourceProvider { override def getBatch(start: Option[Offset], end: Offset): DataFrame = { val startOffset = start.map(_.asInstanceOf[LongOffset].offset).getOrElse(-1L) + 1 - sqlContext.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a") + spark.range(startOffset, end.asInstanceOf[LongOffset].offset + 1).toDF("a") } } } http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 7fa6760..03369c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -20,17 +20,17 @@ package org.apache.spark.sql.test import java.nio.charset.StandardCharsets import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits} +import org.apache.spark.sql.{DataFrame, SparkSession, SQLContext, SQLImplicits} /** * A collection of sample data used in SQL tests. */ private[sql] trait SQLTestData { self => - protected def sqlContext: SQLContext + protected def spark: SparkSession // Helper object to import SQL implicits without a concrete SQLContext private object internalImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self.sqlContext + protected override def _sqlContext: SQLContext = self.spark.wrapped } import internalImplicits._ @@ -39,21 +39,21 @@ private[sql] trait SQLTestData { self => // Note: all test data should be lazy because the SQLContext is not set up yet. protected lazy val emptyTestData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( Seq.empty[Int].map(i => TestData(i, i.toString))).toDF() df.registerTempTable("emptyTestData") df } protected lazy val testData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( (1 to 100).map(i => TestData(i, i.toString))).toDF() df.registerTempTable("testData") df } protected lazy val testData2: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( TestData2(1, 1) :: TestData2(1, 2) :: TestData2(2, 1) :: @@ -65,7 +65,7 @@ private[sql] trait SQLTestData { self => } protected lazy val testData3: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( TestData3(1, None) :: TestData3(2, Some(2)) :: Nil).toDF() df.registerTempTable("testData3") @@ -73,14 +73,14 @@ private[sql] trait SQLTestData { self => } protected lazy val negativeData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() df.registerTempTable("negativeData") df } protected lazy val largeAndSmallInts: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( LargeAndSmallInts(2147483644, 1) :: LargeAndSmallInts(1, 2) :: LargeAndSmallInts(2147483645, 1) :: @@ -92,7 +92,7 @@ private[sql] trait SQLTestData { self => } protected lazy val decimalData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( DecimalData(1, 1) :: DecimalData(1, 2) :: DecimalData(2, 1) :: @@ -104,7 +104,7 @@ private[sql] trait SQLTestData { self => } protected lazy val binaryData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( BinaryData("12".getBytes(StandardCharsets.UTF_8), 1) :: BinaryData("22".getBytes(StandardCharsets.UTF_8), 5) :: BinaryData("122".getBytes(StandardCharsets.UTF_8), 3) :: @@ -115,7 +115,7 @@ private[sql] trait SQLTestData { self => } protected lazy val upperCaseData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( UpperCaseData(1, "A") :: UpperCaseData(2, "B") :: UpperCaseData(3, "C") :: @@ -127,7 +127,7 @@ private[sql] trait SQLTestData { self => } protected lazy val lowerCaseData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: @@ -137,7 +137,7 @@ private[sql] trait SQLTestData { self => } protected lazy val arrayData: RDD[ArrayData] = { - val rdd = sqlContext.sparkContext.parallelize( + val rdd = spark.sparkContext.parallelize( ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) :: ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil) rdd.toDF().registerTempTable("arrayData") @@ -145,7 +145,7 @@ private[sql] trait SQLTestData { self => } protected lazy val mapData: RDD[MapData] = { - val rdd = sqlContext.sparkContext.parallelize( + val rdd = spark.sparkContext.parallelize( MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) :: MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) :: MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) :: @@ -156,13 +156,13 @@ private[sql] trait SQLTestData { self => } protected lazy val repeatedData: RDD[StringData] = { - val rdd = sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test"))) + val rdd = spark.sparkContext.parallelize(List.fill(2)(StringData("test"))) rdd.toDF().registerTempTable("repeatedData") rdd } protected lazy val nullableRepeatedData: RDD[StringData] = { - val rdd = sqlContext.sparkContext.parallelize( + val rdd = spark.sparkContext.parallelize( List.fill(2)(StringData(null)) ++ List.fill(2)(StringData("test"))) rdd.toDF().registerTempTable("nullableRepeatedData") @@ -170,7 +170,7 @@ private[sql] trait SQLTestData { self => } protected lazy val nullInts: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( NullInts(1) :: NullInts(2) :: NullInts(3) :: @@ -180,7 +180,7 @@ private[sql] trait SQLTestData { self => } protected lazy val allNulls: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( NullInts(null) :: NullInts(null) :: NullInts(null) :: @@ -190,7 +190,7 @@ private[sql] trait SQLTestData { self => } protected lazy val nullStrings: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( NullStrings(1, "abc") :: NullStrings(2, "ABC") :: NullStrings(3, null) :: Nil).toDF() @@ -199,13 +199,13 @@ private[sql] trait SQLTestData { self => } protected lazy val tableName: DataFrame = { - val df = sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF() + val df = spark.sparkContext.parallelize(TableName("test") :: Nil).toDF() df.registerTempTable("tableName") df } protected lazy val unparsedStrings: RDD[String] = { - sqlContext.sparkContext.parallelize( + spark.sparkContext.parallelize( "1, A1, true, null" :: "2, B2, false, null" :: "3, C3, true, null" :: @@ -214,13 +214,13 @@ private[sql] trait SQLTestData { self => // An RDD with 4 elements and 8 partitions protected lazy val withEmptyParts: RDD[IntField] = { - val rdd = sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8) + val rdd = spark.sparkContext.parallelize((1 to 4).map(IntField), 8) rdd.toDF().registerTempTable("withEmptyParts") rdd } protected lazy val person: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( Person(0, "mike", 30) :: Person(1, "jim", 20) :: Nil).toDF() df.registerTempTable("person") @@ -228,7 +228,7 @@ private[sql] trait SQLTestData { self => } protected lazy val salary: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( Salary(0, 2000.0) :: Salary(1, 1000.0) :: Nil).toDF() df.registerTempTable("salary") @@ -236,7 +236,7 @@ private[sql] trait SQLTestData { self => } protected lazy val complexData: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) :: Nil).toDF() @@ -245,7 +245,7 @@ private[sql] trait SQLTestData { self => } protected lazy val courseSales: DataFrame = { - val df = sqlContext.sparkContext.parallelize( + val df = spark.sparkContext.parallelize( CourseSales("dotNET", 2012, 10000) :: CourseSales("Java", 2012, 20000) :: CourseSales("dotNET", 2012, 5000) :: @@ -259,7 +259,7 @@ private[sql] trait SQLTestData { self => * Initialize all test data such that all temp tables are properly registered. */ def loadTestData(): Unit = { - assert(sqlContext != null, "attempted to initialize test data before SQLContext.") + assert(spark != null, "attempted to initialize test data before SparkSession.") emptyTestData testData testData2 http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 6d2b95e..a49a8c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -50,23 +50,23 @@ private[sql] trait SQLTestUtils with BeforeAndAfterAll with SQLTestData { self => - protected def sparkContext = sqlContext.sparkContext + protected def sparkContext = spark.sparkContext // Whether to materialize all test data before the first test is run private var loadTestDataBeforeTests = false // Shorthand for running a query using our SQLContext - protected lazy val sql = sqlContext.sql _ + protected lazy val sql = spark.sql _ /** * A helper object for importing SQL implicits. * - * Note that the alternative of importing `sqlContext.implicits._` is not possible here. + * Note that the alternative of importing `spark.implicits._` is not possible here. * This is because we create the [[SQLContext]] immediately before the first test is run, * but the implicits import is needed in the constructor. */ protected object testImplicits extends SQLImplicits { - protected override def _sqlContext: SQLContext = self.sqlContext + protected override def _sqlContext: SQLContext = self.spark.wrapped } /** @@ -92,12 +92,12 @@ private[sql] trait SQLTestUtils */ protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { val (keys, values) = pairs.unzip - val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) - (keys, values).zipped.foreach(sqlContext.conf.setConfString) + val currentValues = keys.map(key => Try(spark.conf.get(key)).toOption) + (keys, values).zipped.foreach(spark.conf.set) try f finally { keys.zip(currentValues).foreach { - case (key, Some(value)) => sqlContext.conf.setConfString(key, value) - case (key, None) => sqlContext.conf.unsetConf(key) + case (key, Some(value)) => spark.conf.set(key, value) + case (key, None) => spark.conf.unset(key) } } } @@ -138,9 +138,9 @@ private[sql] trait SQLTestUtils // temp tables that never got created. functions.foreach { case (functionName, isTemporary) => val withTemporary = if (isTemporary) "TEMPORARY" else "" - sqlContext.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") + spark.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") assert( - !sqlContext.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), + !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), s"Function $functionName should have been dropped. But, it still exists.") } } @@ -153,7 +153,7 @@ private[sql] trait SQLTestUtils try f finally { // If the test failed part way, we don't want to mask the failure by failing to remove // temp tables that never got created. - try tableNames.foreach(sqlContext.dropTempTable) catch { + try tableNames.foreach(spark.catalog.dropTempTable) catch { case _: NoSuchTableException => } } @@ -165,7 +165,7 @@ private[sql] trait SQLTestUtils protected def withTable(tableNames: String*)(f: => Unit): Unit = { try f finally { tableNames.foreach { name => - sqlContext.sql(s"DROP TABLE IF EXISTS $name") + spark.sql(s"DROP TABLE IF EXISTS $name") } } } @@ -176,7 +176,7 @@ private[sql] trait SQLTestUtils protected def withView(viewNames: String*)(f: => Unit): Unit = { try f finally { viewNames.foreach { name => - sqlContext.sql(s"DROP VIEW IF EXISTS $name") + spark.sql(s"DROP VIEW IF EXISTS $name") } } } @@ -191,12 +191,12 @@ private[sql] trait SQLTestUtils val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}" try { - sqlContext.sql(s"CREATE DATABASE $dbName") + spark.sql(s"CREATE DATABASE $dbName") } catch { case cause: Throwable => fail("Failed to create temporary database", cause) } - try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE") + try f(dbName) finally spark.sql(s"DROP DATABASE $dbName CASCADE") } /** @@ -204,8 +204,8 @@ private[sql] trait SQLTestUtils * `f` returns. */ protected def activateDatabase(db: String)(f: => Unit): Unit = { - sqlContext.sessionState.catalog.setCurrentDatabase(db) - try f finally sqlContext.sessionState.catalog.setCurrentDatabase("default") + spark.sessionState.catalog.setCurrentDatabase(db) + try f finally spark.sessionState.catalog.setCurrentDatabase("default") } /** @@ -221,7 +221,7 @@ private[sql] trait SQLTestUtils .execute() .map(row => Row.fromSeq(row.copy().toSeq(schema))) - sqlContext.createDataFrame(childRDD, schema) + spark.createDataFrame(childRDD, schema) } /** @@ -229,7 +229,7 @@ private[sql] trait SQLTestUtils * way to construct [[DataFrame]] directly out of local data without relying on implicits. */ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { - Dataset.ofRows(sqlContext.sparkSession, plan) + Dataset.ofRows(spark, plan) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index 914c6a5..620bfa9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,37 +17,42 @@ package org.apache.spark.sql.test -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SQLContext +import org.apache.spark.SparkConf +import org.apache.spark.sql.{SparkSession, SQLContext} /** - * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]]. + * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. */ trait SharedSQLContext extends SQLTestUtils { protected val sparkConf = new SparkConf() /** - * The [[TestSQLContext]] to use for all tests in this suite. + * The [[TestSparkSession]] to use for all tests in this suite. * * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local * mode with the default test configurations. */ - private var _ctx: TestSQLContext = null + private var _spark: TestSparkSession = null + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + */ + protected implicit def spark: SparkSession = _spark /** * The [[TestSQLContext]] to use for all tests in this suite. */ - protected implicit def sqlContext: SQLContext = _ctx + protected implicit def sqlContext: SQLContext = _spark.wrapped /** - * Initialize the [[TestSQLContext]]. + * Initialize the [[TestSparkSession]]. */ protected override def beforeAll(): Unit = { SQLContext.clearSqlListener() - if (_ctx == null) { - _ctx = new TestSQLContext(sparkConf) + if (_spark == null) { + _spark = new TestSparkSession(sparkConf) } // Ensure we have initialized the context before calling parent code super.beforeAll() @@ -58,9 +63,9 @@ trait SharedSQLContext extends SQLTestUtils { */ protected override def afterAll(): Unit = { try { - if (_ctx != null) { - _ctx.sparkContext.stop() - _ctx = null + if (_spark != null) { + _spark.stop() + _spark = null } } finally { super.afterAll() http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 5ef80b9..785e345 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -18,44 +18,32 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.internal.{SessionState, SQLConf} /** - * A special [[SQLContext]] prepared for testing. + * A special [[SparkSession]] prepared for testing. */ -private[sql] class TestSQLContext( - @transient override val sparkSession: SparkSession, - isRootContext: Boolean) - extends SQLContext(sparkSession, isRootContext) { self => - - def this(sc: SparkContext) { - this(new TestSparkSession(sc), true) - } - +private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => def this(sparkConf: SparkConf) { this(new SparkContext("local[2]", "test-sql-context", sparkConf.set("spark.sql.testkey", "true"))) } def this() { - this(new SparkConf) - } - - // Needed for Java tests - def loadTestData(): Unit = { - testData.loadTestData() - } - - private object testData extends SQLTestData { - protected override def sqlContext: SQLContext = self + this { + val conf = new SparkConf() + conf.set("spark.sql.testkey", "true") + + val spark = SparkSession.builder + .master("local[2]") + .appName("test-sql-context") + .config(conf) + .getOrCreate() + spark.sparkContext + } } -} - - -private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => - @transient protected[sql] override lazy val sessionState: SessionState = new SessionState(self) { override lazy val conf: SQLConf = { @@ -70,6 +58,14 @@ private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { } } + // Needed for Java tests + def loadTestData(): Unit = { + testData.loadTestData() + } + + private object testData extends SQLTestData { + protected override def spark: SparkSession = self + } } http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala index 54acd4d..8788898 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ContinuousQueryListenerSuite.scala @@ -36,11 +36,11 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with import testImplicits._ after { - sqlContext.streams.active.foreach(_.stop()) - assert(sqlContext.streams.active.isEmpty) + spark.streams.active.foreach(_.stop()) + assert(spark.streams.active.isEmpty) assert(addedListeners.isEmpty) // Make sure we don't leak any events to the next test - sqlContext.sparkContext.listenerBus.waitUntilEmpty(10000) + spark.sparkContext.listenerBus.waitUntilEmpty(10000) } test("single listener") { @@ -112,17 +112,17 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with val listener1 = new QueryStatusCollector val listener2 = new QueryStatusCollector - sqlContext.streams.addListener(listener1) + spark.streams.addListener(listener1) assert(isListenerActive(listener1) === true) assert(isListenerActive(listener2) === false) - sqlContext.streams.addListener(listener2) + spark.streams.addListener(listener2) assert(isListenerActive(listener1) === true) assert(isListenerActive(listener2) === true) - sqlContext.streams.removeListener(listener1) + spark.streams.removeListener(listener1) assert(isListenerActive(listener1) === false) assert(isListenerActive(listener2) === true) } finally { - addedListeners.foreach(sqlContext.streams.removeListener) + addedListeners.foreach(spark.streams.removeListener) } } @@ -146,18 +146,18 @@ class ContinuousQueryListenerSuite extends StreamTest with SharedSQLContext with private def withListenerAdded(listener: ContinuousQueryListener)(body: => Unit): Unit = { try { failAfter(1 minute) { - sqlContext.streams.addListener(listener) + spark.streams.addListener(listener) body } } finally { - sqlContext.streams.removeListener(listener) + spark.streams.removeListener(listener) } } private def addedListeners(): Array[ContinuousQueryListener] = { val listenerBusMethod = PrivateMethod[ContinuousQueryListenerBus]('listenerBus) - val listenerBus = sqlContext.streams invokePrivate listenerBusMethod() + val listenerBus = spark.streams invokePrivate listenerBusMethod() listenerBus.listeners.toArray.map(_.asInstanceOf[ContinuousQueryListener]) } http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 8a0578c..3ae5ce6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -39,7 +39,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { metrics += ((funcName, qe, duration)) } } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val df = Seq(1 -> "a").toDF("i", "j") df.select("i").collect() @@ -55,7 +55,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(1)._2.analyzed.isInstanceOf[Aggregate]) assert(metrics(1)._3 > 0) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } test("execute callback functions when a DataFrame action failed") { @@ -68,7 +68,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { // Only test failed case here, so no need to implement `onSuccess` override def onSuccess(funcName: String, qe: QueryExecution, duration: Long): Unit = {} } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val errorUdf = udf[Int, Int] { _ => throw new RuntimeException("udf error") } val df = sparkContext.makeRDD(Seq(1 -> "a")).toDF("i", "j") @@ -82,7 +82,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(0)._2.analyzed.isInstanceOf[Project]) assert(metrics(0)._3.getMessage == e.getMessage) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } test("get numRows metrics by callback") { @@ -99,7 +99,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { metrics += metric.value } } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val df = Seq(1 -> "a").toDF("i", "j").groupBy("i").count() df.collect() @@ -111,7 +111,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(1) === 1) assert(metrics(2) === 2) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } // TODO: Currently some LongSQLMetric use -1 as initial value, so if the accumulator is never @@ -131,10 +131,10 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { metrics += bottomAgg.longMetric("dataSize").value } } - sqlContext.listenerManager.register(listener) + spark.listenerManager.register(listener) val sparkListener = new SaveInfoListener - sqlContext.sparkContext.addSparkListener(sparkListener) + spark.sparkContext.addSparkListener(sparkListener) val df = (1 to 100).map(i => i -> i.toString).toDF("i", "j") df.groupBy("i").count().collect() @@ -157,6 +157,6 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { assert(metrics(0) == topAggDataSize) assert(metrics(1) == bottomAggDataSize) - sqlContext.listenerManager.unregister(listener) + spark.listenerManager.unregister(listener) } } http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala index 154ada3..9bf84ab 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/test/TestHiveSingleton.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.hive.test import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.SparkSession import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.SQLContext trait TestHiveSingleton extends SparkFunSuite with BeforeAndAfterAll { - protected val sqlContext: SQLContext = TestHive + protected val spark: SparkSession = TestHive.sparkSession protected val hiveContext: TestHiveContext = TestHive protected override def afterAll(): Unit = { http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala index a7782ab..72736ee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala @@ -34,13 +34,13 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { val bytes = Array[Byte](1, 2, 3, 4) Seq((bytes, "AQIDBA==")).toDF("a", "b").write.saveAsTable("t0") - sqlContext + spark .range(10) .select('id as 'key, concat(lit("val_"), 'id) as 'value) .write .saveAsTable("t1") - sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2") + spark.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2") } override protected def afterAll(): Unit = { http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala index 34c2773..9abefa5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala @@ -33,16 +33,16 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { sql("DROP TABLE IF EXISTS parquet_t2") sql("DROP TABLE IF EXISTS t0") - sqlContext.range(10).write.saveAsTable("parquet_t0") + spark.range(10).write.saveAsTable("parquet_t0") sql("CREATE TABLE t0 AS SELECT * FROM parquet_t0") - sqlContext + spark .range(10) .select('id as 'key, concat(lit("val_"), 'id) as 'value) .write .saveAsTable("parquet_t1") - sqlContext + spark .range(10) .select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd) .write @@ -52,7 +52,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { when(id % 3 === 0, lit(null)).otherwise(array('id, 'id + 1)) } - sqlContext + spark .range(10) .select( createArray('id).as("arr"), @@ -394,7 +394,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { Seq("orc", "json", "parquet").foreach { format => val tableName = s"${format}_parquet_t0" withTable(tableName) { - sqlContext.range(10).write.format(format).saveAsTable(tableName) + spark.range(10).write.format(format).saveAsTable(tableName) checkHiveQl(s"SELECT id FROM $tableName") } } @@ -458,7 +458,7 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { } test("plans with non-SQL expressions") { - sqlContext.udf.register("foo", (_: Int) * 2) + spark.udf.register("foo", (_: Int) * 2) intercept[UnsupportedOperationException](new SQLBuilder(sql("SELECT foo(id) FROM t0")).toSQL) } http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala index 27c9e99..31755f5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala @@ -64,7 +64,7 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { """.stripMargin) } - checkAnswer(sqlContext.sql(generatedSQL), Dataset.ofRows(sqlContext.sparkSession, plan)) + checkAnswer(spark.sql(generatedSQL), Dataset.ofRows(spark, plan)) } protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = { http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 61910b8..093cd3a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -30,8 +30,8 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd override protected def beforeEach(): Unit = { super.beforeEach() - if (sqlContext.tableNames().contains("src")) { - sqlContext.dropTempTable("src") + if (spark.wrapped.tableNames().contains("src")) { + spark.catalog.dropTempTable("src") } Seq((1, "")).toDF("key", "value").registerTempTable("src") Seq((1, 1, 1)).toDF("a", "a", "b").registerTempTable("dupAttributes") @@ -39,8 +39,8 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd override protected def afterEach(): Unit = { try { - sqlContext.dropTempTable("src") - sqlContext.dropTempTable("dupAttributes") + spark.catalog.dropTempTable("src") + spark.catalog.dropTempTable("dupAttributes") } finally { super.afterEach() } http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index a717a99..bfe559f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -555,7 +555,7 @@ object SparkSQLConfTest extends Logging { object SPARK_9757 extends QueryTest { import org.apache.spark.sql.functions._ - protected var sqlContext: SQLContext = _ + protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") @@ -567,7 +567,7 @@ object SPARK_9757 extends QueryTest { .set("spark.ui.enabled", "false")) val hiveContext = new TestHiveContext(sparkContext) - sqlContext = hiveContext + spark = hiveContext.sparkSession import hiveContext.implicits._ val dir = Utils.createTempDir() @@ -602,7 +602,7 @@ object SPARK_9757 extends QueryTest { object SPARK_11009 extends QueryTest { import org.apache.spark.sql.functions._ - protected var sqlContext: SQLContext = _ + protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") @@ -613,10 +613,10 @@ object SPARK_11009 extends QueryTest { .set("spark.sql.shuffle.partitions", "100")) val hiveContext = new TestHiveContext(sparkContext) - sqlContext = hiveContext + spark = hiveContext.sparkSession try { - val df = sqlContext.range(1 << 20) + val df = spark.range(1 << 20) val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B")) val ws = Window.partitionBy(df2("A")).orderBy(df2("B")) val df3 = df2.select(df2("A"), df2("B"), row_number().over(ws).alias("rn")).filter("rn < 0") @@ -633,7 +633,7 @@ object SPARK_14244 extends QueryTest { import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ - protected var sqlContext: SQLContext = _ + protected var spark: SparkSession = _ def main(args: Array[String]): Unit = { Utils.configTestLog4j("INFO") @@ -644,13 +644,13 @@ object SPARK_14244 extends QueryTest { .set("spark.sql.shuffle.partitions", "100")) val hiveContext = new TestHiveContext(sparkContext) - sqlContext = hiveContext + spark = hiveContext.sparkSession import hiveContext.implicits._ try { val window = Window.orderBy('id) - val df = sqlContext.range(2).select(cume_dist().over(window).as('cdist)).orderBy('cdist) + val df = spark.range(2).select(cume_dist().over(window).as('cdist)).orderBy('cdist) checkAnswer(df, Seq(Row(0.5D), Row(1.0D))) } finally { sparkContext.stop() http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 52aba32..82d3e49 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -251,7 +251,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") // this will pick up the output partitioning from the table definition - sqlContext.table("source").write.insertInto("partitioned") + spark.table("source").write.insertInto("partitioned") checkAnswer(sql("SELECT * FROM partitioned"), data.collect().toSeq) } @@ -272,7 +272,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql( """CREATE TABLE partitioned (id bigint, data string) |PARTITIONED BY (part1 string, part2 string)""".stripMargin) - sqlContext.table("source").write.insertInto("partitioned") + spark.table("source").write.insertInto("partitioned") checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq) } @@ -283,7 +283,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") val data = (1 to 10).map(i => (i.toLong, s"data-$i")).toDF("id", "data") - val logical = InsertIntoTable(sqlContext.table("partitioned").logicalPlan, + val logical = InsertIntoTable(spark.table("partitioned").logicalPlan, Map("part" -> None), data.logicalPlan, overwrite = false, ifNotExists = false) assert(!logical.resolved, "Should not resolve: missing partition data") } http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 78c8f00..b2a80e7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -374,7 +374,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val expectedPath = sessionState.catalog.hiveDefaultTableFilePath(TableIdentifier("ctasJsonTable")) val filesystemPath = new Path(expectedPath) - val fs = filesystemPath.getFileSystem(sqlContext.sessionState.newHadoopConf()) + val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf()) if (fs.exists(filesystemPath)) fs.delete(filesystemPath, true) // It is a managed table when we do not specify the location. @@ -701,7 +701,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv // Manually create a metastore data source table. CreateDataSourceTableUtils.createDataSourceTable( - sparkSession = sqlContext.sparkSession, + sparkSession = spark, tableIdent = TableIdentifier("wide_schema"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], @@ -891,18 +891,18 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv test("SPARK-8156:create table to specific database by 'use dbname' ") { val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c") - sqlContext.sql("""create database if not exists testdb8156""") - sqlContext.sql("""use testdb8156""") + spark.sql("""create database if not exists testdb8156""") + spark.sql("""use testdb8156""") df.write .format("parquet") .mode(SaveMode.Overwrite) .saveAsTable("ttt3") checkAnswer( - sqlContext.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"), + spark.sql("show TABLES in testdb8156").filter("tableName = 'ttt3'"), Row("ttt3", false)) - sqlContext.sql("""use default""") - sqlContext.sql("""drop database if exists testdb8156 CASCADE""") + spark.sql("""use default""") + spark.sql("""drop database if exists testdb8156 CASCADE""") } @@ -911,7 +911,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType))) CreateDataSourceTableUtils.createDataSourceTable( - sparkSession = sqlContext.sparkSession, + sparkSession = spark, tableIdent = TableIdentifier("not_skip_hive_metadata"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], @@ -926,7 +926,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv .forall(column => CatalystSqlParser.parseDataType(column.dataType) == StringType)) CreateDataSourceTableUtils.createDataSourceTable( - sparkSession = sqlContext.sparkSession, + sparkSession = spark, tableIdent = TableIdentifier("skip_hive_metadata"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 850cb1e..6c9ce20 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { - private lazy val df = sqlContext.range(10).coalesce(1).toDF() + private lazy val df = spark.range(10).coalesce(1).toDF() private def checkTablePath(dbName: String, tableName: String): Unit = { val metastoreTable = hiveContext.sharedState.externalCatalog.getTable(dbName, tableName) @@ -36,12 +36,12 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) - checkAnswer(sqlContext.table("t"), df) + assert(spark.wrapped.tableNames().contains("t")) + checkAnswer(spark.table("t"), df) } - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df) + assert(spark.wrapped.tableNames(db).contains("t")) + checkAnswer(spark.table(s"$db.t"), df) checkTablePath(db, "t") } @@ -50,8 +50,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle test(s"saveAsTable() to non-default database - without USE - Overwrite") { withTempDatabase { db => df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df) + assert(spark.wrapped.tableNames(db).contains("t")) + checkAnswer(spark.table(s"$db.t"), df) checkTablePath(db, "t") } @@ -64,9 +64,9 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle val path = dir.getCanonicalPath df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - sqlContext.createExternalTable("t", path, "parquet") - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table("t"), df) + spark.catalog.createExternalTable("t", path, "parquet") + assert(spark.wrapped.tableNames(db).contains("t")) + checkAnswer(spark.table("t"), df) sql( s""" @@ -76,8 +76,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle | path '$path' |) """.stripMargin) - assert(sqlContext.tableNames(db).contains("t1")) - checkAnswer(sqlContext.table("t1"), df) + assert(spark.wrapped.tableNames(db).contains("t1")) + checkAnswer(spark.table("t1"), df) } } } @@ -88,10 +88,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempPath { dir => val path = dir.getCanonicalPath df.write.format("parquet").mode(SaveMode.Overwrite).save(path) - sqlContext.createExternalTable(s"$db.t", path, "parquet") + spark.catalog.createExternalTable(s"$db.t", path, "parquet") - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df) + assert(spark.wrapped.tableNames(db).contains("t")) + checkAnswer(spark.table(s"$db.t"), df) sql( s""" @@ -101,8 +101,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle | path '$path' |) """.stripMargin) - assert(sqlContext.tableNames(db).contains("t1")) - checkAnswer(sqlContext.table(s"$db.t1"), df) + assert(spark.wrapped.tableNames(db).contains("t1")) + checkAnswer(spark.table(s"$db.t1"), df) } } } @@ -112,12 +112,12 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") df.write.mode(SaveMode.Append).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) - checkAnswer(sqlContext.table("t"), df.union(df)) + assert(spark.wrapped.tableNames().contains("t")) + checkAnswer(spark.table("t"), df.union(df)) } - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df.union(df)) + assert(spark.wrapped.tableNames(db).contains("t")) + checkAnswer(spark.table(s"$db.t"), df.union(df)) checkTablePath(db, "t") } @@ -127,8 +127,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => df.write.mode(SaveMode.Overwrite).saveAsTable(s"$db.t") df.write.mode(SaveMode.Append).saveAsTable(s"$db.t") - assert(sqlContext.tableNames(db).contains("t")) - checkAnswer(sqlContext.table(s"$db.t"), df.union(df)) + assert(spark.wrapped.tableNames(db).contains("t")) + checkAnswer(spark.table(s"$db.t"), df.union(df)) checkTablePath(db, "t") } @@ -138,10 +138,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) + assert(spark.wrapped.tableNames().contains("t")) df.write.insertInto(s"$db.t") - checkAnswer(sqlContext.table(s"$db.t"), df.union(df)) + checkAnswer(spark.table(s"$db.t"), df.union(df)) } } } @@ -150,13 +150,13 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => activateDatabase(db) { df.write.mode(SaveMode.Overwrite).saveAsTable("t") - assert(sqlContext.tableNames().contains("t")) + assert(spark.wrapped.tableNames().contains("t")) } - assert(sqlContext.tableNames(db).contains("t")) + assert(spark.wrapped.tableNames(db).contains("t")) df.write.insertInto(s"$db.t") - checkAnswer(sqlContext.table(s"$db.t"), df.union(df)) + checkAnswer(spark.table(s"$db.t"), df.union(df)) } } @@ -164,10 +164,10 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => activateDatabase(db) { sql("CREATE TABLE t (key INT)") - checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + checkAnswer(spark.table("t"), spark.emptyDataFrame) } - checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + checkAnswer(spark.table(s"$db.t"), spark.emptyDataFrame) } } @@ -175,21 +175,21 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle withTempDatabase { db => activateDatabase(db) { sql(s"CREATE TABLE t (key INT)") - assert(sqlContext.tableNames().contains("t")) - assert(!sqlContext.tableNames("default").contains("t")) + assert(spark.wrapped.tableNames().contains("t")) + assert(!spark.wrapped.tableNames("default").contains("t")) } - assert(!sqlContext.tableNames().contains("t")) - assert(sqlContext.tableNames(db).contains("t")) + assert(!spark.wrapped.tableNames().contains("t")) + assert(spark.wrapped.tableNames(db).contains("t")) activateDatabase(db) { sql(s"DROP TABLE t") - assert(!sqlContext.tableNames().contains("t")) - assert(!sqlContext.tableNames("default").contains("t")) + assert(!spark.wrapped.tableNames().contains("t")) + assert(!spark.wrapped.tableNames("default").contains("t")) } - assert(!sqlContext.tableNames().contains("t")) - assert(!sqlContext.tableNames(db).contains("t")) + assert(!spark.wrapped.tableNames().contains("t")) + assert(!spark.wrapped.tableNames(db).contains("t")) } } @@ -208,18 +208,18 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle |LOCATION '$path' """.stripMargin) - checkAnswer(sqlContext.table("t"), sqlContext.emptyDataFrame) + checkAnswer(spark.table("t"), spark.emptyDataFrame) df.write.parquet(s"$path/p=1") sql("ALTER TABLE t ADD PARTITION (p=1)") sql("REFRESH TABLE t") - checkAnswer(sqlContext.table("t"), df.withColumn("p", lit(1))) + checkAnswer(spark.table("t"), df.withColumn("p", lit(1))) df.write.parquet(s"$path/p=2") sql("ALTER TABLE t ADD PARTITION (p=2)") hiveContext.sessionState.refreshTable("t") checkAnswer( - sqlContext.table("t"), + spark.table("t"), df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2)))) } } @@ -240,18 +240,18 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle |LOCATION '$path' """.stripMargin) - checkAnswer(sqlContext.table(s"$db.t"), sqlContext.emptyDataFrame) + checkAnswer(spark.table(s"$db.t"), spark.emptyDataFrame) df.write.parquet(s"$path/p=1") sql(s"ALTER TABLE $db.t ADD PARTITION (p=1)") sql(s"REFRESH TABLE $db.t") - checkAnswer(sqlContext.table(s"$db.t"), df.withColumn("p", lit(1))) + checkAnswer(spark.table(s"$db.t"), df.withColumn("p", lit(1))) df.write.parquet(s"$path/p=2") sql(s"ALTER TABLE $db.t ADD PARTITION (p=2)") hiveContext.sessionState.refreshTable(s"$db.t") checkAnswer( - sqlContext.table(s"$db.t"), + spark.table(s"$db.t"), df.withColumn("p", lit(1)).union(df.withColumn("p", lit(2)))) } } http://git-wip-us.apache.org/repos/asf/spark/blob/ed0b4070/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index af4dc1b..3f6418c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -70,12 +70,12 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHi |$ddl """.stripMargin) - sqlContext.sql(ddl) + spark.sql(ddl) - val schema = sqlContext.table("parquet_compat").schema - val rowRDD = sqlContext.sparkContext.parallelize(rows).coalesce(1) - sqlContext.createDataFrame(rowRDD, schema).registerTempTable("data") - sqlContext.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") + val schema = spark.table("parquet_compat").schema + val rowRDD = spark.sparkContext.parallelize(rows).coalesce(1) + spark.createDataFrame(rowRDD, schema).registerTempTable("data") + spark.sql("INSERT INTO TABLE parquet_compat SELECT * FROM data") } } @@ -84,7 +84,7 @@ class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHi // Unfortunately parquet-hive doesn't add `UTF8` annotation to BINARY when writing strings. // Have to assume all BINARY values are strings here. withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING.key -> "true") { - checkAnswer(sqlContext.read.parquet(path), rows) + checkAnswer(spark.read.parquet(path), rows) } } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
