This is an automated email from the ASF dual-hosted git repository.
gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 434aa30d234 [SPARK-45856] Move ArtifactManager from Spark Connect into SparkSession (sql/core)
434aa30d234 is described below
commit 434aa30d23499ec66a12f1056c0b82000ff6b683
Author: vicennial <[email protected]>
AuthorDate: Tue Nov 21 10:39:20 2023 +0900
[SPARK-45856] Move ArtifactManager from Spark Connect into SparkSession (sql/core)
### What changes were proposed in this pull request?
The significant changes in this PR include:
- `SparkConnectArtifactManager` is renamed to `ArtifactManager` and moved out of Spark Connect and into `sql/core` (available in `SparkSession` through `SessionState`), along with all corresponding tests and confs.
- While `ArtifactManager` is now part of `SparkSession`, we keep the legacy behaviour for non-Connect Spark while utilising the `ArtifactManager` in Connect pathways.
  - This is done by exposing a new method `withResources` on the artifact manager that sets the context class loader (for driver-side operations) and propagates the `JobArtifactState` so that the resources reach the executors (sketched below).
  - Spark Connect pathways utilise this method through `SessionHolder#withSession` (which also activates the session via `SparkSession#withActive`; see the updated `SessionHolder` in the diff below).
  - When `withResources` is not used, neither the custom context class loader nor the `JobArtifactState` is propagated; hence, non-Spark-Connect pathways retain the legacy behaviour.
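For reference, the new hook is a thin wrapper that installs the session's class loader and `JobArtifactState` around a block of code. Below is a condensed sketch of the relevant pieces taken from the diff further down (not a standalone program):

```scala
// ArtifactManager (sql/core): run `f` with this session's resources visible.
def withResources[T](f: => T): T = {
  Utils.withContextClassLoader(classloader) {          // driver-side class loading (e.g. UDF deserialization)
    JobArtifactSet.withActiveJobArtifactState(state) { // propagates the session's artifacts to executors
      f
    }
  }
}

// SessionHolder (Spark Connect server): Connect pathways opt in via withSession.
def withSession[T](f: SparkSession => T): T = {
  artifactManager.withResources {
    session.withActive {
      f(session)
    }
  }
}
```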
### Why are the changes needed?
The `ArtifactManager` that currently lives in the Connect package can be moved into the wider `sql/core` package (e.g. `SparkSession`) to expand its scope. This is possible because the `ArtifactManager` is tied solely to the `SparkSession#sessionUUID` and hence can be cleanly detached from Spark Connect and made generally available.
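As a rough illustration of what this enables (Spark-internal code, since `artifactManager` and `sessionUUID` are `private[sql]` in this patch; the staged file path below is a placeholder), artifacts and cached relations are now keyed purely by the session UUID:

```scala
import java.nio.file.Paths

import org.apache.spark.sql.SparkSession
import org.apache.spark.storage.CacheId

val spark = SparkSession.builder().master("local[2]").getOrCreate()

// Modeled on the relocated ArtifactManagerSuite: a class file added through the
// manager lands under <artifactRoot>/<sessionUUID>/classes/ and is served to
// executors via the Spark file server.
spark.artifactManager.addArtifact(
  Paths.get("classes/Hello.class"),      // remote-relative destination
  Paths.get("/tmp/staging/Hello.class"), // placeholder: locally staged copy
  fragment = None)

// Cache blocks are now identified by (sessionUUID, hash) instead of
// (userId, sessionId, hash).
val blockId = CacheId(spark.sessionUUID, "some-hash")
```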
### Does this PR introduce _any_ user-facing change?
No. Existing behaviour is kept intact for both non-Connect and Connect Spark.
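In particular, the deprecated Connect key is still honoured: the relocated manager reads the new SQL conf first and only falls back to the old key when the new one is unset (condensed from `uploadArtifactToFs` in the diff below):

```scala
// Condensed from the diff below; `session` is the active SparkSession.
val allowDestLocalConf =
  session.conf.get(SQLConf.ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL) // new key, optional
    .getOrElse(
      // fall back to the legacy Connect key for existing users
      session.conf.get("spark.connect.copyFromLocalToFs.allowDestLocal").contains("true"))
```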
### How was this patch tested?
Existing tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43735 from vicennial/SPARK-45856.
Lead-authored-by: vicennial <[email protected]>
Co-authored-by: Venkata Sai Akhil Gudesa <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
---
.../CheckConnectJvmClientCompatibility.scala | 12 +-
.../apache/spark/sql/connect/config/Connect.scala | 26 ++-
.../sql/connect/planner/SparkConnectPlanner.scala | 8 +-
.../spark/sql/connect/service/SessionHolder.scala | 22 +--
.../service/SparkConnectAddArtifactsHandler.scala | 7 +-
.../SparkConnectArtifactStatusesHandler.scala | 2 +-
.../service/SparkConnectSessionHolderSuite.scala | 2 +-
.../scala/org/apache/spark/storage/BlockId.scala | 4 +-
.../org/apache/spark/storage/BlockManager.scala | 6 +-
project/MimaExcludes.scala | 10 +
.../tests/connect/test_connect_classification.py | 2 +-
.../ml/tests/connect/test_connect_pipeline.py | 2 +-
.../ml/tests/connect/test_connect_tuning.py | 2 +-
.../sql/tests/connect/client/test_artifact.py | 2 +-
.../org/apache/spark/sql/internal/SQLConf.scala | 21 ++-
.../scala/org/apache/spark/sql/SparkSession.scala | 11 ++
.../spark/sql/artifact/ArtifactManager.scala | 209 ++++++++-------------
.../spark/sql}/artifact/util/ArtifactUtils.scala | 6 +-
.../sql/internal/BaseSessionStateBuilder.scala | 10 +-
.../apache/spark/sql/internal/SessionState.scala | 6 +-
.../src/test/resources/artifact-tests/Hello.class | Bin 0 -> 5671 bytes
.../resources/artifact-tests/smallClassFile.class | Bin 0 -> 424 bytes
.../src/test/resources/artifact-tests/udf_noA.jar | Bin 0 -> 5545 bytes
.../spark/sql}/artifact/ArtifactManagerSuite.scala | 125 +++++-------
.../spark/sql}/artifact/StubClassLoaderSuite.scala | 9 +-
25 files changed, 234 insertions(+), 270 deletions(-)
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 5178013e455..a9b6f102a51 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -288,7 +288,17 @@ object CheckConnectJvmClientCompatibility {
// SQLImplicits
ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits.rddToDatasetHolder"),
- ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits._sqlContext"))
+ ProblemFilters.exclude[Problem]("org.apache.spark.sql.SQLImplicits._sqlContext"),
+
+ // Artifact Manager
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.artifact.ArtifactManager"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.artifact.ArtifactManager$"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.artifact.util.ArtifactUtils"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.artifact.util.ArtifactUtils$"))
checkMiMaCompatibility(clientJar, sqlJar, includedRules, excludeRules)
}
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 1a5944676f5..f7aa98af2fa 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -20,6 +20,7 @@ import java.util.concurrent.TimeUnit
import org.apache.spark.network.util.ByteUnit
import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.buildConf
object Connect {
@@ -206,20 +207,6 @@ object Connect {
.intConf
.createWithDefault(1024)
- val CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL =
- buildStaticConf("spark.connect.copyFromLocalToFs.allowDestLocal")
- .internal()
- .doc("""
- |Allow `spark.copyFromLocalToFs` destination to be local file system
- | path on spark driver node when
- |`spark.connect.copyFromLocalToFs.allowDestLocal` is true.
- |This will allow user to overwrite arbitrary file on spark
- |driver node we should only enable it for testing purpose.
- |""".stripMargin)
- .version("3.5.0")
- .booleanConf
- .createWithDefault(false)
-
val CONNECT_UI_STATEMENT_LIMIT =
buildStaticConf("spark.sql.connect.ui.retainedStatements")
.doc("The number of statements kept in the Spark Connect UI history.")
@@ -227,6 +214,17 @@ object Connect {
.intConf
.createWithDefault(200)
+ val CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL =
+ buildStaticConf("spark.connect.copyFromLocalToFs.allowDestLocal")
+ .internal()
+ .doc(s"""
+ |(Deprecated since Spark 4.0, please set
+ |'${SQLConf.ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL.key}' instead.
+ |""".stripMargin)
+ .version("3.5.0")
+ .booleanConf
+ .createWithDefault(false)
+
val CONNECT_UI_SESSION_LIMIT =
buildStaticConf("spark.sql.connect.ui.retainedSessions")
.doc("The number of client sessions kept in the Spark Connect UI history.")
.version("3.5.0")
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
index 373ae0f90c6..4a0aa7e5589 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala
@@ -942,7 +942,7 @@ class SparkConnectPlanner(
command = fun.getCommand.toByteArray.toImmutableArraySeq,
// Empty environment variables
envVars = Maps.newHashMap(),
- pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
+ pythonIncludes = sessionHolder.artifactManager.getPythonIncludes.asJava,
pythonExec = pythonExec,
pythonVer = fun.getPythonVer,
// Empty broadcast variables
@@ -996,7 +996,7 @@ class SparkConnectPlanner(
private def transformCachedLocalRelation(rel: proto.CachedLocalRelation): LogicalPlan = {
val blockManager = session.sparkContext.env.blockManager
- val blockId = CacheId(sessionHolder.userId, sessionHolder.sessionId, rel.getHash)
+ val blockId = CacheId(sessionHolder.session.sessionUUID, rel.getHash)
val bytes = blockManager.getLocalBytes(blockId)
bytes
.map { blockData =>
@@ -1014,7 +1014,7 @@ class SparkConnectPlanner(
.getOrElse {
throw InvalidPlanInput(
s"Not found any cached local relation with the hash: ${blockId.hash} in " +
- s"the session ${blockId.sessionId} for the user id ${blockId.userId}.")
+ s"the session with sessionUUID ${blockId.sessionUUID}.")
}
}
@@ -1633,7 +1633,7 @@ class SparkConnectPlanner(
command = fun.getCommand.toByteArray.toImmutableArraySeq,
// Empty environment variables
envVars = Maps.newHashMap(),
- pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
+ pythonIncludes = sessionHolder.artifactManager.getPythonIncludes.asJava,
pythonExec = pythonExec,
pythonVer = fun.getPythonVer,
// Empty broadcast variables
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 0c55e30ba50..fd7c10d5c40 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -27,18 +27,16 @@ import scala.jdk.CollectionConverters._
import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder
-import org.apache.spark.{JobArtifactSet, SparkException, SparkSQLException}
+import org.apache.spark.{SparkException, SparkSQLException}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper
import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE,
ERROR_CACHE_TIMEOUT_SEC}
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.util.SystemClock
-import org.apache.spark.util.Utils
// Unique key identifying session by combination of user, and session id
case class SessionKey(userId: String, sessionId: String)
@@ -166,7 +164,7 @@ case class SessionHolder(userId: String, sessionId: String,
session: SparkSessio
interruptedIds.toSeq
}
- private[connect] lazy val artifactManager = new SparkConnectArtifactManager(this)
+ private[connect] def artifactManager = session.artifactManager
/**
* Add an artifact to this SparkConnect session.
@@ -238,27 +236,13 @@ case class SessionHolder(userId: String, sessionId:
String, session: SparkSessio
eventManager.postClosed()
}
- /**
- * Execute a block of code using this session's classloader.
- * @param f
- * @tparam T
- */
- def withContextClassLoader[T](f: => T): T = {
- // Needed for deserializing and evaluating the UDF on the driver
- Utils.withContextClassLoader(classloader) {
- JobArtifactSet.withActiveJobArtifactState(artifactManager.state) {
- f
- }
- }
- }
-
/**
* Execute a block of code with this session as the active SparkConnect session.
* @param f
* @tparam T
*/
def withSession[T](f: SparkSession => T): T = {
- withContextClassLoader {
+ artifactManager.withResources {
session.withActive {
f(session)
}
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
index 636054198fb..e664e07dce1 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAddArtifactsHandler.scala
@@ -30,8 +30,8 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
import org.apache.spark.connect.proto.AddArtifactsResponse.ArtifactSummary
-import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
-import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
+import org.apache.spark.sql.artifact.ArtifactManager
+import org.apache.spark.sql.artifact.util.ArtifactUtils
import org.apache.spark.util.Utils
/**
@@ -101,8 +101,7 @@ class SparkConnectAddArtifactsHandler(val responseObserver:
StreamObserver[AddAr
// We do not store artifacts that fail the CRC. The failure is reported in the artifact
// summary and it is up to the client to decide whether to retry sending the artifact.
if (artifact.getCrcStatus.contains(true)) {
- if (artifact.path.startsWith(
- SparkConnectArtifactManager.forwardToFSPrefix + File.separator)) {
+ if (artifact.path.startsWith(ArtifactManager.forwardToFSPrefix + File.separator)) {
holder.artifactManager.uploadArtifactToFs(artifact.path, artifact.stagedPath)
} else {
addStagedArtifactToArtifactManager(artifact)
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectArtifactStatusesHandler.scala
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectArtifactStatusesHandler.scala
index 325832ac07e..78def077f2d 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectArtifactStatusesHandler.scala
+++
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectArtifactStatusesHandler.scala
@@ -33,7 +33,7 @@ class SparkConnectArtifactStatusesHandler(
.getOrCreateIsolatedSession(userId, sessionId)
.session
val blockManager = session.sparkContext.env.blockManager
- blockManager.getStatus(CacheId(userId, sessionId, hash)).isDefined
+ blockManager.getStatus(CacheId(session.sessionUUID, hash)).isDefined
}
def handle(request: proto.ArtifactStatusesRequest): Unit = {
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index 2eaa8c8383e..bb51b0a7982 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -163,7 +163,7 @@ class SparkConnectSessionHolderSuite extends
SharedSparkSession {
SimplePythonFunction(
command = fcn(sparkPythonPath).toImmutableArraySeq,
envVars = mutable.Map("PYTHONPATH" -> sparkPythonPath).asJava,
- pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
+ pythonIncludes = sessionHolder.artifactManager.getPythonIncludes.asJava,
pythonExec = IntegratedUDFTestUtils.pythonExec,
pythonVer = IntegratedUDFTestUtils.pythonVer,
broadcastVars = Lists.newArrayList(),
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 456b4edf938..585d9a886b4 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -190,8 +190,8 @@ class UnrecognizedBlockId(name: String)
extends SparkException(s"Failed to parse $name into a block ID")
@DeveloperApi
-case class CacheId(userId: String, sessionId: String, hash: String) extends BlockId {
- override def name: String = s"cache_${userId}_${sessionId}_$hash"
+case class CacheId(sessionUUID: String, hash: String) extends BlockId {
+ override def name: String = s"cache_${sessionUUID}_$hash"
}
@DeveloperApi
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index e64c33382dc..8c22f8473e6 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -2057,10 +2057,10 @@ private[spark] class BlockManager(
*
* @return The number of blocks removed.
*/
- def removeCache(userId: String, sessionId: String): Int = {
- logDebug(s"Removing cache of user id = $userId in the session $sessionId")
+ def removeCache(sessionUUID: String): Int = {
+ logDebug(s"Removing cache of spark session with UUID: $sessionUUID")
val blocksToRemove = blockInfoManager.entries.map(_._1).collect {
- case cid: CacheId if cid.userId == userId && cid.sessionId == sessionId => cid
+ case cid: CacheId if cid.sessionUUID == sessionUUID => cid
}
blocksToRemove.foreach { blockId => removeBlock(blockId) }
blocksToRemove.size
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index d080b16fdc5..46321229087 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -90,6 +90,16 @@ object MimaExcludes {
// SPARK-43299: Convert StreamingQueryException in Scala Client
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.streaming.StreamingQueryException"),
+ // SPARK-45856: Move ArtifactManager from Spark Connect into SparkSession (sql/core)
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.apply"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.userId"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.sessionId"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.copy"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.copy$default$3"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.this"),
+ ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.storage.CacheId$"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.storage.CacheId.apply"),
+
(problem: Problem) => problem match {
case MissingClassProblem(cls) =>
!cls.fullName.startsWith("org.sparkproject.jpmml") &&
!cls.fullName.startsWith("org.sparkproject.dmg.pmml")
diff --git a/python/pyspark/ml/tests/connect/test_connect_classification.py
b/python/pyspark/ml/tests/connect/test_connect_classification.py
index ccf7c346be7..1f811c774cb 100644
--- a/python/pyspark/ml/tests/connect/test_connect_classification.py
+++ b/python/pyspark/ml/tests/connect/test_connect_classification.py
@@ -38,7 +38,7 @@ class ClassificationTestsOnConnect(ClassificationTestsMixin,
unittest.TestCase):
def setUp(self) -> None:
self.spark = (
SparkSession.builder.remote("local[2]")
- .config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
+ .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
.getOrCreate()
)
diff --git a/python/pyspark/ml/tests/connect/test_connect_pipeline.py
b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
index 6925b2482f2..6a895e89239 100644
--- a/python/pyspark/ml/tests/connect/test_connect_pipeline.py
+++ b/python/pyspark/ml/tests/connect/test_connect_pipeline.py
@@ -30,7 +30,7 @@ class PipelineTestsOnConnect(PipelineTestsMixin,
unittest.TestCase):
def setUp(self) -> None:
self.spark = (
SparkSession.builder.remote("local[2]")
- .config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
+ .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
.getOrCreate()
)
diff --git a/python/pyspark/ml/tests/connect/test_connect_tuning.py
b/python/pyspark/ml/tests/connect/test_connect_tuning.py
index d7dbb00b5e1..7b10d91da06 100644
--- a/python/pyspark/ml/tests/connect/test_connect_tuning.py
+++ b/python/pyspark/ml/tests/connect/test_connect_tuning.py
@@ -30,7 +30,7 @@ class CrossValidatorTestsOnConnect(CrossValidatorTestsMixin,
unittest.TestCase):
def setUp(self) -> None:
self.spark = (
SparkSession.builder.remote("local[2]")
- .config("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
+ .config("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
.getOrCreate()
)
diff --git a/python/pyspark/sql/tests/connect/client/test_artifact.py
b/python/pyspark/sql/tests/connect/client/test_artifact.py
index 7e9f9dbbf56..7fde0958e38 100644
--- a/python/pyspark/sql/tests/connect/client/test_artifact.py
+++ b/python/pyspark/sql/tests/connect/client/test_artifact.py
@@ -183,7 +183,7 @@ class ArtifactTests(ReusedConnectTestCase,
ArtifactTestsMixin):
@classmethod
def conf(cls):
conf = super().conf()
- conf.set("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
+ conf.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
return conf
def test_basic_requests(self):
diff --git
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 2d67a8428d2..6a8e1f92fc5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4550,6 +4550,23 @@ object SQLConf {
.booleanConf
.createWithDefault(false)
+ // Deprecate "spark.connect.copyFromLocalToFs.allowDestLocal" in favor of this config. This is
+ // currently optional because we don't want to break existing users who are using the old config.
+ // If this config is set, then we override the deprecated config.
+ val ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL =
+ buildConf("spark.sql.artifact.copyFromLocalToFs.allowDestLocal")
+ .internal()
+ .doc("""
+ |Allow `spark.copyFromLocalToFs` destination to be local file system
+ | path on spark driver node when
+ |`spark.sql.artifact.copyFromLocalToFs.allowDestLocal` is true.
+ |This will allow user to overwrite arbitrary file on spark
+ |driver node we should only enable it for testing purpose.
+ |""".stripMargin)
+ .version("4.0.0")
+ .booleanConf
+ .createOptional
+
val LEGACY_RETAIN_FRACTION_DIGITS_FIRST =
buildConf("spark.sql.legacy.decimal.retainFractionDigitsOnTruncate")
.internal()
@@ -4617,7 +4634,9 @@ object SQLConf {
DeprecatedConfig(COALESCE_PARTITIONS_MIN_PARTITION_NUM.key, "3.2",
s"Use '${COALESCE_PARTITIONS_MIN_PARTITION_SIZE.key}' instead."),
DeprecatedConfig(ESCAPED_STRING_LITERALS.key, "4.0",
- "Use raw string literals with the `r` prefix instead. ")
+ "Use raw string literals with the `r` prefix instead. "),
+ DeprecatedConfig("spark.connect.copyFromLocalToFs.allowDestLocal", "4.0",
+ s"Use '${ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL.key}' instead.")
)
Map(configs.map { cfg => cfg.key -> cfg } : _*)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 779015ee13e..5eba9e59c17 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -33,6 +33,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{ConfigEntry, EXECUTOR_ALLOW_SPARK_CONTEXT}
import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
+import org.apache.spark.sql.artifact.ArtifactManager
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery,
PosParameterizedQuery, UnresolvedRelation}
@@ -243,6 +244,16 @@ class SparkSession private(
@Unstable
def streams: StreamingQueryManager = sessionState.streamingQueryManager
+ /**
+ * Returns an `ArtifactManager` that supports adding, managing and using session-scoped artifacts
+ * (jars, classfiles, etc).
+ *
+ * @since 4.0.0
+ */
+ @Experimental
+ @Unstable
+ private[sql] def artifactManager: ArtifactManager = sessionState.artifactManager
+
/**
* Start a new session with isolated SQL configurations, temporary tables, registered
* functions are isolated, but sharing the underlying `SparkContext` and cached data.
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
similarity index 58%
rename from
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
rename to
sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
index ba36b708e83..69a5fd86074 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/artifact/ArtifactManager.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.connect.artifact
+package org.apache.spark.sql.artifact
import java.io.File
import java.net.{URI, URL, URLClassLoader}
@@ -26,62 +26,85 @@ import javax.ws.rs.core.UriBuilder
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag
-import io.grpc.Status
import org.apache.commons.io.{FilenameUtils, FileUtils}
import org.apache.hadoop.fs.{LocalFileSystem, Path => FSPath}
-import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkContext, SparkEnv}
+import org.apache.spark.{JobArtifactSet, JobArtifactState, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.{CONNECT_SCALA_UDF_STUB_PREFIXES,
EXECUTOR_USER_CLASS_PATH_FIRST}
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connect.artifact.util.ArtifactUtils
-import org.apache.spark.sql.connect.config.Connect.CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL
-import org.apache.spark.sql.connect.service.SessionHolder
+import org.apache.spark.sql.artifact.util.ArtifactUtils
+import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.storage.{CacheId, StorageLevel}
import org.apache.spark.util.{ChildFirstURLClassLoader, StubClassLoader, Utils}
/**
- * The Artifact Manager for the [[SparkConnectService]].
- *
* This class handles the storage of artifacts as well as preparing the artifacts for use.
*
- * Artifacts belonging to different [[SparkSession]]s are segregated and isolated from each other
- * with the help of the `sessionUUID`.
+ * Artifacts belonging to different SparkSessions are isolated from each other with the help of the
+ * `sessionUUID`.
*
- * Jars and classfile artifacts are stored under "jars" and "classes" sub-directories respectively
- * while other types of artifacts are stored under the root directory for that particular
- * [[SparkSession]].
+ * Jars and classfile artifacts are stored under "jars", "classes" and "pyfiles" sub-directories
+ * respectively while other types of artifacts are stored under the root directory for that
+ * particular SparkSession.
*
- * @param sessionHolder
- * The object used to hold the Spark Connect session state.
+ * @param session The object used to hold the Spark Connect session state.
*/
-class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging {
- import SparkConnectArtifactManager._
+class ArtifactManager(session: SparkSession) extends Logging {
+ import ArtifactManager._
+
+ // The base directory where all artifacts are stored.
+ protected def artifactRootPath: Path = artifactRootDirectory
+
+ private[artifact] lazy val artifactRootURI: String = SparkEnv
+ .get
+ .rpcEnv
+ .fileServer
+ .addDirectoryIfAbsent(ARTIFACT_DIRECTORY_PREFIX, artifactRootPath.toFile)
// The base directory/URI where all artifacts are stored for this `sessionUUID`.
- val (artifactPath, artifactURI): (Path, String) =
- getArtifactDirectoryAndUriForSession(sessionHolder)
+ protected[artifact] val (artifactPath, artifactURI): (Path, String) =
+ (ArtifactUtils.concatenatePaths(artifactRootPath, session.sessionUUID),
+ s"$artifactRootURI/${session.sessionUUID}")
+
// The base directory/URI where all class file artifacts are stored for this `sessionUUID`.
- val (classDir, classURI): (Path, String) = getClassfileDirectoryAndUriForSession(sessionHolder)
- val state: JobArtifactState =
- JobArtifactState(sessionHolder.serverSessionId, Option(classURI))
+ protected[artifact] val (classDir, classURI): (Path, String) =
+ (ArtifactUtils.concatenatePaths(artifactPath, "classes"), s"$artifactURI/classes/")
+
+ protected[artifact] val state: JobArtifactState =
+ JobArtifactState(session.sessionUUID, Option(classURI))
+
+ def withResources[T](f: => T): T = {
+ Utils.withContextClassLoader(classloader) {
+ JobArtifactSet.withActiveJobArtifactState(state) {
+ f
+ }
+ }
+ }
- private val jarsList = new CopyOnWriteArrayList[Path]
- private val pythonIncludeList = new CopyOnWriteArrayList[String]
+ protected val jarsList = new CopyOnWriteArrayList[Path]
+ protected val pythonIncludeList = new CopyOnWriteArrayList[String]
/**
- * Get the URLs of all jar artifacts added through the [[SparkConnectService]].
- *
- * @return
+ * Get the URLs of all jar artifacts.
*/
- def getSparkConnectAddedJars: Seq[URL] = jarsList.asScala.map(_.toUri.toURL).toSeq
+ def getAddedJars: Seq[URL] = jarsList.asScala.map(_.toUri.toURL).toSeq
/**
- * Get the py-file names added through the [[SparkConnectService]].
+ * Get the py-file names added to this SparkSession.
*
* @return
*/
- def getSparkConnectPythonIncludes: Seq[String] = pythonIncludeList.asScala.toSeq
+ def getPythonIncludes: Seq[String] = pythonIncludeList.asScala.toSeq
+
+ protected def moveFile(source: Path, target: Path, allowOverwrite: Boolean = false): Unit = {
+ Files.createDirectories(target.getParent)
+ if (allowOverwrite) {
+ Files.move(source, target, StandardCopyOption.REPLACE_EXISTING)
+ } else {
+ Files.move(source, target)
+ }
+ }
/**
* Add and prepare a staged artifact (i.e an artifact that has been rebuilt locally from bytes
@@ -91,7 +114,7 @@ class SparkConnectArtifactManager(sessionHolder:
SessionHolder) extends Logging
* @param serverLocalStagingPath
* @param fragment
*/
- private[connect] def addArtifact(
+ def addArtifact(
remoteRelativePath: Path,
serverLocalStagingPath: Path,
fragment: Option[String]): Unit =
JobArtifactSet.withActiveJobArtifactState(state) {
@@ -99,10 +122,9 @@ class SparkConnectArtifactManager(sessionHolder:
SessionHolder) extends Logging
if (remoteRelativePath.startsWith(s"cache${File.separator}")) {
val tmpFile = serverLocalStagingPath.toFile
Utils.tryWithSafeFinallyAndFailureCallbacks {
- val blockManager = sessionHolder.session.sparkContext.env.blockManager
+ val blockManager = session.sparkContext.env.blockManager
val blockId = CacheId(
- userId = sessionHolder.userId,
- sessionId = sessionHolder.sessionId,
+ sessionUUID = session.sessionUUID,
hash = remoteRelativePath.toString.stripPrefix(s"cache${File.separator}"))
val updater = blockManager.TempFileBasedBlockStoreUpdater(
blockId = blockId,
@@ -118,15 +140,12 @@ class SparkConnectArtifactManager(sessionHolder:
SessionHolder) extends Logging
val target = ArtifactUtils.concatenatePaths(
classDir,
remoteRelativePath.toString.stripPrefix(s"classes${File.separator}"))
- Files.createDirectories(target.getParent)
// Allow overwriting class files to capture updates to classes.
// This is required because the client currently sends all the class files in each class file
// transfer.
- Files.move(serverLocalStagingPath, target, StandardCopyOption.REPLACE_EXISTING)
+ moveFile(serverLocalStagingPath, target, allowOverwrite = true)
} else {
val target = ArtifactUtils.concatenatePaths(artifactPath, remoteRelativePath)
- Files.createDirectories(target.getParent)
-
// Disallow overwriting with modified version
if (Files.exists(target)) {
// makes the query idempotent
@@ -134,22 +153,20 @@ class SparkConnectArtifactManager(sessionHolder:
SessionHolder) extends Logging
return
}
- throw Status.ALREADY_EXISTS
- .withDescription(s"Duplicate Artifact: $remoteRelativePath. " +
+ throw new RuntimeException(s"Duplicate Artifact: $remoteRelativePath. " +
"Artifacts cannot be overwritten.")
- .asRuntimeException()
}
- Files.move(serverLocalStagingPath, target)
+ moveFile(serverLocalStagingPath, target)
// This URI is for Spark file server that starts with "spark://".
val uri = s"$artifactURI/${Utils.encodeRelativeUnixPathToURIRawPath(
FilenameUtils.separatorsToUnix(remoteRelativePath.toString))}"
if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
- sessionHolder.session.sparkContext.addJar(uri)
+ session.sparkContext.addJar(uri)
jarsList.add(target)
} else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
- sessionHolder.session.sparkContext.addFile(uri)
+ session.sparkContext.addFile(uri)
val stringRemotePath = remoteRelativePath.toString
if (stringRemotePath.endsWith(".zip") || stringRemotePath.endsWith(
".egg") || stringRemotePath.endsWith(".jar")) {
@@ -158,9 +175,9 @@ class SparkConnectArtifactManager(sessionHolder:
SessionHolder) extends Logging
} else if (remoteRelativePath.startsWith(s"archives${File.separator}")) {
val canonicalUri =
fragment.map(UriBuilder.fromUri(new URI(uri)).fragment).getOrElse(new URI(uri))
- sessionHolder.session.sparkContext.addArchive(canonicalUri.toString)
+ session.sparkContext.addArchive(canonicalUri.toString)
} else if (remoteRelativePath.startsWith(s"files${File.separator}")) {
- sessionHolder.session.sparkContext.addFile(uri)
+ session.sparkContext.addFile(uri)
}
}
}
@@ -169,7 +186,7 @@ class SparkConnectArtifactManager(sessionHolder:
SessionHolder) extends Logging
* Returns a [[ClassLoader]] for session-specific jar/class file resources.
*/
def classloader: ClassLoader = {
- val urls = getSparkConnectAddedJars :+ classDir.toUri.toURL
+ val urls = getAddedJars :+ classDir.toUri.toURL
val prefixes = SparkEnv.get.conf.get(CONNECT_SCALA_UDF_STUB_PREFIXES)
val userClasspathFirst =
SparkEnv.get.conf.get(EXECUTOR_USER_CLASS_PATH_FIRST)
val loader = if (prefixes.nonEmpty) {
@@ -208,35 +225,34 @@ class SparkConnectArtifactManager(sessionHolder:
SessionHolder) extends Logging
}
/**
- * Cleans up all resources specific to this `sessionHolder`.
+ * Cleans up all resources specific to this `session`.
*/
- private[connect] def cleanUpResources(): Unit = {
+ private[sql] def cleanUpResources(): Unit = {
logDebug(
- s"Cleaning up resources for session with userId: ${sessionHolder.userId} and " +
- s"sessionId: ${sessionHolder.sessionId}")
+ s"Cleaning up resources for session with sessionUUID ${session.sessionUUID}")
// Clean up added files
val fileserver = SparkEnv.get.rpcEnv.fileServer
- val sparkContext = sessionHolder.session.sparkContext
+ val sparkContext = session.sparkContext
sparkContext.addedFiles.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeFile))
sparkContext.addedArchives.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeFile))
sparkContext.addedJars.remove(state.uuid).foreach(_.keys.foreach(fileserver.removeJar))
// Clean up cached relations
val blockManager = sparkContext.env.blockManager
- blockManager.removeCache(sessionHolder.userId, sessionHolder.sessionId)
+ blockManager.removeCache(session.sessionUUID)
// Clean up artifacts folder
FileUtils.deleteDirectory(artifactPath.toFile)
}
- private[connect] def uploadArtifactToFs(
+ def uploadArtifactToFs(
remoteRelativePath: Path,
serverLocalStagingPath: Path): Unit = {
- val hadoopConf = sessionHolder.session.sparkContext.hadoopConfiguration
+ val hadoopConf = session.sparkContext.hadoopConfiguration
assert(
remoteRelativePath.startsWith(
- SparkConnectArtifactManager.forwardToFSPrefix + File.separator))
+ ArtifactManager.forwardToFSPrefix + File.separator))
val destFSPath = new FSPath(
Paths
.get("/")
@@ -246,14 +262,17 @@ class SparkConnectArtifactManager(sessionHolder:
SessionHolder) extends Logging
val fs = destFSPath.getFileSystem(hadoopConf)
if (fs.isInstanceOf[LocalFileSystem]) {
val allowDestLocalConf =
- SparkEnv.get.conf.get(CONNECT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL)
+ session.conf.get(SQLConf.ARTIFACT_COPY_FROM_LOCAL_TO_FS_ALLOW_DEST_LOCAL)
+ .getOrElse(
+ session.conf.get("spark.connect.copyFromLocalToFs.allowDestLocal").contains("true"))
+
if (!allowDestLocalConf) {
// To avoid security issue, by default,
// we don't support uploading file to local file system
// destination path, otherwise user is able to overwrite arbitrary file
// on spark driver node.
// We can temporarily allow the behavior by setting spark config
- // `spark.connect.copyFromLocalToFs.allowDestLocal`
+ // `spark.sql.artifact.copyFromLocalToFs.allowDestLocal`
// to `true` when starting spark driver, we should only enable it for testing
// purpose.
throw new UnsupportedOperationException(
@@ -264,80 +283,12 @@ class SparkConnectArtifactManager(sessionHolder:
SessionHolder) extends Logging
}
}
-object SparkConnectArtifactManager extends Logging {
+object ArtifactManager extends Logging {
val forwardToFSPrefix = "forward_to_fs"
- private var currentArtifactRootUri: String = _
- private var lastKnownSparkContextInstance: SparkContext = _
-
- private val ARTIFACT_DIRECTORY_PREFIX = "artifacts"
+ val ARTIFACT_DIRECTORY_PREFIX = "artifacts"
- // The base directory where all artifacts are stored.
- private[spark] lazy val artifactRootPath = {
+ private[artifact] lazy val artifactRootDirectory =
Utils.createTempDir(ARTIFACT_DIRECTORY_PREFIX).toPath
- }
-
- private[spark] def getArtifactDirectoryAndUriForSession(session: SparkSession): (Path, String) =
- (
- ArtifactUtils.concatenatePaths(artifactRootPath, session.sessionUUID),
- s"$artifactRootURI/${session.sessionUUID}")
-
- private[spark] def getArtifactDirectoryAndUriForSession(
- sessionHolder: SessionHolder): (Path, String) =
- getArtifactDirectoryAndUriForSession(sessionHolder.session)
-
- private[spark] def getClassfileDirectoryAndUriForSession(
- session: SparkSession): (Path, String) = {
- val (artDir, artUri) = getArtifactDirectoryAndUriForSession(session)
- (ArtifactUtils.concatenatePaths(artDir, "classes"), s"$artUri/classes/")
- }
-
- private[spark] def getClassfileDirectoryAndUriForSession(
- sessionHolder: SessionHolder): (Path, String) =
- getClassfileDirectoryAndUriForSession(sessionHolder.session)
-
- /**
- * Updates the URI for the artifact directory.
- *
- * This is required if the SparkContext is restarted.
- *
- * Note: This logic is solely to handle testing where a [[SparkContext]] may be restarted
- * several times in a single JVM lifetime. In a general Spark cluster, the [[SparkContext]] is
- * not expected to be restarted at any point in time.
- */
- private def refreshArtifactUri(sc: SparkContext): Unit = synchronized {
- // If a competing thread had updated the URI, we do not need to refresh the URI again.
- if (sc eq lastKnownSparkContextInstance) {
- return
- }
- val oldArtifactUri = currentArtifactRootUri
- currentArtifactRootUri = SparkEnv.get.rpcEnv.fileServer
- .addDirectoryIfAbsent(ARTIFACT_DIRECTORY_PREFIX, artifactRootPath.toFile)
- lastKnownSparkContextInstance = sc
- logDebug(s"Artifact URI updated from $oldArtifactUri to $currentArtifactRootUri")
- }
-
- /**
- * Checks if the URI for the artifact directory needs to be updated. This is required in cases
- * where SparkContext is restarted as the old URI would no longer be valid.
- *
- * Note: This logic is solely to handle testing where a [[SparkContext]] may be restarted
- * several times in a single JVM lifetime. In a general Spark cluster, the [[SparkContext]] is
- * not expected to be restarted at any point in time.
- */
- private def updateUriIfRequired(): Unit = {
- SparkContext.getActive.foreach { sc =>
- if (lastKnownSparkContextInstance == null || (sc ne lastKnownSparkContextInstance)) {
- logDebug("Refreshing artifact URI due to SparkContext (re)initialisation!")
- refreshArtifactUri(sc)
- }
- }
- }
-
- private[connect] def artifactRootURI: String = {
- updateUriIfRequired()
- require(currentArtifactRootUri != null)
- currentArtifactRootUri
- }
}
diff --git
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/util/ArtifactUtils.scala
b/sql/core/src/main/scala/org/apache/spark/sql/artifact/util/ArtifactUtils.scala
similarity index 88%
rename from
connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/util/ArtifactUtils.scala
rename to
sql/core/src/main/scala/org/apache/spark/sql/artifact/util/ArtifactUtils.scala
index ab1c0f81659..f16d01501d7 100644
---
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/util/ArtifactUtils.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/artifact/util/ArtifactUtils.scala
@@ -15,13 +15,13 @@
* limitations under the License.
*/
-package org.apache.spark.sql.connect.artifact.util
+package org.apache.spark.sql.artifact.util
import java.nio.file.{Path, Paths}
object ArtifactUtils {
- private[connect] def concatenatePaths(basePath: Path, otherPath: Path): Path = {
+ private[sql] def concatenatePaths(basePath: Path, otherPath: Path): Path = {
require(!otherPath.isAbsolute)
// We avoid using the `.resolve()` method here to ensure that we're concatenating the two
// paths.
@@ -37,7 +37,7 @@ object ArtifactUtils {
normalizedPath
}
- private[connect] def concatenatePaths(basePath: Path, otherPath: String): Path = {
+ private[sql] def concatenatePaths(basePath: Path, otherPath: String): Path = {
concatenatePaths(basePath, Paths.get(otherPath))
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 1d496b027ef..630e1202f6d 100644
---
a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.internal
import org.apache.spark.annotation.Unstable
import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _}
+import org.apache.spark.sql.artifact.ArtifactManager
import org.apache.spark.sql.catalyst.analysis.{Analyzer,
EvalSubqueriesForTimeTravel, FunctionRegistry, ReplaceCharWithVarchar,
ResolveSessionCatalog, TableFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.Expression
@@ -349,6 +350,12 @@ abstract class BaseSessionStateBuilder(
new ExecutionListenerManager(session, conf, loadExtensions = true))
}
+ /**
+ * Resource manager that handles the storage of artifacts as well as preparing the artifacts for
+ * use.
+ */
+ protected def artifactManager: ArtifactManager = new ArtifactManager(session)
+
/**
* Function used to make clones of the session state.
*/
@@ -381,7 +388,8 @@ abstract class BaseSessionStateBuilder(
createClone,
columnarRules,
adaptiveRulesHolder,
- planNormalizationRules)
+ planNormalizationRules,
+ () => artifactManager)
}
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 177a25b45fc..adf3e0cb6ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.Unstable
import org.apache.spark.sql._
+import org.apache.spark.sql.artifact.ArtifactManager
import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, TableFunctionRegistry}
import org.apache.spark.sql.catalyst.catalog._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
@@ -84,7 +85,8 @@ private[sql] class SessionState(
createClone: (SparkSession, SessionState) => SessionState,
val columnarRules: Seq[ColumnarRule],
val adaptiveRulesHolder: AdaptiveRulesHolder,
- val planNormalizationRules: Seq[Rule[LogicalPlan]]) {
+ val planNormalizationRules: Seq[Rule[LogicalPlan]],
+ val artifactManagerBuilder: () => ArtifactManager) {
// The following fields are lazy to avoid creating the Hive client when creating SessionState.
lazy val catalog: SessionCatalog = catalogBuilder()
@@ -99,6 +101,8 @@ private[sql] class SessionState(
// when connecting to ThriftServer.
lazy val streamingQueryManager: StreamingQueryManager = streamingQueryManagerBuilder()
+ lazy val artifactManager: ArtifactManager = artifactManagerBuilder()
+
def catalogManager: CatalogManager = analyzer.catalogManager
def newHadoopConf(): Configuration = SessionState.newHadoopConf(
diff --git a/sql/core/src/test/resources/artifact-tests/Hello.class
b/sql/core/src/test/resources/artifact-tests/Hello.class
new file mode 100644
index 00000000000..56725764de2
Binary files /dev/null and
b/sql/core/src/test/resources/artifact-tests/Hello.class differ
diff --git a/sql/core/src/test/resources/artifact-tests/smallClassFile.class
b/sql/core/src/test/resources/artifact-tests/smallClassFile.class
new file mode 100755
index 00000000000..e796030e471
Binary files /dev/null and
b/sql/core/src/test/resources/artifact-tests/smallClassFile.class differ
diff --git a/sql/core/src/test/resources/artifact-tests/udf_noA.jar
b/sql/core/src/test/resources/artifact-tests/udf_noA.jar
new file mode 100644
index 00000000000..4d8c423ab6d
Binary files /dev/null and
b/sql/core/src/test/resources/artifact-tests/udf_noA.jar differ
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
similarity index 68%
rename from
connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
rename to
sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
index 0c095384de8..263006100be 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/ArtifactManagerSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/artifact/ArtifactManagerSuite.scala
@@ -14,37 +14,31 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.connect.artifact
+package org.apache.spark.sql.artifact
+import java.io.File
import java.nio.charset.StandardCharsets
import java.nio.file.{Files, Paths}
-import java.util.UUID
-import io.grpc.StatusRuntimeException
import org.apache.commons.io.FileUtils
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite}
-import org.apache.spark.sql.connect.ResourceHelper
-import org.apache.spark.sql.connect.service.{SessionHolder, SparkConnectService}
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.storage.CacheId
import org.apache.spark.util.Utils
-class ArtifactManagerSuite extends SharedSparkSession with ResourceHelper {
+class ArtifactManagerSuite extends SharedSparkSession {
override protected def sparkConf: SparkConf = {
val conf = super.sparkConf
- conf
- .set("spark.plugins", "org.apache.spark.sql.connect.SparkConnectPlugin")
- .set("spark.connect.copyFromLocalToFs.allowDestLocal", "true")
+ conf.set("spark.sql.artifact.copyFromLocalToFs.allowDestLocal", "true")
}
- private val artifactPath = commonResourcePath.resolve("artifact-tests")
- private lazy val sessionHolder: SessionHolder = {
- SessionHolder("test", UUID.randomUUID().toString, spark)
- }
- private lazy val artifactManager = new SparkConnectArtifactManager(sessionHolder)
+ private val artifactPath = new File("src/test/resources/artifact-tests").toPath
+
+ private lazy val artifactManager = spark.artifactManager
private def sessionUUID: String = spark.sessionUUID
@@ -61,7 +55,7 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
assert(stagingPath.toFile.exists())
artifactManager.addArtifact(remotePath, stagingPath, None)
- val movedClassFile = SparkConnectArtifactManager.artifactRootPath
+ val movedClassFile = ArtifactManager.artifactRootDirectory
.resolve(s"$sessionUUID/classes/smallClassFile.class")
.toFile
assert(movedClassFile.exists())
@@ -75,7 +69,7 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
assert(stagingPath.toFile.exists())
artifactManager.addArtifact(remotePath, stagingPath, None)
- val movedClassFile = SparkConnectArtifactManager.artifactRootPath
+ val movedClassFile = ArtifactManager.artifactRootDirectory
.resolve(s"$sessionUUID/classes/Hello.class")
.toFile
assert(movedClassFile.exists())
@@ -98,16 +92,14 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
val remotePath = Paths.get("classes/Hello.class")
assert(stagingPath.toFile.exists())
- val sessionHolder =
- SparkConnectService.getOrCreateIsolatedSession("c1", UUID.randomUUID.toString())
- sessionHolder.addArtifact(remotePath, stagingPath, None)
+ artifactManager.addArtifact(remotePath, stagingPath, None)
- val movedClassFile = SparkConnectArtifactManager.artifactRootPath
- .resolve(s"${sessionHolder.session.sessionUUID}/classes/Hello.class")
+ val movedClassFile = ArtifactManager.artifactRootDirectory
+ .resolve(s"${spark.sessionUUID}/classes/Hello.class")
.toFile
assert(movedClassFile.exists())
- val classLoader = sessionHolder.classloader
+ val classLoader = spark.artifactManager.classloader
val instance = classLoader
.loadClass("Hello")
.getDeclaredConstructor(classOf[String])
@@ -115,8 +107,8 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
.asInstanceOf[String => String]
val udf = org.apache.spark.sql.functions.udf(instance)
- sessionHolder.withSession { session =>
- session.range(10).select(udf(col("id").cast("string"))).collect()
+ spark.artifactManager.withResources {
+ spark.range(10).select(udf(col("id").cast("string"))).collect()
}
}
@@ -125,9 +117,8 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
val stagingPath = path.toPath
Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
val remotePath = Paths.get("cache/abc")
- val session = sessionHolder
val blockManager = spark.sparkContext.env.blockManager
- val blockId = CacheId(session.userId, session.sessionId, "abc")
+ val blockId = CacheId(spark.sessionUUID, "abc")
try {
artifactManager.addArtifact(remotePath, stagingPath, None)
val bytes = blockManager.getLocalBytes(blockId)
@@ -136,7 +127,7 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
assert(readback === "test")
} finally {
blockManager.releaseLock(blockId)
- blockManager.removeCache(session.userId, session.sessionId)
+ blockManager.removeCache(spark.sessionUUID)
}
}
}
@@ -147,7 +138,7 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
val remotePath = Paths.get("pyfiles/abc.zip")
artifactManager.addArtifact(remotePath, stagingPath, None)
- assert(artifactManager.getSparkConnectPythonIncludes == Seq("abc.zip"))
+ assert(artifactManager.getPythonIncludes == Seq("abc.zip"))
}
}
@@ -167,7 +158,7 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
withTempPath { path =>
Files.write(path.toPath, "updated file".getBytes(StandardCharsets.UTF_8))
- assertThrows[StatusRuntimeException] {
+ assertThrows[RuntimeException] {
artifactManager.addArtifact(remotePath, path.toPath, None)
}
}
@@ -193,9 +184,8 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
val stagingPath = path.toPath
Files.write(path.toPath, "test".getBytes(StandardCharsets.UTF_8))
val remotePath = Paths.get("cache/abc")
- val session = sessionHolder
val blockManager = spark.sparkContext.env.blockManager
- val blockId = CacheId(session.userId, session.sessionId, "abc")
+ val blockId = CacheId(spark.sessionUUID, "abc")
// Setup artifact dir
val copyDir = Utils.createTempDir().toPath
FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
@@ -209,7 +199,7 @@ class ArtifactManagerSuite extends SharedSparkSession with
ResourceHelper {
val bytes = blockManager.getLocalBytes(blockId)
assert(bytes.isDefined)
blockManager.releaseLock(blockId)
- val expectedPath = SparkConnectArtifactManager.artifactRootPath
+ val expectedPath = ArtifactManager.artifactRootDirectory
.resolve(s"$sessionUUID/classes/smallClassFile.class")
assert(expectedPath.toFile.exists())
@@ -226,32 +216,30 @@ class ArtifactManagerSuite extends SharedSparkSession
with ResourceHelper {
case throwable: Throwable => throw throwable
} finally {
FileUtils.deleteDirectory(copyDir.toFile)
- blockManager.removeCache(session.userId, session.sessionId)
+ blockManager.removeCache(spark.sessionUUID)
}
}
}
}
test("Classloaders for spark sessions are isolated") {
- // use same sessionId - different users should still make it isolated.
- val sessionId = UUID.randomUUID.toString()
- val holder1 = SparkConnectService.getOrCreateIsolatedSession("c1", sessionId)
- val holder2 = SparkConnectService.getOrCreateIsolatedSession("c2", sessionId)
- val holder3 = SparkConnectService.getOrCreateIsolatedSession("c3", sessionId)
+ val session1 = spark.newSession()
+ val session2 = spark.newSession()
+ val session3 = spark.newSession()
- def addHelloClass(holder: SessionHolder): Unit = {
+ def addHelloClass(session: SparkSession): Unit = {
val copyDir = Utils.createTempDir().toPath
FileUtils.copyDirectory(artifactPath.toFile, copyDir.toFile)
val stagingPath = copyDir.resolve("Hello.class")
val remotePath = Paths.get("classes/Hello.class")
assert(stagingPath.toFile.exists())
- holder.addArtifact(remotePath, stagingPath, None)
+ session.artifactManager.addArtifact(remotePath, stagingPath, None)
}
// Add the "Hello" classfile for the first user
- addHelloClass(holder1)
+ addHelloClass(session1)
- val classLoader1 = holder1.classloader
+ val classLoader1 = session1.artifactManager.classloader
val instance1 = classLoader1
.loadClass("Hello")
.getDeclaredConstructor(classOf[String])
@@ -259,13 +247,13 @@ class ArtifactManagerSuite extends SharedSparkSession
with ResourceHelper {
.asInstanceOf[String => String]
val udf1 = org.apache.spark.sql.functions.udf(instance1)
- holder1.withSession { session =>
- val result = session.range(10).select(udf1(col("id").cast("string"))).collect()
- assert(result.forall(_.getString(0).contains("Talon")))
+ session1.artifactManager.withResources {
+ val result1 = session1.range(10).select(udf1(col("id").cast("string"))).collect()
+ assert(result1.forall(_.getString(0).contains("Talon")))
}
assertThrows[ClassNotFoundException] {
- val classLoader2 = holder2.classloader
+ val classLoader2 = session2.artifactManager.classloader
val instance2 = classLoader2
.loadClass("Hello")
.getDeclaredConstructor(classOf[String])
@@ -274,17 +262,19 @@ class ArtifactManagerSuite extends SharedSparkSession
with ResourceHelper {
}
// Add the "Hello" classfile for the third user
- addHelloClass(holder3)
- val instance3 = holder3.classloader
+ addHelloClass(session3)
+
+ val classLoader3 = session3.artifactManager.classloader
+ val instance3 = classLoader3
.loadClass("Hello")
.getDeclaredConstructor(classOf[String])
.newInstance("Ahri")
.asInstanceOf[String => String]
val udf3 = org.apache.spark.sql.functions.udf(instance3)
- holder3.withSession { session =>
- val result = session.range(10).select(udf3(col("id").cast("string"))).collect()
- assert(result.forall(_.getString(0).contains("Ahri")))
+ session3.artifactManager.withResources {
+ val result3 = session3.range(10).select(udf3(col("id").cast("string"))).collect()
+ assert(result3.forall(_.getString(0).contains("Ahri")))
}
}
@@ -294,36 +284,13 @@ class ArtifactManagerSuite extends SharedSparkSession
with ResourceHelper {
val stagingPath = copyDir.resolve("Hello.class")
val remotePath = Paths.get("classes/Hello.class")
- val holder =
- SparkConnectService.getOrCreateIsolatedSession("c1", UUID.randomUUID.toString)
- holder.addArtifact(remotePath, stagingPath, None)
+ artifactManager.addArtifact(remotePath, stagingPath, None)
- val sessionDirectory =
- SparkConnectArtifactManager.getArtifactDirectoryAndUriForSession(holder)._1.toFile
+ val sessionDirectory = artifactManager.artifactPath.toFile
assert(sessionDirectory.exists())
- holder.artifactManager.cleanUpResources()
+ artifactManager.cleanUpResources()
assert(!sessionDirectory.exists())
- assert(SparkConnectArtifactManager.artifactRootPath.toFile.exists())
- }
-}
-
-class ArtifactUriSuite extends SparkFunSuite with LocalSparkContext {
-
- private def createSparkContext(): Unit = {
- resetSparkContext()
- sc = new SparkContext("local[4]", "test", new SparkConf())
-
- }
- override def beforeEach(): Unit = {
- super.beforeEach()
- createSparkContext()
- }
-
- test("Artifact URI is reset when SparkContext is restarted") {
- val oldUri = SparkConnectArtifactManager.artifactRootURI
- createSparkContext()
- val newUri = SparkConnectArtifactManager.artifactRootURI
- assert(newUri != oldUri)
+ assert(ArtifactManager.artifactRootDirectory.toFile.exists())
}
}
diff --git
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/artifact/StubClassLoaderSuite.scala
similarity index 94%
rename from
connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala
rename to
sql/core/src/test/scala/org/apache/spark/sql/artifact/StubClassLoaderSuite.scala
index bde9a71fa17..c1a0cc27400 100644
---
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/artifact/StubClassLoaderSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/artifact/StubClassLoaderSuite.scala
@@ -14,7 +14,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-package org.apache.spark.sql.connect.artifact
+package org.apache.spark.sql.artifact
import java.io.File
@@ -23,8 +23,11 @@ import org.apache.spark.util.{ChildFirstURLClassLoader,
StubClassLoader}
class StubClassLoaderSuite extends SparkFunSuite {
- // See src/test/resources/StubClassDummyUdf for how the UDFs and jars are created.
- private val udfNoAJar = new File("src/test/resources/udf_noA.jar").toURI.toURL
+ // TODO: Modify JAR to remove references to connect.
+ // See connector/client/jvm/src/test/resources/StubClassDummyUdf for how the UDFs and jars are
+ // created.
+ private val udfNoAJar = new File(
+ "src/test/resources/artifact-tests/udf_noA.jar").toURI.toURL
private val classDummyUdf = "org.apache.spark.sql.connect.client.StubClassDummyUdf"
private val classA = "org.apache.spark.sql.connect.client.A"
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]