This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new a1b0da200b27 [SPARK-45597][PYTHON][SQL] Support creating table using a
Python data source in SQL (DSv2 exec)
a1b0da200b27 is described below
commit a1b0da200b271214e9d6b3170308509d7d514c7f
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Fri Dec 15 11:04:32 2023 -0800
[SPARK-45597][PYTHON][SQL] Support creating table using a Python data
source in SQL (DSv2 exec)
### What changes were proposed in this pull request?
This PR is the same as https://github.com/apache/spark/pull/44233 except that, instead of
using `V1Table`, it uses the original DSv2 interface by reusing the UDTF execution code.
### Why are the changes needed?
So that Python Data Sources can be used from all other places as well,
including SparkR and Scala.
### Does this PR introduce _any_ user-facing change?
Yes. Users can register their Python Data Sources and use them in SQL,
SparkR, etc.
### How was this patch tested?
Unit tests were added, and the change was manually tested.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44269
Closes #44233
Closes #43784
Closes #44305 from HyukjinKwon/SPARK-45597-3.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../plans/logical/pythonLogicalOperators.scala | 40 +----
.../org/apache/spark/sql/DataFrameReader.scala | 48 +----
.../apache/spark/sql/DataSourceRegistration.scala | 2 +-
.../scala/org/apache/spark/sql/SparkSession.scala | 2 +-
.../spark/sql/execution/SparkOptimizer.scala | 8 +-
.../spark/sql/execution/SparkStrategies.scala | 2 -
.../apache/spark/sql/execution/command/ddl.scala | 6 +-
.../spark/sql/execution/command/tables.scala | 7 +-
.../sql/execution/datasources/DataSource.scala | 35 +++-
.../execution/datasources/DataSourceManager.scala | 24 +--
.../datasources/PlanPythonDataSourceScan.scala | 89 ----------
.../ApplyInPandasWithStatePythonRunner.scala | 6 +-
.../execution/python/ArrowEvalPythonUDTFExec.scala | 2 +-
.../sql/execution/python/ArrowPythonRunner.scala | 6 +-
.../execution/python/ArrowPythonUDTFRunner.scala | 2 +-
.../python/CoGroupedArrowPythonRunner.scala | 6 +-
.../python/FlatMapGroupsInPythonExec.scala | 2 +-
.../python/MapInBatchEvaluatorFactory.scala | 2 +-
.../sql/execution/python/MapInBatchExec.scala | 2 +-
.../sql/execution/python/PythonArrowInput.scala | 4 +-
.../sql/execution/python/PythonArrowOutput.scala | 6 +-
.../python/PythonDataSourcePartitionsExec.scala | 80 ---------
.../python/UserDefinedPythonDataSource.scala | 195 ++++++++++++++++++---
.../spark/sql/streaming/DataStreamReader.scala | 5 +-
.../spark/sql/streaming/DataStreamWriter.scala | 2 +-
.../execution/python/PythonDataSourceSuite.scala | 66 ++++---
26 files changed, 290 insertions(+), 359 deletions(-)
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
index fb8b06eb41bc..f5930c5272a2 100644
---
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
+++
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala
@@ -17,13 +17,11 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.api.python.PythonFunction
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet,
Expression, PythonUDF, PythonUDTF}
import org.apache.spark.sql.catalyst.trees.TreePattern._
-import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode}
-import org.apache.spark.sql.types.{BinaryType, StructType}
+import org.apache.spark.sql.types.StructType
/**
* FlatMap groups using a udf: pandas.Dataframe -> pandas.DataFrame.
@@ -103,42 +101,6 @@ case class PythonMapInArrow(
copy(child = newChild)
}
-/**
- * Represents a Python data source.
- */
-case class PythonDataSource(
- dataSource: PythonFunction,
- outputSchema: StructType,
- override val output: Seq[Attribute]) extends LeafNode {
- require(output.forall(_.resolved),
- "Unresolved attributes found when constructing PythonDataSource.")
- override protected def stringArgs: Iterator[Any] = {
- Iterator(output)
- }
- final override val nodePatterns: Seq[TreePattern] = Seq(PYTHON_DATA_SOURCE)
-}
-
-/**
- * Represents a list of Python data source partitions.
- */
-case class PythonDataSourcePartitions(
- output: Seq[Attribute],
- partitions: Seq[Array[Byte]]) extends LeafNode {
- override protected def stringArgs: Iterator[Any] = {
- if (partitions.isEmpty) {
- Iterator("<empty>", output)
- } else {
- Iterator(output)
- }
- }
-}
-
-object PythonDataSourcePartitions {
- def schema: StructType = new StructType().add("partition", BinaryType)
-
- def getOutputAttrs: Seq[Attribute] = toAttributes(schema)
-}
-
/**
* Flatmap cogroups using a udf: pandas.Dataframe, pandas.Dataframe ->
pandas.Dataframe
* This is used by DataFrame.groupby().cogroup().apply().
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index c29ffb329072..9992d8cbba07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -17,12 +17,11 @@
package org.apache.spark.sql
-import java.util.{Locale, Properties, ServiceConfigurationError}
+import java.util.{Locale, Properties}
import scala.jdk.CollectionConverters._
-import scala.util.{Failure, Success, Try}
-import org.apache.spark.{Partition, SparkClassNotFoundException,
SparkThrowable}
+import org.apache.spark.Partition
import org.apache.spark.annotation.Stable
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.internal.Logging
@@ -209,45 +208,10 @@ class DataFrameReader private[sql](sparkSession:
SparkSession) extends Logging {
throw QueryCompilationErrors.pathOptionNotSetCorrectlyWhenReadingError()
}
- val isUserDefinedDataSource =
- sparkSession.sessionState.dataSourceManager.dataSourceExists(source)
-
- Try(DataSource.lookupDataSourceV2(source, sparkSession.sessionState.conf))
match {
- case Success(providerOpt) =>
- // The source can be successfully loaded as either a V1 or a V2 data
source.
- // Check if it is also a user-defined data source.
- if (isUserDefinedDataSource) {
- throw QueryCompilationErrors.foundMultipleDataSources(source)
- }
- providerOpt.flatMap { provider =>
- DataSourceV2Utils.loadV2Source(
- sparkSession, provider, userSpecifiedSchema, extraOptions, source,
paths: _*)
- }.getOrElse(loadV1Source(paths: _*))
- case Failure(exception) =>
- // Exceptions are thrown while trying to load the data source as a V1
or V2 data source.
- // For the following not found exceptions, if the user-defined data
source is defined,
- // we can instead return the user-defined data source.
- val isNotFoundError = exception match {
- case _: NoClassDefFoundError | _: SparkClassNotFoundException => true
- case e: SparkThrowable => e.getErrorClass == "DATA_SOURCE_NOT_FOUND"
- case e: ServiceConfigurationError =>
e.getCause.isInstanceOf[NoClassDefFoundError]
- case _ => false
- }
- if (isNotFoundError && isUserDefinedDataSource) {
- loadUserDefinedDataSource(paths)
- } else {
- // Throw the original exception.
- throw exception
- }
- }
- }
-
- private def loadUserDefinedDataSource(paths: Seq[String]): DataFrame = {
- val builder =
sparkSession.sessionState.dataSourceManager.lookupDataSource(source)
- // Add `path` and `paths` options to the extra options if specified.
- val optionsWithPath = DataSourceV2Utils.getOptionsWithPaths(extraOptions,
paths: _*)
- val plan = builder(sparkSession, source, userSpecifiedSchema,
optionsWithPath)
- Dataset.ofRows(sparkSession, plan)
+ DataSource.lookupDataSourceV2(source,
sparkSession.sessionState.conf).flatMap { provider =>
+ DataSourceV2Utils.loadV2Source(sparkSession, provider,
userSpecifiedSchema, extraOptions,
+ source, paths: _*)
+ }.getOrElse(loadV1Source(paths: _*))
}
private def loadV1Source(paths: String*) = {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
index 15d26418984b..936286eb0da5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataSourceRegistration.scala
@@ -43,6 +43,6 @@ private[sql] class DataSourceRegistration private[sql]
(dataSourceManager: DataS
| pythonExec: ${dataSource.dataSourceCls.pythonExec}
""".stripMargin)
- dataSourceManager.registerDataSource(name, dataSource.builder)
+ dataSourceManager.registerDataSource(name, dataSource)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 15eeca87dcf6..44a4d82c1dac 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -780,7 +780,7 @@ class SparkSession private(
DataSource.lookupDataSource(runner, sessionState.conf) match {
case source if classOf[ExternalCommandRunner].isAssignableFrom(source) =>
Dataset.ofRows(self, ExternalCommandExecutor(
- source.getDeclaredConstructor().newInstance()
+ DataSource.newDataSourceInstance(runner, source)
.asInstanceOf[ExternalCommandRunner], command, options))
case _ =>
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 00328910f5b6..70a35ea91153 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.connector.catalog.CatalogManager
-import org.apache.spark.sql.execution.datasources.{PlanPythonDataSourceScan,
PruneFileSourcePartitions, SchemaPruning, V1Writes}
+import org.apache.spark.sql.execution.datasources.{PruneFileSourcePartitions,
SchemaPruning, V1Writes}
import
org.apache.spark.sql.execution.datasources.v2.{GroupBasedRowLevelOperationScanPlanning,
OptimizeMetadataOnlyDeleteFromTable, V2ScanPartitioningAndOrdering,
V2ScanRelationPushDown, V2Writes}
import
org.apache.spark.sql.execution.dynamicpruning.{CleanupDynamicPruningFilters,
PartitionPruning, RowLevelOperationRuntimeGroupFiltering}
import
org.apache.spark.sql.execution.python.{ExtractGroupingPythonUDFFromAggregate,
ExtractPythonUDFFromAggregate, ExtractPythonUDFs, ExtractPythonUDTFs}
@@ -42,8 +42,7 @@ class SparkOptimizer(
V2ScanRelationPushDown :+
V2ScanPartitioningAndOrdering :+
V2Writes :+
- PruneFileSourcePartitions :+
- PlanPythonDataSourceScan
+ PruneFileSourcePartitions
override def preCBORules: Seq[Rule[LogicalPlan]] =
OptimizeMetadataOnlyDeleteFromTable :: Nil
@@ -102,8 +101,7 @@ class SparkOptimizer(
V2ScanRelationPushDown.ruleName :+
V2ScanPartitioningAndOrdering.ruleName :+
V2Writes.ruleName :+
- ReplaceCTERefWithRepartition.ruleName :+
- PlanPythonDataSourceScan.ruleName
+ ReplaceCTERefWithRepartition.ruleName
/**
* Optimization batches that are executed before the regular optimization
batches (also before
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 2d24f997d105..35070ac1d562 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -753,8 +753,6 @@ abstract class SparkStrategies extends
QueryPlanner[SparkPlan] {
case ArrowEvalPythonUDTF(udtf, requiredChildOutput, resultAttrs, child,
evalType) =>
ArrowEvalPythonUDTFExec(
udtf, requiredChildOutput, resultAttrs, planLater(child), evalType)
:: Nil
- case PythonDataSourcePartitions(output, partitions) =>
- PythonDataSourcePartitionsExec(output, partitions) :: Nil
case _ =>
Nil
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index dc1c5b3fd580..199c8728a5c9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -45,7 +45,7 @@ import
org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM
import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
import org.apache.spark.sql.errors.QueryCompilationErrors
import
org.apache.spark.sql.errors.QueryExecutionErrors.hiveTableWithAnsiIntervalsError
-import org.apache.spark.sql.execution.datasources.{DataSource,
DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation}
+import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.types._
@@ -1025,7 +1025,9 @@ object DDLUtils extends Logging {
def checkDataColNames(provider: String, schema: StructType): Unit = {
val source = try {
- DataSource.lookupDataSource(provider,
SQLConf.get).getConstructor().newInstance()
+ DataSource.newDataSourceInstance(
+ provider,
+ DataSource.lookupDataSource(provider, SQLConf.get))
} catch {
case e: Throwable =>
logError(s"Failed to find data source: $provider when check data
column names.", e)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 2f8fca7cfd73..9771ee08b258 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -35,7 +35,7 @@ import
org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString,
quoteIfNeeded, CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils,
ResolveDefaultColumns}
+import org.apache.spark.sql.catalyst.util._
import
org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY
import
org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper
import org.apache.spark.sql.errors.{QueryCompilationErrors,
QueryExecutionErrors}
@@ -264,8 +264,9 @@ case class AlterTableAddColumnsCommand(
}
if (DDLUtils.isDatasourceTable(catalogTable)) {
- DataSource.lookupDataSource(catalogTable.provider.get, conf).
- getConstructor().newInstance() match {
+ DataSource.newDataSourceInstance(
+ catalogTable.provider.get,
+ DataSource.lookupDataSource(catalogTable.provider.get, conf)) match {
// For datasource table, this command can only support the following
File format.
// TextFileFormat only default to one column "value"
// Hive type is already considered as hive serde table, so the logic
will not
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 71b6d4b886b4..9612d8ff24f5 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -44,6 +44,7 @@ import
org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2
import org.apache.spark.sql.execution.datasources.xml.XmlFileFormat
+import org.apache.spark.sql.execution.python.PythonTableProvider
import org.apache.spark.sql.execution.streaming._
import org.apache.spark.sql.execution.streaming.sources.{RateStreamProvider,
TextSocketSourceProvider}
import org.apache.spark.sql.internal.SQLConf
@@ -105,13 +106,14 @@ case class DataSource(
// [[FileDataSourceV2]] will still be used if we call the load()/save()
method in
// [[DataFrameReader]]/[[DataFrameWriter]], since they use method
`lookupDataSource`
// instead of `providingClass`.
- cls.getDeclaredConstructor().newInstance() match {
+ DataSource.newDataSourceInstance(className, cls) match {
case f: FileDataSourceV2 => f.fallbackFileFormat
case _ => cls
}
}
- private[sql] def providingInstance(): Any =
providingClass.getConstructor().newInstance()
+ private[sql] def providingInstance(): Any =
+ DataSource.newDataSourceInstance(className, providingClass)
private def newHadoopConfiguration(): Configuration =
sparkSession.sessionState.newHadoopConfWithOptions(options)
@@ -622,6 +624,15 @@ object DataSource extends Logging {
"org.apache.spark.sql.sources.HadoopFsRelationProvider",
"org.apache.spark.Logging")
+ /** Create the instance of the datasource */
+ def newDataSourceInstance(provider: String, providingClass: Class[_]): Any =
{
+ providingClass match {
+ case cls if classOf[PythonTableProvider].isAssignableFrom(cls) =>
+ cls.getDeclaredConstructor(classOf[String]).newInstance(provider)
+ case cls => cls.getDeclaredConstructor().newInstance()
+ }
+ }
+
/** Given a provider name, look up the data source class definition. */
def lookupDataSource(provider: String, conf: SQLConf): Class[_] = {
val provider1 = backwardCompatibilityMap.getOrElse(provider, provider)
match {
@@ -649,6 +660,9 @@ object DataSource extends Logging {
// Found the data source using fully qualified path
dataSource
case Failure(error) =>
+ // TODO(SPARK-45600): should be session-based.
+ val isUserDefinedDataSource =
SparkSession.getActiveSession.exists(
+ _.sessionState.dataSourceManager.dataSourceExists(provider))
if (provider1.startsWith("org.apache.spark.sql.hive.orc")) {
throw QueryCompilationErrors.orcNotUsedWithHiveEnabledError()
} else if (provider1.toLowerCase(Locale.ROOT) == "avro" ||
@@ -657,6 +671,8 @@ object DataSource extends Logging {
throw
QueryCompilationErrors.failedToFindAvroDataSourceError(provider1)
} else if (provider1.toLowerCase(Locale.ROOT) == "kafka") {
throw
QueryCompilationErrors.failedToFindKafkaDataSourceError(provider1)
+ } else if (isUserDefinedDataSource) {
+ classOf[PythonTableProvider]
} else {
throw
QueryExecutionErrors.dataSourceNotFoundError(provider1, error)
}
@@ -673,6 +689,14 @@ object DataSource extends Logging {
}
case head :: Nil =>
// there is exactly one registered alias
+ // TODO(SPARK-45600): should be session-based.
+ val isUserDefinedDataSource = SparkSession.getActiveSession.exists(
+ _.sessionState.dataSourceManager.dataSourceExists(provider))
+ // The source can be successfully loaded as either a V1 or a V2 data
source.
+ // Check if it is also a user-defined data source.
+ if (isUserDefinedDataSource) {
+ throw QueryCompilationErrors.foundMultipleDataSources(provider)
+ }
head.getClass
case sources =>
// There are multiple registered aliases for the input. If there is
single datasource
@@ -708,9 +732,9 @@ object DataSource extends Logging {
def lookupDataSourceV2(provider: String, conf: SQLConf):
Option[TableProvider] = {
val useV1Sources =
conf.getConf(SQLConf.USE_V1_SOURCE_LIST).toLowerCase(Locale.ROOT)
.split(",").map(_.trim)
- val cls = lookupDataSource(provider, conf)
+ val providingClass = lookupDataSource(provider, conf)
val instance = try {
- cls.getDeclaredConstructor().newInstance()
+ newDataSourceInstance(provider, providingClass)
} catch {
// Throw the original error from the data source implementation.
case e: java.lang.reflect.InvocationTargetException => throw e.getCause
@@ -718,7 +742,8 @@ object DataSource extends Logging {
instance match {
case d: DataSourceRegister if useV1Sources.contains(d.shortName()) =>
None
case t: TableProvider
- if
!useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
+ if !useV1Sources.contains(
+ providingClass.getCanonicalName.toLowerCase(Locale.ROOT)) =>
Some(t)
case _ => None
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
index 1cdc3d9cb69e..e6c4749df60a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceManager.scala
@@ -21,26 +21,18 @@ import java.util.Locale
import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
import org.apache.spark.sql.errors.QueryCompilationErrors
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.execution.python.UserDefinedPythonDataSource
+
/**
* A manager for user-defined data sources. It is used to register and lookup
data sources by
* their short names or fully qualified names.
*/
class DataSourceManager extends Logging {
-
- private type DataSourceBuilder = (
- SparkSession, // Spark session
- String, // provider name
- Option[StructType], // user specified schema
- CaseInsensitiveMap[String] // options
- ) => LogicalPlan
-
- private val dataSourceBuilders = new ConcurrentHashMap[String,
DataSourceBuilder]()
+ // TODO(SPARK-45917): Statically load Python Data Source so idempotently
Python
+ // Data Sources can be loaded even when the Driver is restarted.
+ private val dataSourceBuilders = new ConcurrentHashMap[String,
UserDefinedPythonDataSource]()
private def normalize(name: String): String = name.toLowerCase(Locale.ROOT)
@@ -48,9 +40,9 @@ class DataSourceManager extends Logging {
* Register a data source builder for the given provider.
* Note that the provider name is case-insensitive.
*/
- def registerDataSource(name: String, builder: DataSourceBuilder): Unit = {
+ def registerDataSource(name: String, source: UserDefinedPythonDataSource):
Unit = {
val normalizedName = normalize(name)
- val previousValue = dataSourceBuilders.put(normalizedName, builder)
+ val previousValue = dataSourceBuilders.put(normalizedName, source)
if (previousValue != null) {
logWarning(f"The data source $name replaced a previously registered data
source.")
}
@@ -60,7 +52,7 @@ class DataSourceManager extends Logging {
* Returns a data source builder for the given provider and throw an
exception if
* it does not exist.
*/
- def lookupDataSource(name: String): DataSourceBuilder = {
+ def lookupDataSource(name: String): UserDefinedPythonDataSource = {
if (dataSourceExists(name)) {
dataSourceBuilders.get(normalize(name))
} else {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
deleted file mode 100644
index 7ffd61a4a266..000000000000
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PlanPythonDataSourceScan.scala
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.datasources
-
-import org.apache.spark.api.python.{PythonEvalType, PythonFunction,
SimplePythonFunction}
-import org.apache.spark.sql.catalyst.expressions.PythonUDF
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project,
PythonDataSource, PythonDataSourcePartitions, PythonMapInArrow}
-import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.PYTHON_DATA_SOURCE
-import
org.apache.spark.sql.execution.python.UserDefinedPythonDataSourceReadRunner
-import org.apache.spark.util.ArrayImplicits._
-
-/**
- * A logical rule to plan reads from a Python data source.
- *
- * This rule creates a Python process and invokes the `DataSource.reader`
method to create an
- * instance of the user-defined data source reader, generates partitions if
any, and returns
- * the information back to JVM (this rule) to construct the logical plan for
Python data source.
- *
- * For example, prior to applying this rule, the plan might look like:
- *
- * PythonDataSource(dataSource, schema, output)
- *
- * Here, `dataSource` is a serialized Python function that contains an
instance of the DataSource
- * class. Post this rule, the plan is transformed into:
- *
- * Project [output]
- * +- PythonMapInArrow [read_from_data_source, ...]
- * +- PythonDataSourcePartitions [partition_bytes]
- *
- * The PythonDataSourcePartitions contains a list of serialized partition
values for the data
- * source. The `DataSourceReader.read` method will be planned as a MapInArrow
operator that
- * accepts a partition value and yields the scanning output.
- */
-object PlanPythonDataSourceScan extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
- _.containsPattern(PYTHON_DATA_SOURCE)) {
- case ds @ PythonDataSource(dataSource: PythonFunction, schema, _) =>
- val inputSchema = PythonDataSourcePartitions.schema
-
- val info = new UserDefinedPythonDataSourceReadRunner(
- dataSource, inputSchema, schema).runInPython()
-
- val readerFunc = SimplePythonFunction(
- command = info.func.toImmutableArraySeq,
- envVars = dataSource.envVars,
- pythonIncludes = dataSource.pythonIncludes,
- pythonExec = dataSource.pythonExec,
- pythonVer = dataSource.pythonVer,
- broadcastVars = dataSource.broadcastVars,
- accumulator = dataSource.accumulator)
-
- val partitionPlan = PythonDataSourcePartitions(
- PythonDataSourcePartitions.getOutputAttrs, info.partitions)
-
- val pythonUDF = PythonUDF(
- name = "read_from_data_source",
- func = readerFunc,
- dataType = schema,
- children = partitionPlan.output,
- evalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
- udfDeterministic = false)
-
- // Construct the plan.
- val plan = PythonMapInArrow(
- pythonUDF,
- ds.output,
- partitionPlan,
- isBarrier = false)
-
- // Project out partition values.
- Project(ds.output, plan)
- }
-}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
index cfe01f85cbe7..936ab866f5bf 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ApplyInPandasWithStatePythonRunner.scala
@@ -61,12 +61,14 @@ class ApplyInPandasWithStatePythonRunner(
keySchema: StructType,
outputSchema: StructType,
stateValueSchema: StructType,
- val pythonMetrics: Map[String, SQLMetric],
+ pyMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String])
extends BasePythonRunner[InType, OutType](funcs, evalType, argOffsets,
jobArtifactUUID)
with PythonArrowInput[InType]
with PythonArrowOutput[OutType] {
+ override val pythonMetrics: Option[Map[String, SQLMetric]] = Some(pyMetrics)
+
override val pythonExec: String =
SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
funcs.head.funcs.head.pythonExec)
@@ -149,7 +151,7 @@ class ApplyInPandasWithStatePythonRunner(
pandasWriter.finalizeGroup()
val deltaData = dataOut.size() - startData
- pythonMetrics("pythonDataSent") += deltaData
+ pythonMetrics.foreach(_("pythonDataSent") += deltaData)
true
} else {
pandasWriter.finalizeData()
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
index 9e210bf5241b..2503deae7d5a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonUDTFExec.scala
@@ -70,7 +70,7 @@ case class ArrowEvalPythonUDTFExec(
sessionLocalTimeZone,
largeVarTypes,
pythonRunnerConf,
- pythonMetrics,
+ Some(pythonMetrics),
jobArtifactUUID).compute(batchIter, context.partitionId(), context)
columnarBatchIter.map { batch =>
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
index a9eaf79c9db0..5dcb79cc2b91 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -35,7 +35,7 @@ abstract class BaseArrowPythonRunner(
_timeZoneId: String,
protected override val largeVarTypes: Boolean,
protected override val workerConf: Map[String, String],
- val pythonMetrics: Map[String, SQLMetric],
+ override val pythonMetrics: Option[Map[String, SQLMetric]],
jobArtifactUUID: Option[String])
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
funcs, evalType, argOffsets, jobArtifactUUID)
@@ -74,7 +74,7 @@ class ArrowPythonRunner(
_timeZoneId: String,
largeVarTypes: Boolean,
workerConf: Map[String, String],
- pythonMetrics: Map[String, SQLMetric],
+ pythonMetrics: Option[Map[String, SQLMetric]],
jobArtifactUUID: Option[String])
extends BaseArrowPythonRunner(
funcs, evalType, argOffsets, _schema, _timeZoneId, largeVarTypes,
workerConf,
@@ -100,7 +100,7 @@ class ArrowPythonWithNamedArgumentRunner(
jobArtifactUUID: Option[String])
extends BaseArrowPythonRunner(
funcs, evalType, argMetas.map(_.map(_.offset)), _schema, _timeZoneId,
largeVarTypes, workerConf,
- pythonMetrics, jobArtifactUUID) {
+ Some(pythonMetrics), jobArtifactUUID) {
override protected def writeUDF(dataOut: DataOutputStream): Unit =
PythonUDFRunner.writeUDFs(dataOut, funcs, argMetas)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
index 87d1ccb25776..df2e89128124 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonUDTFRunner.scala
@@ -39,7 +39,7 @@ class ArrowPythonUDTFRunner(
protected override val timeZoneId: String,
protected override val largeVarTypes: Boolean,
protected override val workerConf: Map[String, String],
- val pythonMetrics: Map[String, SQLMetric],
+ override val pythonMetrics: Option[Map[String, SQLMetric]],
jobArtifactUUID: Option[String])
extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch](
Seq(ChainedPythonFunctions(Seq(udtf.func))), evalType,
Array(argMetas.map(_.offset)),
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
index eb56298bfbee..70bd1ce82e2e 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/CoGroupedArrowPythonRunner.scala
@@ -46,13 +46,15 @@ class CoGroupedArrowPythonRunner(
rightSchema: StructType,
timeZoneId: String,
conf: Map[String, String],
- val pythonMetrics: Map[String, SQLMetric],
+ pyMetrics: Map[String, SQLMetric],
jobArtifactUUID: Option[String])
extends BasePythonRunner[
(Iterator[InternalRow], Iterator[InternalRow]), ColumnarBatch](
funcs, evalType, argOffsets, jobArtifactUUID)
with BasicPythonArrowOutput {
+ override val pythonMetrics: Option[Map[String, SQLMetric]] = Some(pyMetrics)
+
override val pythonExec: String =
SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
funcs.head.funcs.head.pythonExec)
@@ -93,7 +95,7 @@ class CoGroupedArrowPythonRunner(
writeGroup(nextRight, rightSchema, dataOut, "right")
val deltaData = dataOut.size() - startData
- pythonMetrics("pythonDataSent") += deltaData
+ pythonMetrics.foreach(_("pythonDataSent") += deltaData)
true
} else {
dataOut.writeInt(0)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala
index 0c18206a825a..e5a00e2cc8ea 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPythonExec.scala
@@ -88,7 +88,7 @@ trait FlatMapGroupsInPythonExec extends SparkPlan with
UnaryExecNode with Python
sessionLocalTimeZone,
largeVarTypes,
pythonRunnerConf,
- pythonMetrics,
+ Some(pythonMetrics),
jobArtifactUUID)
executePython(data, output, runner)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
index 316c543ea807..00990ee46ea5 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchEvaluatorFactory.scala
@@ -36,7 +36,7 @@ class MapInBatchEvaluatorFactory(
sessionLocalTimeZone: String,
largeVarTypes: Boolean,
pythonRunnerConf: Map[String, String],
- pythonMetrics: Map[String, SQLMetric],
+ pythonMetrics: Option[Map[String, SQLMetric]],
jobArtifactUUID: Option[String])
extends PartitionEvaluatorFactory[InternalRow, InternalRow] {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
index 8db389f02667..6db6c96b426a 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/MapInBatchExec.scala
@@ -57,7 +57,7 @@ trait MapInBatchExec extends UnaryExecNode with
PythonSQLMetrics {
conf.sessionLocalTimeZone,
conf.arrowUseLargeVarTypes,
pythonRunnerConf,
- pythonMetrics,
+ Some(pythonMetrics),
jobArtifactUUID)
if (isBarrier) {
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
index 1e075cab9224..6d0f31f35ff7 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala
@@ -46,7 +46,7 @@ private[python] trait PythonArrowInput[IN] { self:
BasePythonRunner[IN, _] =>
protected val largeVarTypes: Boolean
- protected def pythonMetrics: Map[String, SQLMetric]
+ protected def pythonMetrics: Option[Map[String, SQLMetric]]
protected def writeNextInputToArrowStream(
root: VectorSchemaRoot,
@@ -132,7 +132,7 @@ private[python] trait BasicPythonArrowInput extends
PythonArrowInput[Iterator[In
writer.writeBatch()
arrowWriter.reset()
val deltaData = dataOut.size() - startData
- pythonMetrics("pythonDataSent") += deltaData
+ pythonMetrics.foreach(_("pythonDataSent") += deltaData)
true
} else {
super[PythonArrowInput].close()
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
index 90922d89ad10..82e8e7aa4f64 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowOutput.scala
@@ -37,7 +37,7 @@ import org.apache.spark.sql.vectorized.{ArrowColumnVector,
ColumnarBatch, Column
*/
private[python] trait PythonArrowOutput[OUT <: AnyRef] { self:
BasePythonRunner[_, OUT] =>
- protected def pythonMetrics: Map[String, SQLMetric]
+ protected def pythonMetrics: Option[Map[String, SQLMetric]]
protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }
@@ -91,8 +91,8 @@ private[python] trait PythonArrowOutput[OUT <: AnyRef] {
self: BasePythonRunner[
val rowCount = root.getRowCount
batch.setNumRows(root.getRowCount)
val bytesReadEnd = reader.bytesRead()
- pythonMetrics("pythonNumRowsReceived") += rowCount
- pythonMetrics("pythonDataReceived") += bytesReadEnd -
bytesReadStart
+ pythonMetrics.foreach(_("pythonNumRowsReceived") += rowCount)
+ pythonMetrics.foreach(_("pythonDataReceived") += bytesReadEnd -
bytesReadStart)
deserializeColumnarBatch(batch, schema)
} else {
reader.close(false)
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonDataSourcePartitionsExec.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonDataSourcePartitionsExec.scala
deleted file mode 100644
index 8f1595cfdd71..000000000000
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonDataSourcePartitionsExec.scala
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.execution.python
-
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.{InputRDDCodegen, LeafExecNode,
SQLExecution}
-import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.util.ArrayImplicits._
-
-/**
- * A physical plan node for scanning data from a list of data source partition
values.
- *
- * It creates a RDD with number of partitions equal to size of the partition
value list and
- * each partition contains a single row with a serialized partition value.
- */
-case class PythonDataSourcePartitionsExec(
- output: Seq[Attribute],
- partitions: Seq[Array[Byte]]) extends LeafExecNode with InputRDDCodegen {
-
- override lazy val metrics = Map(
- "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output
rows"))
-
- @transient private lazy val unsafeRows: Array[InternalRow] = {
- if (partitions.isEmpty) {
- Array.empty
- } else {
- val proj = UnsafeProjection.create(output, output)
- partitions.map(p => proj(InternalRow(p)).copy()).toArray
- }
- }
-
- @transient private lazy val rdd: RDD[InternalRow] = {
- val numPartitions = partitions.size
- if (numPartitions == 0) {
- sparkContext.emptyRDD
- } else {
- sparkContext.parallelize(unsafeRows.toImmutableArraySeq, numPartitions)
- }
- }
-
- override def inputRDD: RDD[InternalRow] = rdd
-
- override protected val createUnsafeProjection: Boolean = false
-
- protected override def doExecute(): RDD[InternalRow] = {
- longMetric("numOutputRows").add(partitions.size)
- sendDriverMetrics()
- rdd
- }
-
- override protected def stringArgs: Iterator[Any] = {
- if (partitions.isEmpty) {
- Iterator("<empty>", output)
- } else {
- Iterator(output)
- }
- }
-
- private def sendDriverMetrics(): Unit = {
- val executionId =
sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
- SQLMetrics.postDriverMetricUpdates(sparkContext, executionId,
metrics.values.toSeq)
- }
-}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 2c8e1b942727..7c850d1e2890 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -20,58 +20,199 @@ package org.apache.spark.sql.execution.python
import java.io.{DataInputStream, DataOutputStream}
import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
import net.razorvine.pickle.Pickler
-import org.apache.spark.api.python.{PythonFunction, PythonWorkerUtils,
SimplePythonFunction, SpecialLengths}
-import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan,
PythonDataSource}
+import org.apache.spark.JobArtifactSet
+import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType,
PythonFunction, PythonWorkerUtils, SimplePythonFunction, SpecialLengths}
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.PythonUDF
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.connector.catalog.{SupportsRead, Table,
TableCapability, TableProvider}
+import org.apache.spark.sql.connector.catalog.TableCapability.BATCH_READ
+import org.apache.spark.sql.connector.expressions.Transform
+import org.apache.spark.sql.connector.read.{Batch, InputPartition,
PartitionReader, PartitionReaderFactory, Scan, ScanBuilder}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types.{BinaryType, DataType, StructType}
+import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.util.ArrayImplicits._
+/**
+ * Data Source V2 wrapper for Python Data Source.
+ *
+ * Looks up the [[UserDefinedPythonDataSource]] registered under `shortName` and exposes it as a
+ * batch-readable DSv2 [[Table]]. The Python data source instance is created on the driver only
+ * when it is not cached yet -- by [[inferSchema]], or by the first scan that needs it -- and is
+ * then reused by later calls on this provider.
+ *
+ * @param shortName name the Python data source was registered under; also used as the table name.
+ */
+class PythonTableProvider(shortName: String) extends TableProvider {
+ // Cached result of creating the data source in Python (pickled instance and its schema);
+ // null until first needed. NOTE(review): plain var without synchronization, so this assumes
+ // the provider is not used from several threads at once -- confirm.
+ private var dataSourceInPython: PythonDataSourceCreationResult = _
+ // Captured when the provider is constructed and handed to every reader factory it creates.
+ private[this] val jobArtifactUUID =
JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+ // Resolved lazily from the active session's data source manager on first use.
+ private lazy val source: UserDefinedPythonDataSource =
+
SparkSession.active.sessionState.dataSourceManager.lookupDataSource(shortName)
+ /**
+ * Returns the schema reported by the Python data source, creating (and caching) the Python
+ * instance without a user-specified schema if it has not been created yet.
+ */
+ override def inferSchema(options: CaseInsensitiveStringMap): StructType = {
+ if (dataSourceInPython == null) {
+ dataSourceInPython = source.createDataSourceInPython(shortName, options,
None)
+ }
+ dataSourceInPython.schema
+ }
+
+ /**
+ * Returns a table named `shortName` whose only capability is BATCH_READ and whose schema is
+ * `schema`. `partitioning` and `properties` are not used.
+ */
+ override def getTable(
+ schema: StructType,
+ partitioning: Array[Transform],
+ properties: java.util.Map[String, String]): Table = {
+ val outputSchema = schema
+ new Table with SupportsRead {
+ override def name(): String = shortName
+
+ override def capabilities(): java.util.Set[TableCapability] =
java.util.EnumSet.of(
+ BATCH_READ)
+
+ // The returned builder is also its own Scan and Batch. `options` are only used when the
+ // Python instance has not been created yet (e.g. inferSchema was never called); in that
+ // case `outputSchema` is passed to Python as the user-specified schema.
+ override def newScanBuilder(options: CaseInsensitiveStringMap):
ScanBuilder = {
+ new ScanBuilder with Batch with Scan {
+
+ // Driver-side: asks Python for the partitions and the pickled read function.
+ // Being a lazy val, this is computed at most once per scan builder.
+ private lazy val infoInPython: PythonDataSourceReadInfo = {
+ if (dataSourceInPython == null) {
+ dataSourceInPython = source
+ .createDataSourceInPython(shortName, options,
Some(outputSchema))
+ }
+ source.createReadInfoInPython(dataSourceInPython, outputSchema)
+ }
+
+ override def build(): Scan = this
+
+ override def toBatch: Batch = this
+
+ override def readSchema(): StructType = outputSchema
+
+ // One input partition per partition value returned by Python, tagged with its
+ // position in that list.
+ override def planInputPartitions(): Array[InputPartition] =
+ infoInPython.partitions.zipWithIndex.map(p =>
PythonInputPartition(p._2, p._1)).toArray
+
+ override def createReaderFactory(): PartitionReaderFactory = {
+ val readerFunc = infoInPython.func
+ new PythonPartitionReaderFactory(
+ source, readerFunc, outputSchema, jobArtifactUUID)
+ }
+ }
+ }
+
+ override def schema(): StructType = outputSchema
+ }
+ }
+
+ // Returning true lets Spark pass a user-specified schema directly to getTable
+ // (TableProvider contract for accepting external table metadata).
+ override def supportsExternalMetadata(): Boolean = true
+}
+
+/**
+ * Input partition consumed by [[PythonPartitionReaderFactory]].
+ *
+ * NOTE(review): `pickedPartition` looks like a typo for "pickledPartition"; it is left as-is
+ * because renaming a public case-class field would break callers.
+ *
+ * @param index position of this partition in the list planned by Python; used as the
+ *              partition index when the read function is evaluated.
+ * @param pickedPartition serialized partition value from Python (presumably pickled), sent to
+ *                        the read function as the single input row.
+ */
+case class PythonInputPartition(index: Int, pickedPartition: Array[Byte])
extends InputPartition
+
+/**
+ * Creates one [[PartitionReader]] per [[PythonInputPartition]]; each reader evaluates the
+ * pickled Python read function over that partition.
+ *
+ * @param source the Python data source; supplies the Python environment used to run the
+ *               read function.
+ * @param pickledReadFunc pickled read function obtained from Python on the driver.
+ * @param outputSchema schema of the rows produced by the read function.
+ * @param jobArtifactUUID job artifact UUID forwarded to `createPartitionReadIteratorInPython`.
+ */
+class PythonPartitionReaderFactory(
+    source: UserDefinedPythonDataSource,
+    pickledReadFunc: Array[Byte],
+    outputSchema: StructType,
+    jobArtifactUUID: Option[String])
+  extends PartitionReaderFactory {
+
+  /**
+   * Returns a reader for `partition`, which must be a [[PythonInputPartition]]; any other
+   * type fails the cast with a ClassCastException.
+   *
+   * The reader advances the underlying iterator in `next()` and caches the row, so `get()`
+   * keeps returning the same record until `next()` is called again, as the
+   * [[PartitionReader]] contract requires.
+   */
+  override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
+    new PartitionReader[InternalRow] {
+      private val outputIter = source.createPartitionReadIteratorInPython(
+        partition.asInstanceOf[PythonInputPartition],
+        pickledReadFunc,
+        outputSchema,
+        jobArtifactUUID)
+
+      // Record produced by the most recent successful next(); null before the first one.
+      private var currentRow: InternalRow = _
+
+      override def next(): Boolean = {
+        val hasMore = outputIter.hasNext
+        if (hasMore) {
+          currentRow = outputIter.next()
+        }
+        hasMore
+      }
+
+      // Idempotent: does not advance the iterator.
+      override def get(): InternalRow = currentRow
+
+      // Releases nothing itself. NOTE(review): assumes the Python worker behind `outputIter`
+      // is cleaned up elsewhere (e.g. on task completion) -- confirm.
+      override def close(): Unit = {}
+    }
+  }
+}
+
/**
* A user-defined Python data source. This is used by the Python API.
+ * Defines the interaction between Python and JVM.
*
* @param dataSourceCls The Python data source class.
*/
case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
- def builder(
- sparkSession: SparkSession,
- provider: String,
- userSpecifiedSchema: Option[StructType],
- options: CaseInsensitiveMap[String]): LogicalPlan = {
+ // Schema of the input rows fed to the Python read function: a single binary column named
+ // "partition" that carries one input partition's serialized value.
+ private val inputSchema: StructType = new StructType().add("partition",
BinaryType)
+
+ /**
+ * (Driver-side) Run Python process, and get the pickled Python Data Source
+ * instance and its schema.
+ *
+ * `options` are copied into a case-insensitive Scala map before being passed to Python;
+ * `userSpecifiedSchema` is forwarded unchanged.
+ */
+ def createDataSourceInPython(
+ shortName: String,
+ options: CaseInsensitiveStringMap,
+ userSpecifiedSchema: Option[StructType]): PythonDataSourceCreationResult
= {
+ new UserDefinedPythonDataSourceRunner(
+ dataSourceCls,
+ shortName,
+ userSpecifiedSchema,
+
CaseInsensitiveMap(options.asCaseSensitiveMap().asScala.toMap)).runInPython()
+ }
- val runner = new UserDefinedPythonDataSourceRunner(
- dataSourceCls, provider, userSpecifiedSchema, options)
+ /**
+ * (Driver-side) Run Python process, and get the partition read function and the
+ * partition information.
+ *
+ * The pickled data source instance in `pythonResult` is wrapped as a Python function that
+ * reuses this data source class's Python environment (see createPythonFunction).
+ */
+ def createReadInfoInPython(
+ pythonResult: PythonDataSourceCreationResult,
+ outputSchema: StructType): PythonDataSourceReadInfo = {
+ new UserDefinedPythonDataSourceReadRunner(
+ createPythonFunction(
+ pythonResult.dataSource), inputSchema, outputSchema).runInPython()
+ }
- val result = runner.runInPython()
- val pickledDataSourceInstance = result.dataSource
+ /**
+ * (Executor-side) Create an iterator over the rows read from one input partition.
+ *
+ * The pickled read function is run as a `SQL_MAP_ARROW_ITER_UDF` through
+ * [[MapInBatchEvaluatorFactory]] over a single input row holding
+ * `partition.pickedPartition`, with `partition.index` as the partition index.
+ * Configuration is read from `SQLConf.get` at call time. `None` is passed for the Python
+ * metrics map (presumably so no Python SQL metrics are updated for this read -- confirm).
+ */
+ def createPartitionReadIteratorInPython(
+ partition: PythonInputPartition,
+ pickledReadFunc: Array[Byte],
+ outputSchema: StructType,
+ jobArtifactUUID: Option[String]): Iterator[InternalRow] = {
+ val readerFunc = createPythonFunction(pickledReadFunc)
+
+ val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
+
+ val pythonUDF = PythonUDF(
+ name = "read_from_data_source",
+ func = readerFunc,
+ dataType = outputSchema,
+ children = toAttributes(inputSchema),
+ evalType = pythonEvalType,
+ udfDeterministic = false)
+
+ val conf = SQLConf.get
+
+ val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
+ val evaluatorFactory = new MapInBatchEvaluatorFactory(
+ toAttributes(outputSchema),
+ Seq(ChainedPythonFunctions(Seq(pythonUDF.func))),
+ inputSchema,
+ conf.arrowMaxRecordsPerBatch,
+ pythonEvalType,
+ conf.sessionLocalTimeZone,
+ conf.arrowUseLargeVarTypes,
+ pythonRunnerConf,
+ None,
+ jobArtifactUUID)
+
+ evaluatorFactory.createEvaluator().eval(
+ partition.index, Iterator.single(InternalRow(partition.pickedPartition)))
+ }
- val dataSource = SimplePythonFunction(
- command = pickledDataSourceInstance.toImmutableArraySeq,
+ // Wraps pickled bytes as a Python function that reuses this data source class's environment
+ // variables, includes, executable, version, broadcast variables and accumulator.
+ private def createPythonFunction(pickledFunc: Array[Byte]): PythonFunction =
{
+ SimplePythonFunction(
+ command = pickledFunc.toImmutableArraySeq,
envVars = dataSourceCls.envVars,
pythonIncludes = dataSourceCls.pythonIncludes,
pythonExec = dataSourceCls.pythonExec,
pythonVer = dataSourceCls.pythonVer,
broadcastVars = dataSourceCls.broadcastVars,
accumulator = dataSourceCls.accumulator)
- val schema = result.schema
-
- PythonDataSource(dataSource, schema, output = toAttributes(schema))
- }
-
- def apply(
- sparkSession: SparkSession,
- provider: String,
- userSpecifiedSchema: Option[StructType] = None,
- options: CaseInsensitiveMap[String] = CaseInsensitiveMap(Map.empty)):
DataFrame = {
- val plan = builder(sparkSession, provider, userSpecifiedSchema, options)
- Dataset.ofRows(sparkSession, plan)
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 1a69678c2f54..c93ca632d3c7 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -156,8 +156,9 @@ final class DataStreamReader private[sql](sparkSession:
SparkSession) extends Lo
extraOptions + ("path" -> path.get)
}
- val ds = DataSource.lookupDataSource(source,
sparkSession.sessionState.conf).
- getConstructor().newInstance()
+ val ds = DataSource.newDataSourceInstance(
+ source,
+ DataSource.lookupDataSource(source, sparkSession.sessionState.conf))
// We need to generate the V1 data source so we can pass it to the V2
relation as a shim.
// We can't be sure at this point whether we'll actually want to use V2,
since we don't know the
// writer or whether the query is continuous.
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 95aa2f8c7a4e..7202f69ab1bf 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -382,7 +382,7 @@ final class DataStreamWriter[T] private[sql](ds:
Dataset[T]) {
}
val sink = if (classOf[TableProvider].isAssignableFrom(cls) &&
!useV1Source) {
- val provider =
cls.getConstructor().newInstance().asInstanceOf[TableProvider]
+ val provider = DataSource.newDataSourceInstance(source,
cls).asInstanceOf[TableProvider]
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
source = provider, conf = df.sparkSession.sessionState.conf)
val finalOptions = sessionOptions.filter { case (k, _) =>
!optionsWithPath.contains(k) } ++
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index 6bc9166117f2..53a54abf8392 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -18,8 +18,7 @@
package org.apache.spark.sql.execution.python
import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils,
QueryTest, Row}
-import
org.apache.spark.sql.catalyst.plans.logical.{PythonDataSourcePartitions,
PythonMapInArrow}
-import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap
+import org.apache.spark.sql.execution.datasources.v2.DataSourceV2ScanRelation
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
@@ -53,12 +52,13 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
val schema = StructType.fromDDL("id INT, partition INT")
val dataSource = createUserDefinedPythonDataSource(
name = dataSourceName, pythonScript = dataSourceScript)
- val df = dataSource.apply(
- spark, provider = dataSourceName, userSpecifiedSchema = Some(schema))
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ val df = spark.read.format(dataSourceName).schema(schema).load()
assert(df.rdd.getNumPartitions == 2)
val plan = df.queryExecution.optimizedPlan
plan match {
- case PythonMapInArrow(_, _, _: PythonDataSourcePartitions, _) =>
+ case s: DataSourceV2ScanRelation
+ if s.relation.table.getClass.toString.contains("PythonTable") =>
case _ => fail(s"Plan did not match the expected pattern. Actual
plan:\n$plan")
}
checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0),
Row(2, 1)))
@@ -79,7 +79,8 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
| return SimpleDataSourceReader()
|""".stripMargin
val dataSource = createUserDefinedPythonDataSource(dataSourceName,
dataSourceScript)
- val df = dataSource(spark, provider = dataSourceName)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ val df = spark.read.format(dataSourceName).load()
checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0),
Row(2, 1)))
}
@@ -102,7 +103,8 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
| return SimpleDataSourceReader()
|""".stripMargin
val dataSource = createUserDefinedPythonDataSource(dataSourceName,
dataSourceScript)
- val df = dataSource(spark, provider = dataSourceName)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ val df = spark.read.format(dataSourceName).load()
checkAnswer(df, Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0),
Row(2, 1)))
}
@@ -121,8 +123,9 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
| return SimpleDataSourceReader()
|""".stripMargin
val dataSource = createUserDefinedPythonDataSource(dataSourceName,
dataSourceScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
checkError(
- exception = intercept[AnalysisException](dataSource(spark, provider =
dataSourceName)),
+ exception =
intercept[AnalysisException](spark.read.format(dataSourceName).load()),
errorClass = "INVALID_SCHEMA.NON_STRUCT_TYPE",
parameters = Map("inputSchema" -> "INT", "dataType" -> "\"INT\""))
}
@@ -145,9 +148,8 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
val dataSource = createUserDefinedPythonDataSource(dataSourceName,
dataSourceScript)
spark.dataSource.registerPython(dataSourceName, dataSource)
assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
- val ds1 =
spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
checkAnswer(
- ds1(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)),
+ spark.read.format(dataSourceName).load(),
Seq(Row(0, 0), Row(0, 1), Row(1, 0), Row(1, 1), Row(2, 0), Row(2, 1)))
// Should be able to override an already registered data source.
@@ -168,10 +170,8 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
val newDataSource = createUserDefinedPythonDataSource(dataSourceName,
newScript)
spark.dataSource.registerPython(dataSourceName, newDataSource)
assert(spark.sessionState.dataSourceManager.dataSourceExists(dataSourceName))
-
- val ds2 =
spark.sessionState.dataSourceManager.lookupDataSource(dataSourceName)
checkAnswer(
- ds2(spark, dataSourceName, None, CaseInsensitiveMap(Map.empty)),
+ spark.read.format(dataSourceName).load(),
Seq(Row(0)))
}
@@ -195,12 +195,12 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
| paths = []
| return [InputPartition(p) for p in paths]
|
- | def read(self, path):
- | if path is not None:
- | assert isinstance(path, InputPartition)
- | yield (path.value, 1)
+ | def read(self, part):
+ | if part is not None:
+ | assert isinstance(part, InputPartition)
+ | yield (part.value, 1)
| else:
- | yield (path, 1)
+ | yield (part, 1)
|
|class $dataSourceName(DataSource):
| @classmethod
@@ -218,6 +218,12 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
checkAnswer(spark.read.format("test").load(), Seq(Row(null, 1)))
checkAnswer(spark.read.format("test").load("1"), Seq(Row("1", 1)))
checkAnswer(spark.read.format("test").load("1", "2"), Seq(Row("1", 1),
Row("2", 1)))
+
+ withTable("tblA") {
+ sql("CREATE TABLE tblA USING test")
+ // The path will be the actual temp path.
+ checkAnswer(spark.table("tblA").selectExpr("value"), Seq(Row(1)))
+ }
}
test("reader not implemented") {
@@ -231,8 +237,9 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
val schema = StructType.fromDDL("id INT, partition INT")
val dataSource = createUserDefinedPythonDataSource(
name = dataSourceName, pythonScript = dataSourceScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
val err = intercept[AnalysisException] {
- dataSource(spark, dataSourceName, userSpecifiedSchema =
Some(schema)).collect()
+ spark.read.format(dataSourceName).schema(schema).load().collect()
}
assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON")
assert(err.getMessage.contains("PYTHON_DATA_SOURCE_METHOD_NOT_IMPLEMENTED"))
@@ -250,8 +257,9 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
val schema = StructType.fromDDL("id INT, partition INT")
val dataSource = createUserDefinedPythonDataSource(
name = dataSourceName, pythonScript = dataSourceScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
val err = intercept[AnalysisException] {
- dataSource(spark, dataSourceName, userSpecifiedSchema =
Some(schema)).collect()
+ spark.read.format(dataSourceName).schema(schema).load().collect()
}
assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON")
assert(err.getMessage.contains("PYTHON_DATA_SOURCE_CREATE_ERROR"))
@@ -269,8 +277,9 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
val schema = StructType.fromDDL("id INT, partition INT")
val dataSource = createUserDefinedPythonDataSource(
name = dataSourceName, pythonScript = dataSourceScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
val err = intercept[AnalysisException] {
- dataSource(spark, dataSourceName, userSpecifiedSchema =
Some(schema)).collect()
+ spark.read.format(dataSourceName).schema(schema).load().collect()
}
assert(err.getErrorClass == "PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON")
assert(err.getMessage.contains("PYTHON_DATA_SOURCE_TYPE_MISMATCH"))
@@ -278,7 +287,7 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
}
test("data source read with custom partitions") {
- assume(shouldTestPythonUDFs)
+ assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource, DataSourceReader,
InputPartition
@@ -304,12 +313,13 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
| return SimpleDataSourceReader()
|""".stripMargin
val dataSource = createUserDefinedPythonDataSource(dataSourceName,
dataSourceScript)
- val df = dataSource(spark, provider = dataSourceName)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ val df = spark.read.format(dataSourceName).load()
checkAnswer(df, Seq(Row(1), Row(3)))
}
test("data source read with empty partitions") {
- assume(shouldTestPythonUDFs)
+ assume(shouldTestPandasUDFs)
val dataSourceScript =
s"""
|from pyspark.sql.datasource import DataSource, DataSourceReader
@@ -331,12 +341,13 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
| return SimpleDataSourceReader()
|""".stripMargin
val dataSource = createUserDefinedPythonDataSource(dataSourceName,
dataSourceScript)
- val df = dataSource(spark, provider = dataSourceName)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
+ val df = spark.read.format(dataSourceName).load()
checkAnswer(df, Row("success"))
}
test("data source read with invalid partitions") {
- assume(shouldTestPythonUDFs)
+ assume(shouldTestPandasUDFs)
val reader1 =
s"""
|class SimpleDataSourceReader(DataSourceReader):
@@ -378,8 +389,9 @@ class PythonDataSourceSuite extends QueryTest with
SharedSparkSession {
| return SimpleDataSourceReader()
|""".stripMargin
val dataSource = createUserDefinedPythonDataSource(dataSourceName,
dataSourceScript)
+ spark.dataSource.registerPython(dataSourceName, dataSource)
val err = intercept[AnalysisException](
- dataSource(spark, provider = dataSourceName).collect())
+ spark.read.format(dataSourceName).load().collect())
assert(err.getErrorClass ==
"PYTHON_DATA_SOURCE_FAILED_TO_PLAN_IN_PYTHON")
assert(err.getMessage.contains("PYTHON_DATA_SOURCE_CREATE_ERROR"))
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]