This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b0f9978ec08 [SPARK-45026][CONNECT] `spark.sql` should support datatypes not compatible with arrow
b0f9978ec08 is described below
commit b0f9978ec08caa9302c7340951b5c2979315ca13
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Fri Sep 1 11:12:51 2023 +0800
[SPARK-45026][CONNECT] `spark.sql` should support datatypes not compatible with arrow
### What changes were proposed in this pull request?
Move the arrow batch creation into the `isCommand` branch of `handleSqlCommand`, so that plain queries no longer require an arrow-compatible schema.
### Why are the changes needed?
https://github.com/apache/spark/pull/42736 and https://github.com/apache/spark/pull/42743 introduced `CalendarIntervalType` in the Spark Connect Python Client; however, there is a failure:
```
spark.sql("SELECT make_interval(100, 11, 1, 1, 12, 30, 01.001001)")
...
pyspark.errors.exceptions.connect.UnsupportedOperationException: [UNSUPPORTED_DATATYPE] Unsupported data type "INTERVAL".
```
The root cause is that `handleSqlCommand` always creates an arrow batch, while `ArrowUtils` does not accept `CalendarIntervalType` yet.
This PR mainly focuses on enabling `schema` for datatypes not compatible with arrow.
In the future, we should make `ArrowUtils` accept `CalendarIntervalType` so that `collect`/`toPandas` also work.
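A minimal sketch of the remaining limitation (assuming a Spark Connect `spark` session; this snippet is illustrative and not part of the patch): fetching the schema now succeeds, while materializing the rows still goes through arrow and is expected to fail until `ArrowUtils` supports the type.
```python
# Illustrative only (not part of the patch): schema access works after this PR,
# but collecting CalendarIntervalType values still relies on arrow conversion.
df = spark.sql("SELECT make_interval(100, 11, 1, 1, 12, 30, 01.001001)")
print(df.schema)  # StructType with a CalendarIntervalType field

try:
    df.collect()  # expected to raise UNSUPPORTED_DATATYPE until ArrowUtils accepts the type
except Exception as e:
    print(type(e).__name__, e)
```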
### Does this PR introduce _any_ user-facing change?
Yes. After this PR:
```
In [1]: spark.sql("SELECT make_interval(100, 11, 1, 1, 12, 30, 01.001001)")
Out[1]: DataFrame[make_interval(100, 11, 1, 1, 12, 30, 1.001001): interval]
In [2]: spark.sql("SELECT make_interval(100, 11, 1, 1, 12, 30, 01.001001)").schema
Out[2]: StructType([StructField('make_interval(100, 11, 1, 1, 12, 30, 1.001001)', CalendarIntervalType(), True)])
```
### How was this patch tested?
Enabled the previously skipped unit test `test_calendar_interval_type` in `test_parity_types.py` (the `@unittest.skip` marker is removed in the diff below).
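A rough sketch of what that parity test exercises (assumed; the real test body lives in the shared `TypesTestsMixin` and may differ): only the schema has to round-trip over Spark Connect, not the arrow-encoded values.
```python
# Assumed sketch of the re-enabled parity test; the actual implementation
# in TypesTestsMixin may differ.
from pyspark.sql.types import CalendarIntervalType

def test_calendar_interval_type(self):
    df = self.spark.sql("SELECT make_interval(100, 11, 1, 1, 12, 30, 01.001001)")
    field = df.schema.fields[0]
    self.assertIsInstance(field.dataType, CalendarIntervalType)
```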
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #42754 from zhengruifeng/connect_sql_types.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../sql/connect/planner/SparkConnectPlanner.scala | 40 +++++++++++-----------
.../pyspark/sql/tests/connect/test_parity_types.py | 4 ---
2 files changed, 20 insertions(+), 24 deletions(-)
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index fbe877b4547..547b6a9fb40 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -2469,30 +2469,30 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
val maxBatchSize = (SparkEnv.get.conf.get(CONNECT_GRPC_ARROW_MAX_BATCH_SIZE) * 0.7).toLong
val timeZoneId = session.sessionState.conf.sessionLocalTimeZone
- // Convert the data.
- val bytes = if (rows.isEmpty) {
- ArrowConverters.createEmptyArrowBatch(
- schema,
- timeZoneId,
- errorOnDuplicatedFieldNames = false)
- } else {
- val batches = ArrowConverters.toBatchWithSchemaIterator(
- rowIter = rows.iterator,
- schema = schema,
- maxRecordsPerBatch = -1,
- maxEstimatedBatchSize = maxBatchSize,
- timeZoneId = timeZoneId,
- errorOnDuplicatedFieldNames = false)
- assert(batches.hasNext)
- val bytes = batches.next()
- assert(!batches.hasNext, s"remaining batches: ${batches.size}")
- bytes
- }
-
// To avoid explicit handling of the result on the client, we build the expected input
// of the relation on the server. The client has to simply forward the result.
val result = SqlCommandResult.newBuilder()
if (isCommand) {
+ // Convert the data.
+ val bytes = if (rows.isEmpty) {
+ ArrowConverters.createEmptyArrowBatch(
+ schema,
+ timeZoneId,
+ errorOnDuplicatedFieldNames = false)
+ } else {
+ val batches = ArrowConverters.toBatchWithSchemaIterator(
+ rowIter = rows.iterator,
+ schema = schema,
+ maxRecordsPerBatch = -1,
+ maxEstimatedBatchSize = maxBatchSize,
+ timeZoneId = timeZoneId,
+ errorOnDuplicatedFieldNames = false)
+ assert(batches.hasNext)
+ val bytes = batches.next()
+ assert(!batches.hasNext, s"remaining batches: ${batches.size}")
+ bytes
+ }
+
result.setRelation(
proto.Relation
.newBuilder()
diff --git a/python/pyspark/sql/tests/connect/test_parity_types.py b/python/pyspark/sql/tests/connect/test_parity_types.py
index 44171fd61a3..807c295fae2 100644
--- a/python/pyspark/sql/tests/connect/test_parity_types.py
+++ b/python/pyspark/sql/tests/connect/test_parity_types.py
@@ -86,10 +86,6 @@ class TypesParityTests(TypesTestsMixin, ReusedConnectTestCase):
def test_udt(self):
super().test_udt()
- @unittest.skip("SPARK-45026: spark.sql should support datatypes not compatible with arrow")
- def test_calendar_interval_type(self):
- super().test_calendar_interval_type()
-
if __name__ == "__main__":
import unittest
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]