This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 0d7618a2ca8 [SPARK-42585][CONNECT] Streaming of local relations
0d7618a2ca8 is described below
commit 0d7618a2ca847cf9577659e50409dd5a383d66d3
Author: Max Gekk <[email protected]>
AuthorDate: Tue May 2 09:58:29 2023 +0900
[SPARK-42585][CONNECT] Streaming of local relations
### What changes were proposed in this pull request?
In the PR, I propose to transfer a local relation to the server in a
streaming way when it exceeds the size defined by the SQL config
`spark.sql.session.localRelationCacheThreshold` (64MB by default). In particular:
1. The client applies the `sha256` function over the Arrow form of the
local relation;
2. It checks the presence of the relation at the server side by sending the
relation hash to the server;
3. If the server doesn't have the local relation, the client transfers it as
an artifact with the name `cache/<sha256>`;
4. Once the relation is already present at the server, or has just been
transferred, the client transforms the logical plan by replacing the
`LocalRelation` node with `CachedLocalRelation` carrying the hash (see the
sketch after this list);
5. The server, in turn, converts `CachedLocalRelation` back to
`LocalRelation` by retrieving the relation body from its local cache.
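The client-side decision reduces to the following condensed sketch of the
change to `SparkSession.createDataset` (all names come from the diff below;
the surrounding variables are those of `createDataset`):
```scala
// Condensed from the SparkSession.createDataset change in the diff.
val (arrowData, arrowDataSize) = ConvertToArrow(encoder, data, timeZoneId, allocator)
if (arrowDataSize <= conf.get("spark.sql.session.localRelationCacheThreshold").toInt) {
  // Small relation: inline the Arrow bytes into the plan as before.
  builder.getLocalRelationBuilder
    .setSchema(encoder.schema.json)
    .setData(arrowData)
} else {
  // Large relation: cache it at the server and reference it by its sha256 hash.
  val hash = client.cacheLocalRelation(arrowDataSize, arrowData, encoder.schema.json)
  builder.getCachedLocalRelationBuilder
    .setUserId(client.userId)
    .setSessionId(client.sessionId)
    .setHash(hash)
}
```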
#### Details of the implementation
The client sends the new `ArtifactStatusesRequest` message to check whether the
local relation is already cached at the server. The request goes through the new
RPC endpoint `ArtifactStatus`, and the server answers with the new message
`ArtifactStatusesResponse`; see **base.proto**.
The client transfers the serialized (Arrow) body of the local relation together
with its schema via the RPC endpoint `AddArtifacts`. The server stores the
received artifact in the block manager under the id `CacheId`, which has
3 parts:
- `userId` - the identifier of the user that created the local relation,
- `sessionId` - the identifier of the session the relation belongs to,
- `hash` - a `sha-256` hash of the relation body.
See **SparkConnectArtifactManager.addArtifact()**.
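For illustration, the blob cached at the server packs the payload size, the
Arrow data and the schema JSON into one byte array. The helper below is a
hypothetical stand-alone rendering of what `SparkConnectClient.cacheLocalRelation`
builds (the object and method names are mine, not from the code):
```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

// Hypothetical helper mirroring SparkConnectClient.cacheLocalRelation:
// blob layout is [4-byte Arrow-data size][Arrow IPC bytes][UTF-8 schema JSON].
object CacheBlobWriter {
  def pack(arrowData: Array[Byte], schemaJson: String): Array[Byte] = {
    val schemaBytes = schemaJson.getBytes(StandardCharsets.UTF_8)
    val blob = ByteBuffer.allocate(4 + arrowData.length + schemaBytes.length)
    blob.putInt(arrowData.length) // size of the Arrow payload
    blob.put(arrowData)           // the Arrow IPC stream
    blob.put(schemaBytes)         // the schema as JSON
    blob.array()
  }
}
// The artifact is then named "cache/" + sha256Hex(blob) and is uploaded only
// when the ArtifactStatus RPC reports that the server does not have it yet.
```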
The current query is blocked until the local relation has been cached at the
server side.
When the server receives the query, it retrieves `userId`, `sessionId` and
`hash` from `CachedLocalRelation`, and gets the local relation data from the
block manager. See **SparkConnectPlanner.transformCachedLocalRelation()**.
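Decoding on the server mirrors the layout above. This hypothetical helper
condenses the byte slicing done in `transformCachedLocalRelation` (block-lock
handling and error reporting omitted; the object and method names are mine):
```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets

// Hypothetical counterpart of the slicing in
// SparkConnectPlanner.transformCachedLocalRelation: split the cached blob
// back into the Arrow bytes and the schema JSON.
object CacheBlobReader {
  def unpack(blob: Array[Byte]): (Array[Byte], String) = {
    val size = ByteBuffer.wrap(blob).getInt // first 4 bytes: Arrow payload size
    val data = blob.slice(4, 4 + size)      // Arrow IPC bytes
    val schema = new String(blob.slice(4 + size, blob.length), StandardCharsets.UTF_8)
    (data, schema)
  }
}
// In the planner the blob itself comes from
// blockManager.getLocalBytes(CacheId(userId, sessionId, hash)).
```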
The occupied blocks in the block manager are removed when a user session
is invalidated in `userSessionMapping`. See
**SparkConnectService.RemoveSessionListener** and
**BlockManager.removeCache()**.
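The cleanup is driven by a Guava removal listener registered on the session
cache; condensed from the diff below:
```scala
// Condensed from SparkConnectService.RemoveSessionListener in the diff: when a
// session entry is evicted, drop all of its cached local-relation blocks.
private class RemoveSessionListener extends RemovalListener[SessionCacheKey, SessionHolder] {
  override def onRemoval(
      notification: RemovalNotification[SessionCacheKey, SessionHolder]): Unit = {
    val SessionHolder(userId, sessionId, session) = notification.getValue
    session.sparkContext.env.blockManager.removeCache(userId, sessionId)
  }
}
```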
### Why are the changes needed?
To allow creating a DataFrame from a large local collection. Without the
changes, `spark.createDataFrame(...)` fails with the following error:
```java
23/04/21 20:32:20 WARN NettyServerStream: Exception processing message
org.sparkproject.connect.grpc.StatusRuntimeException: RESOURCE_EXHAUSTED:
gRPC message exceeds maximum size 134217728: 268435456
at
org.sparkproject.connect.grpc.Status.asRuntimeException(Status.java:526)
```
### Does this PR introduce _any_ user-facing change?
No. The changes extend the existing proto API.
### How was this patch tested?
By running the new tests:
```
$ build/sbt "test:testOnly *.ArtifactManagerSuite"
$ build/sbt "test:testOnly *.ClientE2ETestSuite"
$ build/sbt "test:testOnly *.ArtifactStatusesHandlerSuite"
```
Closes #40827 from MaxGekk/streaming-createDataFrame-2.
Authored-by: Max Gekk <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../scala/org/apache/spark/sql/SparkSession.scala | 19 +-
.../spark/sql/connect/client/ArtifactManager.scala | 42 +++-
.../sql/connect/client/SparkConnectClient.scala | 26 +-
.../sql/connect/client/util/ConvertToArrow.scala | 6 +-
.../org/apache/spark/sql/ClientE2ETestSuite.scala | 14 ++
.../spark/sql/connect/client/ArtifactSuite.scala | 2 +-
.../src/main/protobuf/spark/connect/base.proto | 40 ++++
.../main/protobuf/spark/connect/relations.proto | 13 +
.../artifact/SparkConnectArtifactManager.scala | 27 ++-
.../sql/connect/planner/SparkConnectPlanner.scala | 148 +++++++-----
.../service/SparkConnectAddArtifactsHandler.scala | 2 +-
.../SparkConnectArtifactStatusesHandler.scala | 54 +++++
.../sql/connect/service/SparkConnectService.scala | 29 ++-
.../connect/artifact/ArtifactManagerSuite.scala | 39 ++-
.../service/ArtifactStatusesHandlerSuite.scala | 87 +++++++
.../scala/org/apache/spark/storage/BlockId.scala | 5 +
.../org/apache/spark/storage/BlockManager.scala | 14 ++
python/pyspark/sql/connect/proto/base_pb2.py | 66 ++++-
python/pyspark/sql/connect/proto/base_pb2.pyi | 138 +++++++++++
python/pyspark/sql/connect/proto/base_pb2_grpc.py | 45 ++++
python/pyspark/sql/connect/proto/relations_pb2.py | 266 +++++++++++----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 39 +++
.../org/apache/spark/sql/internal/SQLConf.scala | 9 +
23 files changed, 922 insertions(+), 208 deletions(-)
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 00910d5904a..461b18ec9c1 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -119,12 +119,23 @@ class SparkSession private[sql] (
private def createDataset[T](encoder: AgnosticEncoder[T], data:
Iterator[T]): Dataset[T] = {
newDataset(encoder) { builder =>
- val localRelationBuilder = builder.getLocalRelationBuilder
- .setSchema(encoder.schema.json)
if (data.nonEmpty) {
val timeZoneId = conf.get("spark.sql.session.timeZone")
- val arrowData = ConvertToArrow(encoder, data, timeZoneId, allocator)
- localRelationBuilder.setData(arrowData)
+ val (arrowData, arrowDataSize) = ConvertToArrow(encoder, data,
timeZoneId, allocator)
+ if (arrowDataSize <=
conf.get("spark.sql.session.localRelationCacheThreshold").toInt) {
+ builder.getLocalRelationBuilder
+ .setSchema(encoder.schema.json)
+ .setData(arrowData)
+ } else {
+ val hash = client.cacheLocalRelation(arrowDataSize, arrowData,
encoder.schema.json)
+ builder.getCachedLocalRelationBuilder
+ .setUserId(client.userId)
+ .setSessionId(client.sessionId)
+ .setHash(hash)
+ }
+ } else {
+ builder.getLocalRelationBuilder
+ .setSchema(encoder.schema.json)
}
}
}
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
index ef3d66c85bc..acd9f279c6d 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connect.client
import java.io.{ByteArrayInputStream, InputStream}
import java.net.URI
import java.nio.file.{Files, Path, Paths}
+import java.util.Arrays
import java.util.concurrent.CopyOnWriteArrayList
import java.util.zip.{CheckedInputStream, CRC32}
@@ -32,6 +33,7 @@ import Artifact._
import com.google.protobuf.ByteString
import io.grpc.ManagedChannel
import io.grpc.stub.StreamObserver
+import org.apache.commons.codec.digest.DigestUtils.sha256Hex
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.AddArtifactsResponse
@@ -42,14 +44,20 @@ import org.apache.spark.util.{ThreadUtils, Utils}
* The Artifact Manager is responsible for handling and transferring artifacts
from the local
* client to the server (local/remote).
* @param userContext
+ * @param sessionId
+ * An unique identifier of the session which the artifact manager belongs to.
* @param channel
*/
-class ArtifactManager(userContext: proto.UserContext, channel: ManagedChannel)
{
+class ArtifactManager(
+ userContext: proto.UserContext,
+ sessionId: String,
+ channel: ManagedChannel) {
// Using the midpoint recommendation of 32KiB for chunk size as specified in
// https://github.com/grpc/grpc.github.io/issues/371.
private val CHUNK_SIZE: Int = 32 * 1024
private[this] val stub = proto.SparkConnectServiceGrpc.newStub(channel)
+ private[this] val bstub =
proto.SparkConnectServiceGrpc.newBlockingStub(channel)
private[this] val classFinders = new CopyOnWriteArrayList[ClassFinder]
/**
@@ -100,6 +108,31 @@ class ArtifactManager(userContext: proto.UserContext,
channel: ManagedChannel) {
*/
def addArtifacts(uris: Seq[URI]): Unit =
addArtifacts(uris.flatMap(parseArtifacts))
+ private def isCachedArtifact(hash: String): Boolean = {
+ val artifactName = CACHE_PREFIX + "/" + hash
+ val request = proto.ArtifactStatusesRequest
+ .newBuilder()
+ .setUserContext(userContext)
+ .setSessionId(sessionId)
+ .addAllNames(Arrays.asList(artifactName))
+ .build()
+ val statuses = bstub.artifactStatus(request).getStatusesMap
+ if (statuses.containsKey(artifactName)) {
+ statuses.get(artifactName).getExists
+ } else false
+ }
+
+ /**
+ * Cache the give blob at the session.
+ */
+ def cacheArtifact(blob: Array[Byte]): String = {
+ val hash = sha256Hex(blob)
+ if (!isCachedArtifact(hash)) {
+ addArtifacts(newCacheArtifact(hash, new InMemory(blob)) :: Nil)
+ }
+ hash
+ }
+
/**
* Upload all class file artifacts from the local REPL(s) to the server.
*
@@ -182,6 +215,7 @@ class ArtifactManager(userContext: proto.UserContext,
channel: ManagedChannel) {
val builder = proto.AddArtifactsRequest
.newBuilder()
.setUserContext(userContext)
+ .setSessionId(sessionId)
artifacts.foreach { artifact =>
val in = new
CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32)
try {
@@ -236,6 +270,7 @@ class ArtifactManager(userContext: proto.UserContext,
channel: ManagedChannel) {
val builder = proto.AddArtifactsRequest
.newBuilder()
.setUserContext(userContext)
+ .setSessionId(sessionId)
val in = new
CheckedInputStream(artifact.storage.asInstanceOf[LocalData].stream, new CRC32)
try {
@@ -289,6 +324,7 @@ class Artifact private (val path: Path, val storage:
LocalData) {
object Artifact {
val CLASS_PREFIX: Path = Paths.get("classes")
val JAR_PREFIX: Path = Paths.get("jars")
+ val CACHE_PREFIX: Path = Paths.get("cache")
def newJarArtifact(fileName: Path, storage: LocalData): Artifact = {
newArtifact(JAR_PREFIX, ".jar", fileName, storage)
@@ -298,6 +334,10 @@ object Artifact {
newArtifact(CLASS_PREFIX, ".class", fileName, storage)
}
+ def newCacheArtifact(id: String, storage: LocalData): Artifact = {
+ newArtifact(CACHE_PREFIX, "", Paths.get(id), storage)
+ }
+
private def newArtifact(
prefix: Path,
requiredSuffix: String,
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 924515166d8..1d47f3e663f 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -17,8 +17,11 @@
package org.apache.spark.sql.connect.client
+import com.google.protobuf.ByteString
import io.grpc.{CallCredentials, CallOptions, Channel, ClientCall,
ClientInterceptor, CompositeChannelCredentials, ForwardingClientCall, Grpc,
InsecureChannelCredentials, ManagedChannel, ManagedChannelBuilder, Metadata,
MethodDescriptor, Status, TlsChannelCredentials}
import java.net.URI
+import java.nio.ByteBuffer
+import java.nio.charset.StandardCharsets
import java.util.UUID
import java.util.concurrent.Executor
import scala.language.existentials
@@ -39,19 +42,21 @@ private[sql] class SparkConnectClient(
private[this] val stub =
proto.SparkConnectServiceGrpc.newBlockingStub(channel)
- private[client] val artifactManager: ArtifactManager = new
ArtifactManager(userContext, channel)
-
/**
* Placeholder method.
* @return
* User ID.
*/
- private[client] def userId: String = userContext.getUserId()
+ private[sql] def userId: String = userContext.getUserId()
// Generate a unique session ID for this client. This UUID must be unique to
allow
// concurrent Spark sessions of the same user. If the channel is closed,
creating
// a new client will create a new session ID.
- private[client] val sessionId: String = UUID.randomUUID.toString
+ private[sql] val sessionId: String = UUID.randomUUID.toString
+
+ private[client] val artifactManager: ArtifactManager = {
+ new ArtifactManager(userContext, sessionId, channel)
+ }
/**
* Dispatch the [[proto.AnalyzePlanRequest]] to the Spark Connect server.
@@ -215,6 +220,19 @@ private[sql] class SparkConnectClient(
def shutdown(): Unit = {
channel.shutdownNow()
}
+
+ /**
+ * Cache the given local relation at the server, and return its key in the
remote cache.
+ */
+ def cacheLocalRelation(size: Int, data: ByteString, schema: String): String
= {
+ val schemaBytes = schema.getBytes(StandardCharsets.UTF_8)
+ val locRelData = data.toByteArray
+ val locRel = ByteBuffer.allocate(4 + locRelData.length +
schemaBytes.length)
+ locRel.putInt(size)
+ locRel.put(locRelData)
+ locRel.put(schemaBytes)
+ artifactManager.cacheArtifact(locRel.array())
+ }
}
object SparkConnectClient {
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/ConvertToArrow.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/ConvertToArrow.scala
index d124870e162..46a9493d138 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/ConvertToArrow.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/connect/client/util/ConvertToArrow.scala
@@ -34,13 +34,13 @@ import org.apache.spark.sql.util.ArrowUtils
private[sql] object ConvertToArrow {
/**
- * Convert an iterator of common Scala objects into a sinlge Arrow IPC
Stream.
+ * Convert an iterator of common Scala objects into a single Arrow IPC
Stream.
*/
def apply[T](
encoder: AgnosticEncoder[T],
data: Iterator[T],
timeZoneId: String,
- bufferAllocator: BufferAllocator): ByteString = {
+ bufferAllocator: BufferAllocator): (ByteString, Int) = {
val arrowSchema = ArrowUtils.toArrowSchema(encoder.schema, timeZoneId)
val root = VectorSchemaRoot.create(arrowSchema, bufferAllocator)
val writer: ArrowWriter = ArrowWriter.create(root)
@@ -64,7 +64,7 @@ private[sql] object ConvertToArrow {
ArrowStreamWriter.writeEndOfStream(channel, IpcOption.DEFAULT)
// Done
- bytes.toByteString
+ (bytes.toByteString, bytes.size)
} finally {
root.close()
}
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 5b11604ebe9..abeeaf7e483 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -34,6 +34,7 @@ import
org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.catalyst.parser.ParseException
import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils,
RemoteSparkSession}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper {
@@ -853,6 +854,19 @@ class ClientE2ETestSuite extends RemoteSparkSession with
SQLHelper {
}.getMessage
assert(message.contains("PARSE_SYNTAX_ERROR"))
}
+
+ test("SparkSession.createDataFrame - large data set") {
+ val threshold = 1024 * 1024
+ withSQLConf(SQLConf.LOCAL_RELATION_CACHE_THRESHOLD.key ->
threshold.toString) {
+ val count = 2
+ val suffix = "abcdef"
+ val str = scala.util.Random.alphanumeric.take(1024 * 1024).mkString +
suffix
+ val data = Seq.tabulate(count)(i => (i, str))
+ val df = spark.createDataFrame(data)
+ assert(df.count() === count)
+ assert(!df.filter(df("_2").endsWith(suffix)).isEmpty)
+ }
+ }
}
private[sql] case class MyType(id: Long, a: Double, b: Double)
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
index 09072b8d6eb..5db40806d18 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/ArtifactSuite.scala
@@ -49,7 +49,7 @@ class ArtifactSuite extends ConnectFunSuite with
BeforeAndAfterEach {
private def createArtifactManager(): Unit = {
channel =
InProcessChannelBuilder.forName(getClass.getName).directExecutor().build()
- artifactManager = new
ArtifactManager(proto.UserContext.newBuilder().build(), channel)
+ artifactManager = new
ArtifactManager(proto.UserContext.newBuilder().build(), "", channel)
}
override def beforeEach(): Unit = {
diff --git
a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 1c927adcffa..d2eb882af7f 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -542,6 +542,43 @@ message AddArtifactsResponse {
repeated ArtifactSummary artifacts = 1;
}
+// Request to get current statuses of artifacts at the server side.
+message ArtifactStatusesRequest {
+ // (Required)
+ //
+ // The session_id specifies a spark session for a user id (which is specified
+ // by user_context.user_id). The session_id is set by the client to be able
to
+ // collate streaming responses from different queries within the dedicated
session.
+ string session_id = 1;
+
+ // User context
+ UserContext user_context = 2;
+
+ // Provides optional information about the client sending the request. This
field
+ // can be used for language or version specific information and is only
intended for
+ // logging purposes and will not be interpreted by the server.
+ optional string client_type = 3;
+
+ // The name of the artifact is expected in the form of a "Relative Path"
that is made up of a
+ // sequence of directories and the final file element.
+ // Examples of "Relative Path"s: "jars/test.jar", "classes/xyz.class",
"abc.xyz", "a/b/X.jar".
+ // The server is expected to maintain the hierarchy of files as defined by
their name. (i.e
+ // The relative path of the file on the server's filesystem will be the same
as the name of
+ // the provided artifact)
+ repeated string names = 4;
+}
+
+// Response to checking artifact statuses.
+message ArtifactStatusesResponse {
+ message ArtifactStatus {
+ // Exists or not particular artifact at the server.
+ bool exists = 1;
+ }
+
+ // A map of artifact names to their statuses.
+ map<string, ArtifactStatus> statuses = 1;
+}
+
// Main interface for the SparkConnect service.
service SparkConnectService {
@@ -559,5 +596,8 @@ service SparkConnectService {
// Add artifacts to the session and returns a [[AddArtifactsResponse]]
containing metadata about
// the added artifacts.
rpc AddArtifacts(stream AddArtifactsRequest) returns (AddArtifactsResponse)
{}
+
+ // Check statuses of artifacts in the session and returns them in a
[[ArtifactStatusesResponse]]
+ rpc ArtifactStatus(ArtifactStatusesRequest) returns
(ArtifactStatusesResponse) {}
}
diff --git
a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 77a1d40a2ea..984b7d3166c 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -68,6 +68,7 @@ message Relation {
WithWatermark with_watermark = 33;
ApplyInPandasWithState apply_in_pandas_with_state = 34;
HtmlString html_string = 35;
+ CachedLocalRelation cached_local_relation = 36;
// NA functions
NAFill fill_na = 90;
@@ -381,6 +382,18 @@ message LocalRelation {
optional string schema = 2;
}
+// A local relation that has been cached already.
+message CachedLocalRelation {
+ // (Required) An identifier of the user which created the local relation
+ string userId = 1;
+
+ // (Required) An identifier of the Spark SQL session in which the user
created the local relation.
+ string sessionId = 2;
+
+ // (Required) A sha-256 hash of the serialized local relation.
+ string hash = 3;
+}
+
// Relation of type [[Sample]] that samples a fraction of the dataset.
message Sample {
// (Required) Input relation for a Sample.
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
index 9ed5fd945f2..2521515f850 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
@@ -22,9 +22,11 @@ import java.nio.file.{Files, Path, Paths, StandardCopyOption}
import java.util.concurrent.CopyOnWriteArrayList
import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
import org.apache.spark.{SparkContext, SparkEnv}
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.connect.service.SessionHolder
+import org.apache.spark.storage.{CacheId, StorageLevel}
import org.apache.spark.util.Utils
/**
@@ -87,11 +89,28 @@ class SparkConnectArtifactManager private[connect] {
* @param serverLocalStagingPath
*/
private[connect] def addArtifact(
- session: SparkSession,
+ sessionHolder: SessionHolder,
remoteRelativePath: Path,
serverLocalStagingPath: Path): Unit = {
require(!remoteRelativePath.isAbsolute)
- if (remoteRelativePath.startsWith("classes/")) {
+ if (remoteRelativePath.startsWith("cache/")) {
+ val tmpFile = serverLocalStagingPath.toFile
+ Utils.tryWithSafeFinallyAndFailureCallbacks {
+ val blockManager = sessionHolder.session.sparkContext.env.blockManager
+ val blockId = CacheId(
+ userId = sessionHolder.userId,
+ sessionId = sessionHolder.sessionId,
+ hash = remoteRelativePath.toString.stripPrefix("cache/"))
+ val updater = blockManager.TempFileBasedBlockStoreUpdater(
+ blockId = blockId,
+ level = StorageLevel.MEMORY_AND_DISK_SER,
+ classTag = implicitly[ClassTag[Array[Byte]]],
+ tmpFile = tmpFile,
+ blockSize = tmpFile.length(),
+ tellMaster = false)
+ updater.save()
+ }(catchBlock = { tmpFile.delete() })
+ } else if (remoteRelativePath.startsWith("classes/")) {
// Move class files to common location (shared among all users)
val target =
classArtifactDir.resolve(remoteRelativePath.toString.stripPrefix("classes/"))
Files.createDirectories(target.getParent)
@@ -110,7 +129,7 @@ class SparkConnectArtifactManager private[connect] {
Files.move(serverLocalStagingPath, target)
if (remoteRelativePath.startsWith("jars")) {
// Adding Jars to the underlying spark context (visible to all users)
- session.sessionState.resourceLoader.addJar(target.toString)
+
sessionHolder.session.sessionState.resourceLoader.addJar(target.toString)
jarsList.add(target)
}
}
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 22229ba98b2..8bac639023c 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.connect.planner
+import java.nio.ByteBuffer
+import java.nio.charset.StandardCharsets
+
import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -63,6 +66,7 @@ import org.apache.spark.sql.internal.CatalogImpl
import org.apache.spark.sql.streaming.Trigger
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.storage.CacheId
import org.apache.spark.util.Utils
final case class InvalidCommandInput(
@@ -122,6 +126,8 @@ class SparkConnectPlanner(val session: SparkSession) {
case proto.Relation.RelTypeCase.WITH_COLUMNS =>
transformWithColumns(rel.getWithColumns)
case proto.Relation.RelTypeCase.WITH_WATERMARK =>
transformWithWatermark(rel.getWithWatermark)
+ case proto.Relation.RelTypeCase.CACHED_LOCAL_RELATION =>
+ transformCachedLocalRelation(rel.getCachedLocalRelation)
case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint)
case proto.Relation.RelTypeCase.UNPIVOT =>
transformUnpivot(rel.getUnpivot)
case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION =>
@@ -777,6 +783,31 @@ class SparkConnectPlanner(val session: SparkSession) {
.logicalPlan
}
+ private def transformCachedLocalRelation(rel: proto.CachedLocalRelation):
LogicalPlan = {
+ val blockManager = session.sparkContext.env.blockManager
+ val blockId = CacheId(rel.getUserId, rel.getSessionId, rel.getHash)
+ val bytes = blockManager.getLocalBytes(blockId)
+ bytes
+ .map { blockData =>
+ try {
+ val blob = blockData.toByteBuffer().array()
+ val blobSize = blockData.size.toInt
+ val size = ByteBuffer.wrap(blob).getInt
+ val intSize = 4
+ val data = blob.slice(intSize, intSize + size)
+ val schema = new String(blob.slice(intSize + size, blobSize),
StandardCharsets.UTF_8)
+ transformLocalRelation(Option(schema), Option(data))
+ } finally {
+ blockManager.releaseLock(blockId)
+ }
+ }
+ .getOrElse {
+ throw InvalidPlanInput(
+ s"Not found any cached local relation with the hash: ${blockId.hash}
in " +
+ s"the session ${blockId.sessionId} for the user id
${blockId.userId}.")
+ }
+ }
+
private def transformHint(rel: proto.Hint): LogicalPlan = {
def extractValue(expr: Expression): Any = {
@@ -908,70 +939,79 @@ class SparkConnectPlanner(val session: SparkSession) {
}
}
- private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
- var schema: StructType = null
- if (rel.hasSchema) {
+ private def transformLocalRelation(
+ schema: Option[String],
+ data: Option[Array[Byte]]): LogicalPlan = {
+ val optStruct = schema.map { schemaStr =>
val schemaType = DataType.parseTypeWithFallback(
- rel.getSchema,
+ schemaStr,
parseDatatypeString,
fallbackParser = DataType.fromJson)
- schema = schemaType match {
+ schemaType match {
case s: StructType => s
case d => StructType(Seq(StructField("value", d)))
}
}
-
- if (rel.hasData) {
- val (rows, structType) = ArrowConverters.fromBatchWithSchemaIterator(
- Iterator(rel.getData.toByteArray),
- TaskContext.get())
- if (structType == null) {
- throw InvalidPlanInput(s"Input data for LocalRelation does not produce
a schema.")
- }
- val attributes = structType.toAttributes
- val proj = UnsafeProjection.create(attributes, attributes)
- val data = rows.map(proj)
-
- if (schema == null) {
- logical.LocalRelation(attributes, data.map(_.copy()).toSeq)
- } else {
- def normalize(dt: DataType): DataType = dt match {
- case udt: UserDefinedType[_] => normalize(udt.sqlType)
- case StructType(fields) =>
- val newFields = fields.zipWithIndex.map {
- case (StructField(_, dataType, nullable, metadata), i) =>
- StructField(s"col_$i", normalize(dataType), nullable, metadata)
- }
- StructType(newFields)
- case ArrayType(elementType, containsNull) =>
- ArrayType(normalize(elementType), containsNull)
- case MapType(keyType, valueType, valueContainsNull) =>
- MapType(normalize(keyType), normalize(valueType),
valueContainsNull)
- case _ => dt
+ data
+ .map { dataBytes =>
+ val (rows, structType) =
+ ArrowConverters.fromBatchWithSchemaIterator(Iterator(dataBytes),
TaskContext.get())
+ if (structType == null) {
+ throw InvalidPlanInput(s"Input data for LocalRelation does not
produce a schema.")
}
-
- val normalized = normalize(schema).asInstanceOf[StructType]
-
- val project = Dataset
- .ofRows(
- session,
- logicalPlan =
-
logical.LocalRelation(normalize(structType).asInstanceOf[StructType].toAttributes))
- .toDF(normalized.names: _*)
- .to(normalized)
- .logicalPlan
- .asInstanceOf[Project]
-
- val proj = UnsafeProjection.create(project.projectList,
project.child.output)
- logical.LocalRelation(schema.toAttributes,
data.map(proj).map(_.copy()).toSeq)
+ val attributes = structType.toAttributes
+ val proj = UnsafeProjection.create(attributes, attributes)
+ val data = rows.map(proj)
+ optStruct
+ .map { struct =>
+ def normalize(dt: DataType): DataType = dt match {
+ case udt: UserDefinedType[_] => normalize(udt.sqlType)
+ case StructType(fields) =>
+ val newFields = fields.zipWithIndex.map {
+ case (StructField(_, dataType, nullable, metadata), i) =>
+ StructField(s"col_$i", normalize(dataType), nullable,
metadata)
+ }
+ StructType(newFields)
+ case ArrayType(elementType, containsNull) =>
+ ArrayType(normalize(elementType), containsNull)
+ case MapType(keyType, valueType, valueContainsNull) =>
+ MapType(normalize(keyType), normalize(valueType),
valueContainsNull)
+ case _ => dt
+ }
+ val normalized = normalize(struct).asInstanceOf[StructType]
+ val project = Dataset
+ .ofRows(
+ session,
+ logicalPlan = logical.LocalRelation(
+ normalize(structType).asInstanceOf[StructType].toAttributes))
+ .toDF(normalized.names: _*)
+ .to(normalized)
+ .logicalPlan
+ .asInstanceOf[Project]
+
+ val proj = UnsafeProjection.create(project.projectList,
project.child.output)
+ logical.LocalRelation(struct.toAttributes,
data.map(proj).map(_.copy()).toSeq)
+ }
+ .getOrElse {
+ logical.LocalRelation(attributes, data.map(_.copy()).toSeq)
+ }
}
- } else {
- if (schema == null) {
- throw InvalidPlanInput(
- s"Schema for LocalRelation is required when the input data is not
provided.")
+ .getOrElse {
+ optStruct
+ .map { struct =>
+ LocalRelation(struct.toAttributes, data = Seq.empty)
+ }
+ .getOrElse {
+ throw InvalidPlanInput(
+ s"Schema for LocalRelation is required when the input data is
not provided.")
+ }
}
- LocalRelation(schema.toAttributes, data = Seq.empty)
- }
+ }
+
+ private def transformLocalRelation(rel: proto.LocalRelation): LogicalPlan = {
+ transformLocalRelation(
+ if (rel.hasSchema) Some(rel.getSchema) else None,
+ if (rel.hasData) Some(rel.getData.toByteArray) else None)
}
/** Parse as DDL, with a fallback to JSON. Throws an exception if if fails
to parse. */
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
index 7f447c9672f..4f619c52544 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
@@ -83,7 +83,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver:
StreamObserver[AddAr
}
protected def addStagedArtifactToArtifactManager(artifact: StagedArtifact):
Unit = {
- artifactManager.addArtifact(holder.session, artifact.path,
artifact.stagedPath)
+ artifactManager.addArtifact(holder, artifact.path, artifact.stagedPath)
}
/**
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectArtifactStatusesHandler.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectArtifactStatusesHandler.scala
new file mode 100644
index 00000000000..d67d01bef57
--- /dev/null
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectArtifactStatusesHandler.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.service
+
+import scala.collection.JavaConverters._
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto
+import org.apache.spark.internal.Logging
+import org.apache.spark.storage.CacheId
+
+class SparkConnectArtifactStatusesHandler(
+ val responseObserver: StreamObserver[proto.ArtifactStatusesResponse])
+ extends Logging {
+
+ protected def cacheExists(userId: String, sessionId: String, hash: String):
Boolean = {
+ val session = SparkConnectService
+ .getOrCreateIsolatedSession(userId, sessionId)
+ .session
+ val blockManager = session.sparkContext.env.blockManager
+ blockManager.get(CacheId(userId, sessionId, hash)).isDefined
+ }
+
+ def handle(request: proto.ArtifactStatusesRequest): Unit = {
+ val builder = proto.ArtifactStatusesResponse.newBuilder()
+ request.getNamesList().iterator().asScala.foreach { name =>
+ val status = proto.ArtifactStatusesResponse.ArtifactStatus.newBuilder()
+ val exists = if (name.startsWith("cache/")) {
+ cacheExists(
+ userId = request.getUserContext.getUserId,
+ sessionId = request.getSessionId,
+ hash = name.stripPrefix("cache/"))
+ } else false
+ builder.putStatuses(name, status.setExists(exists).build())
+ }
+ responseObserver.onNext(builder.build())
+ responseObserver.onCompleted()
+ }
+}
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index a051eef5e40..e8e1b6177a8 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer
import scala.util.control.NonFatal
import com.google.common.base.Ticker
-import com.google.common.cache.CacheBuilder
+import com.google.common.cache.{CacheBuilder, RemovalListener,
RemovalNotification}
import com.google.protobuf.{Any => ProtoAny}
import com.google.rpc.{Code => RPCCode, ErrorInfo, Status => RPCStatus}
import io.grpc.{Server, Status}
@@ -229,6 +229,22 @@ class SparkConnectService(debug: Boolean)
override def addArtifacts(responseObserver:
StreamObserver[AddArtifactsResponse])
: StreamObserver[AddArtifactsRequest] = new
SparkConnectAddArtifactsHandler(
responseObserver)
+
+ /**
+ * This is the entry point for all calls of getting artifact statuses.
+ */
+ override def artifactStatus(
+ request: proto.ArtifactStatusesRequest,
+ responseObserver: StreamObserver[proto.ArtifactStatusesResponse]): Unit
= {
+ try {
+ new SparkConnectArtifactStatusesHandler(responseObserver).handle(request)
+ } catch
+ handleError(
+ "artifactStatus",
+ observer = responseObserver,
+ userId = request.getUserContext.getUserId,
+ sessionId = request.getSessionId)
+ }
}
/**
@@ -274,6 +290,15 @@ object SparkConnectService {
userSessionMapping.getIfPresent((userId, sessionId))
})
+ private class RemoveSessionListener extends RemovalListener[SessionCacheKey,
SessionHolder] {
+ override def onRemoval(
+ notification: RemovalNotification[SessionCacheKey, SessionHolder]):
Unit = {
+ val SessionHolder(userId, sessionId, session) = notification.getValue
+ val blockManager = session.sparkContext.env.blockManager
+ blockManager.removeCache(userId, sessionId)
+ }
+ }
+
// Simple builder for creating the cache of Sessions.
private def cacheBuilder(cacheSize: Int, timeoutSeconds: Int):
CacheBuilder[Object, Object] = {
var cacheBuilder = CacheBuilder.newBuilder().ticker(Ticker.systemTicker())
@@ -283,6 +308,7 @@ object SparkConnectService {
if (timeoutSeconds >= 0) {
cacheBuilder.expireAfterAccess(timeoutSeconds, TimeUnit.SECONDS)
}
+ cacheBuilder.removalListener(new RemoveSessionListener)
cacheBuilder
}
@@ -338,6 +364,7 @@ object SparkConnectService {
server.shutdownNow()
}
}
+ userSessionMapping.invalidateAll()
}
def extractErrorMessage(st: Throwable): String = {
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
index 6c661cbe1bb..ba71b1839e9 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
+++
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
@@ -16,15 +16,17 @@
*/
package org.apache.spark.sql.connect.artifact
-import java.nio.file.Paths
+import java.nio.charset.StandardCharsets
+import java.nio.file.{Files, Paths}
import org.apache.commons.io.FileUtils
import org.apache.spark.SparkConf
import org.apache.spark.sql.connect.ResourceHelper
-import org.apache.spark.sql.connect.service.SparkConnectService
+import org.apache.spark.sql.connect.service.{SessionHolder,
SparkConnectService}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.storage.CacheId
import org.apache.spark.util.Utils
class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
@@ -37,12 +39,16 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
private val artifactPath = commonResourcePath.resolve("artifact-tests")
private lazy val artifactManager =
SparkConnectArtifactManager.getOrCreateArtifactManager
+ private def sessionHolder(): SessionHolder = {
+ SessionHolder("test", spark.sessionUUID, spark)
+ }
+
test("Jar artifacts are added to spark session") {
val copyDir = Utils.createTempDir().toPath
FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
val stagingPath = copyDir.resolve("smallJar.jar")
val remotePath = Paths.get("jars/smallJar.jar")
- artifactManager.addArtifact(spark, remotePath, stagingPath)
+ artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
val jarList = spark.sparkContext.listJars()
assert(jarList.exists(_.contains(remotePath.toString)))
@@ -54,7 +60,7 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
val stagingPath = copyDir.resolve("smallClassFile.class")
val remotePath = Paths.get("classes/smallClassFile.class")
assert(stagingPath.toFile.exists())
- artifactManager.addArtifact(spark, remotePath, stagingPath)
+ artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
val classFileDirectory = artifactManager.classArtifactDir
val movedClassFile =
classFileDirectory.resolve("smallClassFile.class").toFile
@@ -67,7 +73,7 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
val stagingPath = copyDir.resolve("Hello.class")
val remotePath = Paths.get("classes/Hello.class")
assert(stagingPath.toFile.exists())
- artifactManager.addArtifact(spark, remotePath, stagingPath)
+ artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
val classFileDirectory = artifactManager.classArtifactDir
val movedClassFile = classFileDirectory.resolve("Hello.class").toFile
@@ -90,7 +96,7 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
val stagingPath = copyDir.resolve("Hello.class")
val remotePath = Paths.get("classes/Hello.class")
assert(stagingPath.toFile.exists())
- artifactManager.addArtifact(spark, remotePath, stagingPath)
+ artifactManager.addArtifact(sessionHolder, remotePath, stagingPath)
val classFileDirectory = artifactManager.classArtifactDir
val movedClassFile = classFileDirectory.resolve("Hello.class").toFile
@@ -107,4 +113,25 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
val session = SparkConnectService.getOrCreateIsolatedSession("c1",
"session").session
session.range(10).select(udf(col("id").cast("string"))).collect()
}
+
+ test("add a cache artifact to the Block Manager") {
+ withTempPath { path =>
+ val stagingPath = path.toPath
+ Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
+ val remotePath = Paths.get("cache/abc")
+ val session = sessionHolder()
+ val blockManager = spark.sparkContext.env.blockManager
+ val blockId = CacheId(session.userId, session.sessionId, "abc")
+ try {
+ artifactManager.addArtifact(session, remotePath, stagingPath)
+ val bytes = blockManager.getLocalBytes(blockId)
+ assert(bytes.isDefined)
+ val readback = new String(bytes.get.toByteBuffer().array(),
StandardCharsets.UTF_8)
+ assert(readback === "test")
+ } finally {
+ blockManager.releaseLock(blockId)
+ blockManager.removeCache(session.userId, session.sessionId)
+ }
+ }
+ }
}
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala
new file mode 100644
index 00000000000..b2e7f52825b
--- /dev/null
+++
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/ArtifactStatusesHandlerSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.service
+
+import scala.collection.JavaConverters._
+import scala.concurrent.Promise
+import scala.concurrent.duration._
+
+import io.grpc.stub.StreamObserver
+import org.apache.commons.codec.digest.DigestUtils.sha256Hex
+
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.ArtifactStatusesResponse
+import org.apache.spark.sql.connect.ResourceHelper
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.ThreadUtils
+
+private class DummyStreamObserver(p: Promise[ArtifactStatusesResponse])
+ extends StreamObserver[ArtifactStatusesResponse] {
+ override def onNext(v: ArtifactStatusesResponse): Unit = p.success(v)
+ override def onError(throwable: Throwable): Unit = throw throwable
+ override def onCompleted(): Unit = {}
+}
+
+class ArtifactStatusesHandlerSuite extends SharedSparkSession with
ResourceHelper {
+ def getStatuses(names: Seq[String], exist: Set[String]):
ArtifactStatusesResponse = {
+ val promise = Promise[ArtifactStatusesResponse]
+ val handler = new SparkConnectArtifactStatusesHandler(new
DummyStreamObserver(promise)) {
+ override protected def cacheExists(
+ userId: String,
+ sessionId: String,
+ hash: String): Boolean = {
+ exist.contains(hash)
+ }
+ }
+ val context = proto.UserContext
+ .newBuilder()
+ .setUserId("user1")
+ .build()
+ val request = proto.ArtifactStatusesRequest
+ .newBuilder()
+ .setUserContext(context)
+ .setSessionId("abc")
+ .addAllNames(names.asJava)
+ .build()
+ handler.handle(request)
+ ThreadUtils.awaitResult(promise.future, 5.seconds)
+ }
+
+ private def id(name: String): String = "cache/" + sha256Hex(name)
+
+ test("non-existent artifact") {
+ val response = getStatuses(names = Seq(id("name1")), exist = Set.empty)
+ assert(response.getStatusesCount === 1)
+ assert(response.getStatusesMap.get(id("name1")).getExists === false)
+ }
+
+ test("single artifact") {
+ val response = getStatuses(names = Seq(id("name1")), exist =
Set(sha256Hex("name1")))
+ assert(response.getStatusesCount === 1)
+ assert(response.getStatusesMap.get(id("name1")).getExists)
+ }
+
+ test("multiple artifacts") {
+ val response = getStatuses(
+ names = Seq("name1", "name2", "name3").map(id),
+ exist = Set("name2", "name3").map(sha256Hex))
+ assert(response.getStatusesCount === 3)
+ assert(!response.getStatusesMap.get(id("name1")).getExists)
+ assert(response.getStatusesMap.get(id("name2")).getExists)
+ assert(response.getStatusesMap.get(id("name3")).getExists)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index b8ec93f74ab..456b4edf938 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -189,6 +189,11 @@ private[spark] case class TestBlockId(id: String) extends
BlockId {
class UnrecognizedBlockId(name: String)
extends SparkException(s"Failed to parse $name into a block ID")
+@DeveloperApi
+case class CacheId(userId: String, sessionId: String, hash: String) extends
BlockId {
+ override def name: String = s"cache_${userId}_${sessionId}_$hash"
+}
+
@DeveloperApi
object BlockId {
val RDD = "rdd_([0-9]+)_([0-9]+)".r
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 143f4b3ada4..a8f74ef179b 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -2042,6 +2042,20 @@ private[spark] class BlockManager(
blocksToRemove.size
}
+ /**
+ * Remove cache blocks that might be related to cached local relations.
+ *
+ * @return The number of blocks removed.
+ */
+ def removeCache(userId: String, sessionId: String): Int = {
+ logDebug(s"Removing cache of user id = $userId in the session $sessionId")
+ val blocksToRemove = blockInfoManager.entries.map(_._1).collect {
+ case cid: CacheId if cid.userId == userId && cid.sessionId == sessionId
=> cid
+ }
+ blocksToRemove.foreach { blockId => removeBlock(blockId) }
+ blocksToRemove.size
+ }
+
/**
* Remove a block from both memory and disk.
*/
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py
b/python/pyspark/sql/connect/proto/base_pb2.py
index 93fd7ac16d0..9b59aaecdf9 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -38,7 +38,7 @@ from pyspark.sql.connect.proto import types_pb2 as
spark_dot_connect_dot_types__
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
[...]
+
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
[...]
)
@@ -113,6 +113,14 @@ _ADDARTIFACTSRESPONSE =
DESCRIPTOR.message_types_by_name["AddArtifactsResponse"]
_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY =
_ADDARTIFACTSRESPONSE.nested_types_by_name[
"ArtifactSummary"
]
+_ARTIFACTSTATUSESREQUEST =
DESCRIPTOR.message_types_by_name["ArtifactStatusesRequest"]
+_ARTIFACTSTATUSESRESPONSE =
DESCRIPTOR.message_types_by_name["ArtifactStatusesResponse"]
+_ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS =
_ARTIFACTSTATUSESRESPONSE.nested_types_by_name[
+ "ArtifactStatus"
+]
+_ARTIFACTSTATUSESRESPONSE_STATUSESENTRY =
_ARTIFACTSTATUSESRESPONSE.nested_types_by_name[
+ "StatusesEntry"
+]
_ANALYZEPLANREQUEST_EXPLAIN_EXPLAINMODE =
_ANALYZEPLANREQUEST_EXPLAIN.enum_types_by_name[
"ExplainMode"
]
@@ -697,6 +705,48 @@ AddArtifactsResponse =
_reflection.GeneratedProtocolMessageType(
_sym_db.RegisterMessage(AddArtifactsResponse)
_sym_db.RegisterMessage(AddArtifactsResponse.ArtifactSummary)
+ArtifactStatusesRequest = _reflection.GeneratedProtocolMessageType(
+ "ArtifactStatusesRequest",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _ARTIFACTSTATUSESREQUEST,
+ "__module__": "spark.connect.base_pb2"
+ #
@@protoc_insertion_point(class_scope:spark.connect.ArtifactStatusesRequest)
+ },
+)
+_sym_db.RegisterMessage(ArtifactStatusesRequest)
+
+ArtifactStatusesResponse = _reflection.GeneratedProtocolMessageType(
+ "ArtifactStatusesResponse",
+ (_message.Message,),
+ {
+ "ArtifactStatus": _reflection.GeneratedProtocolMessageType(
+ "ArtifactStatus",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS,
+ "__module__": "spark.connect.base_pb2"
+ #
@@protoc_insertion_point(class_scope:spark.connect.ArtifactStatusesResponse.ArtifactStatus)
+ },
+ ),
+ "StatusesEntry": _reflection.GeneratedProtocolMessageType(
+ "StatusesEntry",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY,
+ "__module__": "spark.connect.base_pb2"
+ #
@@protoc_insertion_point(class_scope:spark.connect.ArtifactStatusesResponse.StatusesEntry)
+ },
+ ),
+ "DESCRIPTOR": _ARTIFACTSTATUSESRESPONSE,
+ "__module__": "spark.connect.base_pb2"
+ #
@@protoc_insertion_point(class_scope:spark.connect.ArtifactStatusesResponse)
+ },
+)
+_sym_db.RegisterMessage(ArtifactStatusesResponse)
+_sym_db.RegisterMessage(ArtifactStatusesResponse.ArtifactStatus)
+_sym_db.RegisterMessage(ArtifactStatusesResponse.StatusesEntry)
+
_SPARKCONNECTSERVICE = DESCRIPTOR.services_by_name["SparkConnectService"]
if _descriptor._USE_C_DESCRIPTORS == False:
@@ -704,6 +754,8 @@ if _descriptor._USE_C_DESCRIPTORS == False:
DESCRIPTOR._serialized_options =
b"\n\036org.apache.spark.connect.protoP\001"
_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._options =
None
_EXECUTEPLANRESPONSE_METRICS_METRICOBJECT_EXECUTIONMETRICSENTRY._serialized_options
= b"8\001"
+ _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._options = None
+ _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_options = b"8\001"
_PLAN._serialized_start = 219
_PLAN._serialized_end = 335
_USERCONTEXT._serialized_start = 337
@@ -820,6 +872,14 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_ADDARTIFACTSRESPONSE._serialized_end = 8704
_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_start = 8623
_ADDARTIFACTSRESPONSE_ARTIFACTSUMMARY._serialized_end = 8704
- _SPARKCONNECTSERVICE._serialized_start = 8707
- _SPARKCONNECTSERVICE._serialized_end = 9072
+ _ARTIFACTSTATUSESREQUEST._serialized_start = 8707
+ _ARTIFACTSTATUSESREQUEST._serialized_end = 8902
+ _ARTIFACTSTATUSESRESPONSE._serialized_start = 8905
+ _ARTIFACTSTATUSESRESPONSE._serialized_end = 9173
+ _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_start = 9016
+ _ARTIFACTSTATUSESRESPONSE_ARTIFACTSTATUS._serialized_end = 9056
+ _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_start = 9058
+ _ARTIFACTSTATUSESRESPONSE_STATUSESENTRY._serialized_end = 9173
+ _SPARKCONNECTSERVICE._serialized_start = 9176
+ _SPARKCONNECTSERVICE._serialized_end = 9642
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi
b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 037afb822d1..f9ebc8930c6 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -1998,3 +1998,141 @@ class
AddArtifactsResponse(google.protobuf.message.Message):
) -> None: ...
global___AddArtifactsResponse = AddArtifactsResponse
+
+class ArtifactStatusesRequest(google.protobuf.message.Message):
+ """Request to get current statuses of artifacts at the server side."""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ SESSION_ID_FIELD_NUMBER: builtins.int
+ USER_CONTEXT_FIELD_NUMBER: builtins.int
+ CLIENT_TYPE_FIELD_NUMBER: builtins.int
+ NAMES_FIELD_NUMBER: builtins.int
+ session_id: builtins.str
+ """(Required)
+
+ The session_id specifies a spark session for a user id (which is specified
+ by user_context.user_id). The session_id is set by the client to be able to
+ collate streaming responses from different queries within the dedicated
session.
+ """
+ @property
+ def user_context(self) -> global___UserContext:
+ """User context"""
+ client_type: builtins.str
+ """Provides optional information about the client sending the request.
This field
+ can be used for language or version specific information and is only
intended for
+ logging purposes and will not be interpreted by the server.
+ """
+ @property
+ def names(
+ self,
+ ) ->
google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """The name of the artifact is expected in the form of a "Relative
Path" that is made up of a
+ sequence of directories and the final file element.
+ Examples of "Relative Path"s: "jars/test.jar", "classes/xyz.class",
"abc.xyz", "a/b/X.jar".
+ The server is expected to maintain the hierarchy of files as defined
by their name. (i.e
+ The relative path of the file on the server's filesystem will be the
same as the name of
+ the provided artifact)
+ """
+ def __init__(
+ self,
+ *,
+ session_id: builtins.str = ...,
+ user_context: global___UserContext | None = ...,
+ client_type: builtins.str | None = ...,
+ names: collections.abc.Iterable[builtins.str] | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_client_type",
+ b"_client_type",
+ "client_type",
+ b"client_type",
+ "user_context",
+ b"user_context",
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_client_type",
+ b"_client_type",
+ "client_type",
+ b"client_type",
+ "names",
+ b"names",
+ "session_id",
+ b"session_id",
+ "user_context",
+ b"user_context",
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_client_type",
b"_client_type"]
+ ) -> typing_extensions.Literal["client_type"] | None: ...
+
+global___ArtifactStatusesRequest = ArtifactStatusesRequest
+
+class ArtifactStatusesResponse(google.protobuf.message.Message):
+ """Response to checking artifact statuses."""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ class ArtifactStatus(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ EXISTS_FIELD_NUMBER: builtins.int
+ exists: builtins.bool
+ """Exists or not particular artifact at the server."""
+ def __init__(
+ self,
+ *,
+ exists: builtins.bool = ...,
+ ) -> None: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["exists", b"exists"]
+ ) -> None: ...
+
+ class StatusesEntry(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ KEY_FIELD_NUMBER: builtins.int
+ VALUE_FIELD_NUMBER: builtins.int
+ key: builtins.str
+ @property
+ def value(self) -> global___ArtifactStatusesResponse.ArtifactStatus:
...
+ def __init__(
+ self,
+ *,
+ key: builtins.str = ...,
+ value: global___ArtifactStatusesResponse.ArtifactStatus | None =
...,
+ ) -> None: ...
+ def HasField(
+ self, field_name: typing_extensions.Literal["value", b"value"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["key", b"key",
"value", b"value"]
+ ) -> None: ...
+
+ STATUSES_FIELD_NUMBER: builtins.int
+ @property
+ def statuses(
+ self,
+ ) -> google.protobuf.internal.containers.MessageMap[
+ builtins.str, global___ArtifactStatusesResponse.ArtifactStatus
+ ]:
+ """A map of artifact names to their statuses."""
+ def __init__(
+ self,
+ *,
+ statuses: collections.abc.Mapping[
+ builtins.str, global___ArtifactStatusesResponse.ArtifactStatus
+ ]
+ | None = ...,
+ ) -> None: ...
+ def ClearField(
+ self, field_name: typing_extensions.Literal["statuses", b"statuses"]
+ ) -> None: ...
+
+global___ArtifactStatusesResponse = ArtifactStatusesResponse
diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
index c372cbcc487..ecbe4f9c389 100644
--- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py
+++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
@@ -50,6 +50,11 @@ class SparkConnectServiceStub(object):
request_serializer=spark_dot_connect_dot_base__pb2.AddArtifactsRequest.SerializeToString,
response_deserializer=spark_dot_connect_dot_base__pb2.AddArtifactsResponse.FromString,
)
+ self.ArtifactStatus = channel.unary_unary(
+ "/spark.connect.SparkConnectService/ArtifactStatus",
+ request_serializer=spark_dot_connect_dot_base__pb2.ArtifactStatusesRequest.SerializeToString,
+ response_deserializer=spark_dot_connect_dot_base__pb2.ArtifactStatusesResponse.FromString,
+ )
class SparkConnectServiceServicer(object):
@@ -84,6 +89,12 @@ class SparkConnectServiceServicer(object):
context.set_details("Method not implemented!")
raise NotImplementedError("Method not implemented!")
+ def ArtifactStatus(self, request, context):
+ """Check statuses of artifacts in the session and returns them in a
[[ArtifactStatusesResponse]]"""
+ context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+ context.set_details("Method not implemented!")
+ raise NotImplementedError("Method not implemented!")
+
def add_SparkConnectServiceServicer_to_server(servicer, server):
rpc_method_handlers = {
@@ -107,6 +118,11 @@ def add_SparkConnectServiceServicer_to_server(servicer, server):
request_deserializer=spark_dot_connect_dot_base__pb2.AddArtifactsRequest.FromString,
response_serializer=spark_dot_connect_dot_base__pb2.AddArtifactsResponse.SerializeToString,
),
+ "ArtifactStatus": grpc.unary_unary_rpc_method_handler(
+ servicer.ArtifactStatus,
+ request_deserializer=spark_dot_connect_dot_base__pb2.ArtifactStatusesRequest.FromString,
+ response_serializer=spark_dot_connect_dot_base__pb2.ArtifactStatusesResponse.SerializeToString,
+ ),
}
generic_handler = grpc.method_handlers_generic_handler(
"spark.connect.SparkConnectService", rpc_method_handlers
@@ -233,3 +249,32 @@ class SparkConnectService(object):
timeout,
metadata,
)
+
+ @staticmethod
+ def ArtifactStatus(
+ request,
+ target,
+ options=(),
+ channel_credentials=None,
+ call_credentials=None,
+ insecure=False,
+ compression=None,
+ wait_for_ready=None,
+ timeout=None,
+ metadata=None,
+ ):
+ return grpc.experimental.unary_unary(
+ request,
+ target,
+ "/spark.connect.SparkConnectService/ArtifactStatus",
+ spark_dot_connect_dot_base__pb2.ArtifactStatusesRequest.SerializeToString,
+ spark_dot_connect_dot_base__pb2.ArtifactStatusesResponse.FromString,
+ options,
+ channel_credentials,
+ insecure,
+ call_credentials,
+ compression,
+ wait_for_ready,
+ timeout,
+ metadata,
+ )
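
The new `ArtifactStatus` endpoint is a plain unary RPC, so it can also be exercised directly through the generated stub. A hedged sketch (not part of this patch) follows; the server address is a placeholder, and in practice the Spark Connect client wraps this call.

```python
# Sketch only: endpoint, ids and hash are placeholders.
import grpc

from pyspark.sql.connect.proto import base_pb2 as pb2
from pyspark.sql.connect.proto import base_pb2_grpc as pb2_grpc

with grpc.insecure_channel("localhost:15002") as channel:
    stub = pb2_grpc.SparkConnectServiceStub(channel)
    request = pb2.ArtifactStatusesRequest(
        session_id="00000000-0000-0000-0000-000000000000",
        user_context=pb2.UserContext(user_id="test_user"),
        names=["cache/" + "0" * 64],  # placeholder sha-256
    )
    response = stub.ArtifactStatus(request)
    for name, status in response.statuses.items():
        print(name, status.exists)
```
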
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 0cddac1e2a4..a61223caf12 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\x99\x16\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf3\x16\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -63,6 +63,7 @@ _SORT = DESCRIPTOR.message_types_by_name["Sort"]
_DROP = DESCRIPTOR.message_types_by_name["Drop"]
_DEDUPLICATE = DESCRIPTOR.message_types_by_name["Deduplicate"]
_LOCALRELATION = DESCRIPTOR.message_types_by_name["LocalRelation"]
+_CACHEDLOCALRELATION = DESCRIPTOR.message_types_by_name["CachedLocalRelation"]
_SAMPLE = DESCRIPTOR.message_types_by_name["Sample"]
_RANGE = DESCRIPTOR.message_types_by_name["Range"]
_SUBQUERYALIAS = DESCRIPTOR.message_types_by_name["SubqueryAlias"]
@@ -352,6 +353,17 @@ LocalRelation = _reflection.GeneratedProtocolMessageType(
)
_sym_db.RegisterMessage(LocalRelation)
+CachedLocalRelation = _reflection.GeneratedProtocolMessageType(
+ "CachedLocalRelation",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _CACHEDLOCALRELATION,
+ "__module__": "spark.connect.relations_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.CachedLocalRelation)
+ },
+)
+_sym_db.RegisterMessage(CachedLocalRelation)
+
Sample = _reflection.GeneratedProtocolMessageType(
"Sample",
(_message.Message,),
@@ -758,129 +770,131 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_PARSE_OPTIONSENTRY._options = None
_PARSE_OPTIONSENTRY._serialized_options = b"8\001"
_RELATION._serialized_start = 165
- _RELATION._serialized_end = 3006
- _UNKNOWN._serialized_start = 3008
- _UNKNOWN._serialized_end = 3017
- _RELATIONCOMMON._serialized_start = 3019
- _RELATIONCOMMON._serialized_end = 3110
- _SQL._serialized_start = 3113
- _SQL._serialized_end = 3282
- _SQL_ARGSENTRY._serialized_start = 3192
- _SQL_ARGSENTRY._serialized_end = 3282
- _READ._serialized_start = 3285
- _READ._serialized_end = 3948
- _READ_NAMEDTABLE._serialized_start = 3463
- _READ_NAMEDTABLE._serialized_end = 3655
- _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 3597
- _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 3655
- _READ_DATASOURCE._serialized_start = 3658
- _READ_DATASOURCE._serialized_end = 3935
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3597
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3655
- _PROJECT._serialized_start = 3950
- _PROJECT._serialized_end = 4067
- _FILTER._serialized_start = 4069
- _FILTER._serialized_end = 4181
- _JOIN._serialized_start = 4184
- _JOIN._serialized_end = 4655
- _JOIN_JOINTYPE._serialized_start = 4447
- _JOIN_JOINTYPE._serialized_end = 4655
- _SETOPERATION._serialized_start = 4658
- _SETOPERATION._serialized_end = 5137
- _SETOPERATION_SETOPTYPE._serialized_start = 4974
- _SETOPERATION_SETOPTYPE._serialized_end = 5088
- _LIMIT._serialized_start = 5139
- _LIMIT._serialized_end = 5215
- _OFFSET._serialized_start = 5217
- _OFFSET._serialized_end = 5296
- _TAIL._serialized_start = 5298
- _TAIL._serialized_end = 5373
- _AGGREGATE._serialized_start = 5376
- _AGGREGATE._serialized_end = 5958
- _AGGREGATE_PIVOT._serialized_start = 5715
- _AGGREGATE_PIVOT._serialized_end = 5826
- _AGGREGATE_GROUPTYPE._serialized_start = 5829
- _AGGREGATE_GROUPTYPE._serialized_end = 5958
- _SORT._serialized_start = 5961
- _SORT._serialized_end = 6121
- _DROP._serialized_start = 6124
- _DROP._serialized_end = 6265
- _DEDUPLICATE._serialized_start = 6268
- _DEDUPLICATE._serialized_end = 6508
- _LOCALRELATION._serialized_start = 6510
- _LOCALRELATION._serialized_end = 6599
- _SAMPLE._serialized_start = 6602
- _SAMPLE._serialized_end = 6875
- _RANGE._serialized_start = 6878
- _RANGE._serialized_end = 7023
- _SUBQUERYALIAS._serialized_start = 7025
- _SUBQUERYALIAS._serialized_end = 7139
- _REPARTITION._serialized_start = 7142
- _REPARTITION._serialized_end = 7284
- _SHOWSTRING._serialized_start = 7287
- _SHOWSTRING._serialized_end = 7429
- _HTMLSTRING._serialized_start = 7431
- _HTMLSTRING._serialized_end = 7545
- _STATSUMMARY._serialized_start = 7547
- _STATSUMMARY._serialized_end = 7639
- _STATDESCRIBE._serialized_start = 7641
- _STATDESCRIBE._serialized_end = 7722
- _STATCROSSTAB._serialized_start = 7724
- _STATCROSSTAB._serialized_end = 7825
- _STATCOV._serialized_start = 7827
- _STATCOV._serialized_end = 7923
- _STATCORR._serialized_start = 7926
- _STATCORR._serialized_end = 8063
- _STATAPPROXQUANTILE._serialized_start = 8066
- _STATAPPROXQUANTILE._serialized_end = 8230
- _STATFREQITEMS._serialized_start = 8232
- _STATFREQITEMS._serialized_end = 8357
- _STATSAMPLEBY._serialized_start = 8360
- _STATSAMPLEBY._serialized_end = 8669
- _STATSAMPLEBY_FRACTION._serialized_start = 8561
- _STATSAMPLEBY_FRACTION._serialized_end = 8660
- _NAFILL._serialized_start = 8672
- _NAFILL._serialized_end = 8806
- _NADROP._serialized_start = 8809
- _NADROP._serialized_end = 8943
- _NAREPLACE._serialized_start = 8946
- _NAREPLACE._serialized_end = 9242
- _NAREPLACE_REPLACEMENT._serialized_start = 9101
- _NAREPLACE_REPLACEMENT._serialized_end = 9242
- _TODF._serialized_start = 9244
- _TODF._serialized_end = 9332
- _WITHCOLUMNSRENAMED._serialized_start = 9335
- _WITHCOLUMNSRENAMED._serialized_end = 9574
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 9507
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 9574
- _WITHCOLUMNS._serialized_start = 9576
- _WITHCOLUMNS._serialized_end = 9695
- _WITHWATERMARK._serialized_start = 9698
- _WITHWATERMARK._serialized_end = 9832
- _HINT._serialized_start = 9835
- _HINT._serialized_end = 9967
- _UNPIVOT._serialized_start = 9970
- _UNPIVOT._serialized_end = 10297
- _UNPIVOT_VALUES._serialized_start = 10227
- _UNPIVOT_VALUES._serialized_end = 10286
- _TOSCHEMA._serialized_start = 10299
- _TOSCHEMA._serialized_end = 10405
- _REPARTITIONBYEXPRESSION._serialized_start = 10408
- _REPARTITIONBYEXPRESSION._serialized_end = 10611
- _MAPPARTITIONS._serialized_start = 10614
- _MAPPARTITIONS._serialized_end = 10795
- _GROUPMAP._serialized_start = 10798
- _GROUPMAP._serialized_end = 11077
- _COGROUPMAP._serialized_start = 11080
- _COGROUPMAP._serialized_end = 11606
- _APPLYINPANDASWITHSTATE._serialized_start = 11609
- _APPLYINPANDASWITHSTATE._serialized_end = 11966
- _COLLECTMETRICS._serialized_start = 11969
- _COLLECTMETRICS._serialized_end = 12105
- _PARSE._serialized_start = 12108
- _PARSE._serialized_end = 12496
- _PARSE_OPTIONSENTRY._serialized_start = 3597
- _PARSE_OPTIONSENTRY._serialized_end = 3655
- _PARSE_PARSEFORMAT._serialized_start = 12397
- _PARSE_PARSEFORMAT._serialized_end = 12485
+ _RELATION._serialized_end = 3096
+ _UNKNOWN._serialized_start = 3098
+ _UNKNOWN._serialized_end = 3107
+ _RELATIONCOMMON._serialized_start = 3109
+ _RELATIONCOMMON._serialized_end = 3200
+ _SQL._serialized_start = 3203
+ _SQL._serialized_end = 3372
+ _SQL_ARGSENTRY._serialized_start = 3282
+ _SQL_ARGSENTRY._serialized_end = 3372
+ _READ._serialized_start = 3375
+ _READ._serialized_end = 4038
+ _READ_NAMEDTABLE._serialized_start = 3553
+ _READ_NAMEDTABLE._serialized_end = 3745
+ _READ_NAMEDTABLE_OPTIONSENTRY._serialized_start = 3687
+ _READ_NAMEDTABLE_OPTIONSENTRY._serialized_end = 3745
+ _READ_DATASOURCE._serialized_start = 3748
+ _READ_DATASOURCE._serialized_end = 4025
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3687
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3745
+ _PROJECT._serialized_start = 4040
+ _PROJECT._serialized_end = 4157
+ _FILTER._serialized_start = 4159
+ _FILTER._serialized_end = 4271
+ _JOIN._serialized_start = 4274
+ _JOIN._serialized_end = 4745
+ _JOIN_JOINTYPE._serialized_start = 4537
+ _JOIN_JOINTYPE._serialized_end = 4745
+ _SETOPERATION._serialized_start = 4748
+ _SETOPERATION._serialized_end = 5227
+ _SETOPERATION_SETOPTYPE._serialized_start = 5064
+ _SETOPERATION_SETOPTYPE._serialized_end = 5178
+ _LIMIT._serialized_start = 5229
+ _LIMIT._serialized_end = 5305
+ _OFFSET._serialized_start = 5307
+ _OFFSET._serialized_end = 5386
+ _TAIL._serialized_start = 5388
+ _TAIL._serialized_end = 5463
+ _AGGREGATE._serialized_start = 5466
+ _AGGREGATE._serialized_end = 6048
+ _AGGREGATE_PIVOT._serialized_start = 5805
+ _AGGREGATE_PIVOT._serialized_end = 5916
+ _AGGREGATE_GROUPTYPE._serialized_start = 5919
+ _AGGREGATE_GROUPTYPE._serialized_end = 6048
+ _SORT._serialized_start = 6051
+ _SORT._serialized_end = 6211
+ _DROP._serialized_start = 6214
+ _DROP._serialized_end = 6355
+ _DEDUPLICATE._serialized_start = 6358
+ _DEDUPLICATE._serialized_end = 6598
+ _LOCALRELATION._serialized_start = 6600
+ _LOCALRELATION._serialized_end = 6689
+ _CACHEDLOCALRELATION._serialized_start = 6691
+ _CACHEDLOCALRELATION._serialized_end = 6786
+ _SAMPLE._serialized_start = 6789
+ _SAMPLE._serialized_end = 7062
+ _RANGE._serialized_start = 7065
+ _RANGE._serialized_end = 7210
+ _SUBQUERYALIAS._serialized_start = 7212
+ _SUBQUERYALIAS._serialized_end = 7326
+ _REPARTITION._serialized_start = 7329
+ _REPARTITION._serialized_end = 7471
+ _SHOWSTRING._serialized_start = 7474
+ _SHOWSTRING._serialized_end = 7616
+ _HTMLSTRING._serialized_start = 7618
+ _HTMLSTRING._serialized_end = 7732
+ _STATSUMMARY._serialized_start = 7734
+ _STATSUMMARY._serialized_end = 7826
+ _STATDESCRIBE._serialized_start = 7828
+ _STATDESCRIBE._serialized_end = 7909
+ _STATCROSSTAB._serialized_start = 7911
+ _STATCROSSTAB._serialized_end = 8012
+ _STATCOV._serialized_start = 8014
+ _STATCOV._serialized_end = 8110
+ _STATCORR._serialized_start = 8113
+ _STATCORR._serialized_end = 8250
+ _STATAPPROXQUANTILE._serialized_start = 8253
+ _STATAPPROXQUANTILE._serialized_end = 8417
+ _STATFREQITEMS._serialized_start = 8419
+ _STATFREQITEMS._serialized_end = 8544
+ _STATSAMPLEBY._serialized_start = 8547
+ _STATSAMPLEBY._serialized_end = 8856
+ _STATSAMPLEBY_FRACTION._serialized_start = 8748
+ _STATSAMPLEBY_FRACTION._serialized_end = 8847
+ _NAFILL._serialized_start = 8859
+ _NAFILL._serialized_end = 8993
+ _NADROP._serialized_start = 8996
+ _NADROP._serialized_end = 9130
+ _NAREPLACE._serialized_start = 9133
+ _NAREPLACE._serialized_end = 9429
+ _NAREPLACE_REPLACEMENT._serialized_start = 9288
+ _NAREPLACE_REPLACEMENT._serialized_end = 9429
+ _TODF._serialized_start = 9431
+ _TODF._serialized_end = 9519
+ _WITHCOLUMNSRENAMED._serialized_start = 9522
+ _WITHCOLUMNSRENAMED._serialized_end = 9761
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 9694
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 9761
+ _WITHCOLUMNS._serialized_start = 9763
+ _WITHCOLUMNS._serialized_end = 9882
+ _WITHWATERMARK._serialized_start = 9885
+ _WITHWATERMARK._serialized_end = 10019
+ _HINT._serialized_start = 10022
+ _HINT._serialized_end = 10154
+ _UNPIVOT._serialized_start = 10157
+ _UNPIVOT._serialized_end = 10484
+ _UNPIVOT_VALUES._serialized_start = 10414
+ _UNPIVOT_VALUES._serialized_end = 10473
+ _TOSCHEMA._serialized_start = 10486
+ _TOSCHEMA._serialized_end = 10592
+ _REPARTITIONBYEXPRESSION._serialized_start = 10595
+ _REPARTITIONBYEXPRESSION._serialized_end = 10798
+ _MAPPARTITIONS._serialized_start = 10801
+ _MAPPARTITIONS._serialized_end = 10982
+ _GROUPMAP._serialized_start = 10985
+ _GROUPMAP._serialized_end = 11264
+ _COGROUPMAP._serialized_start = 11267
+ _COGROUPMAP._serialized_end = 11793
+ _APPLYINPANDASWITHSTATE._serialized_start = 11796
+ _APPLYINPANDASWITHSTATE._serialized_end = 12153
+ _COLLECTMETRICS._serialized_start = 12156
+ _COLLECTMETRICS._serialized_end = 12292
+ _PARSE._serialized_start = 12295
+ _PARSE._serialized_end = 12683
+ _PARSE_OPTIONSENTRY._serialized_start = 3687
+ _PARSE_OPTIONSENTRY._serialized_end = 3745
+ _PARSE_PARSEFORMAT._serialized_start = 12584
+ _PARSE_PARSEFORMAT._serialized_end = 12672
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 2b60a117b71..7898645dca5 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -97,6 +97,7 @@ class Relation(google.protobuf.message.Message):
WITH_WATERMARK_FIELD_NUMBER: builtins.int
APPLY_IN_PANDAS_WITH_STATE_FIELD_NUMBER: builtins.int
HTML_STRING_FIELD_NUMBER: builtins.int
+ CACHED_LOCAL_RELATION_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -182,6 +183,8 @@ class Relation(google.protobuf.message.Message):
@property
def html_string(self) -> global___HtmlString: ...
@property
+ def cached_local_relation(self) -> global___CachedLocalRelation: ...
+ @property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -253,6 +256,7 @@ class Relation(google.protobuf.message.Message):
with_watermark: global___WithWatermark | None = ...,
apply_in_pandas_with_state: global___ApplyInPandasWithState | None = ...,
html_string: global___HtmlString | None = ...,
+ cached_local_relation: global___CachedLocalRelation | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -277,6 +281,8 @@ class Relation(google.protobuf.message.Message):
b"apply_in_pandas_with_state",
"approx_quantile",
b"approx_quantile",
+ "cached_local_relation",
+ b"cached_local_relation",
"catalog",
b"catalog",
"co_group_map",
@@ -382,6 +388,8 @@ class Relation(google.protobuf.message.Message):
b"apply_in_pandas_with_state",
"approx_quantile",
b"approx_quantile",
+ "cached_local_relation",
+ b"cached_local_relation",
"catalog",
b"catalog",
"co_group_map",
@@ -515,6 +523,7 @@ class Relation(google.protobuf.message.Message):
"with_watermark",
"apply_in_pandas_with_state",
"html_string",
+ "cached_local_relation",
"fill_na",
"drop_na",
"replace",
@@ -1554,6 +1563,36 @@ class LocalRelation(google.protobuf.message.Message):
global___LocalRelation = LocalRelation
+class CachedLocalRelation(google.protobuf.message.Message):
+ """A local relation that has been cached already."""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ USERID_FIELD_NUMBER: builtins.int
+ SESSIONID_FIELD_NUMBER: builtins.int
+ HASH_FIELD_NUMBER: builtins.int
+ userId: builtins.str
+ """(Required) An identifier of the user which created the local relation"""
+ sessionId: builtins.str
+ """(Required) An identifier of the Spark SQL session in which the user
created the local relation."""
+ hash: builtins.str
+ """(Required) A sha-256 hash of the serialized local relation."""
+ def __init__(
+ self,
+ *,
+ userId: builtins.str = ...,
+ sessionId: builtins.str = ...,
+ hash: builtins.str = ...,
+ ) -> None: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "hash", b"hash", "sessionId", b"sessionId", "userId", b"userId"
+ ],
+ ) -> None: ...
+
+global___CachedLocalRelation = CachedLocalRelation
+
class Sample(google.protobuf.message.Message):
"""Relation of type [[Sample]] that samples a fraction of the dataset."""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 3093f0c1378..874f95af1cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4212,6 +4212,15 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ val LOCAL_RELATION_CACHE_THRESHOLD =
+ buildConf("spark.sql.session.localRelationCacheThreshold")
+ .doc("The threshold for the size in bytes of local relations to be
cached at " +
+ "the driver side after serialization.")
+ .version("3.5.0")
+ .intConf
+ .checkValue(_ >= 0, "The threshold of cached local relations must not be negative")
+ .createWithDefault(64 * 1024 * 1024)
+
/**
* Holds information about keys that have been deprecated.
*
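
As a hedged usage sketch (not part of the patch): with the new config in place, a session could lower the threshold so that even small local relations take the artifact-cache path. The connect URL below is a placeholder, and the ability to change the config at runtime in a given deployment is an assumption; 64 MB remains the default.

```python
# Sketch only: connect URL is a placeholder; requires pandas/pyarrow on the client.
import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

# Force even small local relations through the caching path (default is 64 MB).
spark.conf.set("spark.sql.session.localRelationCacheThreshold", 1024)

spark.createDataFrame(pd.DataFrame({"id": range(1_000)})).show(3)
```
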
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]