This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new db81309dd80a [SPARK-53917][CONNECT] Support large local relations
db81309dd80a is described below
commit db81309dd80a77f02a25c08522c3f06fd1c2926b
Author: Alex Khakhlyuk <[email protected]>
AuthorDate: Wed Oct 22 09:49:18 2025 -0400
[SPARK-53917][CONNECT] Support large local relations
### What changes were proposed in this pull request?
https://issues.apache.org/jira/browse/SPARK-53917
#### Problem description
LocalRelation is a Catalyst logical operator that represents a dataset of
rows inline as part of the LogicalPlan. LocalRelations back dataframes
created directly from Python and Scala objects, e.g., Python and Scala lists,
pandas dataframes, CSV files loaded into memory, etc.
In Spark Connect, local relations are transferred over gRPC using the
LocalRelation message (for relations under 64MB) and the CachedLocalRelation
message (for relations over 64MB).
CachedLocalRelations currently have a hard size limit of 2GB, which means
that Spark users can't execute queries over local client data, pandas
dataframes, or CSV files larger than 2GB.
#### Design
In Spark Connect, the client needs to serialize the local relation before
transferring it to the server. It serializes the data as a single record batch
in an Arrow IPC stream and the schema as a JSON string, then embeds both in a
LocalRelation{schema, data} proto message.
Small local relations (under 64MB) are sent directly as part of the
ExecutePlanRequest.
<img width="1398" height="550" alt="image"
src="https://github.com/user-attachments/assets/c176f4cd-1a8f-4d72-8217-5a3bc221ace9"
/>
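For illustration, a minimal PyArrow sketch of this client-side serialization
(mirroring the `_serialize_table` helper added to the Python client in this PR;
the table contents below are placeholders):
```
import pyarrow as pa

def serialize_local_relation_data(table: pa.Table) -> bytes:
    # Write all record batches of the table into a single Arrow IPC stream.
    sink = pa.BufferOutputStream()
    with pa.ipc.new_stream(sink, table.schema) as writer:
        for batch in table.to_batches():
            writer.write_batch(batch)
    return sink.getvalue().to_pybytes()

# The resulting bytes are embedded as LocalRelation.data in the proto message,
# alongside the schema serialized as a JSON string.
data = serialize_local_relation_data(pa.table({"id": [1, 2, 3]}))
```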
Larger local relations are first sent to the server via addArtifact and
stored in memory or on disk via the BlockManager. Then an ExecutePlanRequest is
sent containing CachedLocalRelation{hash}, where hash is the artifact hash. The
server retrieves the cached LocalRelation from the BlockManager using the hash,
deserializes it, adds it to the LogicalPlan, and then executes it.
<img width="1401" height="518" alt="image"
src="https://github.com/user-attachments/assets/51352194-5439-4559-9d43-fc19dfe81437"
/>
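A rough sketch of the hash-based caching described above (the sha-256 hashing
mirrors the Python client's ArtifactManager; the blob value is a placeholder):
```
import hashlib

# The client hashes the serialized relation with sha-256 and uploads the blob
# via addArtifact under the session's cache namespace (CACHE_PREFIX).
serialized_relation = b"...serialized LocalRelation proto..."
relation_hash = hashlib.sha256(serialized_relation).hexdigest()

# The ExecutePlanRequest then carries CachedLocalRelation(hash=relation_hash)
# instead of the inline data, and the server looks the blob up in BlockManager.
```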
The server reads the data from the BlockManager as a stream and tries to
create a proto.LocalRelation via
```
proto.Relation
.newBuilder()
.getLocalRelation
.getParserForType
.parseFrom(blockData.toInputStream())
```
This fails because the Java protobuf library has a 2GB limit on deserializing
protobuf messages:
```
org.sparkproject.connect.com.google.protobuf.InvalidProtocolBufferException) CodedInputStream encountered an embedded string or message which claimed to have negative size.
```
<img width="1396" height="503" alt="image"
src="https://github.com/user-attachments/assets/60da9441-f4cc-45d5-b028-57573a0175c2"
/>
To fix this, I propose bypassing the protobuf layer during serialization on
the client and deserialization on the server. Instead of caching the full
protobuf LocalRelation message, we cache the data and schema as separate
artifacts, send two hashes {data_hash, schema_hash} to the server, load both
from the BlockManager directly, and create a LocalRelation on the server from
the unpacked data and schema.
<img width="1397" height="515" alt="image"
src="https://github.com/user-attachments/assets/e44558de-df64-43b0-8813-d03de6689810"
/>
After creating a prototype with the new proto message, I discovered that
there are additional limits for CachedLocalRelations. Both the Scala client and
the server store the data in a single Java Array[Byte], which has a 2GB size
limit in Java. To avoid this limit, I propose transferring data in chunks: the
Python and Scala clients split the data into multiple Arrow batches and upload
them separately to the server, with each batch uploaded and stored as a
separate artifact (a simplified sketch of the chunking follows the proto
message below). The Serve [...]
<img width="1395" height="569" alt="image"
src="https://github.com/user-attachments/assets/16fac7b2-d247-42a6-9ac3-decb48df023d"
/>
The final proto message looks like this:
```
message ChunkedCachedLocalRelation {
repeated string dataHashes = 1;
optional string schemaHash = 2;
}
```
### Why are the changes needed?
LocalRelations currently have a hard size limit of 2GB, which means that
Spark users can't execute queries over local client data, pandas dataframes,
or CSV files larger than 2GB.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New Python and Scala tests for large local relations.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #52613 from khakhlyuk/largelocalrelations.
Authored-by: Alex Khakhlyuk <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
python/pyspark/sql/connect/client/artifact.py | 55 ++++
python/pyspark/sql/connect/client/core.py | 6 +
python/pyspark/sql/connect/plan.py | 91 +++++-
python/pyspark/sql/connect/proto/relations_pb2.py | 338 +++++++++++----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 58 +++-
python/pyspark/sql/connect/session.py | 51 +++-
python/pyspark/sql/tests/arrow/test_arrow.py | 23 ++
.../sql/tests/connect/arrow/test_parity_arrow.py | 6 +
.../org/apache/spark/sql/internal/SqlApiConf.scala | 7 +-
.../spark/sql/internal/SqlApiConfHelper.scala | 3 +
.../org/apache/spark/sql/internal/SQLConf.scala | 46 ++-
.../spark/sql/connect/SparkSessionE2ESuite.scala | 26 ++
.../main/protobuf/spark/connect/relations.proto | 16 +
.../apache/spark/sql/connect/SparkSession.scala | 39 ++-
.../spark/sql/connect/client/ArtifactManager.scala | 64 ++++
.../sql/connect/client/SparkConnectClient.scala | 26 +-
.../sql/connect/planner/InvalidInputErrors.scala | 20 ++
.../sql/connect/planner/SparkConnectPlanner.scala | 136 +++++++--
18 files changed, 785 insertions(+), 226 deletions(-)
diff --git a/python/pyspark/sql/connect/client/artifact.py
b/python/pyspark/sql/connect/client/artifact.py
index ac33233a00ff..72a6ffa8bf68 100644
--- a/python/pyspark/sql/connect/client/artifact.py
+++ b/python/pyspark/sql/connect/client/artifact.py
@@ -427,6 +427,30 @@ class ArtifactManager:
status = resp.statuses.get(artifactName)
return status.exists if status is not None else False
+ def get_cached_artifacts(self, hashes: list[str]) -> set[str]:
+ """
+ Batch check which artifacts are already cached on the server.
+ Returns a set of hashes that are already cached.
+ """
+ if not hashes:
+ return set()
+
+ artifact_names = [f"{CACHE_PREFIX}/{hash}" for hash in hashes]
+ request = proto.ArtifactStatusesRequest(
+ user_context=self._user_context, session_id=self._session_id,
names=artifact_names
+ )
+ resp: proto.ArtifactStatusesResponse = self._stub.ArtifactStatus(
+ request, metadata=self._metadata
+ )
+
+ cached = set()
+ for hash in hashes:
+ artifact_name = f"{CACHE_PREFIX}/{hash}"
+ status = resp.statuses.get(artifact_name)
+ if status is not None and status.exists:
+ cached.add(hash)
+ return cached
+
def cache_artifact(self, blob: bytes) -> str:
"""
Cache the give blob at the session.
@@ -442,3 +466,34 @@ class ArtifactManager:
# TODO(SPARK-42658): Handle responses containing CRC failures.
return hash
+
+ def cache_artifacts(self, blobs: list[bytes]) -> list[str]:
+ """
+ Cache the given blobs at the session.
+
+ This method batches artifact status checks and uploads to minimize RPC
overhead.
+ """
+ # Compute hashes for all blobs upfront
+ hashes = [hashlib.sha256(blob).hexdigest() for blob in blobs]
+ unique_hashes = list(set(hashes))
+
+ # Batch check which artifacts are already cached
+ cached_hashes = self.get_cached_artifacts(unique_hashes)
+
+ # Collect unique artifacts that need to be uploaded
+ seen_hashes = set()
+ artifacts_to_add = []
+ for blob, hash in zip(blobs, hashes):
+ if hash not in cached_hashes and hash not in seen_hashes:
+ artifacts_to_add.append(new_cache_artifact(hash,
InMemory(blob)))
+ seen_hashes.add(hash)
+
+ # Batch upload all missing artifacts in a single RPC call
+ if artifacts_to_add:
+ requests = self._add_artifacts(artifacts_to_add)
+ response: proto.AddArtifactsResponse =
self._retrieve_responses(requests)
+ summaries: List[proto.AddArtifactsResponse.ArtifactSummary] = []
+ for summary in response.artifacts:
+ summaries.append(summary)
+ # TODO(SPARK-42658): Handle responses containing CRC failures.
+ return hashes
diff --git a/python/pyspark/sql/connect/client/core.py
b/python/pyspark/sql/connect/client/core.py
index d0d191dbd7fd..414781d67cd4 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -2003,6 +2003,12 @@ class SparkConnectClient(object):
return self._artifact_manager.cache_artifact(blob)
raise SparkConnectException("Invalid state during retry exception
handling.")
+ def cache_artifacts(self, blobs: list[bytes]) -> list[str]:
+ for attempt in self._retrying():
+ with attempt:
+ return self._artifact_manager.cache_artifacts(blobs)
+ raise SparkConnectException("Invalid state during retry exception
handling.")
+
def _verify_response_integrity(
self,
response: Union[
diff --git a/python/pyspark/sql/connect/plan.py
b/python/pyspark/sql/connect/plan.py
index c5b6f5430d6d..82a6326c7dc5 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -429,16 +429,78 @@ class LocalRelation(LogicalPlan):
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
if self._table is not None:
- sink = pa.BufferOutputStream()
- with pa.ipc.new_stream(sink, self._table.schema) as writer:
- for b in self._table.to_batches():
- writer.write_batch(b)
- plan.local_relation.data = sink.getvalue().to_pybytes()
+ plan.local_relation.data = self._serialize_table()
if self._schema is not None:
plan.local_relation.schema = self._schema
return plan
+ def _serialize_table(self) -> bytes:
+ assert self._table is not None
+ sink = pa.BufferOutputStream()
+ with pa.ipc.new_stream(sink, self._table.schema) as writer:
+ batches = self._table.to_batches()
+ for b in batches:
+ writer.write_batch(b)
+ return sink.getvalue().to_pybytes()
+
+ def _serialize_table_chunks(
+ self,
+ max_chunk_size_rows: int,
+ max_chunk_size_bytes: int,
+ ) -> list[bytes]:
+ """
+ Serialize the table into multiple chunks, each up to
max_chunk_size_bytes bytes
+ and max_chunk_size_rows rows.
+ Each chunk is a valid Arrow IPC stream.
+
+ This method processes the table in fixed-size batches (1024 rows) for
+ efficiency, matching the Scala implementation's batchSizeCheckInterval.
+ """
+ assert self._table is not None
+ chunks = []
+ schema = self._table.schema
+
+ # Calculate schema serialization size once
+ schema_buffer = pa.BufferOutputStream()
+ with pa.ipc.new_stream(schema_buffer, schema):
+ pass # Just write schema
+ schema_size = len(schema_buffer.getvalue())
+
+ current_batches: list[pa.RecordBatch] = []
+ current_size = schema_size
+
+ for batch in self._table.to_batches(max_chunksize=min(1024,
max_chunk_size_rows)):
+ batch_size = sum(arr.nbytes for arr in batch.columns)
+
+ # If this batch would exceed limit and we have data, flush current
chunk
+ if current_size > schema_size and current_size + batch_size >
max_chunk_size_bytes:
+ combined = pa.Table.from_batches(current_batches,
schema=schema)
+ sink = pa.BufferOutputStream()
+ with pa.ipc.new_stream(sink, schema) as writer:
+ writer.write_table(combined)
+ chunks.append(sink.getvalue().to_pybytes())
+ current_batches = []
+ current_size = schema_size
+
+ current_batches.append(batch)
+ current_size += batch_size
+
+ # Flush remaining batches
+ if current_batches:
+ combined = pa.Table.from_batches(current_batches, schema=schema)
+ sink = pa.BufferOutputStream()
+ with pa.ipc.new_stream(sink, schema) as writer:
+ writer.write_table(combined)
+ chunks.append(sink.getvalue().to_pybytes())
+
+ return chunks
+
+ def _serialize_schema(self) -> bytes:
+ # the server uses UTF-8 for decoding the schema
+ assert self._schema is not None
+ return self._schema.encode("utf-8")
+
def serialize(self, session: "SparkConnectClient") -> bytes:
p = self.plan(session)
return bytes(p.local_relation.SerializeToString())
@@ -454,29 +516,34 @@ class LocalRelation(LogicalPlan):
"""
-class CachedLocalRelation(LogicalPlan):
+class ChunkedCachedLocalRelation(LogicalPlan):
"""Creates a CachedLocalRelation plan object based on a hash of a
LocalRelation."""
- def __init__(self, hash: str) -> None:
+ def __init__(self, data_hashes: list[str], schema_hash: Optional[str]) ->
None:
super().__init__(None)
- self._hash = hash
+ self._data_hashes = data_hashes
+ self._schema_hash = schema_hash
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = self._create_proto_relation()
- clr = plan.cached_local_relation
+ clr = plan.chunked_cached_local_relation
- clr.hash = self._hash
+ # Add hex string hashes directly to protobuf
+ for data_hash in self._data_hashes:
+ clr.dataHashes.append(data_hash)
+ if self._schema_hash is not None:
+ clr.schemaHash = self._schema_hash
return plan
def print(self, indent: int = 0) -> str:
- return f"{' ' * indent}<CachedLocalRelation>\n"
+ return f"{' ' * indent}<ChunkedCachedLocalRelation>\n"
def _repr_html_(self) -> str:
return """
<ul>
- <li><b>CachedLocalRelation</b></li>
+ <li><b>ChunkedCachedLocalRelation</b></li>
</ul>
"""
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py
b/python/pyspark/sql/connect/proto/relations_pb2.py
index be114f61e7d5..e7f319554c5e 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -43,7 +43,7 @@ from pyspark.sql.connect.proto import ml_common_pb2 as
spark_dot_connect_dot_ml_
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/ml_common.proto"\x9c\x1d\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x [...]
+
b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto\x1a\x1aspark/connect/common.proto\x1a\x1dspark/connect/ml_common.proto"\x8c\x1e\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01
\x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02
\x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03
\x [...]
)
_globals = globals()
@@ -79,171 +79,173 @@ if not _descriptor._USE_C_DESCRIPTORS:
_globals["_PARSE_OPTIONSENTRY"]._loaded_options = None
_globals["_PARSE_OPTIONSENTRY"]._serialized_options = b"8\001"
_globals["_RELATION"]._serialized_start = 224
- _globals["_RELATION"]._serialized_end = 3964
- _globals["_MLRELATION"]._serialized_start = 3967
- _globals["_MLRELATION"]._serialized_end = 4451
- _globals["_MLRELATION_TRANSFORM"]._serialized_start = 4179
- _globals["_MLRELATION_TRANSFORM"]._serialized_end = 4414
- _globals["_FETCH"]._serialized_start = 4454
- _globals["_FETCH"]._serialized_end = 4785
- _globals["_FETCH_METHOD"]._serialized_start = 4570
- _globals["_FETCH_METHOD"]._serialized_end = 4785
- _globals["_FETCH_METHOD_ARGS"]._serialized_start = 4658
- _globals["_FETCH_METHOD_ARGS"]._serialized_end = 4785
- _globals["_UNKNOWN"]._serialized_start = 4787
- _globals["_UNKNOWN"]._serialized_end = 4796
- _globals["_RELATIONCOMMON"]._serialized_start = 4799
- _globals["_RELATIONCOMMON"]._serialized_end = 4941
- _globals["_SQL"]._serialized_start = 4944
- _globals["_SQL"]._serialized_end = 5422
- _globals["_SQL_ARGSENTRY"]._serialized_start = 5238
- _globals["_SQL_ARGSENTRY"]._serialized_end = 5328
- _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_start = 5330
- _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_end = 5422
- _globals["_WITHRELATIONS"]._serialized_start = 5424
- _globals["_WITHRELATIONS"]._serialized_end = 5541
- _globals["_READ"]._serialized_start = 5544
- _globals["_READ"]._serialized_end = 6207
- _globals["_READ_NAMEDTABLE"]._serialized_start = 5722
- _globals["_READ_NAMEDTABLE"]._serialized_end = 5914
- _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_start = 5856
- _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_end = 5914
- _globals["_READ_DATASOURCE"]._serialized_start = 5917
- _globals["_READ_DATASOURCE"]._serialized_end = 6194
- _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_start = 5856
- _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_end = 5914
- _globals["_PROJECT"]._serialized_start = 6209
- _globals["_PROJECT"]._serialized_end = 6326
- _globals["_FILTER"]._serialized_start = 6328
- _globals["_FILTER"]._serialized_end = 6440
- _globals["_JOIN"]._serialized_start = 6443
- _globals["_JOIN"]._serialized_end = 7104
- _globals["_JOIN_JOINDATATYPE"]._serialized_start = 6782
- _globals["_JOIN_JOINDATATYPE"]._serialized_end = 6874
- _globals["_JOIN_JOINTYPE"]._serialized_start = 6877
- _globals["_JOIN_JOINTYPE"]._serialized_end = 7085
- _globals["_SETOPERATION"]._serialized_start = 7107
- _globals["_SETOPERATION"]._serialized_end = 7586
- _globals["_SETOPERATION_SETOPTYPE"]._serialized_start = 7423
- _globals["_SETOPERATION_SETOPTYPE"]._serialized_end = 7537
- _globals["_LIMIT"]._serialized_start = 7588
- _globals["_LIMIT"]._serialized_end = 7664
- _globals["_OFFSET"]._serialized_start = 7666
- _globals["_OFFSET"]._serialized_end = 7745
- _globals["_TAIL"]._serialized_start = 7747
- _globals["_TAIL"]._serialized_end = 7822
- _globals["_AGGREGATE"]._serialized_start = 7825
- _globals["_AGGREGATE"]._serialized_end = 8591
- _globals["_AGGREGATE_PIVOT"]._serialized_start = 8240
- _globals["_AGGREGATE_PIVOT"]._serialized_end = 8351
- _globals["_AGGREGATE_GROUPINGSETS"]._serialized_start = 8353
- _globals["_AGGREGATE_GROUPINGSETS"]._serialized_end = 8429
- _globals["_AGGREGATE_GROUPTYPE"]._serialized_start = 8432
- _globals["_AGGREGATE_GROUPTYPE"]._serialized_end = 8591
- _globals["_SORT"]._serialized_start = 8594
- _globals["_SORT"]._serialized_end = 8754
- _globals["_DROP"]._serialized_start = 8757
- _globals["_DROP"]._serialized_end = 8898
- _globals["_DEDUPLICATE"]._serialized_start = 8901
- _globals["_DEDUPLICATE"]._serialized_end = 9141
- _globals["_LOCALRELATION"]._serialized_start = 9143
- _globals["_LOCALRELATION"]._serialized_end = 9232
- _globals["_CACHEDLOCALRELATION"]._serialized_start = 9234
- _globals["_CACHEDLOCALRELATION"]._serialized_end = 9306
- _globals["_CACHEDREMOTERELATION"]._serialized_start = 9308
- _globals["_CACHEDREMOTERELATION"]._serialized_end = 9363
- _globals["_SAMPLE"]._serialized_start = 9366
- _globals["_SAMPLE"]._serialized_end = 9639
- _globals["_RANGE"]._serialized_start = 9642
- _globals["_RANGE"]._serialized_end = 9787
- _globals["_SUBQUERYALIAS"]._serialized_start = 9789
- _globals["_SUBQUERYALIAS"]._serialized_end = 9903
- _globals["_REPARTITION"]._serialized_start = 9906
- _globals["_REPARTITION"]._serialized_end = 10048
- _globals["_SHOWSTRING"]._serialized_start = 10051
- _globals["_SHOWSTRING"]._serialized_end = 10193
- _globals["_HTMLSTRING"]._serialized_start = 10195
- _globals["_HTMLSTRING"]._serialized_end = 10309
- _globals["_STATSUMMARY"]._serialized_start = 10311
- _globals["_STATSUMMARY"]._serialized_end = 10403
- _globals["_STATDESCRIBE"]._serialized_start = 10405
- _globals["_STATDESCRIBE"]._serialized_end = 10486
- _globals["_STATCROSSTAB"]._serialized_start = 10488
- _globals["_STATCROSSTAB"]._serialized_end = 10589
- _globals["_STATCOV"]._serialized_start = 10591
- _globals["_STATCOV"]._serialized_end = 10687
- _globals["_STATCORR"]._serialized_start = 10690
- _globals["_STATCORR"]._serialized_end = 10827
- _globals["_STATAPPROXQUANTILE"]._serialized_start = 10830
- _globals["_STATAPPROXQUANTILE"]._serialized_end = 10994
- _globals["_STATFREQITEMS"]._serialized_start = 10996
- _globals["_STATFREQITEMS"]._serialized_end = 11121
- _globals["_STATSAMPLEBY"]._serialized_start = 11124
- _globals["_STATSAMPLEBY"]._serialized_end = 11433
- _globals["_STATSAMPLEBY_FRACTION"]._serialized_start = 11325
- _globals["_STATSAMPLEBY_FRACTION"]._serialized_end = 11424
- _globals["_NAFILL"]._serialized_start = 11436
- _globals["_NAFILL"]._serialized_end = 11570
- _globals["_NADROP"]._serialized_start = 11573
- _globals["_NADROP"]._serialized_end = 11707
- _globals["_NAREPLACE"]._serialized_start = 11710
- _globals["_NAREPLACE"]._serialized_end = 12006
- _globals["_NAREPLACE_REPLACEMENT"]._serialized_start = 11865
- _globals["_NAREPLACE_REPLACEMENT"]._serialized_end = 12006
- _globals["_TODF"]._serialized_start = 12008
- _globals["_TODF"]._serialized_end = 12096
- _globals["_WITHCOLUMNSRENAMED"]._serialized_start = 12099
- _globals["_WITHCOLUMNSRENAMED"]._serialized_end = 12481
- _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_start =
12343
- _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_end =
12410
- _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_start = 12412
- _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_end = 12481
- _globals["_WITHCOLUMNS"]._serialized_start = 12483
- _globals["_WITHCOLUMNS"]._serialized_end = 12602
- _globals["_WITHWATERMARK"]._serialized_start = 12605
- _globals["_WITHWATERMARK"]._serialized_end = 12739
- _globals["_HINT"]._serialized_start = 12742
- _globals["_HINT"]._serialized_end = 12874
- _globals["_UNPIVOT"]._serialized_start = 12877
- _globals["_UNPIVOT"]._serialized_end = 13204
- _globals["_UNPIVOT_VALUES"]._serialized_start = 13134
- _globals["_UNPIVOT_VALUES"]._serialized_end = 13193
- _globals["_TRANSPOSE"]._serialized_start = 13206
- _globals["_TRANSPOSE"]._serialized_end = 13328
- _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_start = 13330
- _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_end = 13455
- _globals["_TOSCHEMA"]._serialized_start = 13457
- _globals["_TOSCHEMA"]._serialized_end = 13563
- _globals["_REPARTITIONBYEXPRESSION"]._serialized_start = 13566
- _globals["_REPARTITIONBYEXPRESSION"]._serialized_end = 13769
- _globals["_MAPPARTITIONS"]._serialized_start = 13772
- _globals["_MAPPARTITIONS"]._serialized_end = 14004
- _globals["_GROUPMAP"]._serialized_start = 14007
- _globals["_GROUPMAP"]._serialized_end = 14857
- _globals["_TRANSFORMWITHSTATEINFO"]._serialized_start = 14860
- _globals["_TRANSFORMWITHSTATEINFO"]._serialized_end = 15083
- _globals["_COGROUPMAP"]._serialized_start = 15086
- _globals["_COGROUPMAP"]._serialized_end = 15612
- _globals["_APPLYINPANDASWITHSTATE"]._serialized_start = 15615
- _globals["_APPLYINPANDASWITHSTATE"]._serialized_end = 15972
- _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_start = 15975
- _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_end = 16219
- _globals["_PYTHONUDTF"]._serialized_start = 16222
- _globals["_PYTHONUDTF"]._serialized_end = 16399
- _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_start = 16402
- _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_end = 16553
- _globals["_PYTHONDATASOURCE"]._serialized_start = 16555
- _globals["_PYTHONDATASOURCE"]._serialized_end = 16630
- _globals["_COLLECTMETRICS"]._serialized_start = 16633
- _globals["_COLLECTMETRICS"]._serialized_end = 16769
- _globals["_PARSE"]._serialized_start = 16772
- _globals["_PARSE"]._serialized_end = 17160
- _globals["_PARSE_OPTIONSENTRY"]._serialized_start = 5856
- _globals["_PARSE_OPTIONSENTRY"]._serialized_end = 5914
- _globals["_PARSE_PARSEFORMAT"]._serialized_start = 17061
- _globals["_PARSE_PARSEFORMAT"]._serialized_end = 17149
- _globals["_ASOFJOIN"]._serialized_start = 17163
- _globals["_ASOFJOIN"]._serialized_end = 17638
- _globals["_LATERALJOIN"]._serialized_start = 17641
- _globals["_LATERALJOIN"]._serialized_end = 17871
+ _globals["_RELATION"]._serialized_end = 4076
+ _globals["_MLRELATION"]._serialized_start = 4079
+ _globals["_MLRELATION"]._serialized_end = 4563
+ _globals["_MLRELATION_TRANSFORM"]._serialized_start = 4291
+ _globals["_MLRELATION_TRANSFORM"]._serialized_end = 4526
+ _globals["_FETCH"]._serialized_start = 4566
+ _globals["_FETCH"]._serialized_end = 4897
+ _globals["_FETCH_METHOD"]._serialized_start = 4682
+ _globals["_FETCH_METHOD"]._serialized_end = 4897
+ _globals["_FETCH_METHOD_ARGS"]._serialized_start = 4770
+ _globals["_FETCH_METHOD_ARGS"]._serialized_end = 4897
+ _globals["_UNKNOWN"]._serialized_start = 4899
+ _globals["_UNKNOWN"]._serialized_end = 4908
+ _globals["_RELATIONCOMMON"]._serialized_start = 4911
+ _globals["_RELATIONCOMMON"]._serialized_end = 5053
+ _globals["_SQL"]._serialized_start = 5056
+ _globals["_SQL"]._serialized_end = 5534
+ _globals["_SQL_ARGSENTRY"]._serialized_start = 5350
+ _globals["_SQL_ARGSENTRY"]._serialized_end = 5440
+ _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_start = 5442
+ _globals["_SQL_NAMEDARGUMENTSENTRY"]._serialized_end = 5534
+ _globals["_WITHRELATIONS"]._serialized_start = 5536
+ _globals["_WITHRELATIONS"]._serialized_end = 5653
+ _globals["_READ"]._serialized_start = 5656
+ _globals["_READ"]._serialized_end = 6319
+ _globals["_READ_NAMEDTABLE"]._serialized_start = 5834
+ _globals["_READ_NAMEDTABLE"]._serialized_end = 6026
+ _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_start = 5968
+ _globals["_READ_NAMEDTABLE_OPTIONSENTRY"]._serialized_end = 6026
+ _globals["_READ_DATASOURCE"]._serialized_start = 6029
+ _globals["_READ_DATASOURCE"]._serialized_end = 6306
+ _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_start = 5968
+ _globals["_READ_DATASOURCE_OPTIONSENTRY"]._serialized_end = 6026
+ _globals["_PROJECT"]._serialized_start = 6321
+ _globals["_PROJECT"]._serialized_end = 6438
+ _globals["_FILTER"]._serialized_start = 6440
+ _globals["_FILTER"]._serialized_end = 6552
+ _globals["_JOIN"]._serialized_start = 6555
+ _globals["_JOIN"]._serialized_end = 7216
+ _globals["_JOIN_JOINDATATYPE"]._serialized_start = 6894
+ _globals["_JOIN_JOINDATATYPE"]._serialized_end = 6986
+ _globals["_JOIN_JOINTYPE"]._serialized_start = 6989
+ _globals["_JOIN_JOINTYPE"]._serialized_end = 7197
+ _globals["_SETOPERATION"]._serialized_start = 7219
+ _globals["_SETOPERATION"]._serialized_end = 7698
+ _globals["_SETOPERATION_SETOPTYPE"]._serialized_start = 7535
+ _globals["_SETOPERATION_SETOPTYPE"]._serialized_end = 7649
+ _globals["_LIMIT"]._serialized_start = 7700
+ _globals["_LIMIT"]._serialized_end = 7776
+ _globals["_OFFSET"]._serialized_start = 7778
+ _globals["_OFFSET"]._serialized_end = 7857
+ _globals["_TAIL"]._serialized_start = 7859
+ _globals["_TAIL"]._serialized_end = 7934
+ _globals["_AGGREGATE"]._serialized_start = 7937
+ _globals["_AGGREGATE"]._serialized_end = 8703
+ _globals["_AGGREGATE_PIVOT"]._serialized_start = 8352
+ _globals["_AGGREGATE_PIVOT"]._serialized_end = 8463
+ _globals["_AGGREGATE_GROUPINGSETS"]._serialized_start = 8465
+ _globals["_AGGREGATE_GROUPINGSETS"]._serialized_end = 8541
+ _globals["_AGGREGATE_GROUPTYPE"]._serialized_start = 8544
+ _globals["_AGGREGATE_GROUPTYPE"]._serialized_end = 8703
+ _globals["_SORT"]._serialized_start = 8706
+ _globals["_SORT"]._serialized_end = 8866
+ _globals["_DROP"]._serialized_start = 8869
+ _globals["_DROP"]._serialized_end = 9010
+ _globals["_DEDUPLICATE"]._serialized_start = 9013
+ _globals["_DEDUPLICATE"]._serialized_end = 9253
+ _globals["_LOCALRELATION"]._serialized_start = 9255
+ _globals["_LOCALRELATION"]._serialized_end = 9344
+ _globals["_CACHEDLOCALRELATION"]._serialized_start = 9346
+ _globals["_CACHEDLOCALRELATION"]._serialized_end = 9418
+ _globals["_CHUNKEDCACHEDLOCALRELATION"]._serialized_start = 9420
+ _globals["_CHUNKEDCACHEDLOCALRELATION"]._serialized_end = 9532
+ _globals["_CACHEDREMOTERELATION"]._serialized_start = 9534
+ _globals["_CACHEDREMOTERELATION"]._serialized_end = 9589
+ _globals["_SAMPLE"]._serialized_start = 9592
+ _globals["_SAMPLE"]._serialized_end = 9865
+ _globals["_RANGE"]._serialized_start = 9868
+ _globals["_RANGE"]._serialized_end = 10013
+ _globals["_SUBQUERYALIAS"]._serialized_start = 10015
+ _globals["_SUBQUERYALIAS"]._serialized_end = 10129
+ _globals["_REPARTITION"]._serialized_start = 10132
+ _globals["_REPARTITION"]._serialized_end = 10274
+ _globals["_SHOWSTRING"]._serialized_start = 10277
+ _globals["_SHOWSTRING"]._serialized_end = 10419
+ _globals["_HTMLSTRING"]._serialized_start = 10421
+ _globals["_HTMLSTRING"]._serialized_end = 10535
+ _globals["_STATSUMMARY"]._serialized_start = 10537
+ _globals["_STATSUMMARY"]._serialized_end = 10629
+ _globals["_STATDESCRIBE"]._serialized_start = 10631
+ _globals["_STATDESCRIBE"]._serialized_end = 10712
+ _globals["_STATCROSSTAB"]._serialized_start = 10714
+ _globals["_STATCROSSTAB"]._serialized_end = 10815
+ _globals["_STATCOV"]._serialized_start = 10817
+ _globals["_STATCOV"]._serialized_end = 10913
+ _globals["_STATCORR"]._serialized_start = 10916
+ _globals["_STATCORR"]._serialized_end = 11053
+ _globals["_STATAPPROXQUANTILE"]._serialized_start = 11056
+ _globals["_STATAPPROXQUANTILE"]._serialized_end = 11220
+ _globals["_STATFREQITEMS"]._serialized_start = 11222
+ _globals["_STATFREQITEMS"]._serialized_end = 11347
+ _globals["_STATSAMPLEBY"]._serialized_start = 11350
+ _globals["_STATSAMPLEBY"]._serialized_end = 11659
+ _globals["_STATSAMPLEBY_FRACTION"]._serialized_start = 11551
+ _globals["_STATSAMPLEBY_FRACTION"]._serialized_end = 11650
+ _globals["_NAFILL"]._serialized_start = 11662
+ _globals["_NAFILL"]._serialized_end = 11796
+ _globals["_NADROP"]._serialized_start = 11799
+ _globals["_NADROP"]._serialized_end = 11933
+ _globals["_NAREPLACE"]._serialized_start = 11936
+ _globals["_NAREPLACE"]._serialized_end = 12232
+ _globals["_NAREPLACE_REPLACEMENT"]._serialized_start = 12091
+ _globals["_NAREPLACE_REPLACEMENT"]._serialized_end = 12232
+ _globals["_TODF"]._serialized_start = 12234
+ _globals["_TODF"]._serialized_end = 12322
+ _globals["_WITHCOLUMNSRENAMED"]._serialized_start = 12325
+ _globals["_WITHCOLUMNSRENAMED"]._serialized_end = 12707
+ _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_start =
12569
+ _globals["_WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY"]._serialized_end =
12636
+ _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_start = 12638
+ _globals["_WITHCOLUMNSRENAMED_RENAME"]._serialized_end = 12707
+ _globals["_WITHCOLUMNS"]._serialized_start = 12709
+ _globals["_WITHCOLUMNS"]._serialized_end = 12828
+ _globals["_WITHWATERMARK"]._serialized_start = 12831
+ _globals["_WITHWATERMARK"]._serialized_end = 12965
+ _globals["_HINT"]._serialized_start = 12968
+ _globals["_HINT"]._serialized_end = 13100
+ _globals["_UNPIVOT"]._serialized_start = 13103
+ _globals["_UNPIVOT"]._serialized_end = 13430
+ _globals["_UNPIVOT_VALUES"]._serialized_start = 13360
+ _globals["_UNPIVOT_VALUES"]._serialized_end = 13419
+ _globals["_TRANSPOSE"]._serialized_start = 13432
+ _globals["_TRANSPOSE"]._serialized_end = 13554
+ _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_start = 13556
+ _globals["_UNRESOLVEDTABLEVALUEDFUNCTION"]._serialized_end = 13681
+ _globals["_TOSCHEMA"]._serialized_start = 13683
+ _globals["_TOSCHEMA"]._serialized_end = 13789
+ _globals["_REPARTITIONBYEXPRESSION"]._serialized_start = 13792
+ _globals["_REPARTITIONBYEXPRESSION"]._serialized_end = 13995
+ _globals["_MAPPARTITIONS"]._serialized_start = 13998
+ _globals["_MAPPARTITIONS"]._serialized_end = 14230
+ _globals["_GROUPMAP"]._serialized_start = 14233
+ _globals["_GROUPMAP"]._serialized_end = 15083
+ _globals["_TRANSFORMWITHSTATEINFO"]._serialized_start = 15086
+ _globals["_TRANSFORMWITHSTATEINFO"]._serialized_end = 15309
+ _globals["_COGROUPMAP"]._serialized_start = 15312
+ _globals["_COGROUPMAP"]._serialized_end = 15838
+ _globals["_APPLYINPANDASWITHSTATE"]._serialized_start = 15841
+ _globals["_APPLYINPANDASWITHSTATE"]._serialized_end = 16198
+ _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_start = 16201
+ _globals["_COMMONINLINEUSERDEFINEDTABLEFUNCTION"]._serialized_end = 16445
+ _globals["_PYTHONUDTF"]._serialized_start = 16448
+ _globals["_PYTHONUDTF"]._serialized_end = 16625
+ _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_start = 16628
+ _globals["_COMMONINLINEUSERDEFINEDDATASOURCE"]._serialized_end = 16779
+ _globals["_PYTHONDATASOURCE"]._serialized_start = 16781
+ _globals["_PYTHONDATASOURCE"]._serialized_end = 16856
+ _globals["_COLLECTMETRICS"]._serialized_start = 16859
+ _globals["_COLLECTMETRICS"]._serialized_end = 16995
+ _globals["_PARSE"]._serialized_start = 16998
+ _globals["_PARSE"]._serialized_end = 17386
+ _globals["_PARSE_OPTIONSENTRY"]._serialized_start = 5968
+ _globals["_PARSE_OPTIONSENTRY"]._serialized_end = 6026
+ _globals["_PARSE_PARSEFORMAT"]._serialized_start = 17287
+ _globals["_PARSE_PARSEFORMAT"]._serialized_end = 17375
+ _globals["_ASOFJOIN"]._serialized_start = 17389
+ _globals["_ASOFJOIN"]._serialized_end = 17864
+ _globals["_LATERALJOIN"]._serialized_start = 17867
+ _globals["_LATERALJOIN"]._serialized_end = 18097
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi
b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index e1eb7945c19f..c6f20c158a6c 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -108,6 +108,7 @@ class Relation(google.protobuf.message.Message):
TRANSPOSE_FIELD_NUMBER: builtins.int
UNRESOLVED_TABLE_VALUED_FUNCTION_FIELD_NUMBER: builtins.int
LATERAL_JOIN_FIELD_NUMBER: builtins.int
+ CHUNKED_CACHED_LOCAL_RELATION_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -216,6 +217,8 @@ class Relation(google.protobuf.message.Message):
@property
def lateral_join(self) -> global___LateralJoin: ...
@property
+ def chunked_cached_local_relation(self) ->
global___ChunkedCachedLocalRelation: ...
+ @property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -301,6 +304,7 @@ class Relation(google.protobuf.message.Message):
transpose: global___Transpose | None = ...,
unresolved_table_valued_function:
global___UnresolvedTableValuedFunction | None = ...,
lateral_join: global___LateralJoin | None = ...,
+ chunked_cached_local_relation: global___ChunkedCachedLocalRelation |
None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -334,6 +338,8 @@ class Relation(google.protobuf.message.Message):
b"cached_remote_relation",
"catalog",
b"catalog",
+ "chunked_cached_local_relation",
+ b"chunked_cached_local_relation",
"co_group_map",
b"co_group_map",
"collect_metrics",
@@ -459,6 +465,8 @@ class Relation(google.protobuf.message.Message):
b"cached_remote_relation",
"catalog",
b"catalog",
+ "chunked_cached_local_relation",
+ b"chunked_cached_local_relation",
"co_group_map",
b"co_group_map",
"collect_metrics",
@@ -614,6 +622,7 @@ class Relation(google.protobuf.message.Message):
"transpose",
"unresolved_table_valued_function",
"lateral_join",
+ "chunked_cached_local_relation",
"fill_na",
"drop_na",
"replace",
@@ -2084,7 +2093,9 @@ class LocalRelation(google.protobuf.message.Message):
global___LocalRelation = LocalRelation
class CachedLocalRelation(google.protobuf.message.Message):
- """A local relation that has been cached already."""
+ """A local relation that has been cached already.
+ CachedLocalRelation doesn't support LocalRelations of size over 2GB.
+ """
DESCRIPTOR: google.protobuf.descriptor.Descriptor
@@ -2100,6 +2111,51 @@ class
CachedLocalRelation(google.protobuf.message.Message):
global___CachedLocalRelation = CachedLocalRelation
+class ChunkedCachedLocalRelation(google.protobuf.message.Message):
+ """A local relation that has been cached already."""
+
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ DATAHASHES_FIELD_NUMBER: builtins.int
+ SCHEMAHASH_FIELD_NUMBER: builtins.int
+ @property
+ def dataHashes(
+ self,
+ ) ->
google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """(Required) A list of sha-256 hashes for representing
LocalRelation.data.
+ Data is serialized in Arrow IPC streaming format, each batch is cached
on the server as
+ a separate artifact. Each hash represents one batch stored on the
server.
+ Hashes are hex-encoded strings (e.g., "a3b2c1d4...").
+ """
+ schemaHash: builtins.str
+ """(Optional) A sha-256 hash of the serialized LocalRelation.schema.
+ Scala clients always provide the schema, Python clients can omit it.
+ Hash is a hex-encoded string (e.g., "a3b2c1d4...").
+ """
+ def __init__(
+ self,
+ *,
+ dataHashes: collections.abc.Iterable[builtins.str] | None = ...,
+ schemaHash: builtins.str | None = ...,
+ ) -> None: ...
+ def HasField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_schemaHash", b"_schemaHash", "schemaHash", b"schemaHash"
+ ],
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+ "_schemaHash", b"_schemaHash", "dataHashes", b"dataHashes",
"schemaHash", b"schemaHash"
+ ],
+ ) -> None: ...
+ def WhichOneof(
+ self, oneof_group: typing_extensions.Literal["_schemaHash",
b"_schemaHash"]
+ ) -> typing_extensions.Literal["schemaHash"] | None: ...
+
+global___ChunkedCachedLocalRelation = ChunkedCachedLocalRelation
+
class CachedRemoteRelation(google.protobuf.message.Message):
"""Represents a remote relation that has been cached on server."""
diff --git a/python/pyspark/sql/connect/session.py
b/python/pyspark/sql/connect/session.py
index f759137fac1d..2a678c95c925 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -61,7 +61,7 @@ from pyspark.sql.connect.plan import (
Range,
LocalRelation,
LogicalPlan,
- CachedLocalRelation,
+ ChunkedCachedLocalRelation,
CachedRelation,
CachedRemoteRelation,
SubqueryAlias,
@@ -535,6 +535,8 @@ class SparkSession:
"spark.sql.timestampType",
"spark.sql.session.timeZone",
"spark.sql.session.localRelationCacheThreshold",
+ "spark.sql.session.localRelationChunkSizeRows",
+ "spark.sql.session.localRelationChunkSizeBytes",
"spark.sql.execution.pandas.convertToArrowArraySafely",
"spark.sql.execution.pandas.inferPandasDictAsMap",
"spark.sql.pyspark.inferNestedDictAsStruct.enabled",
@@ -755,10 +757,21 @@ class SparkSession:
else:
local_relation = LocalRelation(_table)
- cache_threshold =
configs["spark.sql.session.localRelationCacheThreshold"]
+ # get_config_dict throws [SQL_CONF_NOT_FOUND] if the key is not found.
+ cache_threshold = int(
+ configs["spark.sql.session.localRelationCacheThreshold"] # type:
ignore[arg-type]
+ )
+ max_chunk_size_rows = int(
+ configs["spark.sql.session.localRelationChunkSizeRows"] # type:
ignore[arg-type]
+ )
+ max_chunk_size_bytes = int(
+ configs["spark.sql.session.localRelationChunkSizeBytes"] # type:
ignore[arg-type]
+ )
plan: LogicalPlan = local_relation
- if cache_threshold is not None and int(cache_threshold) <=
_table.nbytes:
- plan =
CachedLocalRelation(self._cache_local_relation(local_relation))
+ if cache_threshold <= _table.nbytes:
+ plan = self._cache_local_relation(
+ local_relation, max_chunk_size_rows, max_chunk_size_bytes
+ )
df = DataFrame(plan, self)
if _cols is not None and len(_cols) > 0:
@@ -1031,12 +1044,36 @@ class SparkSession:
addArtifact = addArtifacts
- def _cache_local_relation(self, local_relation: LocalRelation) -> str:
+ def _cache_local_relation(
+ self,
+ local_relation: LocalRelation,
+ max_chunk_size_rows: int,
+ max_chunk_size_bytes: int,
+ ) -> ChunkedCachedLocalRelation:
"""
Cache the local relation at the server side if it has not been cached
yet.
+
+ Should only be called on LocalRelations with _table set.
"""
- serialized = local_relation.serialize(self._client)
- return self._client.cache_artifact(serialized)
+ assert local_relation._table is not None
+ has_schema = local_relation._schema is not None
+
+ # Serialize table into chunks
+ data_chunks = local_relation._serialize_table_chunks(
+ max_chunk_size_rows, max_chunk_size_bytes
+ )
+ blobs = data_chunks.copy() # Start with data chunks
+
+ if has_schema:
+ blobs.append(local_relation._serialize_schema())
+
+ hashes = self._client.cache_artifacts(blobs)
+
+ # Extract data hashes and schema hash
+ data_hashes = hashes[: len(data_chunks)]
+ schema_hash = hashes[len(data_chunks)] if has_schema else None
+
+ return ChunkedCachedLocalRelation(data_hashes, schema_hash)
def copyFromLocalToFs(self, local_path: str, dest_path: str) -> None:
if urllib.parse.urlparse(dest_path).scheme:
diff --git a/python/pyspark/sql/tests/arrow/test_arrow.py
b/python/pyspark/sql/tests/arrow/test_arrow.py
index 819639c63a2c..be7dd2febc94 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow.py
@@ -420,6 +420,29 @@ class ArrowTestsMixin:
)
assert_frame_equal(pdf_ny, pdf_la_corrected)
+ def check_cached_local_relation_changing_values(self):
+ import random
+ import string
+
+ row_size = 1000
+ row_count = 64 * 1000
+ suffix = "abcdef"
+ str_value = (
+ "".join(random.choices(string.ascii_letters + string.digits,
k=row_size)) + suffix
+ )
+ data = [(i, str_value) for i in range(row_count)]
+
+ for _ in range(2):
+ df = self.spark.createDataFrame(data, ["col1", "col2"])
+ assert df.count() == row_count
+ assert not df.filter(df["col2"].endswith(suffix)).isEmpty()
+
+ def check_large_cached_local_relation_same_values(self):
+ data = [("C000000032", "R20", 0.2555)] * 500_000
+ pdf = pd.DataFrame(data=data, columns=["Contrat", "Recommandation",
"Distance"])
+ df = self.spark.createDataFrame(pdf)
+ df.collect()
+
def test_toArrow_keep_utc_timezone(self):
df = self.spark.createDataFrame(self.data, schema=self.schema)
diff --git a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow.py
b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow.py
index fa8cf286b9bd..2cc089a7c0d5 100644
--- a/python/pyspark/sql/tests/connect/arrow/test_parity_arrow.py
+++ b/python/pyspark/sql/tests/connect/arrow/test_parity_arrow.py
@@ -78,6 +78,12 @@ class ArrowParityTests(ArrowTestsMixin,
ReusedConnectTestCase, PandasOnSparkTest
def test_toPandas_respect_session_timezone(self):
self.check_toPandas_respect_session_timezone(True)
+ def test_cached_local_relation_changing_values(self):
+ self.check_cached_local_relation_changing_values()
+
+ def test_large_cached_local_relation_same_values(self):
+ self.check_large_cached_local_relation_same_values()
+
def test_toPandas_with_array_type(self):
self.check_toPandas_with_array_type(True)
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
index 9a69c3d2488f..f715f8f9ed8c 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConf.scala
@@ -61,9 +61,12 @@ private[sql] object SqlApiConf {
val SESSION_LOCAL_TIMEZONE_KEY: String =
SqlApiConfHelper.SESSION_LOCAL_TIMEZONE_KEY
val ARROW_EXECUTION_USE_LARGE_VAR_TYPES: String =
SqlApiConfHelper.ARROW_EXECUTION_USE_LARGE_VAR_TYPES
- val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String = {
+ val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String =
SqlApiConfHelper.LOCAL_RELATION_CACHE_THRESHOLD_KEY
- }
+ val LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY: String =
+ SqlApiConfHelper.LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY
+ val LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY: String =
+ SqlApiConfHelper.LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY
val PARSER_DFA_CACHE_FLUSH_THRESHOLD_KEY: String =
SqlApiConfHelper.PARSER_DFA_CACHE_FLUSH_THRESHOLD_KEY
val PARSER_DFA_CACHE_FLUSH_RATIO_KEY: String =
diff --git
a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
index 727620bd5bd0..b839caba3f54 100644
---
a/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
+++
b/sql/api/src/main/scala/org/apache/spark/sql/internal/SqlApiConfHelper.scala
@@ -32,6 +32,9 @@ private[sql] object SqlApiConfHelper {
val CASE_SENSITIVE_KEY: String = "spark.sql.caseSensitive"
val SESSION_LOCAL_TIMEZONE_KEY: String = "spark.sql.session.timeZone"
val LOCAL_RELATION_CACHE_THRESHOLD_KEY: String =
"spark.sql.session.localRelationCacheThreshold"
+ val LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY: String =
"spark.sql.session.localRelationChunkSizeRows"
+ val LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY: String =
+ "spark.sql.session.localRelationChunkSizeBytes"
val ARROW_EXECUTION_USE_LARGE_VAR_TYPES =
"spark.sql.execution.arrow.useLargeVarTypes"
val PARSER_DFA_CACHE_FLUSH_THRESHOLD_KEY: String =
"spark.sql.parser.parserDfaCacheFlushThreshold"
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index dd0dbe36d69a..f1e4acd00a52 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -5919,7 +5919,47 @@ object SQLConf {
.version("3.5.0")
.intConf
.checkValue(_ >= 0, "The threshold of cached local relations must not be
negative")
- .createWithDefault(64 * 1024 * 1024)
+ .createWithDefault(1024 * 1024)
+
+ val LOCAL_RELATION_CHUNK_SIZE_ROWS =
+ buildConf(SqlApiConfHelper.LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY)
+ .doc("The chunk size in number of rows when splitting
ChunkedCachedLocalRelation.data " +
+ "into batches. A new chunk is created when either " +
+ "spark.sql.session.localRelationChunkSizeBytes " +
+ "or spark.sql.session.localRelationChunkSizeRows is reached.")
+ .version("4.1.0")
+ .intConf
+ .checkValue(_ > 0, "The chunk size in number of rows must be positive")
+ .createWithDefault(10000)
+
+ val LOCAL_RELATION_CHUNK_SIZE_BYTES =
+ buildConf(SqlApiConfHelper.LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY)
+ .doc("The chunk size in bytes when splitting
ChunkedCachedLocalRelation.data " +
+ "into batches. A new chunk is created when either " +
+ "spark.sql.session.localRelationChunkSizeBytes " +
+ "or spark.sql.session.localRelationChunkSizeRows is reached.")
+ .version("4.1.0")
+ .longConf
+ .checkValue(_ > 0, "The chunk size in bytes must be positive")
+ .createWithDefault(16 * 1024 * 1024L)
+
+ val LOCAL_RELATION_CHUNK_SIZE_LIMIT =
+ buildConf("spark.sql.session.localRelationChunkSizeLimit")
+ .internal()
+ .doc("Limit on how large a single chunk of a
ChunkedCachedLocalRelation.data " +
+ "can be in bytes. If the limit is exceeded, an exception is thrown.")
+ .version("4.1.0")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefaultString("2000MB")
+
+ val LOCAL_RELATION_SIZE_LIMIT =
+ buildConf("spark.sql.session.localRelationSizeLimit")
+ .internal()
+ .doc("Limit on how large ChunkedCachedLocalRelation.data can be in
bytes." +
+ "If the limit is exceeded, an exception is thrown.")
+ .version("4.1.0")
+ .bytesConf(ByteUnit.BYTE)
+ .createWithDefaultString("3GB")
val DECORRELATE_JOIN_PREDICATE_ENABLED =
buildConf("spark.sql.optimizer.decorrelateJoinPredicate.enabled")
@@ -7165,6 +7205,10 @@ class SQLConf extends Serializable with Logging with
SqlApiConf {
def rangeExchangeSampleSizePerPartition: Int =
getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION)
+ def localRelationChunkSizeLimit: Long =
getConf(LOCAL_RELATION_CHUNK_SIZE_LIMIT)
+
+ def localRelationSizeLimit: Long = getConf(LOCAL_RELATION_SIZE_LIMIT)
+
def arrowPySparkEnabled: Boolean = getConf(ARROW_PYSPARK_EXECUTION_ENABLED)
def arrowLocalRelationThreshold: Long =
getConf(ARROW_LOCAL_RELATION_THRESHOLD)
diff --git
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala
index 4c0073cad567..6678a11a80b0 100644
---
a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala
+++
b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/SparkSessionE2ESuite.scala
@@ -450,4 +450,30 @@ class SparkSessionE2ESuite extends ConnectFunSuite with
RemoteSparkSession {
Map("one" -> "1", "two" -> "2"))
assert(df.as(StringEncoder).collect().toSet == Set("one", "two"))
}
+
+ test("dataframes with cached local relations succeed - changing values") {
+ val rowSize = 1000
+ val rowCount = 64 * 1000
+ val suffix = "abcdef"
+ val str = scala.util.Random.alphanumeric.take(rowSize).mkString + suffix
+ val data = Seq.tabulate(rowCount)(i => (i, str))
+ for (_ <- 0 until 2) {
+ val df = spark.createDataFrame(data)
+ assert(df.count() === rowCount)
+ assert(!df.filter(df("_2").endsWith(suffix)).isEmpty)
+ }
+ }
+
+ test("dataframes with cached local relations succeed - same values") {
+ val rowSize = 1000
+ val rowCount = 64 * 1000
+ val suffix = "abcdef"
+ val str = scala.util.Random.alphanumeric.take(rowSize).mkString + suffix
+ val data = Seq.tabulate(rowCount)(_ => (0, str))
+ for (_ <- 0 until 2) {
+ val df = spark.createDataFrame(data)
+ assert(df.count() === rowCount)
+ assert(!df.filter(df("_2").endsWith(suffix)).isEmpty)
+ }
+ }
}
diff --git a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
index ccb674e812dc..1583785e69fb 100644
--- a/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/sql/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -80,6 +80,7 @@ message Relation {
Transpose transpose = 42;
UnresolvedTableValuedFunction unresolved_table_valued_function = 43;
LateralJoin lateral_join = 44;
+ ChunkedCachedLocalRelation chunked_cached_local_relation = 45;
// NA functions
NAFill fill_na = 90;
@@ -499,6 +500,7 @@ message LocalRelation {
}
// A local relation that has been cached already.
+// CachedLocalRelation doesn't support LocalRelations of size over 2GB.
message CachedLocalRelation {
// `userId` and `sessionId` fields are deleted since the server must always
use the active
// session/user rather than arbitrary values provided by the client. It is
never valid to access
@@ -510,6 +512,20 @@ message CachedLocalRelation {
string hash = 3;
}
+// A local relation that has been cached already.
+message ChunkedCachedLocalRelation {
+ // (Required) A list of sha-256 hashes for representing LocalRelation.data.
+ // Data is serialized in Arrow IPC streaming format, each batch is cached on
the server as
+ // a separate artifact. Each hash represents one batch stored on the server.
+ // Hashes are hex-encoded strings (e.g., "a3b2c1d4...").
+ repeated string dataHashes = 1;
+
+ // (Optional) A sha-256 hash of the serialized LocalRelation.schema.
+ // Scala clients always provide the schema, Python clients can omit it.
+ // Hash is a hex-encoded string (e.g., "a3b2c1d4...").
+ optional string schemaHash = 2;
+}
+
// Represents a remote relation that has been cached on server.
message CachedRemoteRelation {
// (Required) ID of the remote related (assigned by the service).
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
index f7869a8b4dd8..0d9d4e5d60f0 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/SparkSession.scala
@@ -31,6 +31,7 @@ import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
import com.google.common.cache.{CacheBuilder, CacheLoader}
+import com.google.protobuf.ByteString
import io.grpc.ClientInterceptor
import org.apache.arrow.memory.RootAllocator
@@ -116,16 +117,40 @@ class SparkSession private[sql] (
private def createDataset[T](encoder: AgnosticEncoder[T], data:
Iterator[T]): Dataset[T] = {
newDataset(encoder) { builder =>
if (data.nonEmpty) {
- val arrowData =
- ArrowSerializer.serialize(data, encoder, allocator, timeZoneId,
largeVarTypes)
- if (arrowData.size() <=
conf.get(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY).toInt) {
+ val threshold =
conf.get(SqlApiConf.LOCAL_RELATION_CACHE_THRESHOLD_KEY).toInt
+ val maxRecordsPerBatch =
conf.get(SqlApiConf.LOCAL_RELATION_CHUNK_SIZE_ROWS_KEY).toInt
+ val maxBatchSize =
conf.get(SqlApiConf.LOCAL_RELATION_CHUNK_SIZE_BYTES_KEY).toInt
+ // Serialize with chunking support
+ val it = ArrowSerializer.serialize(
+ data,
+ encoder,
+ allocator,
+ maxRecordsPerBatch = maxRecordsPerBatch,
+ maxBatchSize = maxBatchSize,
+ timeZoneId = timeZoneId,
+ largeVarTypes = largeVarTypes,
+ batchSizeCheckInterval = math.min(1024, maxRecordsPerBatch))
+
+ val chunks =
+ try {
+ it.toArray
+ } finally {
+ it.close()
+ }
+
+ // If we got multiple chunks or a single large chunk, use
ChunkedCachedLocalRelation
+ val totalSize = chunks.map(_.length).sum
+ if (chunks.length > 1 || totalSize > threshold) {
+ val (dataHashes, schemaHash) = client.cacheLocalRelation(chunks,
encoder.schema.json)
+ builder.getChunkedCachedLocalRelationBuilder
+ .setSchemaHash(schemaHash)
+ .addAllDataHashes(dataHashes.asJava)
+ } else {
+ // Small data, use LocalRelation directly
+ val arrowData = ByteString.copyFrom(chunks(0))
builder.getLocalRelationBuilder
.setSchema(encoder.schema.json)
.setData(arrowData)
- } else {
- val hash = client.cacheLocalRelation(arrowData, encoder.schema.json)
- builder.getCachedLocalRelationBuilder
- .setHash(hash)
}
} else {
builder.getLocalRelationBuilder
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
index 213cd1d2e867..44a2a7aa9a2f 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/ArtifactManager.scala
@@ -185,6 +185,38 @@ class ArtifactManager(
} else false
}
+ /**
+ * Batch check which artifacts are already cached on the server. Returns a
Set of hashes that
+ * are already cached.
+ */
+ private[client] def getCachedArtifacts(hashes: Seq[String]): Set[String] = {
+ if (hashes.isEmpty) {
+ return Set.empty
+ }
+
+ val artifactNames = hashes.map(hash => s"${Artifact.CACHE_PREFIX}/$hash")
+ val request = proto.ArtifactStatusesRequest
+ .newBuilder()
+ .setUserContext(clientConfig.userContext)
+ .setClientType(clientConfig.userAgent)
+ .setSessionId(sessionId)
+ .addAllNames(artifactNames.asJava)
+ .build()
+
+ val response = bstub.artifactStatus(request)
+ if (SparkStringUtils.isNotEmpty(response.getSessionId) &&
+ response.getSessionId != sessionId) {
+ throw new IllegalStateException(
+ s"Session ID mismatch: $sessionId != ${response.getSessionId}")
+ }
+
+ val statuses = response.getStatusesMap
+ hashes.filter { hash =>
+ val artifactName = s"${Artifact.CACHE_PREFIX}/$hash"
+ statuses.containsKey(artifactName) &&
statuses.get(artifactName).getExists
+ }.toSet
+ }
+
/**
* Cache the give blob at the session.
*/
@@ -196,6 +228,38 @@ class ArtifactManager(
hash
}
+ /**
+ * Cache the given blobs at the session.
+ *
+ * This method batches artifact status checks and uploads to minimize RPC
overhead. Returns the
+ * list of hashes corresponding to the input blobs.
+ */
+ def cacheArtifacts(blobs: Array[Array[Byte]]): Seq[String] = {
+ // Compute hashes for all blobs upfront
+ val hashes = blobs.map(sha256Hex).toSeq
+ val uniqueHashes = hashes.distinct
+
+ // Batch check which artifacts are already cached
+ val cachedHashes = getCachedArtifacts(uniqueHashes)
+
+ // Collect unique artifacts that need to be uploaded
+ val seenHashes = scala.collection.mutable.Set[String]()
+ val uniqueBlobsToUpload = scala.collection.mutable.ListBuffer[Artifact]()
+ for ((blob, hash) <- blobs.zip(hashes)) {
+ if (!cachedHashes.contains(hash) && !seenHashes.contains(hash)) {
+ uniqueBlobsToUpload += newCacheArtifact(hash, new
Artifact.InMemory(blob))
+ seenHashes.add(hash)
+ }
+ }
+
+ // Batch upload all missing artifacts in a single RPC call
+ if (uniqueBlobsToUpload.nonEmpty) {
+ addArtifacts(uniqueBlobsToUpload.toList)
+ }
+
+ hashes
+ }
+
/**
* Upload all class file artifacts from the local REPL(s) to the server.
*
diff --git
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 3c328681dd9a..fa32eba91eb2 100644
---
a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++
b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -25,7 +25,6 @@ import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.util.Properties
-import com.google.protobuf.ByteString
import io.grpc._
import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
@@ -404,16 +403,23 @@ private[sql] class SparkConnectClient(
}
/**
- * Cache the given local relation at the server, and return its key in the remote cache.
+ * Cache the given local relation Arrow stream from a local file and return its hashes. The file
+ * is streamed in chunks and does not need to fit in memory.
+ *
+ * This method batches artifact status checks and uploads to minimize RPC overhead.
*/
- private[sql] def cacheLocalRelation(data: ByteString, schema: String): String = {
- val localRelation = proto.Relation
- .newBuilder()
- .getLocalRelationBuilder
- .setSchema(schema)
- .setData(data)
- .build()
- artifactManager.cacheArtifact(localRelation.toByteArray)
+ private[sql] def cacheLocalRelation(
+ data: Array[Array[Byte]],
+ schema: String): (Seq[String], String) = {
+ val schemaBytes = schema.getBytes
+ val allBlobs = data :+ schemaBytes
+ val allHashes = artifactManager.cacheArtifacts(allBlobs)
+
+ // Last hash is the schema hash, rest are data hashes
+ val dataHashes = allHashes.dropRight(1)
+ val schemaHash = allHashes.last
+
+ (dataHashes, schemaHash)
}
/**
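
As a rough, standalone illustration of the new client-side contract in cacheLocalRelation (not code from this commit): each data chunk is expected to be a self-contained Arrow IPC stream, the schema JSON is cached as one extra blob, and the last returned hash always identifies the schema. cacheArtifacts is passed in as a plain function here to keep the sketch dependency-free.
```
object CacheLocalRelationSketch {
  // Cache the data chunks plus the schema JSON and split the returned hashes,
  // following the "schema hash comes last" convention used above.
  def cacheLocalRelation(
      dataChunks: Array[Array[Byte]],  // each chunk: a self-contained Arrow IPC stream
      schemaJson: String,
      cacheArtifacts: Array[Array[Byte]] => Seq[String]): (Seq[String], String) = {
    val allBlobs = dataChunks :+ schemaJson.getBytes("UTF-8")
    val allHashes = cacheArtifacts(allBlobs)
    (allHashes.dropRight(1), allHashes.last)
  }
}
```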
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/InvalidInputErrors.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/InvalidInputErrors.scala
index 0dd4192908b9..fcef696c88af 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/InvalidInputErrors.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/InvalidInputErrors.scala
@@ -73,12 +73,32 @@ object InvalidInputErrors {
s"Not found any cached local relation with the hash: " +
s"$hash in the session with sessionUUID $sessionUUID.")
+ def notFoundChunkedCachedLocalRelationBlock(
+ hash: String,
+ sessionUUID: String): InvalidPlanInput =
+ InvalidPlanInput(
+ s"Not found chunked cached local relation block with the hash: " +
+ s"$hash in the session with sessionUUID $sessionUUID.")
+
+ def localRelationSizeLimitExceeded(actualSize: Long, limit: Long): InvalidPlanInput =
+ InvalidPlanInput(
+ s"Cached local relation size ($actualSize bytes) exceeds the limit
($limit bytes).")
+
+ def localRelationChunkSizeLimitExceeded(limit: Long): InvalidPlanInput =
+ InvalidPlanInput(s"One of cached local relation chunks exceeded the limit
of $limit bytes.")
+
def withColumnsRequireSingleNamePart(got: String): InvalidPlanInput =
InvalidPlanInput(s"WithColumns require column name only contains one name
part, but got $got")
def inputDataForLocalRelationNoSchema(): InvalidPlanInput =
InvalidPlanInput("Input data for LocalRelation does not produce a schema.")
+ def chunkedCachedLocalRelationWithoutData(): InvalidPlanInput =
+ InvalidPlanInput("ChunkedCachedLocalRelation should contain data.")
+
+ def chunkedCachedLocalRelationChunksWithDifferentSchema(): InvalidPlanInput =
+ InvalidPlanInput("ChunkedCachedLocalRelation data chunks have different
schema.")
+
def schemaRequiredForLocalRelation(): InvalidPlanInput =
InvalidPlanInput("Schema for LocalRelation is required when the input data
is not provided.")
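
The two new limit errors above are raised by the planner change further below, where the server sums the BlockManager sizes of all chunks before reading any of them. A simplified sketch of that check follows; the config names mirror the diff below, and require is used here purely for illustration instead of InvalidPlanInput.
```
object ChunkLimitCheckSketch {
  // Validate total and per-chunk sizes before any chunk is loaded into memory.
  def checkLimits(chunkSizes: Seq[Long], relationSizeLimit: Long, chunkSizeLimit: Long): Unit = {
    val totalSize = chunkSizes.sum
    require(totalSize <= relationSizeLimit,
      s"Cached local relation size ($totalSize bytes) exceeds the limit ($relationSizeLimit bytes).")
    require(chunkSizes.forall(_ <= chunkSizeLimit),
      s"One of the cached local relation chunks exceeds the limit of $chunkSizeLimit bytes.")
  }
}
```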
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 114a20e051b4..a7a8f3506dea 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.connect.planner
import java.util.{HashMap, Properties, UUID}
+import scala.collection.immutable.ArraySeq
import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.util.Try
@@ -42,7 +43,7 @@ import org.apache.spark.internal.{Logging, LogKeys}
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{AnalysisException, Column, Encoders, ForeachWriter, Row}
-import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
+import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, InternalRow, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedOrdinal, UnresolvedPlanId, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedStarWithColumns, UnresolvedStarWithColumnsRenames, UnresolvedSubqueryColumnAliases, UnresolvedTableValuedFunction, UnresolvedTranspose}
import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder}
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ProductEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, UnboundRowEncoder}
@@ -195,6 +196,8 @@ class SparkConnectPlanner(
transformWithWatermark(rel.getWithWatermark)
case proto.Relation.RelTypeCase.CACHED_LOCAL_RELATION =>
transformCachedLocalRelation(rel.getCachedLocalRelation)
+ case proto.Relation.RelTypeCase.CHUNKED_CACHED_LOCAL_RELATION =>
+ transformChunkedCachedLocalRelation(rel.getChunkedCachedLocalRelation)
case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint)
case proto.Relation.RelTypeCase.UNPIVOT => transformUnpivot(rel.getUnpivot)
case proto.Relation.RelTypeCase.TRANSPOSE => transformTranspose(rel.getTranspose)
@@ -1482,25 +1485,128 @@ class SparkConnectPlanner(
rel.getSchema,
parseDatatypeString,
fallbackParser = DataType.fromJson)
- schema = schemaType match {
- case s: StructType => s
- case d => StructType(Seq(StructField("value", d)))
- }
+ schema = toStructTypeOrWrap(schemaType)
}
if (rel.hasData) {
val (rows, structType) =
ArrowConverters.fromIPCStream(rel.getData.toByteArray, TaskContext.get())
- if (structType == null) {
- throw InvalidInputErrors.inputDataForLocalRelationNoSchema()
+ buildLocalRelationFromRows(rows, structType, Option(schema))
+ } else {
+ if (schema == null) {
+ throw InvalidInputErrors.schemaRequiredForLocalRelation()
}
- val attributes = DataTypeUtils.toAttributes(structType)
- val proj = UnsafeProjection.create(attributes, attributes)
- val data = rows.map(proj)
+ LocalRelation(schema)
+ }
+ }
- if (schema == null) {
- logical.LocalRelation(attributes, data.map(_.copy()).toSeq)
+ private def readChunkedCachedLocalRelationBlock(hash: String): Array[Byte] = {
+ val blockManager = session.sparkContext.env.blockManager
+ val blockId = CacheId(sessionHolder.session.sessionUUID, hash)
+ val bytes = blockManager.getLocalBytes(blockId)
+ bytes
+ .map { blockData =>
+ try {
+ blockData.toInputStream().readAllBytes()
+ } finally {
+ blockManager.releaseLock(blockId)
+ }
+ }
+ .getOrElse {
+ throw InvalidInputErrors.notFoundChunkedCachedLocalRelationBlock(
+ blockId.hash,
+ blockId.sessionUUID)
+ }
+ }
+
+ private def getBlockSize(hash: String): Long = {
+ val blockManager = session.sparkContext.env.blockManager
+ val blockId = CacheId(sessionHolder.session.sessionUUID, hash)
+ blockManager.getStatus(blockId).map(status => status.memSize + status.diskSize).getOrElse(0L)
+ }
+
+ private def transformChunkedCachedLocalRelation(
+ rel: proto.ChunkedCachedLocalRelation): LogicalPlan = {
+ if (rel.getDataHashesCount == 0) {
+ throw InvalidInputErrors.chunkedCachedLocalRelationWithoutData()
+ }
+ val dataHashes = rel.getDataHashesList.asScala
+ val allHashes = dataHashes ++ (
+ if (rel.hasSchemaHash) {
+ Seq(rel.getSchemaHash)
} else {
+ Seq.empty
+ }
+ )
+ val allSizes = allHashes.map(hash => getBlockSize(hash))
+ val totalSize = allSizes.sum
+
+ val relationSizeLimit = session.sessionState.conf.localRelationSizeLimit
+ val chunkSizeLimit = session.sessionState.conf.localRelationChunkSizeLimit
+ if (totalSize > relationSizeLimit) {
+ throw InvalidInputErrors.localRelationSizeLimitExceeded(totalSize, relationSizeLimit)
+ }
+ if (allSizes.exists(_ > chunkSizeLimit)) {
+ throw InvalidInputErrors.localRelationChunkSizeLimitExceeded(chunkSizeLimit)
+ }
+
+ var schema: StructType = null
+ if (rel.hasSchemaHash) {
+ val schemaBytes = readChunkedCachedLocalRelationBlock(rel.getSchemaHash)
+ val schemaString = new String(schemaBytes)
+ val schemaType = DataType.parseTypeWithFallback(
+ schemaString,
+ parseDatatypeString,
+ fallbackParser = DataType.fromJson)
+ schema = toStructTypeOrWrap(schemaType)
+ }
+
+ // Load and combine all batches
+ var combinedRows: Iterator[InternalRow] = Iterator.empty
+ var structType: StructType = null
+
+ for ((dataHash, batchIndex) <- dataHashes.zipWithIndex) {
+ val dataBytes = readChunkedCachedLocalRelationBlock(dataHash)
+ val (batchRows, batchStructType) =
+ ArrowConverters.fromIPCStream(dataBytes, TaskContext.get())
+
+ // For the first batch, set the schema; for subsequent batches, verify compatibility
+ if (batchIndex == 0) {
+ structType = batchStructType
+ combinedRows = batchRows
+
+ } else {
+ if (batchStructType != structType) {
+ throw InvalidInputErrors.chunkedCachedLocalRelationChunksWithDifferentSchema()
+ }
+ combinedRows = combinedRows ++ batchRows
+ }
+ }
+
+ buildLocalRelationFromRows(combinedRows, structType, Option(schema))
+ }
+
+ private def toStructTypeOrWrap(dt: DataType): StructType = dt match {
+ case s: StructType => s
+ case d => StructType(Seq(StructField("value", d)))
+ }
+
+ private def buildLocalRelationFromRows(
+ rows: Iterator[InternalRow],
+ structType: StructType,
+ schemaOpt: Option[StructType]): LogicalPlan = {
+ if (structType == null) {
+ throw InvalidInputErrors.inputDataForLocalRelationNoSchema()
+ }
+
+ val attributes = DataTypeUtils.toAttributes(structType)
+ val initialProjection = UnsafeProjection.create(attributes, attributes)
+ val data = rows.map(initialProjection)
+
+ schemaOpt match {
+ case None =>
+ logical.LocalRelation(attributes, ArraySeq.unsafeWrapArray(data.map(_.copy()).toArray))
+ case Some(schema) =>
def normalize(dt: DataType): DataType = dt match {
case udt: UserDefinedType[_] => normalize(udt.sqlType)
case StructType(fields) =>
@@ -1532,12 +1638,6 @@ class SparkConnectPlanner(
logical.LocalRelation(
DataTypeUtils.toAttributes(schema),
data.map(proj).map(_.copy()).toSeq)
- }
- } else {
- if (schema == null) {
- throw InvalidInputErrors.schemaRequiredForLocalRelation()
- }
- LocalRelation(schema)
}
}
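
To summarize the server-side half of the change: transformChunkedCachedLocalRelation reads each chunk from the BlockManager, parses it as its own Arrow IPC stream, insists that every chunk reports the same schema, and concatenates the per-chunk row iterators before building the LocalRelation. A simplified, Spark-free sketch of that combining step is below; rows are a generic type T and schemas are plain strings, which is an illustration rather than the planner's actual types.
```
object CombineChunksSketch {
  // Combine per-chunk (rows, schema) pairs into a single row iterator,
  // rejecting chunks whose schema differs from the first one.
  def combineChunks[T](chunks: Seq[(Iterator[T], String)]): (Iterator[T], String) = {
    require(chunks.nonEmpty, "ChunkedCachedLocalRelation should contain data.")
    val firstSchema = chunks.head._2
    // Check schema compatibility eagerly, as the planner does per batch.
    chunks.foreach { case (_, chunkSchema) =>
      require(chunkSchema == firstSchema,
        "ChunkedCachedLocalRelation data chunks have different schema.")
    }
    // Concatenate the row iterators lazily.
    (chunks.iterator.flatMap(_._1), firstSchema)
  }
}
```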
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]