This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3c97a4c5a1c [SPARK-42085][CONNECT][PYTHON] Make `from_arrow_schema`
support nested types
3c97a4c5a1c is described below
commit 3c97a4c5a1c1ff698b094bf4c7fc1a17f94f1148
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 16 15:59:42 2023 +0800
[SPARK-42085][CONNECT][PYTHON] Make `from_arrow_schema` support nested types
### What changes were proposed in this pull request?
Make `from_arrow_schema` support nested types
### Why are the changes needed?
After this change we no longer need to fetch the schema via `self.schema` as a
fallback, which was previously required because `from_arrow_schema` did not
support some nested types
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing unit tests.
Closes #39594 from zhengruifeng/connect_collect_arrow_schema.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/dataframe.py | 37 ++----------------
python/pyspark/sql/connect/types.py | 66 +++++++++++++++++++++++++++++++++
2 files changed, 70 insertions(+), 33 deletions(-)
diff --git a/python/pyspark/sql/connect/dataframe.py
b/python/pyspark/sql/connect/dataframe.py
index f41eb145612..d9d286b9fb5 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -39,20 +39,12 @@ from collections.abc import Iterable
from pyspark import _NoValue
from pyspark._globals import _NoValueType
-from pyspark.sql.types import (
- Row,
- StructType,
- ArrayType,
- MapType,
- TimestampType,
- TimestampNTZType,
-)
+from pyspark.sql.types import Row, StructType
from pyspark.sql.dataframe import (
DataFrame as PySparkDataFrame,
DataFrameNaFunctions as PySparkDataFrameNaFunctions,
DataFrameStatFunctions as PySparkDataFrameStatFunctions,
)
-from pyspark.sql.pandas.types import from_arrow_schema
import pyspark.sql.connect.plan as plan
from pyspark.sql.connect.group import GroupedData
@@ -66,6 +58,8 @@ from pyspark.sql.connect.functions import (
lit,
expr as sql_expression,
)
+from pyspark.sql.connect.types import from_arrow_schema
+
if TYPE_CHECKING:
from pyspark.sql.connect._typing import (
@@ -1248,30 +1242,7 @@ class DataFrame:
query = self._plan.to_proto(self._session.client)
table = self._session.client.to_table(query)
- # We first try the inferred schema from PyArrow Table instead of
always fetching
- # the Connect Dataframe schema by 'self.schema', for two reasons:
- # 1, the schema maybe quietly simple, then we can save an RPC;
- # 2, if we always invoke 'self.schema' here, all catalog functions
based on
- # 'dataframe.collect' will be invoked twice (1, collect data, 2, fetch
schema),
- # and then some of them (e.g. "CREATE DATABASE") fail due to the
second invocation.
-
- schema: Optional[StructType] = None
- try:
- schema = from_arrow_schema(table.schema)
- except Exception:
- # may fail due to 'from_arrow_schema' not supporting nested struct
- schema = None
-
- if schema is None:
- schema = self.schema
- else:
- if any(
- isinstance(
- f.dataType, (StructType, ArrayType, MapType,
TimestampType, TimestampNTZType)
- )
- for f in schema.fields
- ):
- schema = self.schema
+ schema = from_arrow_schema(table.schema)
assert schema is not None and isinstance(schema, StructType)
diff --git a/python/pyspark/sql/connect/types.py
b/python/pyspark/sql/connect/types.py
index 8c23ed03ef3..a77585fd6d6 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -248,3 +248,69 @@ def to_arrow_schema(schema: StructType) -> "pa.Schema":
for field in schema
]
return pa.schema(fields)
+
+
+def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) ->
DataType:
+ """Convert pyarrow type to Spark data type.
+
+ This function refers to 'pyspark.sql.pandas.types.from_arrow_type' but
relax the restriction,
+ e.g. it supports nested StructType, Array of TimestampType. However, Arrow
DictionaryType is
+ not allowed.
+ """
+ import pyarrow.types as types
+
+ spark_type: DataType
+ if types.is_boolean(at):
+ spark_type = BooleanType()
+ elif types.is_int8(at):
+ spark_type = ByteType()
+ elif types.is_int16(at):
+ spark_type = ShortType()
+ elif types.is_int32(at):
+ spark_type = IntegerType()
+ elif types.is_int64(at):
+ spark_type = LongType()
+ elif types.is_float32(at):
+ spark_type = FloatType()
+ elif types.is_float64(at):
+ spark_type = DoubleType()
+ elif types.is_decimal(at):
+ spark_type = DecimalType(precision=at.precision, scale=at.scale)
+ elif types.is_string(at):
+ spark_type = StringType()
+ elif types.is_binary(at):
+ spark_type = BinaryType()
+ elif types.is_date32(at):
+ spark_type = DateType()
+ elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None:
+ spark_type = TimestampNTZType()
+ elif types.is_timestamp(at):
+ spark_type = TimestampType()
+ elif types.is_duration(at):
+ spark_type = DayTimeIntervalType()
+ elif types.is_list(at):
+ spark_type = ArrayType(from_arrow_type(at.value_type))
+ elif types.is_map(at):
+ spark_type = MapType(from_arrow_type(at.key_type),
from_arrow_type(at.item_type))
+ elif types.is_struct(at):
+ return StructType(
+ [
+ StructField(field.name, from_arrow_type(field.type),
nullable=field.nullable)
+ for field in at
+ ]
+ )
+ elif types.is_null(at):
+ spark_type = NullType()
+ else:
+ raise TypeError("Unsupported type in conversion from Arrow: " +
str(at))
+ return spark_type
+
+
+def from_arrow_schema(arrow_schema: "pa.Schema") -> StructType:
+ """Convert schema from Arrow to Spark."""
+ return StructType(
+ [
+ StructField(field.name, from_arrow_type(field.type),
nullable=field.nullable)
+ for field in arrow_schema
+ ]
+ )
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]