This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 25ee62e19bb [SPARK-45927][PYTHON] Update path handling for Python data 
source
25ee62e19bb is described below

commit 25ee62e19bb00610e79de3284987970d9732195c
Author: allisonwang-db <[email protected]>
AuthorDate: Mon Nov 20 15:06:17 2023 -0800

    [SPARK-45927][PYTHON] Update path handling for Python data source
    
    ### What changes were proposed in this pull request?
    
    This PR updates how `path` values from the `load()` method are handled.
    It changes the DataSource class constructor and adds `path` as a key-value 
pair in the options field.
    
    Also, this PR blocks loading multiple paths.
    
    ### Why are the changes needed?
    
    To make the behavior consistent with the existing data source APIs.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Existing unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43809 from allisonwang-db/spark-45927-path.
    
    Authored-by: allisonwang-db <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 python/pyspark/sql/datasource.py                   | 17 ++--------
 python/pyspark/sql/tests/test_python_datasource.py | 36 ++++++++--------------
 python/pyspark/sql/worker/create_data_source.py    | 15 ++-------
 .../org/apache/spark/sql/DataFrameReader.scala     |  6 ++--
 .../execution/datasources/DataSourceManager.scala  |  1 -
 .../datasources/v2/DataSourceV2Utils.scala         |  2 +-
 .../python/UserDefinedPythonDataSource.scala       | 11 ++-----
 .../execution/python/PythonDataSourceSuite.scala   | 16 +++++++---
 8 files changed, 34 insertions(+), 70 deletions(-)

diff --git a/python/pyspark/sql/datasource.py b/python/pyspark/sql/datasource.py
index c30a2c8689d..b380e8b534e 100644
--- a/python/pyspark/sql/datasource.py
+++ b/python/pyspark/sql/datasource.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 from abc import ABC, abstractmethod
-from typing import final, Any, Dict, Iterator, List, Optional, Tuple, Type, 
Union, TYPE_CHECKING
+from typing import final, Any, Dict, Iterator, List, Tuple, Type, Union, 
TYPE_CHECKING
 
 from pyspark import since
 from pyspark.sql import Row
@@ -45,21 +45,12 @@ class DataSource(ABC):
     """
 
     @final
-    def __init__(
-        self,
-        paths: List[str],
-        userSpecifiedSchema: Optional[StructType],
-        options: Dict[str, "OptionalPrimitiveType"],
-    ) -> None:
+    def __init__(self, options: Dict[str, "OptionalPrimitiveType"]) -> None:
         """
-        Initializes the data source with user-provided information.
+        Initializes the data source with user-provided options.
 
         Parameters
         ----------
-        paths : list
-            A list of paths to the data source.
-        userSpecifiedSchema : StructType, optional
-            The user-specified schema of the data source.
         options : dict
             A dictionary representing the options for this data source.
 
@@ -67,8 +58,6 @@ class DataSource(ABC):
         -----
         This method should not be overridden.
         """
-        self.paths = paths
-        self.userSpecifiedSchema = userSpecifiedSchema
         self.options = options
 
     @classmethod
diff --git a/python/pyspark/sql/tests/test_python_datasource.py 
b/python/pyspark/sql/tests/test_python_datasource.py
index fe6a8417527..46b9fa642fd 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -30,7 +30,7 @@ class BasePythonDataSourceTestsMixin:
             ...
 
         options = dict(a=1, b=2)
-        ds = MyDataSource(paths=[], userSpecifiedSchema=None, options=options)
+        ds = MyDataSource(options=options)
         self.assertEqual(ds.options, options)
         self.assertEqual(ds.name(), "MyDataSource")
         with self.assertRaises(NotImplementedError):
@@ -53,8 +53,7 @@ class BasePythonDataSourceTestsMixin:
         class InMemDataSourceReader(DataSourceReader):
             DEFAULT_NUM_PARTITIONS: int = 3
 
-            def __init__(self, paths, options):
-                self.paths = paths
+            def __init__(self, options):
                 self.options = options
 
             def partitions(self):
@@ -76,7 +75,7 @@ class BasePythonDataSourceTestsMixin:
                 return "x INT, y STRING"
 
             def reader(self, schema) -> "DataSourceReader":
-                return InMemDataSourceReader(self.paths, self.options)
+                return InMemDataSourceReader(self.options)
 
         self.spark.dataSource.register(InMemoryDataSource)
         df = self.spark.read.format("memory").load()
@@ -91,14 +90,13 @@ class BasePythonDataSourceTestsMixin:
         import json
 
         class JsonDataSourceReader(DataSourceReader):
-            def __init__(self, paths, options):
-                self.paths = paths
+            def __init__(self, options):
                 self.options = options
 
-            def partitions(self):
-                return iter(self.paths)
-
-            def read(self, path):
+            def read(self, partition):
+                path = self.options.get("path")
+                if path is None:
+                    raise Exception("path is not specified")
                 with open(path, "r") as file:
                     for line in file.readlines():
                         if line.strip():
@@ -114,28 +112,18 @@ class BasePythonDataSourceTestsMixin:
                 return "name STRING, age INT"
 
             def reader(self, schema) -> "DataSourceReader":
-                return JsonDataSourceReader(self.paths, self.options)
+                return JsonDataSourceReader(self.options)
 
         self.spark.dataSource.register(JsonDataSource)
         path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
         path2 = os.path.join(SPARK_HOME, 
"python/test_support/sql/people1.json")
-        df1 = self.spark.read.format("my-json").load(path1)
-        self.assertEqual(df1.rdd.getNumPartitions(), 1)
         assertDataFrameEqual(
-            df1,
+            self.spark.read.format("my-json").load(path1),
             [Row(name="Michael", age=None), Row(name="Andy", age=30), 
Row(name="Justin", age=19)],
         )
-
-        df2 = self.spark.read.format("my-json").load([path1, path2])
-        self.assertEqual(df2.rdd.getNumPartitions(), 2)
         assertDataFrameEqual(
-            df2,
-            [
-                Row(name="Michael", age=None),
-                Row(name="Andy", age=30),
-                Row(name="Justin", age=19),
-                Row(name="Jonathan", age=None),
-            ],
+            self.spark.read.format("my-json").load(path2),
+            [Row(name="Jonathan", age=None)],
         )
 
 
diff --git a/python/pyspark/sql/worker/create_data_source.py 
b/python/pyspark/sql/worker/create_data_source.py
index 6a9ef79b7c1..1ba4dc9e8a3 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -17,7 +17,7 @@
 import inspect
 import os
 import sys
-from typing import IO, List
+from typing import IO
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, 
PySparkTypeError
@@ -55,7 +55,6 @@ def main(infile: IO, outfile: IO) -> None:
     The JVM sends the following information to this process:
     - a `DataSource` class representing the data source to be created.
     - a provider name in string.
-    - a list of paths in string.
     - an optional user-specified schema in json string.
     - a dictionary of options in string.
 
@@ -107,12 +106,6 @@ def main(infile: IO, outfile: IO) -> None:
                 },
             )
 
-        # Receive the paths.
-        num_paths = read_int(infile)
-        paths: List[str] = []
-        for _ in range(num_paths):
-            paths.append(utf8_deserializer.loads(infile))
-
         # Receive the user-specified schema
         user_specified_schema = None
         if read_bool(infile):
@@ -136,11 +129,7 @@ def main(infile: IO, outfile: IO) -> None:
 
         # Instantiate a data source.
         try:
-            data_source = data_source_cls(
-                paths=paths,
-                userSpecifiedSchema=user_specified_schema,  # type: ignore
-                options=options,
-            )
+            data_source = data_source_cls(options=options)
         except Exception as e:
             raise PySparkRuntimeError(
                 error_class="PYTHON_DATA_SOURCE_CREATE_ERROR",
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index ef447e8a801..7fadbbfac68 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -244,9 +244,9 @@ class DataFrameReader private[sql](sparkSession: 
SparkSession) extends Logging {
 
   private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
     val builder = 
sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
-    // Unless the legacy path option behavior is enabled, the extraOptions here
-    // should not include "path" or "paths" as keys.
-    val plan = builder(sparkSession, source, paths, userSpecifiedSchema, 
extraOptions)
+    // Add `path` and `paths` options to the extra options if specified.
+    val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions, 
paths: _*)
+    val plan = builder(sparkSession, source, userSpecifiedSchema, 
optionsWithPath)
     Dataset.ofRows(sparkSession, plan)
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index 72a9e6497ac..a8c9c892b8b 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -35,7 +35,6 @@ class DataSourceManager {
   private type DataSourceBuilder = (
     SparkSession,  // Spark session
     String,  // provider name
-    Seq[String],  // paths
     Option[StructType],  // user specified schema
     CaseInsensitiveMap[String]  // options
   ) => LogicalPlan
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
index c4e7bf23cac..3dde20ac44e 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala
@@ -152,7 +152,7 @@ private[sql] object DataSourceV2Utils extends Logging {
   }
 
   private lazy val objectMapper = new ObjectMapper()
-  private def getOptionsWithPaths(
+  def getOptionsWithPaths(
       extraOptions: CaseInsensitiveMap[String],
       paths: String*): CaseInsensitiveMap[String] = {
     if (paths.isEmpty) {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 0e7eb056f43..7044ef65c63 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -42,12 +42,11 @@ case class UserDefinedPythonDataSource(dataSourceCls: 
PythonFunction) {
   def builder(
       sparkSession: SparkSession,
       provider: String,
-      paths: Seq[String],
       userSpecifiedSchema: Option[StructType],
       options: CaseInsensitiveMap[String]): LogicalPlan = {
 
     val runner = new UserDefinedPythonDataSourceRunner(
-      dataSourceCls, provider, paths, userSpecifiedSchema, options)
+      dataSourceCls, provider, userSpecifiedSchema, options)
 
     val result = runner.runInPython()
     val pickledDataSourceInstance = result.dataSource
@@ -68,10 +67,9 @@ case class UserDefinedPythonDataSource(dataSourceCls: 
PythonFunction) {
   def apply(
       sparkSession: SparkSession,
       provider: String,
-      paths: Seq[String] = Seq.empty,
       userSpecifiedSchema: Option[StructType] = None,
       options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)): 
DataFrame = {
-    val plan = builder(sparkSession, provider, paths, userSpecifiedSchema, 
options)
+    val plan = builder(sparkSession, provider, userSpecifiedSchema, options)
     Dataset.ofRows(sparkSession, plan)
   }
 }
@@ -89,7 +87,6 @@ case class PythonDataSourceCreationResult(
 class UserDefinedPythonDataSourceRunner(
     dataSourceCls: PythonFunction,
     provider: String,
-    paths: Seq[String],
     userSpecifiedSchema: Option[StructType],
     options: CaseInsensitiveMap[String])
   extends PythonPlannerRunner[PythonDataSourceCreationResult](dataSourceCls) {
@@ -103,10 +100,6 @@ class UserDefinedPythonDataSourceRunner(
     // Send the provider name
     PythonWorkerUtils.writeUTF(provider, dataOut)
 
-    // Send the paths
-    dataOut.writeInt(paths.length)
-    paths.foreach(PythonWorkerUtils.writeUTF(_, dataOut))
-
     // Send the user-specified schema, if provided
     dataOut.writeBoolean(userSpecifiedSchema.isDefined)
     userSpecifiedSchema.map(_.json).foreach(PythonWorkerUtils.writeUTF(_, 
dataOut))
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index 22a1e5250cd..bd0b08cbec8 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -160,13 +160,20 @@ class PythonDataSourceSuite extends QueryTest with 
SharedSparkSession {
     val dataSourceScript =
       s"""
          |from pyspark.sql.datasource import DataSource, DataSourceReader
+         |import json
+         |
          |class SimpleDataSourceReader(DataSourceReader):
-         |    def __init__(self, paths, options):
-         |        self.paths = paths
+         |    def __init__(self, options):
          |        self.options = options
          |
          |    def partitions(self):
-         |        return iter(self.paths)
+         |        if "paths" in self.options:
+         |            paths = json.loads(self.options["paths"])
+         |        elif "path" in self.options:
+         |            paths = [self.options["path"]]
+         |        else:
+         |            paths = []
+         |        return paths
          |
          |    def read(self, path):
          |        yield (path, 1)
@@ -180,11 +187,10 @@ class PythonDataSourceSuite extends QueryTest with 
SharedSparkSession {
          |        return "id STRING, value INT"
          |
          |    def reader(self, schema):
-         |        return SimpleDataSourceReader(self.paths, self.options)
+         |        return SimpleDataSourceReader(self.options)
          |""".stripMargin
     val dataSource = createUserDefinedPythonDataSource(dataSourceName, 
dataSourceScript)
     spark.dataSource.registerPython("test", dataSource)
-
     checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1)))
     checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1)))
     checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1), 
Row("2", 1)))


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to