This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3c97a4c5a1c [SPARK-42085][CONNECT][PYTHON] Make `from_arrow_schema` support nested types
3c97a4c5a1c is described below

commit 3c97a4c5a1c1ff698b094bf4c7fc1a17f94f1148
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Mon Jan 16 15:59:42 2023 +0800

    [SPARK-42085][CONNECT][PYTHON] Make `from_arrow_schema` support nested types
    
    ### What changes were proposed in this pull request?
    Make `from_arrow_schema` support nested types
    
    ### Why are the changes needed?
    So that we no longer need to fetch the schema via `self.schema` (an extra RPC) as a fallback for nested types that `from_arrow_schema` did not support.
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Existing unit tests.
    
    Closes #39594 from zhengruifeng/connect_collect_arrow_schema.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/dataframe.py | 37 ++----------------
 python/pyspark/sql/connect/types.py     | 66 +++++++++++++++++++++++++++++++++
 2 files changed, 70 insertions(+), 33 deletions(-)

diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index f41eb145612..d9d286b9fb5 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -39,20 +39,12 @@ from collections.abc import Iterable
 
 from pyspark import _NoValue
 from pyspark._globals import _NoValueType
-from pyspark.sql.types import (
-    Row,
-    StructType,
-    ArrayType,
-    MapType,
-    TimestampType,
-    TimestampNTZType,
-)
+from pyspark.sql.types import Row, StructType
 from pyspark.sql.dataframe import (
     DataFrame as PySparkDataFrame,
     DataFrameNaFunctions as PySparkDataFrameNaFunctions,
     DataFrameStatFunctions as PySparkDataFrameStatFunctions,
 )
-from pyspark.sql.pandas.types import from_arrow_schema
 
 import pyspark.sql.connect.plan as plan
 from pyspark.sql.connect.group import GroupedData
@@ -66,6 +58,8 @@ from pyspark.sql.connect.functions import (
     lit,
     expr as sql_expression,
 )
+from pyspark.sql.connect.types import from_arrow_schema
+
 
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import (
@@ -1248,30 +1242,7 @@ class DataFrame:
         query = self._plan.to_proto(self._session.client)
         table = self._session.client.to_table(query)
 
-        # We first try the inferred schema from PyArrow Table instead of always fetching
-        # the Connect Dataframe schema by 'self.schema', for two reasons:
-        # 1, the schema maybe quietly simple, then we can save an RPC;
-        # 2, if we always invoke 'self.schema' here, all catalog functions based on
-        # 'dataframe.collect' will be invoked twice (1, collect data, 2, fetch schema),
-        # and then some of them (e.g. "CREATE DATABASE") fail due to the second invocation.
-
-        schema: Optional[StructType] = None
-        try:
-            schema = from_arrow_schema(table.schema)
-        except Exception:
-            # may fail due to 'from_arrow_schema' not supporting nested struct
-            schema = None
-
-        if schema is None:
-            schema = self.schema
-        else:
-            if any(
-                isinstance(
-                    f.dataType, (StructType, ArrayType, MapType, TimestampType, TimestampNTZType)
-                )
-                for f in schema.fields
-            ):
-                schema = self.schema
+        schema = from_arrow_schema(table.schema)
 
         assert schema is not None and isinstance(schema, StructType)
 
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index 8c23ed03ef3..a77585fd6d6 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -248,3 +248,69 @@ def to_arrow_schema(schema: StructType) -> "pa.Schema":
         for field in schema
     ]
     return pa.schema(fields)
+
+
+def from_arrow_type(at: "pa.DataType", prefer_timestamp_ntz: bool = False) -> DataType:
+    """Convert pyarrow type to Spark data type.
+
+    This function refers to 'pyspark.sql.pandas.types.from_arrow_type' but relax the restriction,
+    e.g. it supports nested StructType, Array of TimestampType. However, Arrow DictionaryType is
+    not allowed.
+    """
+    import pyarrow.types as types
+
+    spark_type: DataType
+    if types.is_boolean(at):
+        spark_type = BooleanType()
+    elif types.is_int8(at):
+        spark_type = ByteType()
+    elif types.is_int16(at):
+        spark_type = ShortType()
+    elif types.is_int32(at):
+        spark_type = IntegerType()
+    elif types.is_int64(at):
+        spark_type = LongType()
+    elif types.is_float32(at):
+        spark_type = FloatType()
+    elif types.is_float64(at):
+        spark_type = DoubleType()
+    elif types.is_decimal(at):
+        spark_type = DecimalType(precision=at.precision, scale=at.scale)
+    elif types.is_string(at):
+        spark_type = StringType()
+    elif types.is_binary(at):
+        spark_type = BinaryType()
+    elif types.is_date32(at):
+        spark_type = DateType()
+    elif types.is_timestamp(at) and prefer_timestamp_ntz and at.tz is None:
+        spark_type = TimestampNTZType()
+    elif types.is_timestamp(at):
+        spark_type = TimestampType()
+    elif types.is_duration(at):
+        spark_type = DayTimeIntervalType()
+    elif types.is_list(at):
+        spark_type = ArrayType(from_arrow_type(at.value_type))
+    elif types.is_map(at):
+        spark_type = MapType(from_arrow_type(at.key_type), from_arrow_type(at.item_type))
+    elif types.is_struct(at):
+        return StructType(
+            [
+                StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
+                for field in at
+            ]
+        )
+    elif types.is_null(at):
+        spark_type = NullType()
+    else:
+        raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
+    return spark_type
+
+
+def from_arrow_schema(arrow_schema: "pa.Schema") -> StructType:
+    """Convert schema from Arrow to Spark."""
+    return StructType(
+        [
+            StructField(field.name, from_arrow_type(field.type), nullable=field.nullable)
+            for field in arrow_schema
+        ]
+    )


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to