This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 51b06ac4356 [SPARK-42043][CONNECT] Scala Client Result with E2E Tests
51b06ac4356 is described below

commit 51b06ac43563613af048113219698fef3f40eb79
Author: Zhen Li <[email protected]>
AuthorDate: Thu Jan 19 18:00:28 2023 -0400

    [SPARK-42043][CONNECT] Scala Client Result with E2E Tests
    
    ### What changes were proposed in this pull request?
    
    Added the implementation of the Spark Connect Scala client result (`SparkResult`).
    Added a minimal Dataset and SparkSession implementation, and E2E tests that
    verify the result is received correctly.
    
    ### Why are the changes needed?
    
    Provides a minimal Scala client to run queries via `spark.sql` and receive
    their results.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    E2E tests
    
    Closes #39541 from zhenlineo/client-result.
    
    Authored-by: Zhen Li <[email protected]>
    Signed-off-by: Herman van Hovell <[email protected]>
---
 connector/connect/client/jvm/pom.xml               |   6 +
 .../main/scala/org/apache/spark/sql/Dataset.scala  |  24 +++
 .../scala/org/apache/spark/sql/SparkSession.scala  | 103 +++++++++++
 .../sql/connect/client/SparkConnectClient.scala    |  12 +-
 .../spark/sql/connect/client/SparkResult.scala     | 170 ++++++++++++++++++
 .../spark/sql/connect/client/util/Cleaner.scala    | 113 ++++++++++++
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  43 +++++
 .../connect/client/SparkConnectClientSuite.scala   |   2 +-
 .../connect/client/util/RemoteSparkSession.scala   | 198 +++++++++++++++++++++
 9 files changed, 669 insertions(+), 2 deletions(-)

diff --git a/connector/connect/client/jvm/pom.xml 
b/connector/connect/client/jvm/pom.xml
index 29a00a71cf5..ba3ee80f7de 100644
--- a/connector/connect/client/jvm/pom.xml
+++ b/connector/connect/client/jvm/pom.xml
@@ -47,6 +47,12 @@
         </exclusion>
       </exclusions>
     </dependency>
+    <!--TODO: fix the dependency once the catalyst refactoring is done-->
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
     <dependency>
       <groupId>com.google.protobuf</groupId>
       <artifactId>protobuf-java</artifactId>
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
new file mode 100644
index 00000000000..f7ed764a11e
--- /dev/null
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.client.SparkResult
+
+/**
+ * A client-side handle on a Spark Connect plan. Creating a Dataset does not contact the server;
+ * the plan is only executed when an action such as [[collectResult]] is called.
+ *
+ * @param session
+ *   the session that created this Dataset and that executes its plan.
+ * @param plan
+ *   the Spark Connect plan represented by this Dataset.
+ */
+class Dataset(val session: SparkSession, private[sql] val plan: proto.Plan) {
+
+  /**
+   * Executes this Dataset's plan through its session.
+   *
+   * @return
+   *   the result of the execution; its rows are read from the server on demand.
+   */
+  def collectResult(): SparkResult = {
+    val owningSession = session
+    owningSession.execute(plan)
+  }
+}
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
new file mode 100644
index 00000000000..21f4ebd75db
--- /dev/null
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.arrow.memory.RootAllocator
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.connect.client.{SparkConnectClient, SparkResult}
+import org.apache.spark.sql.connect.client.util.Cleaner
+
+/**
+ * The entry point to programming Spark with the Dataset and DataFrame API.
+ *
+ * This session talks to a remote Spark Connect server through a [[SparkConnectClient]]. Only a
+ * minimal builder is available for now; create a session from a client:
+ *
+ * {{{
+ *   val spark = SparkSession
+ *     .builder()
+ *     .client(SparkConnectClient.builder().port(port).build())
+ *     .build()
+ * }}}
+ *
+ * TODO: support the full builder API (e.g. `master`, `appName`, `config` and `getOrCreate`).
+ *
+ * @param client
+ *   the client used to send plans to the server; shut down by [[close]].
+ * @param cleaner
+ *   releases the resources of results that are garbage collected without being closed.
+ */
+class SparkSession(private val client: SparkConnectClient, private val cleaner: Cleaner)
+    extends AutoCloseable {
+
+  // Arrow allocator handed to every SparkResult produced by this session; closed by close().
+  private[this] val allocator = new RootAllocator()
+
+  /**
+   * Creates a `Dataset` for the given SQL query.
+   *
+   * The query is not sent to the server here: it only becomes part of the Dataset's plan, which
+   * is executed when an action such as `collectResult()` is called.
+   *
+   * @since 3.4.0
+   */
+  def sql(query: String): Dataset = newDataset { builder =>
+    builder.setSql(proto.SQL.newBuilder().setQuery(query))
+  }
+
+  /** Builds a root relation with `f` and wraps it in a plan owned by a new [[Dataset]]. */
+  private[sql] def newDataset(f: proto.Relation.Builder => Unit): Dataset = {
+    val builder = proto.Relation.newBuilder()
+    f(builder)
+    val plan = proto.Plan.newBuilder().setRoot(builder).build()
+    new Dataset(this, plan)
+  }
+
+  /**
+   * Executes `plan` on the server. The returned result is registered with the cleaner so its
+   * buffers are released even if the caller never closes it.
+   */
+  private[sql] def execute(plan: proto.Plan): SparkResult = {
+    val value = client.execute(plan)
+    val result = new SparkResult(value, allocator)
+    cleaner.register(result)
+    result
+  }
+
+  /**
+   * Shuts down the client and closes the Arrow allocator. The allocator is closed even when
+   * shutting down the client fails.
+   */
+  override def close(): Unit = {
+    try {
+      client.shutdown()
+    } finally {
+      allocator.close()
+    }
+  }
+}
+
+// The minimal builder needed to create a spark session.
+// TODO: implement all methods mentioned in the scaladoc of [[SparkSession]]
+object SparkSession {
+
+  /** Creates a new [[Builder]]. */
+  def builder(): Builder = new Builder()
+
+  // One cleaner thread shared by all sessions, started on first use.
+  private lazy val cleaner = {
+    val cleaner = new Cleaner
+    cleaner.start()
+    cleaner
+  }
+
+  /**
+   * Builder for [[SparkSession]]. When no client is supplied through [[client]], a default
+   * [[SparkConnectClient]] is created by [[build]] and reused by later calls to [[build]].
+   */
+  class Builder() {
+    // Created lazily: building the default client eagerly would open a channel that is never
+    // shut down whenever the caller supplies its own client.
+    private var _client: Option[SparkConnectClient] = None
+
+    /** Uses `client` for the session instead of a default client. */
+    def client(client: SparkConnectClient): Builder = {
+      _client = Some(client)
+      this
+    }
+
+    /** Creates the [[SparkSession]]. */
+    def build(): SparkSession = {
+      val sessionClient = _client.getOrElse {
+        val defaultClient = SparkConnectClient.builder().build()
+        _client = Some(defaultClient)
+        defaultClient
+      }
+      new SparkSession(sessionClient, cleaner)
+    }
+  }
+}
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index cdae9f0ceea..87682fdd700 100644
--- 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -49,6 +49,15 @@ class SparkConnectClient(
   def analyze(request: proto.AnalyzePlanRequest): proto.AnalyzePlanResponse =
     stub.analyzePlan(request)
 
+  /**
+   * Sends `plan` to the server for execution, tagged with this client's user context.
+   *
+   * @return
+   *   the iterator of responses returned by the stub's `executePlan` call.
+   */
+  def execute(plan: proto.Plan): java.util.Iterator[proto.ExecutePlanResponse] = {
+    val requestBuilder = proto.ExecutePlanRequest.newBuilder()
+    requestBuilder.setPlan(plan)
+    requestBuilder.setUserContext(userContext)
+    stub.executePlan(requestBuilder.build())
+  }
+
   /**
    * Shutdown the client's connection to the server.
    */
@@ -158,7 +167,8 @@ object SparkConnectClient {
 
     def build(): SparkConnectClient = {
       val channelBuilder = ManagedChannelBuilder.forAddress(host, 
port).usePlaintext()
-      new SparkConnectClient(userContextBuilder.build(), 
channelBuilder.build())
+      // The channel is built with usePlaintext(), i.e. the connection is not TLS-encrypted.
+      val channel: ManagedChannel = channelBuilder.build()
+      new SparkConnectClient(userContextBuilder.build(), channel)
     }
   }
 }
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
new file mode 100644
index 00000000000..317c20cad3e
--- /dev/null
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkResult.scala
@@ -0,0 +1,170 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client
+
+import java.util.Collections
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.arrow.memory.BufferAllocator
+import org.apache.arrow.vector.FieldVector
+import org.apache.arrow.vector.ipc.ArrowStreamReader
+
+import org.apache.spark.connect.proto
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder.Deserializer
+import org.apache.spark.sql.connect.client.util.{AutoCloseables, Cleanable}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, 
ColumnVector}
+
+/**
+ * The result of executing a `proto.Plan` on a Spark Connect server.
+ *
+ * Responses are pulled from `responses` lazily, only as far as the method being called needs.
+ * The Arrow vectors of every decoded batch are transferred to `allocator` and kept until
+ * [[close]] is called, or until this result is garbage collected when it has been registered
+ * with a [[org.apache.spark.sql.connect.client.util.Cleaner]].
+ *
+ * @param responses
+ *   the stream of responses returned by the server for the executed plan.
+ * @param allocator
+ *   allocator that takes ownership of the decoded Arrow vectors.
+ */
+private[sql] class SparkResult(
+    responses: java.util.Iterator[proto.ExecutePlanResponse],
+    allocator: BufferAllocator)
+    extends AutoCloseable
+    with Cleanable {
+
+  private[this] var numRecords: Int = 0
+  private[this] var structType: StructType = _
+  private[this] var encoder: ExpressionEncoder[Row] = _
+  private[this] val batches = mutable.Buffer.empty[ColumnarBatch]
+
+  /**
+   * Consumes responses from the server, decoding every Arrow batch they contain.
+   *
+   * @param stopOnFirstNonEmptyResponse
+   *   when true, stop as soon as a response that contributed at least one row has been fully
+   *   decoded; when false, consume all remaining responses.
+   * @return
+   *   true if processing stopped because a non-empty response was decoded (so at least one new
+   *   batch is available), false if the response stream was exhausted.
+   */
+  private def processResponses(stopOnFirstNonEmptyResponse: Boolean): Boolean = {
+    var stopped = false
+    while (!stopped && responses.hasNext) {
+      val response = responses.next()
+      if (response.hasArrowBatch) {
+        val decodedRows = decodeArrowBatch(response)
+        stopped = stopOnFirstNonEmptyResponse && decodedRows > 0
+      }
+    }
+    stopped
+  }
+
+  // Decodes every record batch of the response's Arrow IPC stream, appending the non-empty ones
+  // to `batches`. Returns the number of rows decoded from this response.
+  private def decodeArrowBatch(response: proto.ExecutePlanResponse): Long = {
+    val arrowBatch = response.getArrowBatch
+    val reader = new ArrowStreamReader(arrowBatch.getData.newInput(), allocator)
+    var decodedRows = 0L
+    try {
+      val root = reader.getVectorSchemaRoot
+      if (structType == null) {
+        structType = ArrowUtils.fromArrowSchema(root.getSchema)
+        // TODO: create encoders that directly operate on arrow vectors.
+        encoder = RowEncoder(structType).resolveAndBind(structType.toAttributes)
+      }
+      // Drain the whole stream: the reader is closed below and the response cannot be read
+      // again, so stopping after the first batch would silently drop the remaining rows.
+      while (reader.loadNextBatch()) {
+        val rowCount = root.getRowCount
+        if (rowCount > 0) {
+          val vectors = root.getFieldVectors.asScala
+            .map(v => new ArrowColumnVector(transferToNewVector(v)))
+            .toArray[ColumnVector]
+          batches += new ColumnarBatch(vectors, rowCount)
+          numRecords += rowCount
+          decodedRows += rowCount
+        }
+      }
+    } finally {
+      reader.close()
+    }
+    // Sanity check against the row count the server reports for this response.
+    assert(
+      decodedRows == arrowBatch.getRowCount,
+      s"Decoded $decodedRows rows, but the server reported ${arrowBatch.getRowCount}.")
+    decodedRows
+  }
+
+  // Moves the buffers of `in` into a new vector owned by `allocator`, so they survive the
+  // reader being closed.
+  private def transferToNewVector(in: FieldVector): FieldVector = {
+    val pair = in.getTransferPair(allocator)
+    pair.transfer()
+    pair.getTo.asInstanceOf[FieldVector]
+  }
+
+  /**
+   * Returns the number of rows in the result. This consumes all remaining responses.
+   */
+  def length: Int = {
+    // We need to process all responses to make sure numRecords is correct.
+    processResponses(stopOnFirstNonEmptyResponse = false)
+    numRecords
+  }
+
+  /**
+   * @return
+   *   the schema of the result, or null if no Arrow batch was received.
+   */
+  def schema: StructType = {
+    // The schema is taken from the first Arrow batch; only read more responses if needed.
+    if (structType == null) {
+      processResponses(stopOnFirstNonEmptyResponse = true)
+    }
+    structType
+  }
+
+  /**
+   * Create an Array with the contents of the result.
+   */
+  def toArray: Array[Row] = {
+    val result = new Array[Row](length)
+    val rows = iterator
+    var i = 0
+    while (rows.hasNext) {
+      // Check before writing so an overrun fails the assertion instead of indexing out of bounds.
+      assert(i < numRecords)
+      result(i) = rows.next()
+      i += 1
+    }
+    result
+  }
+
+  /**
+   * Returns an iterator over the contents of the result. Batches are decoded on demand while
+   * iterating. Closing the iterator closes the whole result.
+   */
+  def iterator: java.util.Iterator[Row] with AutoCloseable = {
+    new java.util.Iterator[Row] with AutoCloseable {
+      private[this] var batchIndex: Int = -1
+      private[this] var currentRows: java.util.Iterator[InternalRow] = Collections.emptyIterator()
+      private[this] var deserializer: Deserializer[Row] = _
+      override def hasNext: Boolean = {
+        if (currentRows.hasNext) {
+          return true
+        }
+        val nextBatchIndex = batchIndex + 1
+        val hasNextBatch = if (nextBatchIndex == batches.size) {
+          processResponses(stopOnFirstNonEmptyResponse = true)
+        } else {
+          true
+        }
+        if (hasNextBatch) {
+          batchIndex = nextBatchIndex
+          currentRows = batches(nextBatchIndex).rowIterator()
+          if (deserializer == null) {
+            deserializer = encoder.createDeserializer()
+          }
+        }
+        hasNextBatch
+      }
+
+      override def next(): Row = {
+        if (!hasNext) {
+          throw new NoSuchElementException
+        }
+        deserializer(currentRows.next())
+      }
+
+      override def close(): Unit = SparkResult.this.close()
+    }
+  }
+
+  /**
+   * Close this result, freeing any underlying resources.
+   */
+  override def close(): Unit = {
+    batches.foreach(_.close())
+  }
+
+  /**
+   * Closes every batch decoded so far once this result has been garbage collected.
+   *
+   * The returned closeable references the live batch buffer (never `this`). Registering it with
+   * a Cleaner therefore does not keep this result reachable, and it also sees batches decoded
+   * after registration. A `toSeq` snapshot taken here would be empty on Scala 2.13, because the
+   * result is registered before any response has been read.
+   */
+  override def cleaner: AutoCloseable = {
+    val decodedBatches = batches
+    () => AutoCloseables(decodedBatches.toList).close()
+  }
+}
diff --git 
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/Cleaner.scala
 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/Cleaner.scala
new file mode 100644
index 00000000000..4eecc881356
--- /dev/null
+++ 
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/Cleaner.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client.util
+
+import java.lang.ref.{ReferenceQueue, WeakReference}
+import java.util.Collections
+import java.util.concurrent.ConcurrentHashMap
+
+import scala.collection.mutable
+import scala.util.control.NonFatal
+
+/**
+ * Helper class for cleaning up an object's resources after the object itself has been garbage
+ * collected.
+ *
+ * When we move to Java 9+ we should replace this class by [[java.lang.ref.Cleaner]].
+ */
+private[sql] class Cleaner {
+  // Weak reference to `pin`, enqueued on `referenceQueue` once `pin` has been collected; closing
+  // it closes `resource`.
+  class Ref(pin: AnyRef, val resource: AutoCloseable)
+      extends WeakReference[AnyRef](pin, referenceQueue)
+      with AutoCloseable {
+    override def close(): Unit = resource.close()
+  }
+
+  /** Registers a [[Cleanable]], using its own `cleaner` as the resource to close after GC. */
+  def register(pin: Cleanable): Unit = {
+    register(pin, pin.cleaner)
+  }
+
+  /**
+   * Register an object's resources for clean-up. It is absolutely pivotal that `resource` does
+   * not hold any reference to the object itself. If it does, the object is never garbage
+   * collected and the clean-up never runs.
+   *
+   * @param pin
+   *   the object whose resources need to be cleaned up after GC.
+   * @param resource
+   *   the resource to close.
+   */
+  def register(pin: AnyRef, resource: AutoCloseable): Unit = {
+    referenceBuffer.add(new Ref(pin, resource))
+  }
+
+  @volatile private var stopped = false
+  // Holds every registered Ref strongly until it is processed, so the Ref itself cannot be
+  // collected before it is enqueued.
+  private val referenceBuffer = Collections.newSetFromMap[Ref](new ConcurrentHashMap)
+  private val referenceQueue = new ReferenceQueue[AnyRef]
+
+  private val cleanerThread = {
+    val thread = new Thread(() => cleanUp())
+    thread.setName("cleaner")
+    thread.setDaemon(true)
+    thread
+  }
+
+  /** Starts the background clean-up thread. */
+  def start(): Unit = {
+    require(!stopped)
+    cleanerThread.start()
+  }
+
+  /** Stops the background clean-up thread by interrupting its blocking wait. */
+  def stop(): Unit = {
+    stopped = true
+    cleanerThread.interrupt()
+  }
+
+  private def cleanUp(): Unit = {
+    while (!stopped) {
+      try {
+        val ref = referenceQueue.remove().asInstanceOf[Ref]
+        referenceBuffer.remove(ref)
+        ref.close()
+      } catch {
+        case _: InterruptedException =>
+        // stop() interrupts the blocking remove(). NonFatal does not match interrupts, so handle
+        // them here; the loop condition ends the thread once stopped, otherwise we keep going.
+        case NonFatal(e) =>
+          // Perhaps log this?
+          e.printStackTrace()
+      }
+    }
+  }
+}
+
+/**
+ * An object whose resources can be released by a [[Cleaner]] after the object itself has been
+ * garbage collected.
+ */
+trait Cleanable {
+
+  /**
+   * The resources to close after this object has been garbage collected. The returned value
+   * must not reference this object; otherwise the object is never collected and the clean-up
+   * never runs.
+   */
+  def cleaner: AutoCloseable
+}
+
+/** Helpers for combining several [[AutoCloseable]]s into one. */
+object AutoCloseables {
+
+  /**
+   * Returns a closeable that closes every resource in `resources`, in order, even when some of
+   * them fail. The first non-fatal failure is rethrown with any later failures attached as
+   * suppressed exceptions. Fatal errors propagate immediately.
+   */
+  def apply(resources: Seq[AutoCloseable]): AutoCloseable = { () =>
+    val failures = mutable.ArrayBuffer.empty[Throwable]
+    for (resource <- resources) {
+      try resource.close()
+      catch { case NonFatal(e) => failures += e }
+    }
+    failures.headOption.foreach { first =>
+      failures.drop(1).foreach(first.addSuppressed)
+      throw first
+    }
+  }
+}
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
new file mode 100644
index 00000000000..ff18e36a02f
--- /dev/null
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql
+
+import org.apache.spark.sql.connect.client.util.RemoteSparkSession
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+class ClientE2ETestSuite extends RemoteSparkSession { // scalastyle:ignore funsuite
+
+  private val helloWorldQuery = "select val from (values ('Hello'), ('World')) as t(val)"
+
+  // Runs `f` on the result of `query` and always closes the result afterwards, so its Arrow
+  // buffers are released before the session (and its allocator) is closed.
+  private def withResult[T](query: String)(
+      f: org.apache.spark.sql.connect.client.SparkResult => T): T = {
+    val result = spark.sql(query).collectResult()
+    try f(result)
+    finally result.close()
+  }
+
+  // Spark Result
+  test("test spark result schema") {
+    val schema = withResult(helloWorldQuery)(_.schema)
+    assert(schema == StructType(StructField("val", StringType, false) :: Nil))
+  }
+
+  test("test spark result array") {
+    withResult(helloWorldQuery) { result =>
+      assert(result.length == 2)
+      val array = result.toArray
+      assert(array.length == 2)
+      assert(array(0).getString(0) == "Hello")
+      assert(array(1).getString(0) == "World")
+    }
+  }
+
+  // TODO test large result when we can create table or view
+  // test("test spark large result")
+}
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index 5dd5f0e2502..1229a91aa54 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -78,7 +78,7 @@ class SparkConnectClientSuite
   }
 
   test("Test connection") {
-    val testPort = 16000
+    val testPort = 16001
+    // NOTE(review): changed from 16000 in this commit, presumably to avoid clashing with another
+    // test server's port; confirm.
     client = SparkConnectClient.builder().port(testPort).build()
     testClientConnection(client, testPort)
   }
diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
new file mode 100644
index 00000000000..f843b651ae8
--- /dev/null
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/util/RemoteSparkSession.scala
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.client.util
+
+import java.io.{BufferedOutputStream, File}
+
+import scala.io.Source
+
+import org.scalatest.BeforeAndAfterAll
+import sys.process._
+
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connect.client.SparkConnectClient
+import org.apache.spark.sql.connect.common.config.ConnectCommon
+
+/**
+ * A utility to start a local Spark Connect server in a separate process for local E2E tests. The
+ * server is started once and shared by all tests. This is equivalent to starting the Connect
+ * server from the command line with:
+ *
+ * {{{
+ * bin/spark-shell \
+ * --jars `ls connector/connect/server/target/**/spark-connect*SNAPSHOT.jar | paste -sd ',' -` \
+ * --conf spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin
+ * }}}
+ *
+ * Set system property `SPARK_HOME` if the test is not executed from the Spark project top folder.
+ * Set system property `DEBUG_SC_JVM_CLIENT=true` to print the server process output in the
+ * console to debug server start/stop problems.
+ */
+object SparkConnectServerUtils {
+  // System properties used for testing and debugging
+  private val SPARK_HOME = "SPARK_HOME"
+  private val ENV_DEBUG_SC_JVM_CLIENT = "DEBUG_SC_JVM_CLIENT"
+
+  private val sparkHome = System.getProperty(SPARK_HOME, fileSparkHome())
+  private val isDebug = System.getProperty(ENV_DEBUG_SC_JVM_CLIENT, "false").toBoolean
+
+  // Scala binary version (e.g. "2.12") of the running tests. The SBT assembly jar lives in a
+  // `scala-<binary version>` folder, so the version must not be hard-coded.
+  private val scalaBinaryVersion =
+    scala.util.Properties.versionNumberString.split('.').take(2).mkString(".")
+
+  // Log server start stop debug info into console
+  // scalastyle:off println
+  private[connect] def debug(msg: String): Unit = if (isDebug) println(msg)
+  // scalastyle:on println
+  private[connect] def debug(error: Throwable): Unit = if (isDebug) error.printStackTrace()
+
+  // Server port: the default binding port plus a random offset in [0, 1000).
+  private[connect] val port = ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000)
+
+  @volatile private var stopped = false
+
+  private lazy val sparkConnect: Process = {
+    debug("Starting the Spark Connect Server...")
+    val jar = findSparkConnectJar
+    val builder = Process(
+      new File(sparkHome, "bin/spark-shell").getCanonicalPath,
+      Seq(
+        "--jars",
+        jar,
+        "--driver-class-path",
+        jar,
+        "--conf",
+        "spark.plugins=org.apache.spark.sql.connect.SparkConnectPlugin",
+        "--conf",
+        s"spark.connect.grpc.binding.port=$port"))
+
+    val io = new ProcessIO(
+      // Hold the input channel to the spark console to keep the console open
+      in => new BufferedOutputStream(in),
+      // Only redirect output if debug to avoid channel interruption error on process termination.
+      out => if (isDebug) Source.fromInputStream(out).getLines().foreach(debug),
+      err => if (isDebug) Source.fromInputStream(err).getLines().foreach(debug))
+    val process = builder.run(io)
+
+    // Adding JVM shutdown hook
+    sys.addShutdownHook(kill())
+    process
+  }
+
+  /** Starts the shared server process if it is not running yet. */
+  def start(): Unit = {
+    assert(!stopped)
+    sparkConnect
+  }
+
+  /** Destroys the server process and returns its exit code. */
+  def kill(): Int = {
+    stopped = true
+    debug("Stopping the Spark Connect Server...")
+    sparkConnect.destroy()
+    val code = sparkConnect.exitValue()
+    debug(s"Spark Connect Server is stopped with exit code: $code")
+    code
+  }
+
+  private def fileSparkHome(): String = {
+    val path = new File("./").getCanonicalPath
+    if (path.endsWith("connector/connect/client/jvm")) {
+      // the current folder is the client project folder
+      new File("../../../../").getCanonicalPath
+    } else {
+      path
+    }
+  }
+
+  private def findSparkConnectJar: String = {
+    val target = "connector/connect/server/target"
+    val parentDir = new File(sparkHome, target)
+    assert(
+      parentDir.exists(),
+      s"Fail to locate the spark connect target folder: '${parentDir.getCanonicalPath}'. " +
+        s"SPARK_HOME='${new File(sparkHome).getCanonicalPath}'. " +
+        "Make sure system property `SPARK_HOME` is set correctly.")
+    val jars = recursiveListFiles(parentDir).filter { f =>
+      // SBT jar
+      (f.getParent.endsWith(s"scala-$scalaBinaryVersion") &&
+        f.getName.startsWith("spark-connect-assembly") && f.getName.endsWith("SNAPSHOT.jar")) ||
+      // Maven Jar
+      (f.getParent.endsWith("target") &&
+        f.getName.startsWith("spark-connect") && f.getName.endsWith("SNAPSHOT.jar"))
+    }
+    // It is possible we found more than one: one built by maven, and another by SBT
+    assert(
+      jars.nonEmpty,
+      s"Failed to find the `spark-connect` jar inside folder: ${parentDir.getCanonicalPath}")
+    debug("Using jar: " + jars(0).getCanonicalPath)
+    jars(0).getCanonicalPath // return the first one
+  }
+
+  /** Lists all files and directories below `f`, recursively. */
+  def recursiveListFiles(f: File): Array[File] = {
+    // listFiles returns null when `f` is not a directory or cannot be read.
+    val these = Option(f.listFiles).getOrElse(Array.empty[File])
+    these ++ these.filter(_.isDirectory).flatMap(recursiveListFiles)
+  }
+}
+
+/**
+ * Test mixin that starts the shared Spark Connect server before the suite, exposes a session
+ * connected to it as `spark`, and closes that session after the suite.
+ */
+trait RemoteSparkSession
+    extends org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite
+    with BeforeAndAfterAll {
+  import SparkConnectServerUtils._
+  var spark: SparkSession = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    SparkConnectServerUtils.start()
+    spark = SparkSession.builder().client(SparkConnectClient.builder().port(port).build()).build()
+
+    // Retry and wait for the server to start
+    val stop = System.currentTimeMillis() + 60000 * 1 // ~1 min
+    var sleepInterval = 1000L // 1s with * 2 backoff
+    var success = false
+    val error = new RuntimeException(s"Failed to start the test server on port $port.")
+
+    while (!success && System.currentTimeMillis() < stop) {
+      try {
+        // Run a simple query to verify the server is really up and ready. The result is closed
+        // so its Arrow buffers do not stay allocated in the session's allocator.
+        val result = spark
+          .sql("select val from (values ('Hello'), ('World')) as t(val)")
+          .collectResult()
+        val rows = try result.toArray finally result.close()
+        assert(rows.length == 2)
+        success = true
+        debug("Spark Connect Server is up.")
+      } catch {
+        // Remember the failure and retry. Fatal errors (e.g. interrupts) are not retried.
+        case scala.util.control.NonFatal(e) =>
+          error.addSuppressed(e)
+          // Never sleep past the deadline.
+          Thread.sleep(math.max(0L, math.min(sleepInterval, stop - System.currentTimeMillis())))
+          sleepInterval *= 2
+      }
+    }
+
+    // Throw error if failed
+    if (!success) {
+      debug(error)
+      throw error
+    }
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      if (spark != null) {
+        spark.close()
+      }
+    } catch {
+      case scala.util.control.NonFatal(e) => debug(e)
+    }
+    spark = null
+    super.afterAll()
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to