This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch branch-3.4
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push:
new ed797bbdaeb [SPARK-42340][CONNECT][PYTHON][3.4] Implement Grouped Map API
ed797bbdaeb is described below
commit ed797bbdaeb1e421ca1d14620f72257560896a1c
Author: Xinrong Meng <[email protected]>
AuthorDate: Tue Mar 21 16:42:43 2023 +0900
[SPARK-42340][CONNECT][PYTHON][3.4] Implement Grouped Map API
### What changes were proposed in this pull request?
Implement the Grouped Map API: `GroupedData.applyInPandas` and `GroupedData.apply`.
### Why are the changes needed?
Parity with vanilla PySpark.
### Does this PR introduce _any_ user-facing change?
Yes. `GroupedData.applyInPandas` and `GroupedData.apply` are now supported, as shown below.
```python
>>> df = spark.createDataFrame([(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], ("id", "v"))
>>> def normalize(pdf):
... v = pdf.v
... return pdf.assign(v=(v - v.mean()) / v.std())
...
>>> df.groupby("id").applyInPandas(normalize, schema="id long, v double").show()
+---+-------------------+
| id| v|
+---+-------------------+
| 1|-0.7071067811865475|
| 1| 0.7071067811865475|
| 2|-0.8320502943378437|
| 2|-0.2773500981126146|
| 2| 1.1094003924504583|
+---+-------------------+
```
```python
>>> @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
... def normalize(pdf):
... v = pdf.v
... return pdf.assign(v=(v - v.mean()) / v.std())
...
>>> df.groupby("id").apply(normalize).show()
/Users/xinrong.meng/spark/python/pyspark/sql/connect/group.py:228: UserWarning: It is preferred to use 'applyInPandas' over this API. This API will be deprecated in the future releases. See SPARK-28264 for more details.
  warnings.warn(
+---+-------------------+
| id| v|
+---+-------------------+
| 1|-0.7071067811865475|
| 1| 0.7071067811865475|
| 2|-0.8320502943378437|
| 2|-0.2773500981126146|
| 2| 1.1094003924504583|
+---+-------------------+
```
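As in vanilla PySpark, `applyInPandas` also accepts a function that takes the grouping key as its first argument (see the `PandasGroupedMapFunction` union added to `_typing.py` below). A minimal sketch reusing `df` from above; `subtract_mean` is a hypothetical name, not part of this patch:
```python
>>> def subtract_mean(key, pdf):
...     # key is a tuple of the grouping values, e.g. (1,) or (2,)
...     return pdf.assign(v=pdf.v - pdf.v.mean())
...
>>> df.groupby("id").applyInPandas(subtract_mean, schema="id long, v double").show()
```
For `id=1` the centered values are `-0.5` and `0.5`; the output is otherwise analogous to the example above.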
### How was this patch tested?
(Parity) Unit tests.
Closes #40486 from xinrong-meng/group_map3.4.
Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/protobuf/spark/connect/relations.proto | 12 +
.../sql/connect/planner/SparkConnectPlanner.scala | 14 ++
dev/sparktestsupport/modules.py | 1 +
python/pyspark/sql/connect/_typing.py | 10 +-
python/pyspark/sql/connect/group.py | 61 +++++-
python/pyspark/sql/connect/plan.py | 27 +++
python/pyspark/sql/connect/proto/relations_pb2.py | 242 +++++++++++----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 51 +++++
python/pyspark/sql/pandas/group_ops.py | 6 +
.../sql/tests/connect/test_connect_basic.py | 2 -
.../connect/test_parity_pandas_grouped_map.py | 102 +++++++++
.../sql/tests/connect/test_parity_pandas_udf.py | 5 -
.../sql/tests/pandas/test_pandas_grouped_map.py | 6 +-
13 files changed, 411 insertions(+), 128 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index 69451e7b76e..aba965082ea 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -63,6 +63,7 @@ message Relation {
MapPartitions map_partitions = 28;
CollectMetrics collect_metrics = 29;
Parse parse = 30;
+ GroupMap group_map = 31;
// NA functions
NAFill fill_na = 90;
@@ -788,6 +789,17 @@ message MapPartitions {
CommonInlineUserDefinedFunction func = 2;
}
+message GroupMap {
+ // (Required) Input relation for Group Map API: apply, applyInPandas.
+ Relation input = 1;
+
+ // (Required) Expressions for grouping keys.
+ repeated Expression grouping_expressions = 2;
+
+ // (Required) Input user-defined function.
+ CommonInlineUserDefinedFunction func = 3;
+}
+
// Collect arbitrary (named) metrics from a dataset.
message CollectMetrics {
// (Required) The input relation.
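(For orientation, not part of the patch: a sketch of how the regenerated Python bindings expose the new message, assuming only the fields declared above; `group_map` is field 31 in the `rel_type` oneof.)
```python
from pyspark.sql.connect.proto import relations_pb2

# Touching the group_map submessage marks the rel_type oneof as set.
rel = relations_pb2.Relation()
rel.group_map.input.CopyFrom(relations_pb2.Relation())  # placeholder child relation
# grouping_expressions (repeated Expression) and func
# (CommonInlineUserDefinedFunction) are populated by plan.py's GroupMap below.
assert rel.WhichOneof("rel_type") == "group_map"
```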
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index ded5ffa78f9..41a867d3b9d 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -116,6 +116,8 @@ class SparkConnectPlanner(val session: SparkSession) {
transformRepartitionByExpression(rel.getRepartitionByExpression)
case proto.Relation.RelTypeCase.MAP_PARTITIONS =>
transformMapPartitions(rel.getMapPartitions)
+ case proto.Relation.RelTypeCase.GROUP_MAP =>
+ transformGroupMap(rel.getGroupMap)
case proto.Relation.RelTypeCase.COLLECT_METRICS =>
transformCollectMetrics(rel.getCollectMetrics)
case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
@@ -494,6 +496,18 @@ class SparkConnectPlanner(val session: SparkSession) {
}
}
+ private def transformGroupMap(rel: proto.GroupMap): LogicalPlan = {
+ val pythonUdf = transformPythonUDF(rel.getFunc)
+    val cols =
+      rel.getGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr)))
+
+ Dataset
+ .ofRows(session, transformRelation(rel.getInput))
+ .groupBy(cols: _*)
+ .flatMapGroupsInPandas(pythonUdf)
+ .logicalPlan
+ }
+
  private def transformWithColumnsRenamed(rel: proto.WithColumnsRenamed): LogicalPlan = {
Dataset
.ofRows(session, transformRelation(rel.getInput))
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 3a9b7a9d0fa..5921391b227 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -537,6 +537,7 @@ pyspark_connect = Module(
"pyspark.sql.tests.connect.test_parity_pandas_udf",
"pyspark.sql.tests.connect.test_parity_pandas_map",
"pyspark.sql.tests.connect.test_parity_arrow_map",
+ "pyspark.sql.tests.connect.test_parity_pandas_grouped_map",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy,
pandas, and pyarrow and
diff --git a/python/pyspark/sql/connect/_typing.py b/python/pyspark/sql/connect/_typing.py
index 6df3f15d87d..63aae5d2487 100644
--- a/python/pyspark/sql/connect/_typing.py
+++ b/python/pyspark/sql/connect/_typing.py
@@ -22,7 +22,8 @@ if sys.version_info >= (3, 8):
else:
from typing_extensions import Protocol
-from typing import Any, Callable, Iterable, Union, Optional
+from types import FunctionType
+from typing import Any, Callable, Iterable, Union, Optional, NewType
import datetime
import decimal
@@ -53,6 +54,13 @@ PandasMapIterFunction = Callable[[Iterable[DataFrameLike]], Iterable[DataFrameLi
ArrowMapIterFunction = Callable[[Iterable[pyarrow.RecordBatch]], Iterable[pyarrow.RecordBatch]]
+PandasGroupedMapFunction = Union[
+ Callable[[DataFrameLike], DataFrameLike],
+ Callable[[Any, DataFrameLike], DataFrameLike],
+]
+
+GroupedMapPandasUserDefinedFunction = NewType("GroupedMapPandasUserDefinedFunction", FunctionType)
+
class UserDefinedFunctionLike(Protocol):
func: Callable[..., Any]
diff --git a/python/pyspark/sql/connect/group.py b/python/pyspark/sql/connect/group.py
index e699ce7105a..a75a50501bd 100644
--- a/python/pyspark/sql/connect/group.py
+++ b/python/pyspark/sql/connect/group.py
@@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import warnings
+
from pyspark.sql.connect.utils import check_dependencies
check_dependencies(__name__)
@@ -30,6 +32,7 @@ from typing import (
cast,
)
+from pyspark.rdd import PythonEvalType
from pyspark.sql.group import GroupedData as PySparkGroupedData
from pyspark.sql.types import NumericType
@@ -38,8 +41,13 @@ from pyspark.sql.connect.column import Column
from pyspark.sql.connect.functions import _invoke_function, col, lit
if TYPE_CHECKING:
- from pyspark.sql.connect._typing import LiteralType
+ from pyspark.sql.connect._typing import (
+ LiteralType,
+ PandasGroupedMapFunction,
+ GroupedMapPandasUserDefinedFunction,
+ )
from pyspark.sql.connect.dataframe import DataFrame
+ from pyspark.sql.types import StructType
class GroupedData:
@@ -203,11 +211,54 @@ class GroupedData:
pivot.__doc__ = PySparkGroupedData.pivot.__doc__
- def apply(self, *args: Any, **kwargs: Any) -> None:
- raise NotImplementedError("apply() is not implemented.")
+ def apply(self, udf: "GroupedMapPandasUserDefinedFunction") -> "DataFrame":
+ # Columns are special because hasattr always return True
+ if (
+ isinstance(udf, Column)
+ or not hasattr(udf, "func")
+ or (
+ udf.evalType # type: ignore[attr-defined]
+ != PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF
+ )
+ ):
+            raise ValueError(
+                "Invalid udf: the udf argument must be a pandas_udf of type " "GROUPED_MAP."
+ )
+
+ warnings.warn(
+ "It is preferred to use 'applyInPandas' over this "
+            "API. This API will be deprecated in the future releases. See SPARK-28264 for "
+ "more details.",
+ UserWarning,
+ )
+
+        return self.applyInPandas(udf.func, schema=udf.returnType)  # type: ignore[attr-defined]
+
+ apply.__doc__ = PySparkGroupedData.apply.__doc__
+
+ def applyInPandas(
+        self, func: "PandasGroupedMapFunction", schema: Union["StructType", str]
+ ) -> "DataFrame":
+ from pyspark.sql.connect.udf import UserDefinedFunction
+ from pyspark.sql.connect.dataframe import DataFrame
+
+ udf_obj = UserDefinedFunction(
+ func,
+ returnType=schema,
+ evalType=PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+ )
+
+ return DataFrame.withPlan(
+ plan.GroupMap(
+ child=self._df._plan,
+ grouping_cols=self._grouping_cols,
+ function=udf_obj,
+ cols=self._df.columns,
+ ),
+ session=self._df._session,
+ )
- def applyInPandas(self, *args: Any, **kwargs: Any) -> None:
- raise NotImplementedError("applyInPandas() is not implemented.")
+ applyInPandas.__doc__ = PySparkGroupedData.applyInPandas.__doc__
def applyInPandasWithState(self, *args: Any, **kwargs: Any) -> None:
        raise NotImplementedError("applyInPandasWithState() is not implemented.")
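To illustrate the guard added to `apply()` above, a hypothetical session (with `df` as in the examples earlier): passing a pandas_udf whose eval type is not GROUPED_MAP — here an assumed SCALAR udf named `plus_one` — hits the new ValueError.
```python
>>> from pyspark.sql.functions import pandas_udf
>>> @pandas_udf("double")  # SCALAR eval type, not GROUPED_MAP
... def plus_one(v):
...     return v + 1
...
>>> df.groupby("id").apply(plus_one)
Traceback (most recent call last):
    ...
ValueError: Invalid udf: the udf argument must be a pandas_udf of type GROUPED_MAP.
```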
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 9807c9722a6..dbfcfea7678 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -1923,6 +1923,33 @@ class MapPartitions(LogicalPlan):
return plan
+class GroupMap(LogicalPlan):
+ """Logical plan object for a Group Map API: apply, applyInPandas."""
+
+ def __init__(
+ self,
+ child: Optional["LogicalPlan"],
+ grouping_cols: Sequence[Column],
+ function: "UserDefinedFunction",
+ cols: List[str],
+ ):
+        assert isinstance(grouping_cols, list) and all(isinstance(c, Column) for c in grouping_cols)
+
+ super().__init__(child)
+ self._grouping_cols = grouping_cols
+ self._func = function._build_common_inline_user_defined_function(*cols)
+
+ def plan(self, session: "SparkConnectClient") -> proto.Relation:
+ assert self._child is not None
+ plan = self._create_proto_relation()
+ plan.group_map.input.CopyFrom(self._child.plan(session))
+ plan.group_map.grouping_expressions.extend(
+ [c.to_plan(session) for c in self._grouping_cols]
+ )
+ plan.group_map.func.CopyFrom(self._func.to_plan_udf(session))
+ return plan
+
+
class CachedRelation(LogicalPlan):
def __init__(self, plan: proto.Relation) -> None:
super(CachedRelation, self).__init__(None)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index 521a10f214c..aa6d39cd4f0 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xb8\x13\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+    b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xf0\x13\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -92,6 +92,7 @@ _UNPIVOT_VALUES = _UNPIVOT.nested_types_by_name["Values"]
_TOSCHEMA = DESCRIPTOR.message_types_by_name["ToSchema"]
_REPARTITIONBYEXPRESSION = DESCRIPTOR.message_types_by_name["RepartitionByExpression"]
_MAPPARTITIONS = DESCRIPTOR.message_types_by_name["MapPartitions"]
+_GROUPMAP = DESCRIPTOR.message_types_by_name["GroupMap"]
_COLLECTMETRICS = DESCRIPTOR.message_types_by_name["CollectMetrics"]
_PARSE = DESCRIPTOR.message_types_by_name["Parse"]
_PARSE_OPTIONSENTRY = _PARSE.nested_types_by_name["OptionsEntry"]
@@ -640,6 +641,17 @@ MapPartitions = _reflection.GeneratedProtocolMessageType(
)
_sym_db.RegisterMessage(MapPartitions)
+GroupMap = _reflection.GeneratedProtocolMessageType(
+ "GroupMap",
+ (_message.Message,),
+ {
+ "DESCRIPTOR": _GROUPMAP,
+ "__module__": "spark.connect.relations_pb2"
+ # @@protoc_insertion_point(class_scope:spark.connect.GroupMap)
+ },
+)
+_sym_db.RegisterMessage(GroupMap)
+
CollectMetrics = _reflection.GeneratedProtocolMessageType(
"CollectMetrics",
(_message.Message,),
@@ -685,117 +697,119 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_PARSE_OPTIONSENTRY._options = None
_PARSE_OPTIONSENTRY._serialized_options = b"8\001"
_RELATION._serialized_start = 165
- _RELATION._serialized_end = 2653
- _UNKNOWN._serialized_start = 2655
- _UNKNOWN._serialized_end = 2664
- _RELATIONCOMMON._serialized_start = 2666
- _RELATIONCOMMON._serialized_end = 2757
- _SQL._serialized_start = 2760
- _SQL._serialized_end = 2894
- _SQL_ARGSENTRY._serialized_start = 2839
- _SQL_ARGSENTRY._serialized_end = 2894
- _READ._serialized_start = 2897
- _READ._serialized_end = 3393
- _READ_NAMEDTABLE._serialized_start = 3039
- _READ_NAMEDTABLE._serialized_end = 3100
- _READ_DATASOURCE._serialized_start = 3103
- _READ_DATASOURCE._serialized_end = 3380
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3300
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3358
- _PROJECT._serialized_start = 3395
- _PROJECT._serialized_end = 3512
- _FILTER._serialized_start = 3514
- _FILTER._serialized_end = 3626
- _JOIN._serialized_start = 3629
- _JOIN._serialized_end = 4100
- _JOIN_JOINTYPE._serialized_start = 3892
- _JOIN_JOINTYPE._serialized_end = 4100
- _SETOPERATION._serialized_start = 4103
- _SETOPERATION._serialized_end = 4582
- _SETOPERATION_SETOPTYPE._serialized_start = 4419
- _SETOPERATION_SETOPTYPE._serialized_end = 4533
- _LIMIT._serialized_start = 4584
- _LIMIT._serialized_end = 4660
- _OFFSET._serialized_start = 4662
- _OFFSET._serialized_end = 4741
- _TAIL._serialized_start = 4743
- _TAIL._serialized_end = 4818
- _AGGREGATE._serialized_start = 4821
- _AGGREGATE._serialized_end = 5403
- _AGGREGATE_PIVOT._serialized_start = 5160
- _AGGREGATE_PIVOT._serialized_end = 5271
- _AGGREGATE_GROUPTYPE._serialized_start = 5274
- _AGGREGATE_GROUPTYPE._serialized_end = 5403
- _SORT._serialized_start = 5406
- _SORT._serialized_end = 5566
- _DROP._serialized_start = 5569
- _DROP._serialized_end = 5710
- _DEDUPLICATE._serialized_start = 5713
- _DEDUPLICATE._serialized_end = 5884
- _LOCALRELATION._serialized_start = 5886
- _LOCALRELATION._serialized_end = 5975
- _SAMPLE._serialized_start = 5978
- _SAMPLE._serialized_end = 6251
- _RANGE._serialized_start = 6254
- _RANGE._serialized_end = 6399
- _SUBQUERYALIAS._serialized_start = 6401
- _SUBQUERYALIAS._serialized_end = 6515
- _REPARTITION._serialized_start = 6518
- _REPARTITION._serialized_end = 6660
- _SHOWSTRING._serialized_start = 6663
- _SHOWSTRING._serialized_end = 6805
- _STATSUMMARY._serialized_start = 6807
- _STATSUMMARY._serialized_end = 6899
- _STATDESCRIBE._serialized_start = 6901
- _STATDESCRIBE._serialized_end = 6982
- _STATCROSSTAB._serialized_start = 6984
- _STATCROSSTAB._serialized_end = 7085
- _STATCOV._serialized_start = 7087
- _STATCOV._serialized_end = 7183
- _STATCORR._serialized_start = 7186
- _STATCORR._serialized_end = 7323
- _STATAPPROXQUANTILE._serialized_start = 7326
- _STATAPPROXQUANTILE._serialized_end = 7490
- _STATFREQITEMS._serialized_start = 7492
- _STATFREQITEMS._serialized_end = 7617
- _STATSAMPLEBY._serialized_start = 7620
- _STATSAMPLEBY._serialized_end = 7929
- _STATSAMPLEBY_FRACTION._serialized_start = 7821
- _STATSAMPLEBY_FRACTION._serialized_end = 7920
- _NAFILL._serialized_start = 7932
- _NAFILL._serialized_end = 8066
- _NADROP._serialized_start = 8069
- _NADROP._serialized_end = 8203
- _NAREPLACE._serialized_start = 8206
- _NAREPLACE._serialized_end = 8502
- _NAREPLACE_REPLACEMENT._serialized_start = 8361
- _NAREPLACE_REPLACEMENT._serialized_end = 8502
- _TODF._serialized_start = 8504
- _TODF._serialized_end = 8592
- _WITHCOLUMNSRENAMED._serialized_start = 8595
- _WITHCOLUMNSRENAMED._serialized_end = 8834
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8767
- _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8834
- _WITHCOLUMNS._serialized_start = 8836
- _WITHCOLUMNS._serialized_end = 8955
- _HINT._serialized_start = 8958
- _HINT._serialized_end = 9090
- _UNPIVOT._serialized_start = 9093
- _UNPIVOT._serialized_end = 9420
- _UNPIVOT_VALUES._serialized_start = 9350
- _UNPIVOT_VALUES._serialized_end = 9409
- _TOSCHEMA._serialized_start = 9422
- _TOSCHEMA._serialized_end = 9528
- _REPARTITIONBYEXPRESSION._serialized_start = 9531
- _REPARTITIONBYEXPRESSION._serialized_end = 9734
- _MAPPARTITIONS._serialized_start = 9737
- _MAPPARTITIONS._serialized_end = 9867
- _COLLECTMETRICS._serialized_start = 9870
- _COLLECTMETRICS._serialized_end = 10006
- _PARSE._serialized_start = 10009
- _PARSE._serialized_end = 10397
- _PARSE_OPTIONSENTRY._serialized_start = 3300
- _PARSE_OPTIONSENTRY._serialized_end = 3358
- _PARSE_PARSEFORMAT._serialized_start = 10298
- _PARSE_PARSEFORMAT._serialized_end = 10386
+ _RELATION._serialized_end = 2709
+ _UNKNOWN._serialized_start = 2711
+ _UNKNOWN._serialized_end = 2720
+ _RELATIONCOMMON._serialized_start = 2722
+ _RELATIONCOMMON._serialized_end = 2813
+ _SQL._serialized_start = 2816
+ _SQL._serialized_end = 2950
+ _SQL_ARGSENTRY._serialized_start = 2895
+ _SQL_ARGSENTRY._serialized_end = 2950
+ _READ._serialized_start = 2953
+ _READ._serialized_end = 3449
+ _READ_NAMEDTABLE._serialized_start = 3095
+ _READ_NAMEDTABLE._serialized_end = 3156
+ _READ_DATASOURCE._serialized_start = 3159
+ _READ_DATASOURCE._serialized_end = 3436
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3356
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3414
+ _PROJECT._serialized_start = 3451
+ _PROJECT._serialized_end = 3568
+ _FILTER._serialized_start = 3570
+ _FILTER._serialized_end = 3682
+ _JOIN._serialized_start = 3685
+ _JOIN._serialized_end = 4156
+ _JOIN_JOINTYPE._serialized_start = 3948
+ _JOIN_JOINTYPE._serialized_end = 4156
+ _SETOPERATION._serialized_start = 4159
+ _SETOPERATION._serialized_end = 4638
+ _SETOPERATION_SETOPTYPE._serialized_start = 4475
+ _SETOPERATION_SETOPTYPE._serialized_end = 4589
+ _LIMIT._serialized_start = 4640
+ _LIMIT._serialized_end = 4716
+ _OFFSET._serialized_start = 4718
+ _OFFSET._serialized_end = 4797
+ _TAIL._serialized_start = 4799
+ _TAIL._serialized_end = 4874
+ _AGGREGATE._serialized_start = 4877
+ _AGGREGATE._serialized_end = 5459
+ _AGGREGATE_PIVOT._serialized_start = 5216
+ _AGGREGATE_PIVOT._serialized_end = 5327
+ _AGGREGATE_GROUPTYPE._serialized_start = 5330
+ _AGGREGATE_GROUPTYPE._serialized_end = 5459
+ _SORT._serialized_start = 5462
+ _SORT._serialized_end = 5622
+ _DROP._serialized_start = 5625
+ _DROP._serialized_end = 5766
+ _DEDUPLICATE._serialized_start = 5769
+ _DEDUPLICATE._serialized_end = 5940
+ _LOCALRELATION._serialized_start = 5942
+ _LOCALRELATION._serialized_end = 6031
+ _SAMPLE._serialized_start = 6034
+ _SAMPLE._serialized_end = 6307
+ _RANGE._serialized_start = 6310
+ _RANGE._serialized_end = 6455
+ _SUBQUERYALIAS._serialized_start = 6457
+ _SUBQUERYALIAS._serialized_end = 6571
+ _REPARTITION._serialized_start = 6574
+ _REPARTITION._serialized_end = 6716
+ _SHOWSTRING._serialized_start = 6719
+ _SHOWSTRING._serialized_end = 6861
+ _STATSUMMARY._serialized_start = 6863
+ _STATSUMMARY._serialized_end = 6955
+ _STATDESCRIBE._serialized_start = 6957
+ _STATDESCRIBE._serialized_end = 7038
+ _STATCROSSTAB._serialized_start = 7040
+ _STATCROSSTAB._serialized_end = 7141
+ _STATCOV._serialized_start = 7143
+ _STATCOV._serialized_end = 7239
+ _STATCORR._serialized_start = 7242
+ _STATCORR._serialized_end = 7379
+ _STATAPPROXQUANTILE._serialized_start = 7382
+ _STATAPPROXQUANTILE._serialized_end = 7546
+ _STATFREQITEMS._serialized_start = 7548
+ _STATFREQITEMS._serialized_end = 7673
+ _STATSAMPLEBY._serialized_start = 7676
+ _STATSAMPLEBY._serialized_end = 7985
+ _STATSAMPLEBY_FRACTION._serialized_start = 7877
+ _STATSAMPLEBY_FRACTION._serialized_end = 7976
+ _NAFILL._serialized_start = 7988
+ _NAFILL._serialized_end = 8122
+ _NADROP._serialized_start = 8125
+ _NADROP._serialized_end = 8259
+ _NAREPLACE._serialized_start = 8262
+ _NAREPLACE._serialized_end = 8558
+ _NAREPLACE_REPLACEMENT._serialized_start = 8417
+ _NAREPLACE_REPLACEMENT._serialized_end = 8558
+ _TODF._serialized_start = 8560
+ _TODF._serialized_end = 8648
+ _WITHCOLUMNSRENAMED._serialized_start = 8651
+ _WITHCOLUMNSRENAMED._serialized_end = 8890
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_start = 8823
+ _WITHCOLUMNSRENAMED_RENAMECOLUMNSMAPENTRY._serialized_end = 8890
+ _WITHCOLUMNS._serialized_start = 8892
+ _WITHCOLUMNS._serialized_end = 9011
+ _HINT._serialized_start = 9014
+ _HINT._serialized_end = 9146
+ _UNPIVOT._serialized_start = 9149
+ _UNPIVOT._serialized_end = 9476
+ _UNPIVOT_VALUES._serialized_start = 9406
+ _UNPIVOT_VALUES._serialized_end = 9465
+ _TOSCHEMA._serialized_start = 9478
+ _TOSCHEMA._serialized_end = 9584
+ _REPARTITIONBYEXPRESSION._serialized_start = 9587
+ _REPARTITIONBYEXPRESSION._serialized_end = 9790
+ _MAPPARTITIONS._serialized_start = 9793
+ _MAPPARTITIONS._serialized_end = 9923
+ _GROUPMAP._serialized_start = 9926
+ _GROUPMAP._serialized_end = 10129
+ _COLLECTMETRICS._serialized_start = 10132
+ _COLLECTMETRICS._serialized_end = 10268
+ _PARSE._serialized_start = 10271
+ _PARSE._serialized_end = 10659
+ _PARSE_OPTIONSENTRY._serialized_start = 3356
+ _PARSE_OPTIONSENTRY._serialized_end = 3414
+ _PARSE_PARSEFORMAT._serialized_start = 10560
+ _PARSE_PARSEFORMAT._serialized_end = 10648
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index ab1561996ef..6ae4a323f6f 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -92,6 +92,7 @@ class Relation(google.protobuf.message.Message):
MAP_PARTITIONS_FIELD_NUMBER: builtins.int
COLLECT_METRICS_FIELD_NUMBER: builtins.int
PARSE_FIELD_NUMBER: builtins.int
+ GROUP_MAP_FIELD_NUMBER: builtins.int
FILL_NA_FIELD_NUMBER: builtins.int
DROP_NA_FIELD_NUMBER: builtins.int
REPLACE_FIELD_NUMBER: builtins.int
@@ -167,6 +168,8 @@ class Relation(google.protobuf.message.Message):
@property
def parse(self) -> global___Parse: ...
@property
+ def group_map(self) -> global___GroupMap: ...
+ @property
def fill_na(self) -> global___NAFill:
"""NA functions"""
@property
@@ -233,6 +236,7 @@ class Relation(google.protobuf.message.Message):
map_partitions: global___MapPartitions | None = ...,
collect_metrics: global___CollectMetrics | None = ...,
parse: global___Parse | None = ...,
+ group_map: global___GroupMap | None = ...,
fill_na: global___NAFill | None = ...,
drop_na: global___NADrop | None = ...,
replace: global___NAReplace | None = ...,
@@ -283,6 +287,8 @@ class Relation(google.protobuf.message.Message):
b"filter",
"freq_items",
b"freq_items",
+ "group_map",
+ b"group_map",
"hint",
b"hint",
"join",
@@ -378,6 +384,8 @@ class Relation(google.protobuf.message.Message):
b"filter",
"freq_items",
b"freq_items",
+ "group_map",
+ b"group_map",
"hint",
b"hint",
"join",
@@ -470,6 +478,7 @@ class Relation(google.protobuf.message.Message):
"map_partitions",
"collect_metrics",
"parse",
+ "group_map",
"fill_na",
"drop_na",
"replace",
@@ -2733,6 +2742,48 @@ class MapPartitions(google.protobuf.message.Message):
global___MapPartitions = MapPartitions
+class GroupMap(google.protobuf.message.Message):
+ DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+ INPUT_FIELD_NUMBER: builtins.int
+ GROUPING_EXPRESSIONS_FIELD_NUMBER: builtins.int
+ FUNC_FIELD_NUMBER: builtins.int
+ @property
+ def input(self) -> global___Relation:
+        """(Required) Input relation for Group Map API: apply, applyInPandas."""
+ @property
+ def grouping_expressions(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]:
+ """(Required) Expressions for grouping keys."""
+ @property
+    def func(self) -> pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction:
+ """(Required) Input user-defined function."""
+ def __init__(
+ self,
+ *,
+ input: global___Relation | None = ...,
+ grouping_expressions: collections.abc.Iterable[
+ pyspark.sql.connect.proto.expressions_pb2.Expression
+ ]
+ | None = ...,
+        func: pyspark.sql.connect.proto.expressions_pb2.CommonInlineUserDefinedFunction
+ | None = ...,
+ ) -> None: ...
+ def HasField(
+        self, field_name: typing_extensions.Literal["func", b"func", "input", b"input"]
+ ) -> builtins.bool: ...
+ def ClearField(
+ self,
+ field_name: typing_extensions.Literal[
+            "func", b"func", "grouping_expressions", b"grouping_expressions", "input", b"input"
+ ],
+ ) -> None: ...
+
+global___GroupMap = GroupMap
+
class CollectMetrics(google.protobuf.message.Message):
"""Collect arbitrary (named) metrics from a dataset."""
diff --git a/python/pyspark/sql/pandas/group_ops.py b/python/pyspark/sql/pandas/group_ops.py
index bca96eaf205..f03aa35bb83 100644
--- a/python/pyspark/sql/pandas/group_ops.py
+++ b/python/pyspark/sql/pandas/group_ops.py
@@ -48,6 +48,9 @@ class PandasGroupedOpsMixin:
.. versionadded:: 2.3.0
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
Parameters
----------
udf : :func:`pyspark.sql.functions.pandas_udf`
@@ -128,6 +131,9 @@ class PandasGroupedOpsMixin:
.. versionadded:: 3.0.0
+ .. versionchanged:: 3.4.0
+ Support Spark Connect.
+
Parameters
----------
func : function
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index a8e161a42a6..491865ad9c9 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -2845,8 +2845,6 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
# SPARK-41927: Disable unsupported functions.
cg = self.connect.read.table(self.tbl_name).groupBy("id")
for f in (
- "apply",
- "applyInPandas",
"applyInPandasWithState",
"cogroup",
):
diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py b/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py
new file mode 100644
index 00000000000..e4a0d2ad85e
--- /dev/null
+++ b/python/pyspark/sql/tests/connect/test_parity_pandas_grouped_map.py
@@ -0,0 +1,102 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+from pyspark.sql.tests.pandas.test_pandas_grouped_map import GroupedApplyInPandasTestsMixin
+from pyspark.testing.connectutils import ReusedConnectTestCase
+
+
+class GroupedApplyInPandasTests(GroupedApplyInPandasTestsMixin, ReusedConnectTestCase):
+    # TODO(SPARK-42822): Fix ambiguous reference for case-insensitive grouping column
+ @unittest.skip("Fails in Spark Connect, should enable.")
+ def test_case_insensitive_grouping_column(self):
+ super().test_case_insensitive_grouping_column()
+
+ # TODO(SPARK-42857): Support CreateDataFrame from Decimal128
+ @unittest.skip("Fails in Spark Connect, should enable.")
+ def test_supported_types(self):
+ super().test_supported_types()
+
+ @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_wrong_return_type(self):
+ super().test_wrong_return_type()
+
+ @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_wrong_args(self):
+ super().test_wrong_args()
+
+ @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_unsupported_types(self):
+ super().test_unsupported_types()
+
+ @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_register_grouped_map_udf(self):
+ super().test_register_grouped_map_udf()
+
+ @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_column_order(self):
+ super().test_column_order()
+
+ @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
+        super().test_apply_in_pandas_returning_no_column_names_and_wrong_amount()
+
+ @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_apply_in_pandas_not_returning_pandas_dataframe(self):
+ super().test_apply_in_pandas_not_returning_pandas_dataframe()
+
+    @unittest.skip("Spark Connect doesn't support RDD but the test depends on it.")
+ def test_grouped_with_empty_partition(self):
+ super().test_grouped_with_empty_partition()
+
+ @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+    def test_apply_in_pandas_returning_empty_dataframe_and_wrong_number_of_columns(self):
+        super().test_apply_in_pandas_returning_empty_dataframe_and_wrong_number_of_columns()
+
+ @unittest.skip(
+        "Spark Connect does not support sc._jvm.org.apache.log4j but the test depends on it."
+ )
+ def test_apply_in_pandas_returning_wrong_number_of_columns(self):
+ super().test_apply_in_pandas_returning_wrong_number_of_columns()
+
+
+if __name__ == "__main__":
+    from pyspark.sql.tests.connect.test_parity_pandas_grouped_map import *  # noqa: F401
+
+ try:
+ import xmlrunner
+
+        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
+ except ImportError:
+ testRunner = None
+ unittest.main(testRunner=testRunner, verbosity=2)
diff --git a/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py b/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py
index 571ee74287e..d2eab7fa4f3 100644
--- a/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py
+++ b/python/pyspark/sql/tests/connect/test_parity_pandas_udf.py
@@ -66,11 +66,6 @@ class PandasUDFParityTests(PandasUDFTestsMixin, ReusedConnectTestCase):
self.assertEqual(udf.returnType, UnparsedDataType("v double"))
self.assertEqual(udf.evalType, PandasUDFType.GROUPED_MAP)
- # TODO(SPARK-42340): implement GroupedData.applyInPandas
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_stopiteration_in_grouped_map(self):
- super().test_stopiteration_in_grouped_map()
-
if __name__ == "__main__":
import unittest
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index 5f103c97926..c2c97bc7149 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -73,7 +73,7 @@ if have_pyarrow:
not have_pandas or not have_pyarrow,
cast(str, pandas_requirement_message or pyarrow_requirement_message),
)
-class GroupedMapInPandasTests(ReusedSQLTestCase):
+class GroupedApplyInPandasTestsMixin:
@property
def data(self):
return (
@@ -740,6 +740,10 @@ class GroupedMapInPandasTests(ReusedSQLTestCase):
self.assertEqual(row.asDict(), Row(column=1, score=0.5).asDict())
+class GroupedApplyInPandasTests(GroupedApplyInPandasTestsMixin, ReusedSQLTestCase):
+ pass
+
+
if __name__ == "__main__":
    from pyspark.sql.tests.pandas.test_pandas_grouped_map import *  # noqa: F401