This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new b689c2ad756 [SPARK-42028][CONNECT][PYTHON][FOLLOW-UP] Uses the same
logic as PySpark, and re-enables the skipped test
b689c2ad756 is described below
commit b689c2ad756cc00dfc0f71f142e771dd367bcf4a
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Fri Jan 13 15:20:22 2023 +0900
[SPARK-42028][CONNECT][PYTHON][FOLLOW-UP] Uses the same logic as PySpark,
and re-enables the skipped test
### What changes were proposed in this pull request?
This PR is a followup of https://github.com/apache/spark/pull/39469 that
uses the same logic with PySpark:
https://github.com/apache/spark/blob/baa6fa9b148467bfc83e6c2d22ea9fd9fa5b4564/python/pyspark/sql/pandas/conversion.py#L546-L631
and re-enables the skipped test
`test_create_dataframe_from_pandas_with_timestamp`.
This PR also fixes a bug by doing this. A naive datetime was previously inferred as
`TimestampNTZType`, but it is now inferred as `TimestampType`, which
matches the behavior of regular PySpark.
### Why are the changes needed?
To deduplicate the changes in the future, and maintainability.
### Does this PR introduce _any_ user-facing change?
No to end users.
It matches the behaviour to the existing PySpark's `createDataFrame(pdf)`.
### How was this patch tested?
Reenabled a skipped test, and existing test added in the previous PR should
cover this.
Closes #39544 from HyukjinKwon/SPARK-42028-followup.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/session.py | 112 ++++++++++-----------
.../sql/tests/connect/test_parity_dataframe.py | 7 +-
python/pyspark/sql/tests/test_dataframe.py | 12 +--
3 files changed, 60 insertions(+), 71 deletions(-)
diff --git a/python/pyspark/sql/connect/session.py
b/python/pyspark/sql/connect/session.py
index 76073fc2717..6aec28d70a8 100644
--- a/python/pyspark/sql/connect/session.py
+++ b/python/pyspark/sql/connect/session.py
@@ -16,17 +16,39 @@
#
import os
import warnings
-from distutils.version import LooseVersion
-from threading import RLock
from collections.abc import Sized
+from distutils.version import LooseVersion
from functools import reduce
+from threading import RLock
+from typing import (
+ Optional,
+ Any,
+ Union,
+ Dict,
+ List,
+ Tuple,
+ cast,
+ overload,
+ Iterable,
+ TYPE_CHECKING,
+)
import numpy as np
import pandas as pd
import pyarrow as pa
+from pandas.api.types import ( # type: ignore[attr-defined]
+ is_datetime64_dtype,
+ is_datetime64tz_dtype,
+)
from pyspark import SparkContext, SparkConf, __version__
from pyspark.java_gateway import launch_gateway
+from pyspark.sql.connect.client import SparkConnectClient
+from pyspark.sql.connect.dataframe import DataFrame
+from pyspark.sql.connect.plan import SQL, Range, LocalRelation
+from pyspark.sql.connect.readwriter import DataFrameReader
+from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
+from pyspark.sql.pandas.types import to_arrow_type, _get_local_timezone
from pyspark.sql.session import classproperty, SparkSession as PySparkSession
from pyspark.sql.types import (
_infer_schema,
@@ -36,28 +58,10 @@ from pyspark.sql.types import (
DataType,
StructType,
AtomicType,
+ TimestampType,
)
from pyspark.sql.utils import to_str
-from pyspark.sql.connect.client import SparkConnectClient
-from pyspark.sql.connect.dataframe import DataFrame
-from pyspark.sql.connect.plan import SQL, Range, LocalRelation
-from pyspark.sql.connect.readwriter import DataFrameReader
-
-from typing import (
- Optional,
- Any,
- Union,
- Dict,
- List,
- Tuple,
- cast,
- overload,
- Iterable,
- TYPE_CHECKING,
-)
-
-
if TYPE_CHECKING:
from pyspark.sql.connect._typing import OptionalPrimitiveType
from pyspark.sql.connect.catalog import Catalog
@@ -221,47 +225,37 @@ class SparkSession:
_inferred_schema: Optional[StructType] = None
if isinstance(data, pd.DataFrame):
- from pandas.api.types import ( # type: ignore[attr-defined]
- is_datetime64_dtype,
- is_datetime64tz_dtype,
- )
- from pyspark.sql.pandas.types import (
- _check_series_convert_timestamps_internal,
- _get_local_timezone,
+ # Logic was borrowed from `_create_from_pandas_with_arrow` in
+ # `pyspark.sql.pandas.conversion.py`. Should ideally deduplicate
the logics.
+
+ # If no schema supplied by user then get the names of columns only
+ if schema is None:
+ _cols = [str(x) if not isinstance(x, str) else x for x in
data.columns]
+
+ # Determine arrow types to coerce data when creating batches
+ if isinstance(schema, StructType):
+ arrow_types = [to_arrow_type(f.dataType) for f in
schema.fields]
+ _cols = [str(x) if not isinstance(x, str) else x for x in
schema.fieldNames()]
+ elif isinstance(schema, DataType):
+ raise ValueError("Single data type %s is not supported with
Arrow" % str(schema))
+ else:
+ # Any timestamps must be coerced to be compatible with Spark
+ arrow_types = [
+ to_arrow_type(TimestampType())
+ if is_datetime64_dtype(t) or is_datetime64tz_dtype(t)
+ else None
+ for t in data.dtypes
+ ]
+
+ ser = ArrowStreamPandasSerializer(
+ _get_local_timezone(), # 'spark.session.timezone' should be
respected
+ False, #
'spark.sql.execution.pandas.convertToArrowArraySafely' should be respected
+ True,
)
- # First, check if we need to create a copy of the input data to
adjust
- # the timestamps.
- input_data = data
- has_timestamp_data = any(
- [is_datetime64_dtype(data[c]) or
is_datetime64tz_dtype(data[c]) for c in data]
+ _table = pa.Table.from_batches(
+ [ser._create_batch([(c, t) for (_, c), t in zip(data.items(),
arrow_types)])]
)
- if has_timestamp_data:
- input_data = data.copy()
- # We need double conversions for the truncation, first
truncate to microseconds.
- for col in input_data:
- if is_datetime64tz_dtype(input_data[col].dtype):
- input_data[col] =
_check_series_convert_timestamps_internal(
- input_data[col], _get_local_timezone()
- ).astype("datetime64[us, UTC]")
- elif is_datetime64_dtype(input_data[col].dtype):
- input_data[col] =
input_data[col].astype("datetime64[us]")
-
- # Create a new schema and change the types to the truncated
microseconds.
- pd_schema = pa.Schema.from_pandas(input_data)
- new_schema = pa.schema([])
- for x in range(len(pd_schema.types)):
- f = pd_schema.field(x)
- # TODO(SPARK-42027) Add support for struct types.
- if isinstance(f.type, pa.TimestampType) and f.type.unit ==
"ns":
- tmp = f.with_type(pa.timestamp("us"))
- new_schema = new_schema.append(tmp)
- else:
- new_schema = new_schema.append(f)
- new_schema = new_schema.with_metadata(pd_schema.metadata)
- _table = pa.Table.from_pandas(input_data, schema=new_schema)
- else:
- _table = pa.Table.from_pandas(data)
elif isinstance(data, np.ndarray):
if data.ndim not in [1, 2]:
diff --git a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
index 61fd72b6bbf..c722f4693e4 100644
--- a/python/pyspark/sql/tests/connect/test_parity_dataframe.py
+++ b/python/pyspark/sql/tests/connect/test_parity_dataframe.py
@@ -37,16 +37,11 @@ class DataFrameParityTests(DataFrameTestsMixin,
ReusedConnectTestCase):
def test_create_dataframe_from_pandas_with_day_time_interval(self):
super().test_create_dataframe_from_pandas_with_day_time_interval()
- # TODO(SPARK-41842): Support data type Timestamp(NANOSECOND, null)
+ # TODO(SPARK-41834): Implement SparkSession.conf
@unittest.skip("Fails in Spark Connect, should enable.")
def test_create_dataframe_from_pandas_with_dst(self):
super().test_create_dataframe_from_pandas_with_dst()
- # TODO(SPARK-41842): Support data type Timestamp(NANOSECOND, null)
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_create_dataframe_from_pandas_with_timestamp(self):
- super().test_create_dataframe_from_pandas_with_timestamp()
-
# TODO(SPARK-41855): createDataFrame doesn't handle None/NaN properly
@unittest.skip("Fails in Spark Connect, should enable.")
def test_create_nan_decimal_dataframe(self):
diff --git a/python/pyspark/sql/tests/test_dataframe.py
b/python/pyspark/sql/tests/test_dataframe.py
index c67f43ecb64..4ab16f231c0 100644
--- a/python/pyspark/sql/tests/test_dataframe.py
+++ b/python/pyspark/sql/tests/test_dataframe.py
@@ -1246,15 +1246,15 @@ class DataFrameTestsMixin:
)
# test types are inferred correctly without specifying schema
df = self.spark.createDataFrame(pdf)
- self.assertTrue(isinstance(df.schema["ts"].dataType, TimestampType))
- self.assertTrue(isinstance(df.schema["d"].dataType, DateType))
+ self.assertIsInstance(df.schema["ts"].dataType, TimestampType)
+ self.assertIsInstance(df.schema["d"].dataType, DateType)
# test with schema will accept pdf as input
df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp")
- self.assertTrue(isinstance(df.schema["ts"].dataType, TimestampType))
- self.assertTrue(isinstance(df.schema["d"].dataType, DateType))
+ self.assertIsInstance(df.schema["ts"].dataType, TimestampType)
+ self.assertIsInstance(df.schema["d"].dataType, DateType)
df = self.spark.createDataFrame(pdf, schema="d date, ts timestamp_ntz")
- self.assertTrue(isinstance(df.schema["ts"].dataType, TimestampNTZType))
- self.assertTrue(isinstance(df.schema["d"].dataType, DateType))
+ self.assertIsInstance(df.schema["ts"].dataType, TimestampNTZType)
+ self.assertIsInstance(df.schema["d"].dataType, DateType)
@unittest.skipIf(have_pandas, "Required Pandas was found.")
def test_create_dataframe_required_pandas_not_found(self):
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]