This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 731b89d5914 [SPARK-41833][SPARK-41881][SPARK-41815][CONNECT][PYTHON] Make `DataFrame.collect` handle None/NaN/Array/Binary properly
731b89d5914 is described below
commit 731b89d59143adb8a4ab3d16dd9f0e08c799abf2
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Thu Jan 5 08:52:08 2023 +0900
[SPARK-41833][SPARK-41881][SPARK-41815][CONNECT][PYTHON] Make `DataFrame.collect` handle None/NaN/Array/Binary properly
### What changes were proposed in this pull request?
The existing `DataFrame.collect` collects the incoming Arrow batches directly into a
Pandas DataFrame and then converts each series into a Row, which is problematic
since it cannot correctly handle None/NaN/Array/Binary values, etc.
This PR refactors `DataFrame.collect` to build rows directly from the raw
Arrow Table, as sketched below, in order to support:
1. None/NaN values
2. ArrayType
3. BinaryType
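A minimal sketch of the approach (a standalone PyArrow illustration with made-up
column data, not the exact connect client code):

    import pyarrow as pa
    from pyspark.sql import Row

    # Stand-in for the Arrow table fetched from the server.
    table = pa.table({
        "a": [1, None],         # nulls stay None instead of turning into NaN
        "b": [[1, 2], None],    # list column (ArrayType)
        "c": [b"ab", None],     # binary column (BinaryType)
    })

    rows = []
    for rec in table.to_pylist():
        # Convert bytes to bytearray to match PySpark's BinaryType representation.
        rows.append(Row(**{k: bytearray(v) if isinstance(v, bytes) else v
                           for k, v in rec.items()}))
    # rows == [Row(a=1, b=[1, 2], c=bytearray(b'ab')), Row(a=None, b=None, c=None)]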
### Why are the changes needed?
To be consistent with PySpark
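For example, assuming a Spark Connect session `spark` (a hedged illustration of the
intended behavior, not an actual doctest from this patch):

    df = spark.sql("SELECT 1 AS a, CAST(NULL AS INT) AS b, ARRAY(1, 2) AS c, X'4142' AS d")
    df.collect()
    # expected, matching PySpark: [Row(a=1, b=None, c=[1, 2], d=bytearray(b'AB'))]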
### Does this PR introduce _any_ user-facing change?
Yes.
### How was this patch tested?
Enabled doctests.
Closes #39386 from zhengruifeng/connect_fix_41833.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
python/pyspark/sql/connect/client.py | 54 ++++++++++++++++++---------------
python/pyspark/sql/connect/column.py | 2 --
python/pyspark/sql/connect/dataframe.py | 22 +++++++++++---
python/pyspark/sql/connect/functions.py | 31 +++++--------------
4 files changed, 55 insertions(+), 54 deletions(-)
diff --git a/python/pyspark/sql/connect/client.py b/python/pyspark/sql/connect/client.py
index e78c4de0f70..832b5648676 100644
--- a/python/pyspark/sql/connect/client.py
+++ b/python/pyspark/sql/connect/client.py
@@ -21,12 +21,13 @@ import urllib.parse
import uuid
from typing import Iterable, Optional, Any, Union, List, Tuple, Dict, NoReturn, cast
+import pandas as pd
+import pyarrow as pa
+
import google.protobuf.message
from grpc_status import rpc_status
import grpc
-import pandas
from google.protobuf import text_format
-import pyarrow as pa
from google.rpc import error_details_pb2
import pyspark.sql.connect.proto as pb2
@@ -406,11 +407,22 @@ class SparkConnectClient(object):
for x in metrics.metrics
]
- def to_pandas(self, plan: pb2.Plan) -> "pandas.DataFrame":
+ def to_table(self, plan: pb2.Plan) -> "pa.Table":
+ logger.info(f"Executing plan {self._proto_to_string(plan)}")
+ req = self._execute_plan_request_with_metadata()
+ req.plan.CopyFrom(plan)
+ table, _ = self._execute_and_fetch(req)
+ return table
+
+ def to_pandas(self, plan: pb2.Plan) -> "pd.DataFrame":
logger.info(f"Executing plan {self._proto_to_string(plan)}")
req = self._execute_plan_request_with_metadata()
req.plan.CopyFrom(plan)
- return self._execute_and_fetch(req)
+ table, metrics = self._execute_and_fetch(req)
+ pdf = table.to_pandas()
+ if len(metrics) > 0:
+ pdf.attrs["metrics"] = metrics
+ return pdf
def _proto_schema_to_pyspark_schema(self, schema: pb2.DataType) -> DataType:
return types.proto_schema_to_pyspark_data_type(schema)
@@ -521,10 +533,6 @@ class SparkConnectClient(object):
except grpc.RpcError as rpc_error:
self._handle_error(rpc_error)
- def _process_batch(self, arrow_batch: pb2.ExecutePlanResponse.ArrowBatch) -> "pandas.DataFrame":
- with pa.ipc.open_stream(arrow_batch.data) as rd:
- return rd.read_pandas()
-
def _execute(self, req: pb2.ExecutePlanRequest) -> None:
"""
Execute the passed request `req` and drop all results.
@@ -546,12 +554,14 @@ class SparkConnectClient(object):
except grpc.RpcError as rpc_error:
self._handle_error(rpc_error)
- def _execute_and_fetch(self, req: pb2.ExecutePlanRequest) -> "pandas.DataFrame":
+ def _execute_and_fetch(
+ self, req: pb2.ExecutePlanRequest
+ ) -> Tuple["pa.Table", List[PlanMetrics]]:
logger.info("ExecuteAndFetch")
- import pandas as pd
m: Optional[pb2.ExecutePlanResponse.Metrics] = None
- result_dfs = []
+
+ batches: List[pa.RecordBatch] = []
try:
for b in self._stub.ExecutePlan(req, metadata=self._builder.metadata()):
@@ -567,25 +577,21 @@ class SparkConnectClient(object):
f"Received arrow batch rows={b.arrow_batch.row_count} "
f"size={len(b.arrow_batch.data)}"
)
- pb = self._process_batch(b.arrow_batch)
- result_dfs.append(pb)
+
+ with pa.ipc.open_stream(b.arrow_batch.data) as reader:
+ for batch in reader:
+ assert isinstance(batch, pa.RecordBatch)
+ batches.append(batch)
except grpc.RpcError as rpc_error:
self._handle_error(rpc_error)
- assert len(result_dfs) > 0
+ assert len(batches) > 0
- df = pd.concat(result_dfs)
+ table = pa.Table.from_batches(batches=batches)
- # pd.concat generates non-consecutive index like:
- # Int64Index([0, 1, 0, 1, 2, 0, 1, 0, 1, 2], dtype='int64')
- # set it to RangeIndex to be consistent with pyspark
- n = len(df)
- df.set_index(pd.RangeIndex(start=0, stop=n, step=1), inplace=True)
+ metrics: List[PlanMetrics] = self._build_metrics(m) if m is not None else []
- # Attach the metrics to the DataFrame attributes.
- if m is not None:
- df.attrs["metrics"] = self._build_metrics(m)
- return df
+ return table, metrics
def _handle_error(self, rpc_error: grpc.RpcError) -> NoReturn:
"""
diff --git a/python/pyspark/sql/connect/column.py
b/python/pyspark/sql/connect/column.py
index 6fda15e084a..4d0b3de322d 100644
--- a/python/pyspark/sql/connect/column.py
+++ b/python/pyspark/sql/connect/column.py
@@ -448,8 +448,6 @@ def _test() -> None:
del pyspark.sql.connect.column.Column.dropFields.__doc__
# TODO(SPARK-41772): Enable pyspark.sql.connect.column.Column.withField doctest
del pyspark.sql.connect.column.Column.withField.__doc__
- # TODO(SPARK-41815): Column.isNull returns nan instead of None
- del pyspark.sql.connect.column.Column.isNull.__doc__
# TODO(SPARK-41746): SparkSession.createDataFrame does not support nested datatypes
del pyspark.sql.connect.column.Column.getField.__doc__
diff --git a/python/pyspark/sql/connect/dataframe.py b/python/pyspark/sql/connect/dataframe.py
index b9d613870ab..fdb75d377b7 100644
--- a/python/pyspark/sql/connect/dataframe.py
+++ b/python/pyspark/sql/connect/dataframe.py
@@ -1016,11 +1016,23 @@ class DataFrame:
return ""
def collect(self) -> List[Row]:
- pdf = self.toPandas()
- if pdf is not None:
- return list(pdf.apply(lambda row: Row(**row), axis=1))
- else:
- return []
+ if self._plan is None:
+ raise Exception("Cannot collect on empty plan.")
+ if self._session is None:
+ raise Exception("Cannot collect on empty session.")
+ query = self._plan.to_proto(self._session.client)
+ table = self._session.client.to_table(query)
+
+ rows: List[Row] = []
+ for row in table.to_pylist():
+ _dict = {}
+ for k, v in row.items():
+ if isinstance(v, bytes):
+ _dict[k] = bytearray(v)
+ else:
+ _dict[k] = v
+ rows.append(Row(**_dict))
+ return rows
collect.__doc__ = PySparkDataFrame.collect.__doc__
diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index 77c7db2d808..965a9a5331e 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -1872,7 +1872,7 @@ translate.__doc__ = pysparkfuncs.translate.__doc__
# Date/Timestamp functions
-# TODO(SPARK-41283): Resolve dtypes inconsistencies for:
+# TODO(SPARK-41455): Resolve dtypes inconsistencies for:
# to_timestamp, from_utc_timestamp, to_utc_timestamp,
# timestamp_seconds, current_timestamp, date_trunc
@@ -2347,33 +2347,18 @@ def _test() -> None:
# Spark Connect does not support Spark Context but the test depends on that.
del pyspark.sql.connect.functions.monotonically_increasing_id.__doc__
- # TODO(SPARK-41833): fix collect() output
- del pyspark.sql.connect.functions.array.__doc__
- del pyspark.sql.connect.functions.array_distinct.__doc__
- del pyspark.sql.connect.functions.array_except.__doc__
- del pyspark.sql.connect.functions.array_intersect.__doc__
- del pyspark.sql.connect.functions.array_remove.__doc__
- del pyspark.sql.connect.functions.array_repeat.__doc__
- del pyspark.sql.connect.functions.array_sort.__doc__
- del pyspark.sql.connect.functions.array_union.__doc__
- del pyspark.sql.connect.functions.collect_list.__doc__
- del pyspark.sql.connect.functions.collect_set.__doc__
- del pyspark.sql.connect.functions.concat.__doc__
+ # TODO(SPARK-41880): Function `from_json` should support non-literal expression
+ # TODO(SPARK-41879): `DataFrame.collect` should support nested types
+ del pyspark.sql.connect.functions.struct.__doc__
del pyspark.sql.connect.functions.create_map.__doc__
- del pyspark.sql.connect.functions.date_trunc.__doc__
- del pyspark.sql.connect.functions.from_utc_timestamp.__doc__
del pyspark.sql.connect.functions.from_csv.__doc__
del pyspark.sql.connect.functions.from_json.__doc__
- del pyspark.sql.connect.functions.isnull.__doc__
- del pyspark.sql.connect.functions.reverse.__doc__
- del pyspark.sql.connect.functions.sequence.__doc__
- del pyspark.sql.connect.functions.slice.__doc__
- del pyspark.sql.connect.functions.sort_array.__doc__
- del pyspark.sql.connect.functions.split.__doc__
- del pyspark.sql.connect.functions.struct.__doc__
+
+ # TODO(SPARK-41455): Resolve dtypes inconsistencies of date/timestamp functions
del pyspark.sql.connect.functions.to_timestamp.__doc__
del pyspark.sql.connect.functions.to_utc_timestamp.__doc__
- del pyspark.sql.connect.functions.unhex.__doc__
+ del pyspark.sql.connect.functions.date_trunc.__doc__
+ del pyspark.sql.connect.functions.from_utc_timestamp.__doc__
# TODO(SPARK-41825): Dataframe.show formatting int as double
del pyspark.sql.connect.functions.coalesce.__doc__
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]