This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 9d93b7112a31 [SPARK-45639][SQL][PYTHON] Support loading Python data sources in DataFrameReader
9d93b7112a31 is described below
commit 9d93b7112a31965447a34301889f90d14578e628
Author: allisonwang-db <[email protected]>
AuthorDate: Wed Nov 8 09:23:12 2023 -0800
[SPARK-45639][SQL][PYTHON] Support loading Python data sources in DataFrameReader
### What changes were proposed in this pull request?
This PR supports `spark.read.format(...).load()` for Python data sources.
After this PR, users can use a Python data source directly like this:
```python
from pyspark.sql.datasource import DataSource, DataSourceReader

class MyReader(DataSourceReader):
    def read(self, partition):
        yield (0, 1)

class MyDataSource(DataSource):
    @classmethod
    def name(cls):
        return "my-source"

    def schema(self):
        return "id INT, value INT"

    def reader(self, schema):
        return MyReader()
spark.dataSource.register(MyDataSource)
df = spark.read.format("my-source").load()
df.show()
+---+-----+
| id|value|
+---+-----+
|  0|    1|
+---+-----+
```
### Why are the changes needed?
To allow user-defined Python data sources to be loaded through the standard `spark.read.format(...).load()` API.
### Does this PR introduce _any_ user-facing change?
Yes. After this PR, users can load a custom Python data source using
`spark.read.format(...).load()`.
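For illustration, a minimal sketch of how load paths and options flow through to a Python data source (modelled on the new tests added in this patch; the `PathDataSource`/`path-source` names are made up for the example, and `spark` is assumed to be an active `SparkSession`):
```python
from pyspark.sql.datasource import DataSource, DataSourceReader

class PathReader(DataSourceReader):
    def __init__(self, paths, options):
        self.paths = paths
        self.options = options

    def partitions(self):
        # One input partition per path passed to load().
        return iter(self.paths)

    def read(self, partition):
        # Each partition here is one of the load paths.
        yield (partition, 1)

class PathDataSource(DataSource):
    @classmethod
    def name(cls):
        return "path-source"  # illustrative name

    def schema(self):
        return "id STRING, value INT"

    def reader(self, schema):
        # Paths and options given to the reader arrive here as
        # self.paths and self.options.
        return PathReader(self.paths, self.options)

spark.dataSource.register(PathDataSource)
df = spark.read.format("path-source").load(["a", "b"])  # one row per path
df.show()
```
Each path passed to `load()` reaches the data source as `self.paths`, and anything set via `.option(...)` shows up in `self.options`, which is how the new in-memory and JSON tests drive partitioning.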
### How was this patch tested?
New unit tests.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43630 from allisonwang-db/spark-45639-ds-lookup.
Authored-by: allisonwang-db <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../src/main/resources/error/error-classes.json | 12 +++
dev/sparktestsupport/modules.py | 1 +
docs/sql-error-conditions.md | 12 +++
python/pyspark/sql/session.py | 4 +
python/pyspark/sql/tests/test_python_datasource.py | 97 ++++++++++++++++++++--
python/pyspark/sql/worker/create_data_source.py | 16 +++-
.../spark/sql/errors/QueryCompilationErrors.scala | 12 +++
.../org/apache/spark/sql/DataFrameReader.scala | 48 +++++++++--
.../execution/datasources/DataSourceManager.scala | 31 ++++++-
.../python/UserDefinedPythonDataSource.scala | 15 ++--
.../execution/python/PythonDataSourceSuite.scala | 35 ++++++++
11 files changed, 255 insertions(+), 28 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json
index db46ee8ca208..c38171c3d9e6 100644
--- a/common/utils/src/main/resources/error/error-classes.json
+++ b/common/utils/src/main/resources/error/error-classes.json
@@ -850,6 +850,12 @@
],
"sqlState" : "42710"
},
+ "DATA_SOURCE_NOT_EXIST" : {
+ "message" : [
+ "Data source '<provider>' not found. Please make sure the data source is
registered."
+ ],
+ "sqlState" : "42704"
+ },
"DATA_SOURCE_NOT_FOUND" : {
"message" : [
"Failed to find the data source: <provider>. Please find packages at
`https://spark.apache.org/third-party-projects.html`."
@@ -1095,6 +1101,12 @@
],
"sqlState" : "42809"
},
+ "FOUND_MULTIPLE_DATA_SOURCES" : {
+ "message" : [
+ "Detected multiple data sources with the name '<provider>'. Please check
the data source isn't simultaneously registered and located in the classpath."
+ ],
+ "sqlState" : "42710"
+ },
"GENERATED_COLUMN_WITH_DEFAULT_VALUE" : {
"message" : [
"A column cannot have both a default value and a generation expression
but column <colName> has default value: (<defaultValue>) and generation
expression: (<genExpr>)."
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index 95c9069a8313..01757ba28dd2 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -511,6 +511,7 @@ pyspark_sql = Module(
"pyspark.sql.tests.pandas.test_pandas_udf_window",
"pyspark.sql.tests.pandas.test_converter",
"pyspark.sql.tests.test_pandas_sqlmetrics",
+ "pyspark.sql.tests.test_python_datasource",
"pyspark.sql.tests.test_readwriter",
"pyspark.sql.tests.test_serde",
"pyspark.sql.tests.test_session",
diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md
index 7b0bc8ceb2b5..8a5faa15dc9c 100644
--- a/docs/sql-error-conditions.md
+++ b/docs/sql-error-conditions.md
@@ -454,6 +454,12 @@ DataType `<type>` requires a length parameter, for example `<type>`(10). Please
Data source '`<provider>`' already exists in the registry. Please use a different name for the new data source.
+### DATA_SOURCE_NOT_EXIST
+
+[SQLSTATE: 42704](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Data source '`<provider>`' not found. Please make sure the data source is registered.
+
### DATA_SOURCE_NOT_FOUND
[SQLSTATE: 42K02](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
@@ -669,6 +675,12 @@ No such struct field `<fieldName>` in `<fields>`.
The operation `<statement>` is not allowed on the `<objectType>`: `<objectName>`.
+### FOUND_MULTIPLE_DATA_SOURCES
+
+[SQLSTATE: 42710](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
+
+Detected multiple data sources with the name '`<provider>`'. Please check the data source isn't simultaneously registered and located in the classpath.
+
### GENERATED_COLUMN_WITH_DEFAULT_VALUE
[SQLSTATE: 42623](sql-error-conditions-sqlstates.html#class-42-syntax-error-or-access-rule-violation)
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 4ab7281d7ac8..85aff09aa3df 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -884,6 +884,10 @@ class SparkSession(SparkConversionMixin):
Returns
-------
:class:`DataSourceRegistration`
+
+ Notes
+ -----
+ This feature is experimental and unstable.
"""
from pyspark.sql.datasource import DataSourceRegistration
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index b429d73fb7d7..fe6a84175274 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -14,10 +14,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import os
import unittest
from pyspark.sql.datasource import DataSource, DataSourceReader
+from pyspark.sql.types import Row
+from pyspark.testing import assertDataFrameEqual
from pyspark.testing.sqlutils import ReusedSQLTestCase
+from pyspark.testing.utils import SPARK_HOME
class BasePythonDataSourceTestsMixin:
@@ -45,16 +49,93 @@ class BasePythonDataSourceTestsMixin:
self.assertEqual(list(reader.partitions()), [None])
self.assertEqual(list(reader.read(None)), [(None,)])
- def test_register_data_source(self):
- class MyDataSource(DataSource):
- ...
+ def test_in_memory_data_source(self):
+ class InMemDataSourceReader(DataSourceReader):
+ DEFAULT_NUM_PARTITIONS: int = 3
+
+ def __init__(self, paths, options):
+ self.paths = paths
+ self.options = options
+
+ def partitions(self):
+ if "num_partitions" in self.options:
+ num_partitions = int(self.options["num_partitions"])
+ else:
+ num_partitions = self.DEFAULT_NUM_PARTITIONS
+ return range(num_partitions)
+
+ def read(self, partition):
+ yield partition, str(partition)
+
+ class InMemoryDataSource(DataSource):
+ @classmethod
+ def name(cls):
+ return "memory"
+
+ def schema(self):
+ return "x INT, y STRING"
+
+ def reader(self, schema) -> "DataSourceReader":
+ return InMemDataSourceReader(self.paths, self.options)
+
+ self.spark.dataSource.register(InMemoryDataSource)
+ df = self.spark.read.format("memory").load()
+ self.assertEqual(df.rdd.getNumPartitions(), 3)
+ assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1"), Row(x=2, y="2")])
- self.spark.dataSource.register(MyDataSource)
+ df = self.spark.read.format("memory").option("num_partitions",
2).load()
+ assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
+ self.assertEqual(df.rdd.getNumPartitions(), 2)
+
+ def test_custom_json_data_source(self):
+ import json
+
+ class JsonDataSourceReader(DataSourceReader):
+ def __init__(self, paths, options):
+ self.paths = paths
+ self.options = options
+
+ def partitions(self):
+ return iter(self.paths)
+
+ def read(self, path):
+ with open(path, "r") as file:
+ for line in file.readlines():
+ if line.strip():
+ data = json.loads(line)
+ yield data.get("name"), data.get("age")
+
+ class JsonDataSource(DataSource):
+ @classmethod
+ def name(cls):
+ return "my-json"
+
+ def schema(self):
+ return "name STRING, age INT"
+
+ def reader(self, schema) -> "DataSourceReader":
+ return JsonDataSourceReader(self.paths, self.options)
+
+ self.spark.dataSource.register(JsonDataSource)
+ path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
+ path2 = os.path.join(SPARK_HOME, "python/test_support/sql/people1.json")
+ df1 = self.spark.read.format("my-json").load(path1)
+ self.assertEqual(df1.rdd.getNumPartitions(), 1)
+ assertDataFrameEqual(
+ df1,
+ [Row(name="Michael", age=None), Row(name="Andy", age=30),
Row(name="Justin", age=19)],
+ )
- self.assertTrue(
- self.spark._jsparkSession.sharedState()
- .dataSourceRegistry()
- .dataSourceExists("MyDataSource")
+ df2 = self.spark.read.format("my-json").load([path1, path2])
+ self.assertEqual(df2.rdd.getNumPartitions(), 2)
+ assertDataFrameEqual(
+ df2,
+ [
+ Row(name="Michael", age=None),
+ Row(name="Andy", age=30),
+ Row(name="Justin", age=19),
+ Row(name="Jonathan", age=None),
+ ],
)
diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py
index ea56d2cc7522..6a9ef79b7c18 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-
+import inspect
import os
import sys
from typing import IO, List
from pyspark.accumulators import _accumulatorRegistry
-from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
+from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError
from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import (
read_bool,
@@ -84,8 +84,20 @@ def main(infile: IO, outfile: IO) -> None:
},
)
+ # Check the name method is a class method.
+ if not inspect.ismethod(data_source_cls.name):
+ raise PySparkTypeError(
+ error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
+ message_parameters={
+ "expected": "'name()' method to be a classmethod",
+ "actual": f"'{type(data_source_cls.name).__name__}'",
+ },
+ )
+
# Receive the provider name.
provider = utf8_deserializer.loads(infile)
+
+ # Check if the provider name matches the data source's name.
if provider.lower() != data_source_cls.name().lower():
raise PySparkAssertionError(
error_class="PYTHON_DATA_SOURCE_TYPE_MISMATCH",
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
index 1925eddd2ce2..0c5dcb1ead01 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala
@@ -3805,4 +3805,16 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat
errorClass = "DATA_SOURCE_ALREADY_EXISTS",
messageParameters = Map("provider" -> name))
}
+
+ def dataSourceDoesNotExist(name: String): Throwable = {
+ new AnalysisException(
+ errorClass = "DATA_SOURCE_NOT_EXIST",
+ messageParameters = Map("provider" -> name))
+ }
+
+ def foundMultipleDataSources(provider: String): Throwable = {
+ new AnalysisException(
+ errorClass = "FOUND_MULTIPLE_DATA_SOURCES",
+ messageParameters = Map("provider" -> provider))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 9992d8cbba07..ef447e8a8010 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -17,11 +17,12 @@
package org.apache.spark.sql
-import java.util.{Locale, Properties}
+import java.util.{Locale, Properties, ServiceConfigurationError}
import scala.jdk.CollectionConverters._
+import scala.util.{Failure, Success, Try}
-import org.apache.spark.Partition
+import org.apache.spark.{Partition, SparkClassNotFoundException, SparkThrowable}
import org.apache.spark.annotation.Stable
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
@@ -208,10 +209,45 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError()
}
- DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf).flatMap { provider =>
- DataSourceV2Utils.loadV2Source(sparkSession, provider, userSpecifiedSchema, extraOptions,
- source, paths: _*)
- }.getOrElse(loadV1Source(paths: _*))
+ val isUserDefinedDataSource =
+ sparkSession.sharedState.dataSourceManager.dataSourceExists(source)
+
+ Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf)) match {
+ case Success(providerOpt) =>
+ // The source can be successfully loaded as either a V1 or a V2 data source.
+ // Check if it is also a user-defined data source.
+ if (isUserDefinedDataSource) {
+ throw QueryCompilationErrors.foundMultipleDataSources(source)
+ }
+ providerOpt.flatMap { provider =>
+ DataSourceV2Utils.loadV2Source(
+ sparkSession, provider, userSpecifiedSchema, extraOptions, source, paths: _*)
+ }.getOrElse(loadV1Source(paths: _*))
+ case Failure(exception) =>
+ // Exceptions are thrown while trying to load the data source as a V1 or V2 data source.
+ // For the following not found exceptions, if the user-defined data source is defined,
+ // we can instead return the user-defined data source.
+ val isNotFoundError = exception match {
+ case _: NoClassDefFoundError | _: SparkClassNotFoundException => true
+ case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND"
+ case e: ServiceConfigurationError => e.getCause.isInstanceOf[NoClassDefFoundError]
+ case _ => false
+ }
+ if (isNotFoundError && isUserDefinedDataSource) {
+ loadUserDefinedDataSource(paths)
+ } else {
+ // Throw the original exception.
+ throw exception
+ }
+ }
+ }
+
+ private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
+ val builder = sparkSession.sharedState.dataSourceManager.lookupDataSource(source)
+ // Unless the legacy path option behavior is enabled, the extraOptions here
+ // should not include "path" or "paths" as keys.
+ val plan = builder(sparkSession, source, paths, userSpecifiedSchema, extraOptions)
+ Dataset.ofRows(sparkSession, plan)
}
private def loadV1Source(paths: String*) = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index 283ca2ac62ed..72a9e6497aca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -22,10 +22,14 @@ import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
+/**
+ * A manager for user-defined data sources. It is used to register and lookup data sources by
+ * their short names or fully qualified names.
+ */
class DataSourceManager {
private type DataSourceBuilder = (
@@ -33,22 +37,41 @@ class DataSourceManager {
String, // provider name
Seq[String], // paths
Option[StructType], // user specified schema
- CaseInsensitiveStringMap // options
+ CaseInsensitiveMap[String] // options
) => LogicalPlan
private val dataSourceBuilders = new ConcurrentHashMap[String, DataSourceBuilder]()
private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)
+ /**
+ * Register a data source builder for the given provider.
+ * Note that the provider name is case-insensitive.
+ */
def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
val normalizedName = normalize(name)
if (dataSourceBuilders.containsKey(normalizedName)) {
throw QueryCompilationErrors.dataSourceAlreadyExists(name)
}
- // TODO(SPARK-45639): check if the data source is a DSv1 or DSv2 using loadDataSource.
dataSourceBuilders.put(normalizedName, builder)
}
- def dataSourceExists(name: String): Boolean =
+ /**
+ * Returns a data source builder for the given provider and throw an exception if
+ * it does not exist.
+ */
+ def lookupDataSource(name: String): DataSourceBuilder = {
+ if (dataSourceExists(name)) {
+ dataSourceBuilders.get(normalize(name))
+ } else {
+ throw QueryCompilationErrors.dataSourceDoesNotExist(name)
+ }
+ }
+
+ /**
+ * Checks if a data source with the specified name exists (case-insensitive).
+ */
+ def dataSourceExists(name: String): Boolean = {
dataSourceBuilders.containsKey(normalize(name))
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index dbff8eefcd5f..703c1e10ce26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql.execution.python
import java.io.{DataInputStream, DataOutputStream}
import scala.collection.mutable.ArrayBuffer
-import scala.jdk.CollectionConverters._
import net.razorvine.pickle.Pickler
@@ -28,9 +27,9 @@ import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils, SimplePyt
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PythonDataSource}
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
+import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{DataType, StructType}
-import org.apache.spark.sql.util.CaseInsensitiveStringMap
/**
* A user-defined Python data source. This is used by the Python API.
@@ -44,7 +43,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
provider: String,
paths: Seq[String],
userSpecifiedSchema: Option[StructType],
- options: CaseInsensitiveStringMap): LogicalPlan = {
+ options: CaseInsensitiveMap[String]): LogicalPlan = {
val runner = new UserDefinedPythonDataSourceRunner(
dataSourceCls, provider, paths, userSpecifiedSchema, options)
@@ -70,7 +69,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
provider: String,
paths: Seq[String] = Seq.empty,
userSpecifiedSchema: Option[StructType] = None,
- options: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty): DataFrame = {
+ options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)): DataFrame = {
val plan = builder(sparkSession, provider, paths, userSpecifiedSchema, options)
Dataset.ofRows(sparkSession, plan)
}
@@ -91,7 +90,7 @@ class UserDefinedPythonDataSourceRunner(
provider: String,
paths: Seq[String],
userSpecifiedSchema: Option[StructType],
- options: CaseInsensitiveStringMap)
+ options: CaseInsensitiveMap[String])
extends PythonPlannerRunner[PythonDataSourceCreationResult](dataSourceCls) {
override val workerModule = "pyspark.sql.worker.create_data_source"
@@ -113,9 +112,9 @@ class UserDefinedPythonDataSourceRunner(
// Send the options
dataOut.writeInt(options.size)
- options.entrySet.asScala.foreach { e =>
- PythonWorkerUtils.writeUTF(e.getKey, dataOut)
- PythonWorkerUtils.writeUTF(e.getValue, dataOut)
+ options.iterator.foreach { case (key, value) =>
+ PythonWorkerUtils.writeUTF(key, dataOut)
+ PythonWorkerUtils.writeUTF(value, dataOut)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index 6c749c2c9b67..22a1e5250cd9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -155,6 +155,41 @@ class PythonDataSourceSuite extends QueryTest with SharedSparkSession {
parameters = Map("provider" -> dataSourceName))
}
+ test("load data source") {
+ assume(shouldTestPythonUDFs)
+ val dataSourceScript =
+ s"""
+ |from pyspark.sql.datasource import DataSource, DataSourceReader
+ |class SimpleDataSourceReader(DataSourceReader):
+ | def __init__(self, paths, options):
+ | self.paths = paths
+ | self.options = options
+ |
+ | def partitions(self):
+ | return iter(self.paths)
+ |
+ | def read(self, path):
+ | yield (path, 1)
+ |
+ |class $dataSourceName(DataSource):
+ | @classmethod
+ | def name(cls) -> str:
+ | return "test"
+ |
+ | def schema(self) -> str:
+ | return "id STRING, value INT"
+ |
+ | def reader(self, schema):
+ | return SimpleDataSourceReader(self.paths, self.options)
+ |""".stripMargin
+ val dataSource = createUserDefinedPythonDataSource(dataSourceName, dataSourceScript)
+ spark.dataSource.registerPython("test", dataSource)
+
+ checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1)))
+ checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1)))
+ checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1),
Row("2", 1)))
+ }
+
test("reader not implemented") {
assume(shouldTestPythonUDFs)
val dataSourceScript =
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]