This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 201535d3254 [SPARK-42041][SPARK-42013][CONNECT][PYTHON] DataFrameReader should support list of paths
201535d3254 is described below
commit 201535d3254af1534bcb802da4d309bfae373d73
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Sat Jan 14 15:45:19 2023 +0900
[SPARK-42041][SPARK-42013][CONNECT][PYTHON] DataFrameReader should support list of paths
### What changes were proposed in this pull request?
Make `DataFrameReader` in Spark Connect accept a list of paths in `load`, `json`, `parquet` and `text`, carried to the server through a new repeated `paths` field on the `Read.DataSource` message.
### Why are the changes needed?
For API parity with the existing PySpark `DataFrameReader`.
### Does this PR introduce _any_ user-facing change?
Yes. `DataFrameReader.load`, `json`, `parquet` and `text` in Spark Connect now accept multiple paths, matching vanilla PySpark.
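For example (a minimal sketch, assuming an existing Spark Connect session `spark`; the file paths are placeholders):

```python
# Reading multiple files at once now works the same way as in vanilla PySpark.
df_text = spark.read.text(["/tmp/a.txt", "/tmp/b.txt"])             # list of paths
df_json = spark.read.json(["/tmp/a.json", "/tmp/b.json"])           # list of paths
df_parq = spark.read.parquet("/tmp/p1.parquet", "/tmp/p2.parquet")  # varargs
df_any = spark.read.load(["/tmp/a.csv", "/tmp/b.csv"], format="csv")
```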
### How was this patch tested?
Added a new unit test (`test_multi_paths`) and enabled the previously skipped parity test `test_read_text_file_list`.
Closes #39553 from zhengruifeng/connect_io_paths.
Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../main/protobuf/spark/connect/relations.proto | 3 +
.../sql/connect/planner/SparkConnectPlanner.scala | 13 +-
python/pyspark/sql/connect/plan.py | 41 +++--
python/pyspark/sql/connect/proto/relations_pb2.py | 178 ++++++++++-----------
python/pyspark/sql/connect/proto/relations_pb2.pyi | 9 ++
python/pyspark/sql/connect/readwriter.py | 30 ++--
.../sql/tests/connect/test_connect_basic.py | 27 ++++
.../pyspark/sql/tests/connect/test_connect_plan.py | 5 +-
.../sql/tests/connect/test_parity_datasources.py | 5 -
9 files changed, 190 insertions(+), 121 deletions(-)
diff --git a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
index f029273c48b..e283c522aa7 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/relations.proto
@@ -127,6 +127,9 @@ message Read {
// data source format. This options could be empty for valid data source format.
// The map key is case insensitive.
map<string, string> options = 3;
+
+ // (Optional) A list of path for file-system backed data sources.
+ repeated string paths = 4;
}
}
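As an illustration only (not part of this commit), a client can populate the new field through the generated Python classes roughly like this:

```python
from pyspark.sql.connect.proto import relations_pb2

# Sketch: a Read relation over two JSON files; the paths go into the new repeated field.
rel = relations_pb2.Relation()
rel.read.data_source.format = "json"
rel.read.data_source.options["multiLine"] = "true"
rel.read.data_source.paths.extend(["/tmp/a.json", "/tmp/b.json"])
```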
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 512ac8efe2e..a888c55769b 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -645,11 +645,12 @@ class SparkConnectPlanner(session: SparkSession) {
}
private def transformReadRel(rel: proto.Read): LogicalPlan = {
- val baseRelation = rel.getReadTypeCase match {
+ rel.getReadTypeCase match {
case proto.Read.ReadTypeCase.NAMED_TABLE =>
val multipartIdentifier =
CatalystSqlParser.parseMultipartIdentifier(rel.getNamedTable.getUnparsedIdentifier)
UnresolvedRelation(multipartIdentifier)
+
case proto.Read.ReadTypeCase.DATA_SOURCE =>
if (rel.getDataSource.getFormat == "") {
throw InvalidPlanInput("DataSource requires a format")
@@ -668,10 +669,16 @@ class SparkConnectPlanner(session: SparkSession) {
case other => throw InvalidPlanInput(s"Invalid schema $other")
}
}
- reader.load().queryExecution.analyzed
+ if (rel.getDataSource.getPathsCount == 0) {
+ reader.load().queryExecution.analyzed
+ } else if (rel.getDataSource.getPathsCount == 1) {
+ reader.load(rel.getDataSource.getPaths(0)).queryExecution.analyzed
+ } else {
+ reader.load(rel.getDataSource.getPathsList.asScala.toSeq: _*).queryExecution.analyzed
+ }
+
case _ => throw InvalidPlanInput("Does not support " + rel.getReadTypeCase.name())
}
- baseRelation
}
private def transformFilter(rel: proto.Filter): LogicalPlan = {
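The planner now dispatches on the number of paths: zero keeps the old `reader.load()` behavior, one path calls `load(path)`, and several paths are passed as varargs. A hedged Python analogue of that dispatch, for illustration only:

```python
def load_with_paths(reader, paths):
    # Mirrors the Scala branches above: no path, a single path, or many paths.
    if len(paths) == 0:
        return reader.load()
    elif len(paths) == 1:
        return reader.load(paths[0])
    else:
        return reader.load(*paths)
```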
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py
index 90900ffe27c..4a2e76a8954 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -229,26 +229,41 @@ class DataSource(LogicalPlan):
def __init__(
self,
- format: str = "",
+ format: str,
schema: Optional[str] = None,
options: Optional[Mapping[str, str]] = None,
+ paths: Optional[List[str]] = None,
) -> None:
super().__init__(None)
- self.format = format
- self.schema = schema
- self.options = options
+
+ assert isinstance(format, str) and format != ""
+
+ assert schema is None or isinstance(schema, str)
+
+ if options is not None:
+ for k, v in options.items():
+ assert isinstance(k, str)
+ assert isinstance(v, str)
+
+ if paths is not None:
+ assert isinstance(paths, list)
+ assert all(isinstance(path, str) for path in paths)
+
+ self._format = format
+ self._schema = schema
+ self._options = options
+ self._paths = paths
def plan(self, session: "SparkConnectClient") -> proto.Relation:
plan = proto.Relation()
- if self.format is not None:
- plan.read.data_source.format = self.format
- if self.schema is not None:
- plan.read.data_source.schema = self.schema
- if self.options is not None:
- for k in self.options.keys():
- v = self.options.get(k)
- if v is not None:
- plan.read.data_source.options[k] = v
+ plan.read.data_source.format = self._format
+ if self._schema is not None:
+ plan.read.data_source.schema = self._schema
+ if self._options is not None and len(self._options) > 0:
+ for k, v in self._options.items():
+ plan.read.data_source.options[k] = v
+ if self._paths is not None and len(self._paths) > 0:
+ plan.read.data_source.paths.extend(self._paths)
return plan
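For reference, a rough sketch of how the updated `DataSource` node is exercised (the `client` object is an assumed `SparkConnectClient`, not shown here):

```python
from pyspark.sql.connect.plan import DataSource

# Sketch: build the plan node with explicit paths and convert it to a Relation proto.
node = DataSource(format="text", schema="id INT", options={"op1": "opv"}, paths=["/tmp/a.txt"])
relation = node.plan(session=client)  # `client` is an assumed SparkConnectClient
assert relation.read.data_source.paths[0] == "/tmp/a.txt"
assert relation.read.data_source.options["op1"] == "opv"
```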
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.py b/python/pyspark/sql/connect/proto/relations_pb2.py
index cb2b0347293..6180adc894d 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.py
+++ b/python/pyspark/sql/connect/proto/relations_pb2.py
@@ -36,7 +36,7 @@ from pyspark.sql.connect.proto import catalog_pb2 as spark_dot_connect_dot_catal
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
- b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xed\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
+ b'\n\x1dspark/connect/relations.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1fspark/connect/expressions.proto\x1a\x19spark/connect/types.proto\x1a\x1bspark/connect/catalog.proto"\xed\x12\n\x08Relation\x12\x35\n\x06\x63ommon\x18\x01 \x01(\x0b\x32\x1d.spark.connect.RelationCommonR\x06\x63ommon\x12)\n\x04read\x18\x02 \x01(\x0b\x32\x13.spark.connect.ReadH\x00R\x04read\x12\x32\n\x07project\x18\x03 \x01(\x0b\x32\x16.spark.connect.ProjectH\x00R\x07project\x12/\n\x06\x66il [...]
)
@@ -619,95 +619,95 @@ if _descriptor._USE_C_DESCRIPTORS == False:
_SQL._serialized_start = 2642
_SQL._serialized_end = 2669
_READ._serialized_start = 2672
- _READ._serialized_end = 3098
+ _READ._serialized_end = 3120
_READ_NAMEDTABLE._serialized_start = 2814
_READ_NAMEDTABLE._serialized_end = 2875
_READ_DATASOURCE._serialized_start = 2878
- _READ_DATASOURCE._serialized_end = 3085
- _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3016
- _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3074
- _PROJECT._serialized_start = 3100
- _PROJECT._serialized_end = 3217
- _FILTER._serialized_start = 3219
- _FILTER._serialized_end = 3331
- _JOIN._serialized_start = 3334
- _JOIN._serialized_end = 3805
- _JOIN_JOINTYPE._serialized_start = 3597
- _JOIN_JOINTYPE._serialized_end = 3805
- _SETOPERATION._serialized_start = 3808
- _SETOPERATION._serialized_end = 4204
- _SETOPERATION_SETOPTYPE._serialized_start = 4067
- _SETOPERATION_SETOPTYPE._serialized_end = 4181
- _LIMIT._serialized_start = 4206
- _LIMIT._serialized_end = 4282
- _OFFSET._serialized_start = 4284
- _OFFSET._serialized_end = 4363
- _TAIL._serialized_start = 4365
- _TAIL._serialized_end = 4440
- _AGGREGATE._serialized_start = 4443
- _AGGREGATE._serialized_end = 5025
- _AGGREGATE_PIVOT._serialized_start = 4782
- _AGGREGATE_PIVOT._serialized_end = 4893
- _AGGREGATE_GROUPTYPE._serialized_start = 4896
- _AGGREGATE_GROUPTYPE._serialized_end = 5025
- _SORT._serialized_start = 5028
- _SORT._serialized_end = 5188
- _DROP._serialized_start = 5190
- _DROP._serialized_end = 5290
- _DEDUPLICATE._serialized_start = 5293
- _DEDUPLICATE._serialized_end = 5464
- _LOCALRELATION._serialized_start = 5466
- _LOCALRELATION._serialized_end = 5555
- _SAMPLE._serialized_start = 5558
- _SAMPLE._serialized_end = 5831
- _RANGE._serialized_start = 5834
- _RANGE._serialized_end = 5979
- _SUBQUERYALIAS._serialized_start = 5981
- _SUBQUERYALIAS._serialized_end = 6095
- _REPARTITION._serialized_start = 6098
- _REPARTITION._serialized_end = 6240
- _SHOWSTRING._serialized_start = 6243
- _SHOWSTRING._serialized_end = 6385
- _STATSUMMARY._serialized_start = 6387
- _STATSUMMARY._serialized_end = 6479
- _STATDESCRIBE._serialized_start = 6481
- _STATDESCRIBE._serialized_end = 6562
- _STATCROSSTAB._serialized_start = 6564
- _STATCROSSTAB._serialized_end = 6665
- _STATCOV._serialized_start = 6667
- _STATCOV._serialized_end = 6763
- _STATCORR._serialized_start = 6766
- _STATCORR._serialized_end = 6903
- _STATAPPROXQUANTILE._serialized_start = 6906
- _STATAPPROXQUANTILE._serialized_end = 7070
- _STATFREQITEMS._serialized_start = 7072
- _STATFREQITEMS._serialized_end = 7197
- _STATSAMPLEBY._serialized_start = 7200
- _STATSAMPLEBY._serialized_end = 7509
- _STATSAMPLEBY_FRACTION._serialized_start = 7401
- _STATSAMPLEBY_FRACTION._serialized_end = 7500
- _NAFILL._serialized_start = 7512
- _NAFILL._serialized_end = 7646
- _NADROP._serialized_start = 7649
- _NADROP._serialized_end = 7783
- _NAREPLACE._serialized_start = 7786
- _NAREPLACE._serialized_end = 8082
- _NAREPLACE_REPLACEMENT._serialized_start = 7941
- _NAREPLACE_REPLACEMENT._serialized_end = 8082
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 8084
- _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 8198
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 8201
- _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8460
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8393
- _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8460
- _WITHCOLUMNS._serialized_start = 8462
- _WITHCOLUMNS._serialized_end = 8581
- _HINT._serialized_start = 8584
- _HINT._serialized_end = 8716
- _UNPIVOT._serialized_start = 8719
- _UNPIVOT._serialized_end = 8965
- _TOSCHEMA._serialized_start = 8967
- _TOSCHEMA._serialized_end = 9073
- _REPARTITIONBYEXPRESSION._serialized_start = 9076
- _REPARTITIONBYEXPRESSION._serialized_end = 9279
+ _READ_DATASOURCE._serialized_end = 3107
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_start = 3038
+ _READ_DATASOURCE_OPTIONSENTRY._serialized_end = 3096
+ _PROJECT._serialized_start = 3122
+ _PROJECT._serialized_end = 3239
+ _FILTER._serialized_start = 3241
+ _FILTER._serialized_end = 3353
+ _JOIN._serialized_start = 3356
+ _JOIN._serialized_end = 3827
+ _JOIN_JOINTYPE._serialized_start = 3619
+ _JOIN_JOINTYPE._serialized_end = 3827
+ _SETOPERATION._serialized_start = 3830
+ _SETOPERATION._serialized_end = 4226
+ _SETOPERATION_SETOPTYPE._serialized_start = 4089
+ _SETOPERATION_SETOPTYPE._serialized_end = 4203
+ _LIMIT._serialized_start = 4228
+ _LIMIT._serialized_end = 4304
+ _OFFSET._serialized_start = 4306
+ _OFFSET._serialized_end = 4385
+ _TAIL._serialized_start = 4387
+ _TAIL._serialized_end = 4462
+ _AGGREGATE._serialized_start = 4465
+ _AGGREGATE._serialized_end = 5047
+ _AGGREGATE_PIVOT._serialized_start = 4804
+ _AGGREGATE_PIVOT._serialized_end = 4915
+ _AGGREGATE_GROUPTYPE._serialized_start = 4918
+ _AGGREGATE_GROUPTYPE._serialized_end = 5047
+ _SORT._serialized_start = 5050
+ _SORT._serialized_end = 5210
+ _DROP._serialized_start = 5212
+ _DROP._serialized_end = 5312
+ _DEDUPLICATE._serialized_start = 5315
+ _DEDUPLICATE._serialized_end = 5486
+ _LOCALRELATION._serialized_start = 5488
+ _LOCALRELATION._serialized_end = 5577
+ _SAMPLE._serialized_start = 5580
+ _SAMPLE._serialized_end = 5853
+ _RANGE._serialized_start = 5856
+ _RANGE._serialized_end = 6001
+ _SUBQUERYALIAS._serialized_start = 6003
+ _SUBQUERYALIAS._serialized_end = 6117
+ _REPARTITION._serialized_start = 6120
+ _REPARTITION._serialized_end = 6262
+ _SHOWSTRING._serialized_start = 6265
+ _SHOWSTRING._serialized_end = 6407
+ _STATSUMMARY._serialized_start = 6409
+ _STATSUMMARY._serialized_end = 6501
+ _STATDESCRIBE._serialized_start = 6503
+ _STATDESCRIBE._serialized_end = 6584
+ _STATCROSSTAB._serialized_start = 6586
+ _STATCROSSTAB._serialized_end = 6687
+ _STATCOV._serialized_start = 6689
+ _STATCOV._serialized_end = 6785
+ _STATCORR._serialized_start = 6788
+ _STATCORR._serialized_end = 6925
+ _STATAPPROXQUANTILE._serialized_start = 6928
+ _STATAPPROXQUANTILE._serialized_end = 7092
+ _STATFREQITEMS._serialized_start = 7094
+ _STATFREQITEMS._serialized_end = 7219
+ _STATSAMPLEBY._serialized_start = 7222
+ _STATSAMPLEBY._serialized_end = 7531
+ _STATSAMPLEBY_FRACTION._serialized_start = 7423
+ _STATSAMPLEBY_FRACTION._serialized_end = 7522
+ _NAFILL._serialized_start = 7534
+ _NAFILL._serialized_end = 7668
+ _NADROP._serialized_start = 7671
+ _NADROP._serialized_end = 7805
+ _NAREPLACE._serialized_start = 7808
+ _NAREPLACE._serialized_end = 8104
+ _NAREPLACE_REPLACEMENT._serialized_start = 7963
+ _NAREPLACE_REPLACEMENT._serialized_end = 8104
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_start = 8106
+ _RENAMECOLUMNSBYSAMELENGTHNAMES._serialized_end = 8220
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_start = 8223
+ _RENAMECOLUMNSBYNAMETONAMEMAP._serialized_end = 8482
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_start = 8415
+ _RENAMECOLUMNSBYNAMETONAMEMAP_RENAMECOLUMNSMAPENTRY._serialized_end = 8482
+ _WITHCOLUMNS._serialized_start = 8484
+ _WITHCOLUMNS._serialized_end = 8603
+ _HINT._serialized_start = 8606
+ _HINT._serialized_end = 8738
+ _UNPIVOT._serialized_start = 8741
+ _UNPIVOT._serialized_end = 8987
+ _TOSCHEMA._serialized_start = 8989
+ _TOSCHEMA._serialized_end = 9095
+ _REPARTITIONBYEXPRESSION._serialized_start = 9098
+ _REPARTITIONBYEXPRESSION._serialized_end = 9301
# @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/relations_pb2.pyi b/python/pyspark/sql/connect/proto/relations_pb2.pyi
index 04512a4c891..b80a8a5b1bb 100644
--- a/python/pyspark/sql/connect/proto/relations_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/relations_pb2.pyi
@@ -554,6 +554,7 @@ class Read(google.protobuf.message.Message):
FORMAT_FIELD_NUMBER: builtins.int
SCHEMA_FIELD_NUMBER: builtins.int
OPTIONS_FIELD_NUMBER: builtins.int
+ PATHS_FIELD_NUMBER: builtins.int
format: builtins.str
"""(Required) Supported formats include: parquet, orc, text, json,
parquet, csv, avro."""
schema: builtins.str
@@ -569,12 +570,18 @@ class Read(google.protobuf.message.Message):
data source format. This options could be empty for valid data source format.
The map key is case insensitive.
"""
+ @property
+ def paths(
+ self,
+ ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+ """(Optional) A list of path for file-system backed data sources."""
def __init__(
self,
*,
format: builtins.str = ...,
schema: builtins.str | None = ...,
options: collections.abc.Mapping[builtins.str, builtins.str] | None = ...,
+ paths: collections.abc.Iterable[builtins.str] | None = ...,
) -> None: ...
def HasField(
self, field_name: typing_extensions.Literal["_schema", b"_schema", "schema", b"schema"]
@@ -588,6 +595,8 @@ class Read(google.protobuf.message.Message):
b"format",
"options",
b"options",
+ "paths",
+ b"paths",
"schema",
b"schema",
],
diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py
index 62a082cc90a..f172b4ecd39 100644
--- a/python/pyspark/sql/connect/readwriter.py
+++ b/python/pyspark/sql/connect/readwriter.py
@@ -33,6 +33,7 @@ if TYPE_CHECKING:
from pyspark.sql.connect._typing import OptionalPrimitiveType
from pyspark.sql.connect.session import SparkSession
+__all__ = ["DataFrameReader", "DataFrameWriter"]
PathOrPaths = Union[str, List[str]]
TupleOrListOfString = Union[List[str], Tuple[str, ...]]
@@ -95,7 +96,7 @@ class DataFrameReader(OptionUtils):
def load(
self,
- path: Optional[str] = None,
+ path: Optional[PathOrPaths] = None,
format: Optional[str] = None,
schema: Optional[Union[StructType, str]] = None,
**options: "OptionalPrimitiveType",
@@ -105,10 +106,17 @@ class DataFrameReader(OptionUtils):
if schema is not None:
self.schema(schema)
self.options(**options)
- if path is not None:
- self.option("path", path)
- plan = DataSource(format=self._format, schema=self._schema, options=self._options)
+ paths = path
+ if isinstance(path, str):
+ paths = [path]
+
+ plan = DataSource(
+ format=self._format,
+ schema=self._schema,
+ options=self._options,
+ paths=paths, # type: ignore[arg-type]
+ )
return self._df(plan)
load.__doc__ = PySparkDataFrameReader.load.__doc__
@@ -125,7 +133,7 @@ class DataFrameReader(OptionUtils):
def json(
self,
- path: str,
+ path: PathOrPaths,
schema: Optional[Union[StructType, str]] = None,
primitivesAsString: Optional[Union[bool, str]] = None,
prefersDecimal: Optional[Union[bool, str]] = None,
@@ -176,11 +184,13 @@ class DataFrameReader(OptionUtils):
modifiedAfter=modifiedAfter,
allowNonNumericNumbers=allowNonNumericNumbers,
)
+ if isinstance(path, str):
+ path = [path]
return self.load(path=path, format="json", schema=schema)
json.__doc__ = PySparkDataFrameReader.json.__doc__
- def parquet(self, path: str, **options: "OptionalPrimitiveType") -> "DataFrame":
+ def parquet(self, *paths: str, **options: "OptionalPrimitiveType") -> "DataFrame":
mergeSchema = options.get("mergeSchema", None)
pathGlobFilter = options.get("pathGlobFilter", None)
modifiedBefore = options.get("modifiedBefore", None)
@@ -198,13 +208,13 @@ class DataFrameReader(OptionUtils):
int96RebaseMode=int96RebaseMode,
)
- return self.load(path=path, format="parquet")
+ return self.load(path=list(paths), format="parquet")
parquet.__doc__ = PySparkDataFrameReader.parquet.__doc__
def text(
self,
- path: str,
+ paths: PathOrPaths,
wholetext: Optional[bool] = None,
lineSep: Optional[str] = None,
pathGlobFilter: Optional[Union[bool, str]] = None,
@@ -221,7 +231,9 @@ class DataFrameReader(OptionUtils):
modifiedAfter=modifiedAfter,
)
- return self.load(path=path, format="text")
+ if isinstance(paths, str):
+ paths = [paths]
+ return self.load(path=paths, format="text")
text.__doc__ = PySparkDataFrameReader.text.__doc__
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py b/python/pyspark/sql/tests/connect/test_connect_basic.py
index 3aca5830481..2ea986f3540 100644
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -256,6 +256,33 @@ class SparkConnectBasicTests(SparkConnectSQLTestCase):
# Read the text file as a DataFrame.
self.assert_eq(self.connect.read.text(d).toPandas(),
self.spark.read.text(d).toPandas())
+ def test_multi_paths(self):
+ # SPARK-42041: DataFrameReader should support list of paths
+
+ with tempfile.TemporaryDirectory() as d:
+ text_files = []
+ for i in range(0, 3):
+ text_file = f"{d}/text-{i}.text"
+ shutil.copyfile("python/test_support/sql/text-test.txt",
text_file)
+ text_files.append(text_file)
+
+ self.assertEqual(
+ self.connect.read.text(text_files).collect(),
+ self.spark.read.text(text_files).collect(),
+ )
+
+ with tempfile.TemporaryDirectory() as d:
+ json_files = []
+ for i in range(0, 5):
+ json_file = f"{d}/json-{i}.json"
+ shutil.copyfile("python/test_support/sql/people.json",
json_file)
+ json_files.append(json_file)
+
+ self.assertEqual(
+ self.connect.read.json(json_files).collect(),
+ self.spark.read.json(json_files).collect(),
+ )
+
def test_join_condition_column_list_columns(self):
left_connect_df = self.connect.read.table(self.tbl_name)
right_connect_df = self.connect.read.table(self.tbl_name2)
diff --git a/python/pyspark/sql/tests/connect/test_connect_plan.py b/python/pyspark/sql/tests/connect/test_connect_plan.py
index 3dbe088aa76..e698566c2bd 100644
--- a/python/pyspark/sql/tests/connect/test_connect_plan.py
+++ b/python/pyspark/sql/tests/connect/test_connect_plan.py
@@ -507,10 +507,11 @@ class SparkConnectPlanTests(PlanOnlyTestFixture):
data_source = plan.root.read.data_source
self.assertEqual(data_source.format, "text")
self.assertEqual(data_source.schema, "id INT")
- self.assertEqual(len(data_source.options), 3)
- self.assertEqual(data_source.options.get("path"), "test_path")
+ self.assertEqual(len(data_source.options), 2)
self.assertEqual(data_source.options.get("op1"), "opv")
self.assertEqual(data_source.options.get("op2"), "opv2")
+ self.assertEqual(len(data_source.paths), 1)
+ self.assertEqual(data_source.paths[0], "test_path")
def test_simple_udf(self):
u = udf(lambda x: "Martin", StringType())
diff --git a/python/pyspark/sql/tests/connect/test_parity_datasources.py b/python/pyspark/sql/tests/connect/test_parity_datasources.py
index 83a9c4414e9..db1bba8de10 100644
--- a/python/pyspark/sql/tests/connect/test_parity_datasources.py
+++ b/python/pyspark/sql/tests/connect/test_parity_datasources.py
@@ -46,11 +46,6 @@ class DataSourcesParityTests(DataSourcesTestsMixin, ReusedConnectTestCase):
def test_read_multiple_orc_file(self):
super().test_read_multiple_orc_file()
- # TODO(SPARK-42013): Implement DataFrameReader.text to take multiple paths
- @unittest.skip("Fails in Spark Connect, should enable.")
- def test_read_text_file_list(self):
- super().test_read_text_file_list()
-
if __name__ == "__main__":
import unittest
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]