This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.4 by this push:
     new 2e4238cd011 [SPARK-42679][CONNECT][PYTHON] createDataFrame doesn't 
work with non-nullable schema
2e4238cd011 is described below

commit 2e4238cd0112a5103b0c34037ba0a8201f5287bc
Author: panbingkun <[email protected]>
AuthorDate: Mon Mar 13 11:10:49 2023 +0800

    [SPARK-42679][CONNECT][PYTHON] createDataFrame doesn't work with 
non-nullable schema
    
    ### What changes were proposed in this pull request?
    
    Fixes `spark.createDataFrame` to apply the given schema so that it works with non-nullable data types.
    
    ### Why are the changes needed?
    
    Currently, `spark.createDataFrame` does not work with a non-nullable schema, as shown below:
    
    ```py
    >>> from pyspark.sql.types import *
    >>> schema_false = StructType([StructField("id", IntegerType(), False)])
    >>> spark.createDataFrame([[1]], schema=schema_false)
    Traceback (most recent call last):
    ...
    pyspark.errors.exceptions.connect.AnalysisException: 
[NULLABLE_COLUMN_OR_FIELD] Column or field `id` is nullable while it's required 
to be non-nullable.
    ```
    
    whereas it works fine with a nullable schema:
    
    ```py
    >>> from pyspark.sql.types import *
    >>> schema_true = StructType([StructField("id", IntegerType(), True)])
    >>> spark.createDataFrame([[1]], schema=schema_true)
    DataFrame[id: int]
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. `spark.createDataFrame` now works with a non-nullable schema.
    
    ### How was this patch tested?
    
    Added related tests.
    
    Closes #40382 from ueshin/issues/SPARK-42679/non-nullable.
    
    Lead-authored-by: panbingkun <[email protected]>
    Co-authored-by: Takuya UESHIN <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
    (cherry picked from commit fa5ca7fe87290c81ccd2ba214c8478beefe0c5ec)
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 python/pyspark/sql/connect/session.py              | 62 +++++++++++++---------
 python/pyspark/sql/connect/types.py                |  7 ++-
 .../sql/tests/connect/test_connect_basic.py        | 59 ++++++++++++++++++++
 3 files changed, 101 insertions(+), 27 deletions(-)

diff --git a/python/pyspark/sql/connect/session.py 
b/python/pyspark/sql/connect/session.py
index 9d9af112da4..5e7c8361d80 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -53,7 +53,7 @@ from pyspark.sql.connect.dataframe import DataFrame
 from pyspark.sql.connect.plan import SQL, Range, LocalRelation, CachedRelation
 from pyspark.sql.connect.readwriter import DataFrameReader
 from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
-from pyspark.sql.pandas.types import to_arrow_type, _get_local_timezone
+from pyspark.sql.pandas.types import to_arrow_schema, to_arrow_type, 
_get_local_timezone
 from pyspark.sql.session import classproperty, SparkSession as PySparkSession
 from pyspark.sql.types import (
     _infer_schema,
@@ -241,8 +241,10 @@ class SparkSession:
                 _num_cols = len(_cols)
 
             # Determine arrow types to coerce data when creating batches
+            arrow_schema: Optional[pa.Schema] = None
             if isinstance(schema, StructType):
-                arrow_types = [to_arrow_type(f.dataType) for f in 
schema.fields]
+                arrow_schema = to_arrow_schema(schema)
+                arrow_types = [field.type for field in arrow_schema]
                 _cols = [str(x) if not isinstance(x, str) else x for x in 
schema.fieldNames()]
             elif isinstance(schema, DataType):
                 raise ValueError("Single data type %s is not supported with 
Arrow" % str(schema))
@@ -267,6 +269,10 @@ class SparkSession:
                 [ser._create_batch([(c, t) for (_, c), t in zip(data.items(), 
arrow_types)])]
             )
 
+            if isinstance(schema, StructType):
+                assert arrow_schema is not None
+                _table = _table.rename_columns(schema.names).cast(arrow_schema)
+
         elif isinstance(data, np.ndarray):
             if data.ndim not in [1, 2]:
                 raise ValueError("NumPy array input should be of 1 or 2 
dimensions.")
@@ -311,29 +317,35 @@ class SparkSession:
                 # we need to convert it to [[1], [2], [3]] to be able to infer 
schema.
                 _data = [[d] for d in _data]
 
-            _inferred_schema = self._inferSchemaFromList(_data, _cols)
-
-            if _cols is not None and cast(int, _num_cols) < len(_cols):
-                _num_cols = len(_cols)
-
-            if _has_nulltype(_inferred_schema):
-                # For cases like createDataFrame([("Alice", None, 80.1)], 
schema)
-                # we can not infer the schema from the data itself.
-                warnings.warn("failed to infer the schema from data")
-                if _schema is None and _schema_str is not None:
-                    _parsed = self.client._analyze(
-                        method="ddl_parse", ddl_string=_schema_str
-                    ).parsed
-                    if isinstance(_parsed, StructType):
-                        _schema = _parsed
-                    elif isinstance(_parsed, DataType):
-                        _schema = StructType().add("value", _parsed)
-                if _schema is None or not isinstance(_schema, StructType):
-                    raise ValueError(
-                        "Some of types cannot be determined after inferring, "
-                        "a StructType Schema is required in this case"
-                    )
-                _inferred_schema = _schema
+            if _schema is not None:
+                if isinstance(_schema, StructType):
+                    _inferred_schema = _schema
+                else:
+                    _inferred_schema = StructType().add("value", _schema)
+            else:
+                _inferred_schema = self._inferSchemaFromList(_data, _cols)
+
+                if _cols is not None and cast(int, _num_cols) < len(_cols):
+                    _num_cols = len(_cols)
+
+                if _has_nulltype(_inferred_schema):
+                    # For cases like createDataFrame([("Alice", None, 80.1)], 
schema)
+                    # we can not infer the schema from the data itself.
+                    warnings.warn("failed to infer the schema from data")
+                    if _schema is None and _schema_str is not None:
+                        _parsed = self.client._analyze(
+                            method="ddl_parse", ddl_string=_schema_str
+                        ).parsed
+                        if isinstance(_parsed, StructType):
+                            _schema = _parsed
+                        elif isinstance(_parsed, DataType):
+                            _schema = StructType().add("value", _parsed)
+                    if _schema is None or not isinstance(_schema, StructType):
+                        raise ValueError(
+                            "Some of types cannot be determined after 
inferring, "
+                            "a StructType Schema is required in this case"
+                        )
+                    _inferred_schema = _schema
 
             from pyspark.sql.connect.conversion import 
LocalDataToArrowConversion
 
diff --git a/python/pyspark/sql/connect/types.py 
b/python/pyspark/sql/connect/types.py
index d3c0fbc0272..b5145d91c76 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -309,9 +309,12 @@ def to_arrow_type(dt: DataType) -> "pa.DataType":
     elif type(dt) == DayTimeIntervalType:
         arrow_type = pa.duration("us")
     elif type(dt) == ArrayType:
-        arrow_type = pa.list_(to_arrow_type(dt.elementType))
+        field = pa.field("element", to_arrow_type(dt.elementType), 
nullable=dt.containsNull)
+        arrow_type = pa.list_(field)
     elif type(dt) == MapType:
-        arrow_type = pa.map_(to_arrow_type(dt.keyType), 
to_arrow_type(dt.valueType))
+        key_field = pa.field("key", to_arrow_type(dt.keyType), nullable=False)
+        value_field = pa.field("value", to_arrow_type(dt.valueType), 
nullable=dt.valueContainsNull)
+        arrow_type = pa.map_(key_field, value_field)
     elif type(dt) == StructType:
         fields = [
             pa.field(field.name, to_arrow_type(field.dataType), 
nullable=field.nullable)
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index fc5031bd91a..cd6890a630b 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2889,6 +2889,65 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
             self.connect.sql("show functions").collect(), self.spark.sql("show 
functions").collect()
         )
 
+    def test_schema_has_nullable(self):
+        schema_false = StructType().add("id", IntegerType(), False)
+        cdf1 = self.connect.createDataFrame([[1]], schema=schema_false)
+        sdf1 = self.spark.createDataFrame([[1]], schema=schema_false)
+        self.assertEqual(cdf1.schema, sdf1.schema)
+        self.assertEqual(cdf1.collect(), sdf1.collect())
+
+        schema_true = StructType().add("id", IntegerType(), True)
+        cdf2 = self.connect.createDataFrame([[1]], schema=schema_true)
+        sdf2 = self.spark.createDataFrame([[1]], schema=schema_true)
+        self.assertEqual(cdf2.schema, sdf2.schema)
+        self.assertEqual(cdf2.collect(), sdf2.collect())
+
+        pdf1 = cdf1.toPandas()
+        cdf3 = self.connect.createDataFrame(pdf1, cdf1.schema)
+        sdf3 = self.spark.createDataFrame(pdf1, sdf1.schema)
+        self.assertEqual(cdf3.schema, sdf3.schema)
+        self.assertEqual(cdf3.collect(), sdf3.collect())
+
+        pdf2 = cdf2.toPandas()
+        cdf4 = self.connect.createDataFrame(pdf2, cdf2.schema)
+        sdf4 = self.spark.createDataFrame(pdf2, sdf2.schema)
+        self.assertEqual(cdf4.schema, sdf4.schema)
+        self.assertEqual(cdf4.collect(), sdf4.collect())
+
+    def test_array_has_nullable(self):
+        schema_array_false = StructType().add("arr", ArrayType(IntegerType(), 
False))
+        cdf1 = self.connect.createDataFrame([Row([1, 2]), Row([3])], 
schema=schema_array_false)
+        sdf1 = self.spark.createDataFrame([Row([1, 2]), Row([3])], 
schema=schema_array_false)
+        self.assertEqual(cdf1.schema, sdf1.schema)
+        self.assertEqual(cdf1.collect(), sdf1.collect())
+
+        schema_array_true = StructType().add("arr", ArrayType(IntegerType(), 
True))
+        cdf2 = self.connect.createDataFrame([Row([1, None]), Row([3])], 
schema=schema_array_true)
+        sdf2 = self.spark.createDataFrame([Row([1, None]), Row([3])], 
schema=schema_array_true)
+        self.assertEqual(cdf2.schema, sdf2.schema)
+        self.assertEqual(cdf2.collect(), sdf2.collect())
+
+    def test_map_has_nullable(self):
+        schema_map_false = StructType().add("map", MapType(StringType(), 
IntegerType(), False))
+        cdf1 = self.connect.createDataFrame(
+            [Row({"a": 1, "b": 2}), Row({"a": 3})], schema=schema_map_false
+        )
+        sdf1 = self.spark.createDataFrame(
+            [Row({"a": 1, "b": 2}), Row({"a": 3})], schema=schema_map_false
+        )
+        self.assertEqual(cdf1.schema, sdf1.schema)
+        self.assertEqual(cdf1.collect(), sdf1.collect())
+
+        schema_map_true = StructType().add("map", MapType(StringType(), 
IntegerType(), True))
+        cdf2 = self.connect.createDataFrame(
+            [Row({"a": 1, "b": None}), Row({"a": 3})], schema=schema_map_true
+        )
+        sdf2 = self.spark.createDataFrame(
+            [Row({"a": 1, "b": None}), Row({"a": 3})], schema=schema_map_true
+        )
+        self.assertEqual(cdf2.schema, sdf2.schema)
+        self.assertEqual(cdf2.collect(), sdf2.collect())
+
 
 @unittest.skipIf(not should_test_connect, connect_requirement_message)
 class ClientTests(unittest.TestCase):


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to