This is an automated email from the ASF dual-hosted git repository.
weichenxu123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 7e363747b15e [SPARK-52192][ML][CONNECT] MLCache loading path check
7e363747b15e is described below
commit 7e363747b15ef0dfbe4c56ad196886de8197b203
Author: Weichen Xu <[email protected]>
AuthorDate: Mon May 19 14:03:10 2025 +0800
[SPARK-52192][ML][CONNECT] MLCache loading path check
### What changes were proposed in this pull request?
Add a check for the MLCache loading path to prevent it from reading files outside
the MLCache offloading storage path. This prevents potential security issues.
### Why are the changes needed?
To prevent MLCache from reading files outside the MLCache offloading storage path.
This prevents potential security issues.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Manually.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #50923 from WeichenXu123/SPARK-52192.
Authored-by: Weichen Xu <[email protected]>
Signed-off-by: Weichen Xu <[email protected]>
---
.../org/apache/spark/sql/connect/ml/MLCache.scala | 23 ++++++++++++++++++++--
1 file changed, 21 insertions(+), 2 deletions(-)
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
index bd832fb02854..ef1b17dc2221 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ml/MLCache.scala
@@ -27,6 +27,7 @@ import scala.collection.mutable
import com.google.common.cache.{CacheBuilder, RemovalNotification}
import org.apache.commons.io.FileUtils
+import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
import org.apache.spark.ml.Model
import org.apache.spark.ml.util.{ConnectHelper, MLWritable, Summary}
@@ -137,6 +138,7 @@ private[connect] class MLCache(sessionHolder:
SessionHolder) extends Logging {
cachedModel.put(objectId, CacheItem(obj, sizeBytes))
if (getMemoryControlEnabled) {
val savePath = offloadedModelsDir.resolve(objectId)
+ require(savePath.startsWith(offloadedModelsDir))
obj.asInstanceOf[MLWritable].write.saveToLocal(savePath.toString)
Files.writeString(savePath.resolve(modelClassNameFile),
obj.getClass.getName)
totalMLCacheInMemorySizeBytes.addAndGet(sizeBytes)
@@ -148,6 +150,18 @@ private[connect] class MLCache(sessionHolder:
SessionHolder) extends Logging {
objectId
}
+ private[spark] def verifyObjectId(refId: String): Unit = {
+ // Verify that `refId` is a valid UUID.
+ // This prevents a client from sending a malicious `refId` that could
+ // cause a security issue on the Spark server.
+ try {
+ UUID.fromString(refId)
+ } catch {
+ case _: IllegalArgumentException =>
+ throw SparkException.internalError(s"The MLCache key $refId is invalid.")
+ }
+ }
+
/**
* Get the object by the key
* @param refId
@@ -159,9 +173,11 @@ private[connect] class MLCache(sessionHolder:
SessionHolder) extends Logging {
if (refId == helperID) {
helper
} else {
+ verifyObjectId(refId)
var obj: Object =
Option(cachedModel.get(refId)).map(_.obj).getOrElse(null)
if (obj == null && getMemoryControlEnabled) {
val loadPath = offloadedModelsDir.resolve(refId)
+ require(loadPath.startsWith(offloadedModelsDir))
if (Files.isDirectory(loadPath)) {
val className =
Files.readString(loadPath.resolve(modelClassNameFile))
obj = MLUtils.loadTransformer(
@@ -179,11 +195,14 @@ private[connect] class MLCache(sessionHolder:
SessionHolder) extends Logging {
}
def _removeModel(refId: String): Boolean = {
+ verifyObjectId(refId)
val removedModel = cachedModel.remove(refId)
val removedFromMem = removedModel != null
- val removedFromDisk = if (getMemoryControlEnabled) {
+ val removedFromDisk = if (removedModel != null && getMemoryControlEnabled) {
totalMLCacheSizeBytes.addAndGet(-removedModel.sizeBytes)
- val offloadingPath = new File(offloadedModelsDir.resolve(refId).toString)
+ val removePath = offloadedModelsDir.resolve(refId)
+ require(removePath.startsWith(offloadedModelsDir))
+ val offloadingPath = new File(removePath.toString)
if (offloadingPath.exists()) {
FileUtils.deleteDirectory(offloadingPath)
true
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]