This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new 2e4238cd011 [SPARK-42679][CONNECT][PYTHON] createDataFrame doesn't work with non-nullable schema
2e4238cd011 is described below
commit 2e4238cd0112a5103b0c34037ba0a8201f5287bc
Author: panbingkun <[email protected]>
AuthorDate: Mon Mar 13 11:10:49 2023 +0800
[SPARK-42679][CONNECT][PYTHON] createDataFrame doesn't work with non-nullable schema
### What changes were proposed in this pull request?
Fixes `spark.createDataFrame` to apply the given schema so that it works with non-nullable data types.
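At a high level, the fix derives an Arrow schema from the given `StructType` (via `to_arrow_schema`) and casts the locally created Arrow table to it, so field nullability survives the round trip. A minimal pyarrow sketch of the idea (illustrative only; it mirrors the `rename_columns(...).cast(...)` call in the diff below rather than reproducing it exactly):
```py
import pyarrow as pa

# An Arrow schema whose field is non-nullable, mirroring
# StructField("id", IntegerType(), False).
arrow_schema = pa.schema([pa.field("id", pa.int32(), nullable=False)])

# Build a table from local data, then rename and cast it to the target
# schema so the non-nullable flag is preserved on the Arrow side.
table = pa.table({"id": pa.array([1], type=pa.int32())})
table = table.rename_columns(["id"]).cast(arrow_schema)
assert table.schema.field("id").nullable is False
```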
### Why are the changes needed?
Currently `spark.createDataFrame` does not work with a non-nullable schema, as shown below:
```py
>>> from pyspark.sql.types import *
>>> schema_false = StructType([StructField("id", IntegerType(), False)])
>>> spark.createDataFrame([[1]], schema=schema_false)
Traceback (most recent call last):
...
pyspark.errors.exceptions.connect.AnalysisException: [NULLABLE_COLUMN_OR_FIELD] Column or field `id` is nullable while it's required to be non-nullable.
```
whereas it works fine with a nullable schema:
```py
>>> from pyspark.sql.types import *
>>> schema_true = StructType([StructField("id", IntegerType(), True)])
>>> spark.createDataFrame([[1]], schema=schema_true)
DataFrame[id: int]
```
### Does this PR introduce _any_ user-facing change?
Yes. `spark.createDataFrame` with a non-nullable schema will now work.
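With the fix, the failing snippet above is expected to behave like the nullable case and keep the non-nullable flag in the resulting schema (illustrative output, assuming a Spark Connect session bound to `spark`):
```py
>>> from pyspark.sql.types import *
>>> schema_false = StructType([StructField("id", IntegerType(), False)])
>>> df = spark.createDataFrame([[1]], schema=schema_false)
>>> df
DataFrame[id: int]
>>> df.schema["id"].nullable
False
```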
### How was this patch tested?
Added related tests.
Closes #40382 from ueshin/issues/SPARK-42679/non-nullable.
Lead-authored-by: panbingkun <[email protected]>
Co-authored-by: Takuya UESHIN <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
(cherry picked from commit fa5ca7fe87290c81ccd2ba214c8478beefe0c5ec)
Signed-off-by: Ruifeng Zheng <[email protected]>
---
python/pyspark/sql/connect/session.py | 62 +++++++++++++---------
python/pyspark/sql/connect/types.py | 7 ++-
.../sql/tests/connect/test_connect_basic.py | 59 ++++++++++++++++++++
3 files changed, 101 insertions(+), 27 deletions(-)
diff --git a/python/pyspark/sql/connect/session.py b/python/pyspark/sql/connect/session.py
index 9d9af112da4..5e7c8361d80 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -53,7 +53,7 @@ from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import SQL, Range, LocalRelation, CachedRelation
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
-from pyspark.sql.pandas.types import to_arrow_type, _get_local_timezone
+from pyspark.sql.pandas.types import to_arrow_schema, to_arrow_type, _get_local_timezone
from pyspark.sql.session import classproperty, SparkSession as PySparkSession
from pyspark.sql.types import (
_infer_schema,
@@ -241,8 +241,10 @@ class SparkSession:
_num_cols = len(_cols)
# Determine arrow types to coerce data when creating batches
+ arrow_schema: Optional[pa.Schema] = None
if isinstance(schema, StructType):
- arrow_types = [to_arrow_type(f.dataType) for f in schema.fields]
+ arrow_schema = to_arrow_schema(schema)
+ arrow_types = [field.type for field in arrow_schema]
_cols = [str(x) if not isinstance(x, str) else x for x in schema.fieldNames()]
elif isinstance(schema, DataType):
raise ValueError("Single data type %s is not supported with Arrow" % str(schema))
@@ -267,6 +269,10 @@ class SparkSession:
[ser._create_batch([(c, t) for (_, c), t in zip(data.items(), arrow_types)])]
)
+ if isinstance(schema, StructType):
+ assert arrow_schema is not None
+ _table = _table.rename_columns(schema.names).cast(arrow_schema)
+
elif isinstance(data, np.ndarray):
if data.ndim not in [1, 2]:
raise ValueError("NumPy array input should be of 1 or 2 dimensions.")
@@ -311,29 +317,35 @@ class SparkSession:
# we need to convert it to [[1], [2], [3]] to be able to infer schema.
_data = [[d] for d in _data]
- _inferred_schema = self._inferSchemaFromList(_data, _cols)
-
- if _cols is not None and cast(int, _num_cols) < len(_cols):
- _num_cols = len(_cols)
-
- if _has_nulltype(_inferred_schema):
- # For cases like createDataFrame([("Alice", None, 80.1)], schema)
- # we can not infer the schema from the data itself.
- warnings.warn("failed to infer the schema from data")
- if _schema is None and _schema_str is not None:
- _parsed = self.client._analyze(
- method="ddl_parse", ddl_string=_schema_str
- ).parsed
- if isinstance(_parsed, StructType):
- _schema = _parsed
- elif isinstance(_parsed, DataType):
- _schema = StructType().add("value", _parsed)
- if _schema is None or not isinstance(_schema, StructType):
- raise ValueError(
- "Some of types cannot be determined after inferring, "
- "a StructType Schema is required in this case"
- )
- _inferred_schema = _schema
+ if _schema is not None:
+ if isinstance(_schema, StructType):
+ _inferred_schema = _schema
+ else:
+ _inferred_schema = StructType().add("value", _schema)
+ else:
+ _inferred_schema = self._inferSchemaFromList(_data, _cols)
+
+ if _cols is not None and cast(int, _num_cols) < len(_cols):
+ _num_cols = len(_cols)
+
+ if _has_nulltype(_inferred_schema):
+ # For cases like createDataFrame([("Alice", None, 80.1)], schema)
+ # we can not infer the schema from the data itself.
+ warnings.warn("failed to infer the schema from data")
+ if _schema is None and _schema_str is not None:
+ _parsed = self.client._analyze(
+ method="ddl_parse", ddl_string=_schema_str
+ ).parsed
+ if isinstance(_parsed, StructType):
+ _schema = _parsed
+ elif isinstance(_parsed, DataType):
+ _schema = StructType().add("value", _parsed)
+ if _schema is None or not isinstance(_schema, StructType):
+ raise ValueError(
+ "Some of types cannot be determined after inferring, "
+ "a StructType Schema is required in this case"
+ )
+ _inferred_schema = _schema
from pyspark.sql.connect.conversion import LocalDataToArrowConversion
diff --git a/python/pyspark/sql/connect/types.py b/python/pyspark/sql/connect/types.py
index d3c0fbc0272..b5145d91c76 100644
--- a/python/pyspark/sql/connect/types.py
+++ b/python/pyspark/sql/connect/types.py
@@ -309,9 +309,12 @@ def to_arrow_type(dt: DataType) -> "pa.DataType":
elif type(dt) == DayTimeIntervalType:
arrow_type = pa.duration("us")
elif type(dt) == ArrayType:
- arrow_type = pa.list_(to_arrow_type(dt.elementType))
+ field = pa.field("element", to_arrow_type(dt.elementType), nullable=dt.containsNull)
+ arrow_type = pa.list_(field)
elif type(dt) == MapType:
- arrow_type = pa.map_(to_arrow_type(dt.keyType), to_arrow_type(dt.valueType))
+ key_field = pa.field("key", to_arrow_type(dt.keyType), nullable=False)
+ value_field = pa.field("value", to_arrow_type(dt.valueType), nullable=dt.valueContainsNull)
+ arrow_type = pa.map_(key_field, value_field)
elif type(dt) == StructType:
fields = [
pa.field(field.name, to_arrow_type(field.dataType), nullable=field.nullable)
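For reference, a small pyarrow illustration (not part of the patch) of what the `to_arrow_type` change above produces for arrays and maps: `containsNull` / `valueContainsNull` become the nullability of the nested Arrow fields, while map keys stay non-nullable.
```py
import pyarrow as pa

# ArrayType(IntegerType(), containsNull=False) -> list type whose
# element field is marked non-nullable.
list_type = pa.list_(pa.field("element", pa.int32(), nullable=False))
assert list_type.value_field.nullable is False

# MapType(StringType(), IntegerType(), valueContainsNull=True) ->
# non-nullable key field, nullable value field.
map_type = pa.map_(
    pa.field("key", pa.string(), nullable=False),
    pa.field("value", pa.int32(), nullable=True),
)
assert map_type.key_field.nullable is False
assert map_type.item_field.nullable is True
```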
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index fc5031bd91a..cd6890a630b 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2889,6 +2889,65 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
self.connect.sql("show functions").collect(), self.spark.sql("show functions").collect()
)
+ def test_schema_has_nullable(self):
+ schema_false = StructType().add("id", IntegerType(), False)
+ cdf1 = self.connect.createDataFrame([[1]], schema=schema_false)
+ sdf1 = self.spark.createDataFrame([[1]], schema=schema_false)
+ self.assertEqual(cdf1.schema, sdf1.schema)
+ self.assertEqual(cdf1.collect(), sdf1.collect())
+
+ schema_true = StructType().add("id", IntegerType(), True)
+ cdf2 = self.connect.createDataFrame([[1]], schema=schema_true)
+ sdf2 = self.spark.createDataFrame([[1]], schema=schema_true)
+ self.assertEqual(cdf2.schema, sdf2.schema)
+ self.assertEqual(cdf2.collect(), sdf2.collect())
+
+ pdf1 = cdf1.toPandas()
+ cdf3 = self.connect.createDataFrame(pdf1, cdf1.schema)
+ sdf3 = self.spark.createDataFrame(pdf1, sdf1.schema)
+ self.assertEqual(cdf3.schema, sdf3.schema)
+ self.assertEqual(cdf3.collect(), sdf3.collect())
+
+ pdf2 = cdf2.toPandas()
+ cdf4 = self.connect.createDataFrame(pdf2, cdf2.schema)
+ sdf4 = self.spark.createDataFrame(pdf2, sdf2.schema)
+ self.assertEqual(cdf4.schema, sdf4.schema)
+ self.assertEqual(cdf4.collect(), sdf4.collect())
+
+ def test_array_has_nullable(self):
+ schema_array_false = StructType().add("arr", ArrayType(IntegerType(), False))
+ cdf1 = self.connect.createDataFrame([Row([1, 2]), Row([3])], schema=schema_array_false)
+ sdf1 = self.spark.createDataFrame([Row([1, 2]), Row([3])], schema=schema_array_false)
+ self.assertEqual(cdf1.schema, sdf1.schema)
+ self.assertEqual(cdf1.collect(), sdf1.collect())
+
+ schema_array_true = StructType().add("arr", ArrayType(IntegerType(), True))
+ cdf2 = self.connect.createDataFrame([Row([1, None]), Row([3])], schema=schema_array_true)
+ sdf2 = self.spark.createDataFrame([Row([1, None]), Row([3])], schema=schema_array_true)
+ self.assertEqual(cdf2.schema, sdf2.schema)
+ self.assertEqual(cdf2.collect(), sdf2.collect())
+
+ def test_map_has_nullable(self):
+ schema_map_false = StructType().add("map", MapType(StringType(), IntegerType(), False))
+ cdf1 = self.connect.createDataFrame(
+ [Row({"a": 1, "b": 2}), Row({"a": 3})], schema=schema_map_false
+ )
+ sdf1 = self.spark.createDataFrame(
+ [Row({"a": 1, "b": 2}), Row({"a": 3})], schema=schema_map_false
+ )
+ self.assertEqual(cdf1.schema, sdf1.schema)
+ self.assertEqual(cdf1.collect(), sdf1.collect())
+
+ schema_map_true = StructType().add("map", MapType(StringType(), IntegerType(), True))
+ cdf2 = self.connect.createDataFrame(
+ [Row({"a": 1, "b": None}), Row({"a": 3})], schema=schema_map_true
+ )
+ sdf2 = self.spark.createDataFrame(
+ [Row({"a": 1, "b": None}), Row({"a": 3})], schema=schema_map_true
+ )
+ self.assertEqual(cdf2.schema, sdf2.schema)
+ self.assertEqual(cdf2.collect(), sdf2.collect())
+
@unittest.skipIf(not should_test_connect, connect_requirement_message)
class ClientTests(unittest.TestCase):