This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 25ee62e19bb [SPARK-45927][PYTHON] Update path handling for Python data source
25ee62e19bb is described below
commit 25ee62e19bb00610e79de3284987970d9732195c
Author: allisonwang-db <[email protected]>
AuthorDate: Mon Nov 20 15:06:17 2023 -0800
[SPARK-45927][PYTHON] Update path handling for Python data source
### What changes were proposed in this pull request?
This PR updates how `path` values from the `load()` method are handled.
It changes the `DataSource` class constructor and adds `path` as a key-value pair in the options field.
Also, this PR blocks loading multiple paths.
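For illustration, here is a minimal sketch (not part of this patch; the `SimpleDataSource` and `SimpleDataSourceReader` names and the line-per-row format are made up) of what a user-defined data source looks like after this change, reading the path from `self.options` instead of a constructor argument:

```python
# Hypothetical example: after this change, DataSource.__init__ receives only
# `options`, and the value passed to load(path) shows up as the "path" option.
from pyspark.sql.datasource import DataSource, DataSourceReader


class SimpleDataSourceReader(DataSourceReader):
    def __init__(self, options):
        self.options = options

    def read(self, partition):
        # Look up the path from the options instead of a separate `paths` argument.
        path = self.options.get("path")
        if path is None:
            raise Exception("path is not specified")
        with open(path, "r") as file:
            for line in file.readlines():
                if line.strip():
                    yield (line.strip(),)


class SimpleDataSource(DataSource):
    def schema(self):
        return "value STRING"

    def reader(self, schema) -> "DataSourceReader":
        return SimpleDataSourceReader(self.options)


# Usage sketch: register the class, then pass a single path to load().
# spark.dataSource.register(SimpleDataSource)
# spark.read.format("SimpleDataSource").load("/tmp/data.txt").show()
```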
### Why are the changes needed?
To make the behavior consistent with the existing data source APIs.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43809 from allisonwang-db/spark-45927-path.
Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/sql/datasource.py | 17 ++--------
python/pyspark/sql/tests/test_python_datasource.py | 36 ++++++++--------------
python/pyspark/sql/worker/create_data_source.py | 15 ++-------
.../org/apache/spark/sql/DataFrameReader.scala | 6 ++--
.../execution/datasources/DataSourceManager.scala | 1 -
.../datasources/v2/DataSourceV2Utils.scala | 2 +-
.../python/UserDefinedPythonDataSource.scala | 11 ++-----
.../execution/python/PythonDataSourceSuite.scala | 16 +++++++---
8 files changed, 34 insertions(+), 70 deletions(-)
diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index c30a2c8689d..b380e8b534e 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
from abc import ABC, abstractmethod
-from typing import final, Any, Dict, Iterator, List, Optional, Tuple, Type, Union, TYPE_CHECKING
+from typing import final, Any, Dict, Iterator, List, Tuple, Type, Union, TYPE_CHECKING
from pyspark import since
from pyspark.sql import Row
@@ -45,21 +45,12 @@ class DataSource(ABC):
"""
@final
- def __init__(
- self,
- paths: List[str],
- userSpecifiedSchema: Optional[StructType],
- options: Dict[str, "OptionalPrimitiveType"],
- ) -> None:
+ def __init__(self, options: Dict[str, "OptionalPrimitiveType"]) -> None:
"""
- Initializes the data source with user-provided information.
+ Initializes the data source with user-provided options.
Parameters
----------
- paths : list
- A list of paths to the data source.
- userSpecifiedSchema : StructType, optional
- The user-specified schema of the data source.
options : dict
A dictionary representing the options for this data source.
@@ -67,8 +58,6 @@ class DataSource(ABC):
-----
This method should not be overridden.
"""
- self.paths = paths
- self.userSpecifiedSchema = userSpecifiedSchema
self.options = options
@classmethod
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index fe6a8417527..46b9fa642fd 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -30,7 +30,7 @@ class BasePythonDataSourceTestsMixin:
...
options = dict(a=1, b=2)
- ds = MyDataSource(paths=[], userSpecifiedSchema=None, options=options)
+ ds = MyDataSource(options=options)
self.assertEqual(ds.options, options)
self.assertEqual(ds.name(), "MyDataSource")
with self.assertRaises(NotImplementedError):
@@ -53,8 +53,7 @@ class BasePythonDataSourceTestsMixin:
class InMemDataSourceReader(DataSourceReader):
DEFAULT_NUM_PARTITIONS: int = 3
- def __init__(self, paths, options):
- self.paths = paths
+ def __init__(self, options):
self.options = options
def partitions(self):
@@ -76,7 +75,7 @@ class BasePythonDataSourceTestsMixin:
return "x INT, y STRING"
def reader(self, schema) -> "DataSourceReader":
- return InMemDataSourceReader(self.paths, self.options)
+ return InMemDataSourceReader(self.options)
self.spark.dataSource.register(InMemoryDataSource)
df = self.spark.read.format("memory").load()
@@ -91,14 +90,13 @@ class BasePythonDataSourceTestsMixin:
import json
class JsonDataSourceReader(DataSourceReader):
- def __init__(self, paths, options):
- self.paths = paths
+ def __init__(self, options):
self.options = options
- def partitions(self):
- return iter(self.paths)
-
- def read(self, path):
+ def read(self, partition):
+ path = self.options.get("path")
+ if path is None:
+ raise Exception("path is not specified")
with open(path, "r") as file:
for line in file.readlines():
if line.strip():
@@ -114,28 +112,18 @@ class BasePythonDataSourceTestsMixin:
return "name STRING, age INT"
def reader(self, schema) -> "DataSourceReader":
- return JsonDataSourceReader(self.paths, self.options)
+ return JsonDataSourceReader(self.options)
self.spark.dataSource.register(JsonDataSource)
path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json")
- df1 = self.spark.read.format("my-json").load(path1)
- self.assertEqual(df1.rdd.getNumPartitions(), 1)
assertDataFrameEqual(
- df1,
+ self.spark.read.format("my-json").load(path1),
[Row(name="Michael", age=None), Row(name="Andy", age=30),
Row(name="Justin", age=19)],
)
-
- df2 = self.spark.read.format("my-json").load([path1, path2])
- self.assertEqual(df2.rdd.getNumPartitions(), 2)
assertDataFrameEqual(
- df2,
- [
- Row(name="Michael", age=None),
- Row(name="Andy", age=30),
- Row(name="Justin", age=19),
- Row(name="Jonathan", age=None),
- ],
+ self.spark.read.format("my-json").load(path2),
+ [Row(name="Jonathan", age=None)],
)
diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py
index 6a9ef79b7c1..1ba4dc9e8a3 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -17,7 +17,7 @@
import inspect
import os
import sys
-from typing import IO, List
+from typing import IO
from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError
@@ -55,7 +55,6 @@ def main(infile: IO, outfile: IO) -> None:
The JVM sends the following information to this process:
- a `DataSource` class representing the data source to be created.
- a provider name in string.
- - a list of paths in string.
- an optional user-specified schema in json string.
- a dictionary of options in string.
@@ -107,12 +106,6 @@ def main(infile: IO, outfile: IO) -> None:
},
)
- # Receive the paths.
- num_paths = read_int(infile)
- paths: List[str] = []
- for _ in range(num_paths):
- paths.append(utf8_deserializer.loads(infile))
-
# Receive the user-specified schema
user_specified_schema = None
if read_bool(infile):
@@ -136,11 +129,7 @@ def main(infile: IO, outfile: IO) -> None:
# Instantiate a data source.
try:
- data_source = data_source_cls(
- paths=paths,
- userSpecifiedSchema=user_specified_schema, # type: ignore
- options=options,
- )
+ data_source = data_source_cls(options=options)
except Exception as e:
raise PySparkRuntimeError(
error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index ef447e8a801..7fadbbfac68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -244,9 +244,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
val builder =
sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
- // Unless the legacy path option behavior is enabled, the extraOptions here
- // should not include "path" or "paths" as keys.
- val plan = builder(sparkSession, source, paths, userSpecifiedSchema, extraOptions)
+ // Add `path` and `paths` options to the extra options if specified.
+ val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, paths: _*)
+ val plan = builder(sparkSession, source, userSpecifiedSchema, optionsWithPath)
Dataset.ofRows(sparkSession, plan)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index 72a9e6497ac..a8c9c892b8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -35,7 +35,6 @@ class DataSourceManager {
private type DataSourceBuilder = (
SparkSession, // Spark session
String, // provider name
- Seq[String], // paths
Option[StructType], // user specified schema
CaseInsensitiveMap[String] // options
) => LogicalPlan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
index c4e7bf23cac..3dde20ac44e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
@@ -152,7 +152,7 @@ private[sql] object DataSourceV2Utils extends Logging {
}
private lazy val objectMapper = new ObjectMapper()
- private def getOptionsWithPaths(
+ def getOptionsWithPaths(
extraOptions: CaseInsensitiveMap[String],
paths: String*): CaseInsensitiveMap[String] = {
if (paths.isEmpty) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 0e7eb056f43..7044ef65c63 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -42,12 +42,11 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
def builder(
sparkSession: SparkSession,
provider: String,
- paths: Seq[String],
userSpecifiedSchema: Option[StructType],
options: CaseInsensitiveMap[String]): LogicalPlan = {
val runner = new UserDefinedPythonDataSourceRunner(
- dataSourceCls, provider, paths, userSpecifiedSchema, options)
+ dataSourceCls, provider, userSpecifiedSchema, options)
val result = runner.runInPython()
val pickledDataSourceInstance = result.dataSource
@@ -68,10 +67,9 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
def apply(
sparkSession: SparkSession,
provider: String,
- paths: Seq[String] = Seq.empty,
userSpecifiedSchema: Option[StructType] = None,
options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)): DataFrame = {
- val plan = builder(sparkSession, provider, paths, userSpecifiedSchema, options)
+ val plan = builder(sparkSession, provider, userSpecifiedSchema, options)
Dataset.ofRows(sparkSession, plan)
}
}
@@ -89,7 +87,6 @@ case class PythonDataSourceCreationResult(
class UserDefinedPythonDataSourceRunner(
dataSourceCls: PythonFunction,
provider: String,
- paths: Seq[String],
userSpecifiedSchema: Option[StructType],
options: CaseInsensitiveMap[String])
extends PythonPlannerRunner[PythonDataSourceCreationResult](dataSourceCls) {
@@ -103,10 +100,6 @@ class UserDefinedPythonDataSourceRunner(
// Send the provider name
PythonWorkerUtils.writeUTF(provider, dataOut)
- // Send the paths
- dataOut.writeInt(paths.length)
- paths.foreach(PythonWorkerUtils.writeUTF(_, dataOut))
-
// Send the user-specified schema, if provided
dataOut.writeBoolean(userSpecifiedSchema.isDefined)
userSpecifiedSchema.map(_.json).foreach(PythonWorkerUtils.writeUTF(_, dataOut))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index 22a1e5250cd..bd0b08cbec8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -160,13 +160,20 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource, DataSourceReader
+ |import json
+ |
|class SimpleDataSourceReader(DataSourceReader):
- | def __init__(self, paths, options):
- | self.paths = paths
+ | def __init__(self, options):
| self.options = options
|
| def partitions(self):
- | return iter(self.paths)
+ | if "paths" in self.options:
+ | paths = json.loads(self.options["paths"])
+ | elif "path" in self.options:
+ | paths = [self.options["path"]]
+ | else:
+ | paths = []
+ | return paths
|
| def read(self, path):
| yield (path, 1)
@@ -180,11 +187,10 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
| return "id STRING, value INT"
|
| def reader(self, schema):
- | return SimpleDataSourceReader(self.paths, self.options)
+ | return SimpleDataSourceReader(self.options)
|""".stripMargin
val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
spark.dataSource.registerPython("test", dataSource)
-
checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1)))
checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1)))
checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1),
Row("2", 1)))