This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 62fc27d79d5c [SPARK-46423][PYTHON][SQL] Make the Python Data Source instance at DataSource.lookupDataSourceV2
62fc27d79d5c is described below
commit 62fc27d79d5ccce476671a7a664272c718024617
Author: Hyukjin Kwon <[email protected]>
AuthorDate: Fri Dec 15 15:18:20 2023 -0800
[SPARK-46423][PYTHON][SQL] Make the Python Data Source instance at DataSource.lookupDataSourceV2
### What changes were proposed in this pull request?
This PR is a followup of https://github.com/apache/spark/pull/44305 that proposes to create the Python Data Source instance at `DataSource.lookupDataSourceV2`.
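The diff below boils down to two pieces: `PythonTableProvider` now has a no-arg constructor plus a `setShortName` hook, and `DataSource.lookupDataSourceV2` calls that hook right after reflective instantiation. A minimal standalone sketch of the pattern, using stand-in names rather than the actual Spark classes:

```scala
// Simplified stand-ins for the classes touched by this PR
// (PythonTableProvider and DataSource.lookupDataSourceV2).
trait TableProvider

class PythonTableProvider extends TableProvider {
  // The short name is injected after construction instead of via a constructor argument.
  private var name: String = _
  def setShortName(str: String): Unit = {
    assert(name == null)
    name = str
  }
  def shortName: String = {
    assert(name != null)
    name
  }
}

object LookupSketch {
  // Rough shape of lookupDataSourceV2 after this change: instantiate the provider class
  // through its no-arg constructor, then hand Python data sources their provider name
  // before returning the instance.
  def lookupDataSourceV2(provider: String, cls: Class[_]): Option[TableProvider] = {
    cls.getDeclaredConstructor().newInstance() match {
      case t: TableProvider =>
        t match {
          case p: PythonTableProvider => p.setShortName(provider)
          case _ =>
        }
        Some(t)
      case _ => None
    }
  }
}
```

With that in place, `LookupSketch.lookupDataSourceV2("my_source", classOf[PythonTableProvider])` returns a provider that already knows its short name, which is what the V2 lookup path expects.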
### Why are the changes needed?
Semantically, the instance has to be ready at the `DataSource.lookupDataSourceV2` level rather than being created afterwards. It is also more consistent.
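A side effect is that the Python-specific `DataSource.newDataSourceInstance` helper can be removed, and call sites go back to plain constructor reflection. An illustrative, self-contained snippet of that pattern; the class name here is a stand-in, not a Spark API:

```scala
// Stand-in for whatever class DataSource.lookupDataSource would return.
class SomeFileFormatProvider {
  override def toString: String = "SomeFileFormatProvider()"
}

object PlainReflection {
  def main(args: Array[String]): Unit = {
    val cls: Class[_] = classOf[SomeFileFormatProvider]
    // No special case for Python data sources anymore: the no-arg constructor is enough.
    val instance = cls.getDeclaredConstructor().newInstance()
    println(instance) // prints SomeFileFormatProvider()
  }
}
```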
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Existing tests should cover this change.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #44374 from HyukjinKwon/SPARK-46423.
Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../scala/org/apache/spark/sql/SparkSession.scala | 2 +-
.../apache/spark/sql/execution/command/ddl.scala | 6 ++----
.../spark/sql/execution/command/tables.scala | 7 +++---
.../sql/execution/datasources/DataSource.scala | 25 ++++++++--------------
.../python/UserDefinedPythonDataSource.scala | 11 +++++++++-
.../spark/sql/streaming/DataStreamReader.scala | 5 ++---
.../spark/sql/streaming/DataStreamWriter.scala | 2 +-
7 files changed, 28 insertions(+), 30 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 44a4d82c1dac..15eeca87dcf6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -780,7 +780,7 @@ class SparkSession private(
DataSource.lookupDataSource(runner, sessionState.conf) match {
case source if classOf[ExternalCommandRunner].isAssignableFrom(source) =>
Dataset.ofRows(self, ExternalCommandExecutor(
- DataSource.newDataSourceInstance(runner, source)
+ source.getDeclaredConstructor().newInstance()
.asInstanceOf[ExternalCommandRunner], command, options))
case _ =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 199c8728a5c9..dc1c5b3fd580 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -45,7 +45,7 @@ import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAM
import org.apache.spark.sql.connector.catalog.SupportsNamespaces._
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.errors.QueryExecutionErrors.hiveTableWithAnsiIntervalsError
-import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.datasources.{DataSource, DataSourceUtils, FileFormat, HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2
import org.apache.spark.sql.internal.{HiveSerDe, SQLConf}
import org.apache.spark.sql.types._
@@ -1025,9 +1025,7 @@ object DDLUtils extends Logging {
def checkDataColNames(provider: String, schema: StructType): Unit = {
val source = try {
- DataSource.newDataSourceInstance(
- provider,
- DataSource.lookupDataSource(provider, SQLConf.get))
+ DataSource.lookupDataSource(provider, SQLConf.get).getConstructor().newInstance()
} catch {
case e: Throwable =>
logError(s"Failed to find data source: $provider when check data
column names.", e)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 9771ee08b258..2f8fca7cfd73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIfNeeded, CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns}
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -264,9 +264,8 @@ case class AlterTableAddColumnsCommand(
}
if (DDLUtils.isDatasourceTable(catalogTable)) {
- DataSource.newDataSourceInstance(
- catalogTable.provider.get,
- DataSource.lookupDataSource(catalogTable.provider.get, conf)) match {
+ DataSource.lookupDataSource(catalogTable.provider.get, conf).
+ getConstructor().newInstance() match {
// For datasource table, this command can only support the following File format.
// TextFileFormat only default to one column "value"
// Hive type is already considered as hive serde table, so the logic will not
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 9612d8ff24f5..efec44658d51 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -106,14 +106,13 @@ case class DataSource(
// [[FileDataSourceV2]] will still be used if we call the load()/save() method in
// [[DataFrameReader]]/[[DataFrameWriter]], since they use method `lookupDataSource`
// instead of `providingClass`.
- DataSource.newDataSourceInstance(className, cls) match {
+ cls.getDeclaredConstructor().newInstance() match {
case f: FileDataSourceV2 => f.fallbackFileFormat
case _ => cls
}
}
- private[sql] def providingInstance(): Any =
- DataSource.newDataSourceInstance(className, providingClass)
+ private[sql] def providingInstance(): Any = providingClass.getConstructor().newInstance()
private def newHadoopConfiguration(): Configuration =
sparkSession.sessionState.newHadoopConfWithOptions(options)
@@ -624,15 +623,6 @@ object DataSource extends Logging {
"org.apache.spark.sql.sources.HadoopFsRelationProvider",
"org.apache.spark.Logging")
- /** Create the instance of the datasource */
- def newDataSourceInstance(provider: String, providingClass: Class[_]): Any = {
- providingClass match {
- case cls if classOf[PythonTableProvider].isAssignableFrom(cls) =>
- cls.getDeclaredConstructor(classOf[String]).newInstance(provider)
- case cls => cls.getDeclaredConstructor().newInstance()
- }
- }
-
/** Given a provider name, look up the data source class definition. */
def lookupDataSource(provider: String, conf: SQLConf): Class[_] = {
val provider1 = backwardCompatibilityMap.getOrElse(provider, provider) match {
@@ -732,9 +722,9 @@ object DataSource extends Logging {
def lookupDataSourceV2(provider: String, conf: SQLConf): Option[TableProvider] = {
val useV1Sources = conf.getConf(SQLConf.USE_V1_SOURCE_LIST).toLowerCase(Locale.ROOT)
.split(",").map(_.trim)
- val providingClass = lookupDataSource(provider, conf)
+ val cls = lookupDataSource(provider, conf)
val instance = try {
- newDataSourceInstance(provider, providingClass)
+ cls.getDeclaredConstructor().newInstance()
} catch {
// Throw the original error from the data source implementation.
case e: java.lang.reflect.InvocationTargetException => throw e.getCause
@@ -742,8 +732,11 @@ object DataSource extends Logging {
instance match {
case d: DataSourceRegister if useV1Sources.contains(d.shortName()) => None
case t: TableProvider
- if !useV1Sources.contains(
- providingClass.getCanonicalName.toLowerCase(Locale.ROOT)) =>
+ if !useV1Sources.contains(cls.getCanonicalName.toLowerCase(Locale.ROOT)) =>
+ t match {
+ case p: PythonTableProvider => p.setShortName(provider)
+ case _ =>
+ }
Some(t)
case _ => None
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
index 7c850d1e2890..5e978a900884 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonDataSource.scala
@@ -44,7 +44,16 @@ import org.apache.spark.util.ArrayImplicits._
/**
* Data Source V2 wrapper for Python Data Source.
*/
-class PythonTableProvider(shortName: String) extends TableProvider {
+class PythonTableProvider extends TableProvider {
+ private var name: String = _
+ def setShortName(str: String): Unit = {
+ assert(name == null)
+ name = str
+ }
+ private def shortName: String = {
+ assert(name != null)
+ name
+ }
private var dataSourceInPython: PythonDataSourceCreationResult = _
private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
private lazy val source: UserDefinedPythonDataSource =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index c93ca632d3c7..1a69678c2f54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -156,9 +156,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
extraOptions + ("path" -> path.get)
}
- val ds = DataSource.newDataSourceInstance(
- source,
- DataSource.lookupDataSource(source, sparkSession.sessionState.conf))
+ val ds = DataSource.lookupDataSource(source, sparkSession.sessionState.conf).
+ getConstructor().newInstance()
// We need to generate the V1 data source so we can pass it to the V2 relation as a shim.
// We can't be sure at this point whether we'll actually want to use V2, since we don't know the
// writer or whether the query is continuous.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index 7202f69ab1bf..95aa2f8c7a4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -382,7 +382,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
}
val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) {
- val provider = DataSource.newDataSourceInstance(source, cls).asInstanceOf[TableProvider]
+ val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider]
val sessionOptions = DataSourceV2Utils.extractSessionConfigs(
source = provider, conf = df.sparkSession.sessionState.conf)
val finalOptions = sessionOptions.filter { case (k, _) => !optionsWithPath.contains(k) } ++
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]