This is an automated email from the ASF dual-hosted git repository.
hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 3e2b146eb81 [SPARK-45136][CONNECT] Enhance ClosureCleaner with
Ammonite support
3e2b146eb81 is described below
commit 3e2b146eb81d9a5727f07b58f7bb1760a71a8697
Author: Vsevolod Stepanov <[email protected]>
AuthorDate: Wed Oct 25 21:35:07 2023 -0400
[SPARK-45136][CONNECT] Enhance ClosureCleaner with Ammonite support
### What changes were proposed in this pull request?
This PR enhances the existing ClosureCleaner implementation to support cleaning
closures defined in Ammonite. Please refer to [this
gist](https://gist.github.com/vsevolodstep-db/b8e4d676745d6e2d047ecac291e5254c)
for more context on how Ammonite code wrapping works and what problems I'm
trying to solve here.
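As a rough, hypothetical sketch (class and member names simplified; the real
generated classes live in the `ammonite.$sess` package, as described in the new
`ClosureCleaner` comments), Ammonite wraps each REPL command roughly like this:

```scala
// Hypothetical, simplified model of Ammonite's wrapping for two commands:
//   cmd0: val x = 100
//   cmd1: val myLambda = (a: Int) => a + x
class cmd0 {
  class Helper extends Serializable {
    val x = 100                 // user code of command 0
  }
  val helper = new Helper       // instance kept around by the REPL
}

class cmd1(val cmd0helper: cmd0#Helper) {
  class Helper extends Serializable {
    // the lambda's enclosing "this" is cmd1$Helper; its $outer is cmd1,
    // which transitively references cmd0$Helper (and everything it holds)
    val myLambda = (a: Int) => a + cmd0helper.x
  }
  val helper = new Helper
}
```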
Overall, it contains these logical changes in `ClosureCleaner`:
1. Making it recognize and clean closures defined in Ammonite (previously it
only checked whether the capturing class name starts with `$line` and ends
with `$iw`, which is specific to the native Scala REPL); see the sketch after
this list
2. Making it clean closures if they are defined inside a user class in a
REPL (see corner case 1 in the gist)
3. Making it clean nested closures properly for Ammonite REPL (see corner
case 2 in the gist)
4. Making it transitively follow other Ammonite commands that are captured
by the target closure.
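For item 1, the difference is visible in the `ClosureCleaner` diff below;
restated as standalone helpers for clarity, the old check only recognized
native Scala REPL line classes, while the new checks also recognize
Ammonite-generated command classes by name:

```scala
object ClosureCleanerChecks {
  // Old check: only recognizes classes generated by the native Scala REPL
  def isClosureDeclaredInScalaRepl(capturingClassName: String): Boolean =
    capturingClassName.startsWith("$line") && capturingClassName.endsWith("$iw")

  // New checks added in this PR (see ClosureCleaner.scala in the diff below)
  def isAmmoniteCommandOrHelper(clazz: Class[_]): Boolean =
    clazz.getName.matches("""^ammonite\.\$sess\.cmd[0-9]*(\$Helper\$?)?""")

  def isDefinedInAmmonite(clazz: Class[_]): Boolean =
    clazz.getName.matches("""^ammonite\.\$sess\.cmd[0-9]*.*""")
}
```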
Please note that the `cleanTransitively` option of the `ClosureCleaner.clean()`
method refers to following references transitively within the enclosing command
object; it does not follow other command objects.
As we need `ClosureCleaner` to be available in Spark Connect, I also moved the
implementation to the `common-utils` module. This brings in a new
`xbean-asm9-shaded` dependency, which is fairly small.
Also, this PR moves the `checkSerializable` check from `ClosureCleaner` to
`SparkClosureCleaner`, as it is specific to Spark core; the resulting split is
sketched below.
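The split looks roughly as follows (abridged from the new `SparkClosureCleaner`
in the diff below): `ClosureCleaner.clean` now returns whether cleaning
happened, and the core-side wrapper performs the serializability check:

```scala
// core/src/main/scala/org/apache/spark/util/SparkClosureCleaner.scala (abridged)
def clean(
    closure: AnyRef,
    checkSerializable: Boolean = true,
    cleanTransitively: Boolean = true): Unit = {
  if (ClosureCleaner.clean(closure, cleanTransitively, mutable.Map.empty)) {
    try {
      if (checkSerializable && SparkEnv.get != null) {
        SparkEnv.get.closureSerializer.newInstance().serialize(closure)
      }
    } catch {
      case ex: Exception => throw new SparkException("Task not serializable", ex)
    }
  }
}
```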
The important changes affect `ClosureCleaner` only. They should not affect the
existing code path for normal Scala closures or closures defined in the native
Scala REPL, and they cover only closures defined in Ammonite.
Also, this PR modifies Spark Connect's `UserDefinedFunction` to actually use
`ClosureCleaner` and clean closures in Spark Connect, as sketched below.
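Concretely, the Connect client now runs the cleaner before serializing the UDF
payload (abridged from the `UserDefinedFunction.scala` change below):

```scala
// In ScalarUserDefinedFunction.apply, before building the UdfPacket:
SparkConnectClosureCleaner.clean(function)
val udfPacketBytes =
  SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders, outputEncoder))

// New thin private wrapper added by this PR:
private object SparkConnectClosureCleaner {
  def clean(closure: AnyRef): Unit = {
    ClosureCleaner.clean(closure, cleanTransitively = true, mutable.Map.empty)
  }
}
```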
### Why are the changes needed?
To properly support closures defined in Ammonite, reduce UDF payload size
and avoid possible `NonSerializable` exceptions. This includes:
- a lambda capturing the outer command object, leading to a circular dependency
- a lambda capturing other command objects transitively, exploding the payload
size
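For example (adapted from the new `ReplE2ESuite` tests below), a UDF defined in
the Ammonite REPL may only need `x`, yet it captures the enclosing command
helper, which also references a non-serializable value; without cleaning this
either fails to serialize or bloats the payload:

```scala
// Ammonite / Spark Connect REPL session (adapted from ReplE2ESuite):
class NonSerializable(val id: Int = -1) { }

val x = 100
val y = new NonSerializable   // lives in the same command helper as x

// The lambda captures its cmd$Helper, which (transitively) references y;
// the cleaner nulls out the unused reference before serialization.
val myUdf = udf((i: Int) => i + x)
spark.range(0, 10).withColumn("result", myUdf(col("id"))).agg(sum("result")).collect()
```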
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing tests.
New tests in `ReplE2ESuite` covering various scenarios using Spark Connect +
the Ammonite REPL to make sure closures are actually cleaned.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #42995 from vsevolodstep-db/SPARK-45136/closure-cleaner.
Authored-by: Vsevolod Stepanov <[email protected]>
Signed-off-by: Herman van Hovell <[email protected]>
---
common/utils/pom.xml | 4 +
.../org/apache/spark/util/ClosureCleaner.scala | 636 ++++++++++++++-------
.../org/apache/spark/util/SparkStreamUtils.scala | 109 ++++
.../sql/expressions/UserDefinedFunction.scala | 10 +-
.../spark/sql/application/ReplE2ESuite.scala | 143 +++++
.../CheckConnectJvmClientCompatibility.scala | 8 +
core/pom.xml | 4 -
.../main/scala/org/apache/spark/SparkContext.scala | 2 +-
.../apache/spark/util/SparkClosureCleaner.scala | 49 ++
.../main/scala/org/apache/spark/util/Utils.scala | 85 +--
.../apache/spark/util/ClosureCleanerSuite.scala | 2 +-
.../apache/spark/util/ClosureCleanerSuite2.scala | 4 +-
project/MimaExcludes.scala | 4 +-
.../catalyst/encoders/ExpressionEncoderSuite.scala | 4 +-
.../org/apache/spark/streaming/StateSpec.scala | 6 +-
15 files changed, 756 insertions(+), 314 deletions(-)
diff --git a/common/utils/pom.xml b/common/utils/pom.xml
index 37d1ea48d97..44cb30a19ff 100644
--- a/common/utils/pom.xml
+++ b/common/utils/pom.xml
@@ -39,6 +39,10 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-tags_${scala.binary.version}</artifactId>
</dependency>
+ <dependency>
+ <groupId>org.apache.xbean</groupId>
+ <artifactId>xbean-asm9-shaded</artifactId>
+ </dependency>
<dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-databind</artifactId>
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
b/common/utils/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
similarity index 61%
rename from core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
rename to common/utils/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 29fb0206f90..ffa2f0e60b2 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/common/utils/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -21,7 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream}
import java.lang.invoke.{MethodHandleInfo, SerializedLambda}
import java.lang.reflect.{Field, Modifier}
-import scala.collection.mutable.{Map, Set, Stack}
+import scala.collection.mutable.{Map, Queue, Set, Stack}
import scala.jdk.CollectionConverters._
import org.apache.commons.lang3.ClassUtils
@@ -29,14 +29,13 @@ import org.apache.xbean.asm9.{ClassReader, ClassVisitor,
Handle, MethodVisitor,
import org.apache.xbean.asm9.Opcodes._
import org.apache.xbean.asm9.tree.{ClassNode, MethodNode}
-import org.apache.spark.{SparkEnv, SparkException}
+import org.apache.spark.SparkException
import org.apache.spark.internal.Logging
/**
* A cleaner that renders closures serializable if they can be done so safely.
*/
private[spark] object ClosureCleaner extends Logging {
-
// Get an ASM class reader for a given class from the JAR that loaded it
private[util] def getClassReader(cls: Class[_]): ClassReader = {
// Copy data over, before delegating to ClassReader - else we can run out
of open file handles.
@@ -46,11 +45,18 @@ private[spark] object ClosureCleaner extends Logging {
null
} else {
val baos = new ByteArrayOutputStream(128)
- Utils.copyStream(resourceStream, baos, true)
+
+ SparkStreamUtils.copyStream(resourceStream, baos, closeStreams = true)
new ClassReader(new ByteArrayInputStream(baos.toByteArray))
}
}
+ private[util] def isAmmoniteCommandOrHelper(clazz: Class[_]): Boolean =
clazz.getName.matches(
+ """^ammonite\.\$sess\.cmd[0-9]*(\$Helper\$?)?""")
+
+ private[util] def isDefinedInAmmonite(clazz: Class[_]): Boolean =
clazz.getName.matches(
+ """^ammonite\.\$sess\.cmd[0-9]*.*""")
+
// Check whether a class represents a Scala closure
private def isClosure(cls: Class[_]): Boolean = {
cls.getName.contains("$anonfun$")
@@ -146,23 +152,6 @@ private[spark] object ClosureCleaner extends Logging {
clone
}
- /**
- * Clean the given closure in place.
- *
- * More specifically, this renders the given closure serializable as long as
it does not
- * explicitly reference unserializable objects.
- *
- * @param closure the closure to clean
- * @param checkSerializable whether to verify that the closure is
serializable after cleaning
- * @param cleanTransitively whether to clean enclosing closures transitively
- */
- def clean(
- closure: AnyRef,
- checkSerializable: Boolean = true,
- cleanTransitively: Boolean = true): Unit = {
- clean(closure, checkSerializable, cleanTransitively, Map.empty)
- }
-
/**
* Helper method to clean the given closure in place.
*
@@ -198,18 +187,15 @@ private[spark] object ClosureCleaner extends Logging {
* pointer of a cloned scope "one" and set it the parent of scope "two",
such that scope "two"
* no longer references SomethingNotSerializable transitively.
*
- * @param func the starting closure to clean
- * @param checkSerializable whether to verify that the closure is
serializable after cleaning
+ * @param func the starting closure to clean
* @param cleanTransitively whether to clean enclosing closures transitively
- * @param accessedFields a map from a class to a set of its fields that are
accessed by
- * the starting closure
+ * @param accessedFields a map from a class to a set of its fields that
are accessed by
+ * the starting closure
*/
- private def clean(
+ private[spark] def clean(
func: AnyRef,
- checkSerializable: Boolean,
cleanTransitively: Boolean,
- accessedFields: Map[Class[_], Set[String]]): Unit = {
-
+ accessedFields: Map[Class[_], Set[String]]): Boolean = {
// indylambda check. Most likely to be the case with 2.12, 2.13
// so we check first
// non LMF-closures should be less frequent from now on
@@ -217,131 +203,18 @@ private[spark] object ClosureCleaner extends Logging {
if (!isClosure(func.getClass) && maybeIndylambdaProxy.isEmpty) {
logDebug(s"Expected a closure; got ${func.getClass.getName}")
- return
+ return false
}
// TODO: clean all inner closures first. This requires us to find the
inner objects.
// TODO: cache outerClasses / innerClasses / accessedFields
if (func == null) {
- return
+ return false
}
if (maybeIndylambdaProxy.isEmpty) {
- logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++")
-
- // A list of classes that represents closures enclosed in the given one
- val innerClasses = getInnerClosureClasses(func)
-
- // A list of enclosing objects and their respective classes, from
innermost to outermost
- // An outer object at a given index is of type outer class at the same
index
- val (outerClasses, outerObjects) = getOuterClassesAndObjects(func)
-
- // For logging purposes only
- val declaredFields = func.getClass.getDeclaredFields
- val declaredMethods = func.getClass.getDeclaredMethods
-
- if (log.isDebugEnabled) {
- logDebug(s" + declared fields: ${declaredFields.size}")
- declaredFields.foreach { f => logDebug(s" $f") }
- logDebug(s" + declared methods: ${declaredMethods.size}")
- declaredMethods.foreach { m => logDebug(s" $m") }
- logDebug(s" + inner classes: ${innerClasses.size}")
- innerClasses.foreach { c => logDebug(s" ${c.getName}") }
- logDebug(s" + outer classes: ${outerClasses.size}" )
- outerClasses.foreach { c => logDebug(s" ${c.getName}") }
- }
-
- // Fail fast if we detect return statements in closures
- getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)
-
- // If accessed fields is not populated yet, we assume that
- // the closure we are trying to clean is the starting one
- if (accessedFields.isEmpty) {
- logDebug(" + populating accessed fields because this is the starting
closure")
- // Initialize accessed fields with the outer classes first
- // This step is needed to associate the fields to the correct classes
later
- initAccessedFields(accessedFields, outerClasses)
-
- // Populate accessed fields by visiting all fields and methods
accessed by this and
- // all of its inner closures. If transitive cleaning is enabled, this
may recursively
- // visits methods that belong to other classes in search of
transitively referenced fields.
- for (cls <- func.getClass :: innerClasses) {
- getClassReader(cls).accept(new FieldAccessFinder(accessedFields,
cleanTransitively), 0)
- }
- }
-
- logDebug(s" + fields accessed by starting closure:
${accessedFields.size} classes")
- accessedFields.foreach { f => logDebug(" " + f) }
-
- // List of outer (class, object) pairs, ordered from outermost to
innermost
- // Note that all outer objects but the outermost one (first one in this
list) must be closures
- var outerPairs: List[(Class[_], AnyRef)] =
outerClasses.zip(outerObjects).reverse
- var parent: AnyRef = null
- if (outerPairs.nonEmpty) {
- val outermostClass = outerPairs.head._1
- val outermostObject = outerPairs.head._2
-
- if (isClosure(outermostClass)) {
- logDebug(s" + outermost object is a closure, so we clone it:
${outermostClass}")
- } else if (outermostClass.getName.startsWith("$line")) {
- // SPARK-14558: if the outermost object is a REPL line object, we
should clone
- // and clean it as it may carry a lot of unnecessary information,
- // e.g. hadoop conf, spark conf, etc.
- logDebug(s" + outermost object is a REPL line object, so we clone
it:" +
- s" ${outermostClass}")
- } else {
- // The closure is ultimately nested inside a class; keep the object
of that
- // class without cloning it since we don't want to clone the user's
objects.
- // Note that we still need to keep around the outermost object
itself because
- // we need it to clone its child closure later (see below).
- logDebug(s" + outermost object is not a closure or REPL line
object," +
- s" so do not clone it: ${outermostClass}")
- parent = outermostObject // e.g. SparkContext
- outerPairs = outerPairs.tail
- }
- } else {
- logDebug(" + there are no enclosing objects!")
- }
-
- // Clone the closure objects themselves, nulling out any fields that are
not
- // used in the closure we're working on or any of its inner closures.
- for ((cls, obj) <- outerPairs) {
- logDebug(s" + cloning instance of class ${cls.getName}")
- // We null out these unused references by cloning each object and then
filling in all
- // required fields from the original object. We need the parent here
because the Java
- // language specification requires the first constructor parameter of
any closure to be
- // its enclosing object.
- val clone = cloneAndSetFields(parent, obj, cls, accessedFields)
-
- // If transitive cleaning is enabled, we recursively clean any
enclosing closure using
- // the already populated accessed fields map of the starting closure
- if (cleanTransitively && isClosure(clone.getClass)) {
- logDebug(s" + cleaning cloned closure recursively (${cls.getName})")
- // No need to check serializable here for the outer closures because
we're
- // only interested in the serializability of the starting closure
- clean(clone, checkSerializable = false, cleanTransitively,
accessedFields)
- }
- parent = clone
- }
-
- // Update the parent pointer ($outer) of this closure
- if (parent != null) {
- val field = func.getClass.getDeclaredField("$outer")
- field.setAccessible(true)
- // If the starting closure doesn't actually need our enclosing object,
then just null it out
- if (accessedFields.contains(func.getClass) &&
- !accessedFields(func.getClass).contains("$outer")) {
- logDebug(s" + the starting closure doesn't actually need $parent, so
we null it out")
- field.set(func, null)
- } else {
- // Update this closure's parent pointer to point to our enclosing
object,
- // which could either be a cloned closure or the original user object
- field.set(func, parent)
- }
- }
-
- logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned
+++")
+ cleanNonIndyLambdaClosure(func, cleanTransitively, accessedFields)
} else {
val lambdaProxy = maybeIndylambdaProxy.get
val implMethodName = lambdaProxy.getImplMethodName
@@ -359,62 +232,339 @@ private[spark] object ClosureCleaner extends Logging {
val capturingClassReader = getClassReader(capturingClass)
capturingClassReader.accept(new
ReturnStatementFinder(Option(implMethodName)), 0)
- val isClosureDeclaredInScalaRepl =
capturingClassName.startsWith("$line") &&
- capturingClassName.endsWith("$iw")
- val outerThisOpt = if (lambdaProxy.getCapturedArgCount > 0) {
- Option(lambdaProxy.getCapturedArg(0))
+ val outerThis = if (lambdaProxy.getCapturedArgCount > 0) {
+ // only need to clean when there is an enclosing non-null "this"
captured by the closure
+ Option(lambdaProxy.getCapturedArg(0)).getOrElse(return false)
} else {
- None
+ return false
}
- // only need to clean when there is an enclosing "this" captured by the
closure, and it
- // should be something cleanable, i.e. a Scala REPL line object
- val needsCleaning = isClosureDeclaredInScalaRepl &&
- outerThisOpt.isDefined && outerThisOpt.get.getClass.getName ==
capturingClassName
-
- if (needsCleaning) {
- // indylambda closures do not reference enclosing closures via an
`$outer` chain, so no
- // transitive cleaning on the `$outer` chain is needed.
- // Thus clean() shouldn't be recursively called with a non-empty
accessedFields.
- assert(accessedFields.isEmpty)
-
- initAccessedFields(accessedFields, Seq(capturingClass))
- IndylambdaScalaClosures.findAccessedFields(
- lambdaProxy, classLoader, accessedFields, cleanTransitively)
-
- logDebug(s" + fields accessed by starting closure:
${accessedFields.size} classes")
- accessedFields.foreach { f => logDebug(" " + f) }
-
- if (accessedFields(capturingClass).size <
capturingClass.getDeclaredFields.length) {
- // clone and clean the enclosing `this` only when there are fields
to null out
-
- val outerThis = outerThisOpt.get
-
- logDebug(s" + cloning instance of REPL class $capturingClassName")
- val clonedOuterThis = cloneAndSetFields(
- parent = null, outerThis, capturingClass, accessedFields)
-
- val outerField = func.getClass.getDeclaredField("arg$1")
- // SPARK-37072: When Java 17 is used and `outerField` is read-only,
- // the content of `outerField` cannot be set by reflect api directly.
- // But we can remove the `final` modifier of `outerField` before set
value
- // and reset the modifier after set value.
- val modifiersField = getFinalModifiersFieldForJava17(outerField)
- modifiersField
- .foreach(m => m.setInt(outerField, outerField.getModifiers &
~Modifier.FINAL))
- outerField.setAccessible(true)
- outerField.set(func, clonedOuterThis)
- modifiersField
- .foreach(m => m.setInt(outerField, outerField.getModifiers |
Modifier.FINAL))
+ // clean only if enclosing "this" is something cleanable, i.e. a Scala
REPL line object or
+ // Ammonite command helper object.
+ // For Ammonite closures, we do not care about actual capturing class
name,
+ // as closure needs to be cleaned if it captures Ammonite command helper
object
+ if (isDefinedInAmmonite(outerThis.getClass)) {
+ // If outerThis is a lambda, we have to clean that instead
+ IndylambdaScalaClosures.getSerializationProxy(outerThis).foreach { _ =>
+ return clean(outerThis, cleanTransitively, accessedFields)
+ }
+ cleanupAmmoniteReplClosure(func, lambdaProxy, outerThis,
cleanTransitively)
+ } else {
+ val isClosureDeclaredInScalaRepl =
capturingClassName.startsWith("$line") &&
+ capturingClassName.endsWith("$iw")
+ if (isClosureDeclaredInScalaRepl && outerThis.getClass.getName ==
capturingClassName) {
+ assert(accessedFields.isEmpty)
+ cleanupScalaReplClosure(func, lambdaProxy, outerThis,
cleanTransitively)
}
}
logDebug(s" +++ indylambda closure ($implMethodName) is now cleaned +++")
}
- if (checkSerializable) {
- ensureSerializable(func)
+ true
+ }
+
+ /**
+ * Cleans non-indylambda closure in place
+ *
+ * @param func the starting closure to clean
+ * @param cleanTransitively whether to clean enclosing closures transitively
+ * @param accessedFields a map from a class to a set of its fields that
are accessed by
+ * the starting closure
+ */
+ private def cleanNonIndyLambdaClosure(
+ func: AnyRef,
+ cleanTransitively: Boolean,
+ accessedFields: Map[Class[_], Set[String]]): Unit = {
+ logDebug(s"+++ Cleaning closure $func (${func.getClass.getName}) +++")
+
+ // A list of classes that represents closures enclosed in the given one
+ val innerClasses = getInnerClosureClasses(func)
+
+ // A list of enclosing objects and their respective classes, from
innermost to outermost
+ // An outer object at a given index is of type outer class at the same
index
+ val (outerClasses, outerObjects) = getOuterClassesAndObjects(func)
+
+ // For logging purposes only
+ val declaredFields = func.getClass.getDeclaredFields
+ val declaredMethods = func.getClass.getDeclaredMethods
+
+ if (log.isDebugEnabled) {
+ logDebug(s" + declared fields: ${declaredFields.size}")
+ declaredFields.foreach { f => logDebug(s" $f") }
+ logDebug(s" + declared methods: ${declaredMethods.size}")
+ declaredMethods.foreach { m => logDebug(s" $m") }
+ logDebug(s" + inner classes: ${innerClasses.size}")
+ innerClasses.foreach { c => logDebug(s" ${c.getName}") }
+ logDebug(s" + outer classes: ${outerClasses.size}")
+ outerClasses.foreach { c => logDebug(s" ${c.getName}") }
}
+
+ // Fail fast if we detect return statements in closures
+ getClassReader(func.getClass).accept(new ReturnStatementFinder(), 0)
+
+ // If accessed fields is not populated yet, we assume that
+ // the closure we are trying to clean is the starting one
+ if (accessedFields.isEmpty) {
+ logDebug(" + populating accessed fields because this is the starting
closure")
+ // Initialize accessed fields with the outer classes first
+ // This step is needed to associate the fields to the correct classes
later
+ initAccessedFields(accessedFields, outerClasses)
+
+ // Populate accessed fields by visiting all fields and methods accessed
by this and
+ // all of its inner closures. If transitive cleaning is enabled, this
may recursively
+ // visits methods that belong to other classes in search of transitively
referenced fields.
+ for (cls <- func.getClass :: innerClasses) {
+ getClassReader(cls).accept(new FieldAccessFinder(accessedFields,
cleanTransitively), 0)
+ }
+ }
+
+ logDebug(s" + fields accessed by starting closure: ${accessedFields.size}
classes")
+ accessedFields.foreach { f => logDebug(" " + f) }
+
+ // List of outer (class, object) pairs, ordered from outermost to innermost
+ // Note that all outer objects but the outermost one (first one in this
list) must be closures
+ var outerPairs: List[(Class[_], AnyRef)] =
outerClasses.zip(outerObjects).reverse
+ var parent: AnyRef = null
+ if (outerPairs.nonEmpty) {
+ val outermostClass = outerPairs.head._1
+ val outermostObject = outerPairs.head._2
+
+ if (isClosure(outermostClass)) {
+ logDebug(s" + outermost object is a closure, so we clone it:
${outermostClass}")
+ } else if (outermostClass.getName.startsWith("$line")) {
+ // SPARK-14558: if the outermost object is a REPL line object, we
should clone
+ // and clean it as it may carry a lot of unnecessary information,
+ // e.g. hadoop conf, spark conf, etc.
+ logDebug(s" + outermost object is a REPL line object, so we clone it:"
+
+ s" ${outermostClass}")
+ } else {
+ // The closure is ultimately nested inside a class; keep the object of
that
+ // class without cloning it since we don't want to clone the user's
objects.
+ // Note that we still need to keep around the outermost object itself
because
+ // we need it to clone its child closure later (see below).
+ logDebug(s" + outermost object is not a closure or REPL line object," +
+ s" so do not clone it: ${outermostClass}")
+ parent = outermostObject // e.g. SparkContext
+ outerPairs = outerPairs.tail
+ }
+ } else {
+ logDebug(" + there are no enclosing objects!")
+ }
+
+ // Clone the closure objects themselves, nulling out any fields that are
not
+ // used in the closure we're working on or any of its inner closures.
+ for ((cls, obj) <- outerPairs) {
+ logDebug(s" + cloning instance of class ${cls.getName}")
+ // We null out these unused references by cloning each object and then
filling in all
+ // required fields from the original object. We need the parent here
because the Java
+ // language specification requires the first constructor parameter of
any closure to be
+ // its enclosing object.
+ val clone = cloneAndSetFields(parent, obj, cls, accessedFields)
+
+ // If transitive cleaning is enabled, we recursively clean any enclosing
closure using
+ // the already populated accessed fields map of the starting closure
+ if (cleanTransitively && isClosure(clone.getClass)) {
+ logDebug(s" + cleaning cloned closure recursively (${cls.getName})")
+ clean(clone, cleanTransitively, accessedFields)
+ }
+ parent = clone
+ }
+
+ // Update the parent pointer ($outer) of this closure
+ if (parent != null) {
+ val field = func.getClass.getDeclaredField("$outer")
+ field.setAccessible(true)
+ // If the starting closure doesn't actually need our enclosing object,
then just null it out
+ if (accessedFields.contains(func.getClass) &&
+ !accessedFields(func.getClass).contains("$outer")) {
+ logDebug(s" + the starting closure doesn't actually need $parent, so
we null it out")
+ field.set(func, null)
+ } else {
+ // Update this closure's parent pointer to point to our enclosing
object,
+ // which could either be a cloned closure or the original user object
+ field.set(func, parent)
+ }
+ }
+
+ logDebug(s" +++ closure $func (${func.getClass.getName}) is now cleaned
+++")
+ }
+
+ /**
+ * Null out fields of enclosing class which are not actually accessed by a
closure
+ * @param func the starting closure to clean
+ * @param lambdaProxy starting closure proxy
+ * @param outerThis lambda enclosing class
+ * @param cleanTransitively whether to clean enclosing closures transitively
+ */
+ private def cleanupScalaReplClosure(
+ func: AnyRef,
+ lambdaProxy: SerializedLambda,
+ outerThis: AnyRef,
+ cleanTransitively: Boolean): Unit = {
+
+ val capturingClass = outerThis.getClass
+ val accessedFields: Map[Class[_], Set[String]] = Map.empty
+ initAccessedFields(accessedFields, Seq(capturingClass))
+
+ IndylambdaScalaClosures.findAccessedFields(
+ lambdaProxy,
+ func.getClass.getClassLoader,
+ accessedFields,
+ Map.empty,
+ Map.empty,
+ cleanTransitively)
+
+ logDebug(s" + fields accessed by starting closure: ${accessedFields.size}
classes")
+ accessedFields.foreach { f => logDebug(" " + f) }
+
+ if (accessedFields(capturingClass).size <
capturingClass.getDeclaredFields.length) {
+ // clone and clean the enclosing `this` only when there are fields to
null out
+ logDebug(s" + cloning instance of REPL class ${capturingClass.getName}")
+ val clonedOuterThis = cloneAndSetFields(
+ parent = null, outerThis, capturingClass, accessedFields)
+
+ val outerField = func.getClass.getDeclaredField("arg$1")
+ // SPARK-37072: When Java 17 is used and `outerField` is read-only,
+ // the content of `outerField` cannot be set by reflect api directly.
+ // But we can remove the `final` modifier of `outerField` before set
value
+ // and reset the modifier after set value.
+ setFieldAndIgnoreModifiers(func, outerField, clonedOuterThis)
+ }
+ }
+
+
+ /**
+ * Cleans up Ammonite closures and nulls out fields captured from cmd &
cmd$Helper objects
+ * but not actually accessed by the closure. To achieve this, it does:
+ * 1. Identify all accessed Ammonite cmd & cmd$Helper objects
+ * 2. Clone all accessed cmdX objects
+ * 3. Clone all accessed cmdX$Helper objects and set their $outer field to
the cmdX clone
+ * 4. Iterate over these clones and set all other accessed fields to
+ * - a clone, if the field refers to an Ammonite object
+ * - a previous value otherwise
+ * 5. In case the capturing object is an inner class of an Ammonite cmd$Helper
object, clone & update
+ * this capturing object as well
+ *
+ * As a result:
+ * - For all accessed cmdX objects all their references to cmdY$Helper
objects are
+ * either nulled out or updated to cmdY clone
+ * - For cmdX$Helper objects it means that variables defined in this
command are
+ * nulled out if not accessed
+ * - lambda enclosing class is cleaned up as it's done for normal Scala
closures
+ *
+ * @param func the starting closure to clean
+ * @param lambdaProxy starting closure proxy
+ * @param outerThis lambda enclosing class
+ * @param cleanTransitively whether to clean enclosing closures transitively
+ */
+ private def cleanupAmmoniteReplClosure(
+ func: AnyRef,
+ lambdaProxy: SerializedLambda,
+ outerThis: AnyRef,
+ cleanTransitively: Boolean): Unit = {
+
+ val accessedFields: Map[Class[_], Set[String]] = Map.empty
+ initAccessedFields(accessedFields, Seq(outerThis.getClass))
+
+ // Ammonite generates 3 classes for a command number X:
+ // - cmdX class containing all dependencies needed to execute the command
+ // (i.e. previous command helpers)
+ // - cmdX$Helper - inner class of cmdX - containing the user code. It
pulls
+ // required dependencies (i.e. variables defined in other commands) from
outer command
+ // - cmdX companion object holding an instance of cmdX and cmdX$Helper
classes.
+ // Here, we care only about command objects and their helpers, companion
objects are
+ // not captured by closure
+
+ // instances of cmdX and cmdX$Helper
+ val ammCmdInstances: Map[Class[_], AnyRef] = Map.empty
+ // fields accessed in those commands
+ val accessedAmmCmdFields: Map[Class[_], Set[String]] = Map.empty
+ // outer class may be either Ammonite cmd / cmd$Helper class or an inner
class
+ // defined in a user code. We need to clean up Ammonite classes only
+ if (isAmmoniteCommandOrHelper(outerThis.getClass)) {
+ ammCmdInstances(outerThis.getClass) = outerThis
+ accessedAmmCmdFields(outerThis.getClass) = Set.empty
+ }
+
+ IndylambdaScalaClosures.findAccessedFields(
+ lambdaProxy,
+ func.getClass.getClassLoader,
+ accessedFields,
+ accessedAmmCmdFields,
+ ammCmdInstances,
+ cleanTransitively)
+
+ logTrace(s" + command fields accessed by starting closure: " +
+ s"${accessedAmmCmdFields.size} classes")
+ accessedAmmCmdFields.foreach { f => logTrace(" " + f) }
+
+ val cmdClones = Map[Class[_], AnyRef]()
+ for ((cmdClass, _) <- ammCmdInstances if
!cmdClass.getName.contains("Helper")) {
+ logDebug(s" + Cloning instance of Ammonite command class
${cmdClass.getName}")
+ cmdClones(cmdClass) = instantiateClass(cmdClass, enclosingObject = null)
+ }
+ for ((cmdHelperClass, cmdHelperInstance) <- ammCmdInstances
+ if cmdHelperClass.getName.contains("Helper")) {
+ val cmdHelperOuter = cmdHelperClass.getDeclaredFields
+ .find(_.getName == "$outer")
+ .map { field =>
+ field.setAccessible(true)
+ field.get(cmdHelperInstance)
+ }
+ val outerClone = cmdHelperOuter.flatMap(o =>
cmdClones.get(o.getClass)).orNull
+ logDebug(s" + Cloning instance of Ammonite command helper class
${cmdHelperClass.getName}")
+ cmdClones(cmdHelperClass) =
+ instantiateClass(cmdHelperClass, enclosingObject = outerClone)
+ }
+
+ // set accessed fields
+ for ((_, cmdClone) <- cmdClones) {
+ val cmdClass = cmdClone.getClass
+ val accessedFields = accessedAmmCmdFields(cmdClass)
+ for (field <- cmdClone.getClass.getDeclaredFields
+ // outer fields were initialized during clone construction
+ if accessedFields.contains(field.getName) && field.getName !=
"$outer") {
+ // get command clone if exists, otherwise use an original field value
+ val value = cmdClones.getOrElse(field.getType, {
+ field.setAccessible(true)
+ field.get(ammCmdInstances(cmdClass))
+ })
+ setFieldAndIgnoreModifiers(cmdClone, field, value)
+ }
+ }
+
+ val outerThisClone = if (!isAmmoniteCommandOrHelper(outerThis.getClass)) {
+ // if the outer class is not an Ammonite helper / command object then it was
not cloned
+ // in the code above. We still need to clone it and update accessed
fields
+ logDebug(s" + Cloning instance of lambda capturing class
${outerThis.getClass.getName}")
+ val clone = cloneAndSetFields(parent = null, outerThis,
outerThis.getClass, accessedFields)
+ // making sure that the code below will update references to Ammonite
objects if they exist
+ for (field <- outerThis.getClass.getDeclaredFields) {
+ field.setAccessible(true)
+ cmdClones.get(field.getType).foreach { value =>
+ setFieldAndIgnoreModifiers(clone, field, value)
+ }
+ }
+ clone
+ } else {
+ cmdClones(outerThis.getClass)
+ }
+
+ val outerField = func.getClass.getDeclaredField("arg$1")
+ // update lambda capturing class reference
+ setFieldAndIgnoreModifiers(func, outerField, outerThisClone)
+ }
+
+ private def setFieldAndIgnoreModifiers(obj: AnyRef, field: Field, value:
AnyRef): Unit = {
+ val modifiersField = getFinalModifiersFieldForJava17(field)
+ modifiersField
+ .foreach(m => m.setInt(field, field.getModifiers & ~Modifier.FINAL))
+ field.setAccessible(true)
+ field.set(obj, value)
+
+ modifiersField
+ .foreach(m => m.setInt(field, field.getModifiers | Modifier.FINAL))
}
/**
@@ -434,19 +584,7 @@ private[spark] object ClosureCleaner extends Logging {
} else None
}
- private def ensureSerializable(func: AnyRef): Unit = {
- try {
- if (SparkEnv.get != null) {
- SparkEnv.get.closureSerializer.newInstance().serialize(func)
- }
- } catch {
- case ex: Exception => throw new SparkException("Task not serializable",
ex)
- }
- }
-
- private def instantiateClass(
- cls: Class[_],
- enclosingObject: AnyRef): AnyRef = {
+ private def instantiateClass(cls: Class[_], enclosingObject: AnyRef): AnyRef
= {
// Use reflection to instantiate object without calling constructor
val rf = sun.reflect.ReflectionFactory.getReflectionFactory()
val parentCtor = classOf[java.lang.Object].getDeclaredConstructor()
@@ -561,6 +699,9 @@ private[spark] object IndylambdaScalaClosures extends
Logging {
* same for all three combined, so they can be fused together easily while
maintaining the same
* ordering as the existing implementation.
*
+ * It also transitively visits the Ammonite cmd and cmd$Helper objects it
encounters
+ * and populates their accessed fields so that they can be cleaned up as
well
+ *
* Precondition: this function expects the `accessedFields` to be populated
with all known
* outer classes and their super classes to be in the map as
keys, e.g.
* initializing via ClosureCleaner.initAccessedFields.
@@ -630,6 +771,8 @@ private[spark] object IndylambdaScalaClosures extends
Logging {
lambdaProxy: SerializedLambda,
lambdaClassLoader: ClassLoader,
accessedFields: Map[Class[_], Set[String]],
+ accessedAmmCmdFields: Map[Class[_], Set[String]],
+ ammCmdInstances: Map[Class[_], AnyRef],
findTransitively: Boolean): Unit = {
// We may need to visit the same class multiple times for different
methods on it, and we'll
@@ -642,15 +785,30 @@ private[spark] object IndylambdaScalaClosures extends
Logging {
// scalastyle:off classforname
val clazz = Class.forName(classExternalName, false, lambdaClassLoader)
// scalastyle:on classforname
- val classNode = new ClassNode()
- val classReader = ClosureCleaner.getClassReader(clazz)
- classReader.accept(classNode, 0)
- for (m <- classNode.methods.asScala) {
- methodNodeById(MethodIdentifier(clazz, m.name, m.desc)) = m
+ def getClassNode(clazz: Class[_]): ClassNode = {
+ val classNode = new ClassNode()
+ val classReader = ClosureCleaner.getClassReader(clazz)
+ classReader.accept(classNode, 0)
+ classNode
}
- (clazz, classNode)
+ var curClazz = clazz
+ // we need to add superclass methods as well
+ // e.g. consider the following closure:
+ // object Enclosing {
+ // val closure = () => getClass.getName
+ // }
+ // To scan this closure properly, we need to add Object.getClass method
+ // to methodNodeById map
+ while (curClazz != null) {
+ for (m <- getClassNode(curClazz).methods.asScala) {
+ methodNodeById(MethodIdentifier(clazz, m.name, m.desc)) = m
+ }
+ curClazz = curClazz.getSuperclass
+ }
+
+ (clazz, getClassNode(clazz))
})
classInfo
}
@@ -674,21 +832,55 @@ private[spark] object IndylambdaScalaClosures extends
Logging {
// to better find and track field accesses.
val trackedClassInternalNames = Set[String](implClassInternalName)
- // Depth-first search for inner closures and track the fields that were
accessed in them.
+ // Breadth-first search for inner closures and track the fields that were
accessed in them.
// Start from the lambda body's implementation method, follow method
invocations
val visited = Set.empty[MethodIdentifier[_]]
- val stack = Stack[MethodIdentifier[_]](implMethodId)
+ // Depth-first search will not work here. To make
addAmmoniteCommandFieldsToTracking work,
+ // we need to process objects in the order they appear in the reference tree.
+ // E.g. if there was a reference chain a -> b -> c, then DFS will process
these nodes in order
+ // a -> c -> b. However, to initialize ammCmdInstances(c.getClass) we need
to process node b
+ // first.
+ val queue = Queue[MethodIdentifier[_]](implMethodId)
def pushIfNotVisited(methodId: MethodIdentifier[_]): Unit = {
if (!visited.contains(methodId)) {
- stack.push(methodId)
+ queue.enqueue(methodId)
}
}
- while (!stack.isEmpty) {
- val currentId = stack.pop()
+ def addAmmoniteCommandFieldsToTracking(currentClass: Class[_]): Unit = {
+ // get an instance of currentClass. It can be either lambda enclosing
this
+ // or another already processed Ammonite object
+ val currentInstance = if (currentClass ==
lambdaProxy.getCapturedArg(0).getClass) {
+ Some(lambdaProxy.getCapturedArg(0))
+ } else {
+ // This key exists if we encountered a non-null reference to
`currentClass` before
+ // as we're processing nodes with a breadth-first search (see comment
above)
+ ammCmdInstances.get(currentClass)
+ }
+ currentInstance.foreach { cmdInstance =>
+ // track only cmdX and cmdX$Helper objects generated by Ammonite
+ for (otherCmdField <- cmdInstance.getClass.getDeclaredFields
+ if
ClosureCleaner.isAmmoniteCommandOrHelper(otherCmdField.getType)) {
+ otherCmdField.setAccessible(true)
+ val otherCmdHelperRef = otherCmdField.get(cmdInstance)
+ val otherCmdClass = otherCmdField.getType
+ // Ammonite is clever enough to sometimes nullify references to
unused commands.
+ // Ignoring these references for simplicity
+ if (otherCmdHelperRef != null &&
!ammCmdInstances.contains(otherCmdClass)) {
+ logTrace(s" started tracking ${otherCmdClass.getName}
Ammonite object")
+ ammCmdInstances(otherCmdClass) = otherCmdHelperRef
+ accessedAmmCmdFields(otherCmdClass) = Set()
+ }
+ }
+ }
+ }
+
+ while (queue.nonEmpty) {
+ val currentId = queue.dequeue()
visited += currentId
val currentClass = currentId.cls
+ addAmmoniteCommandFieldsToTracking(currentClass)
val currentMethodNode = methodNodeById(currentId)
logTrace(s" scanning
${currentId.cls.getName}.${currentId.name}${currentId.desc}")
currentMethodNode.accept(new MethodVisitor(ASM9) {
@@ -704,6 +896,10 @@ private[spark] object IndylambdaScalaClosures extends
Logging {
logTrace(s" found field access $name on $ownerExternalName")
accessedFields(cl) += name
}
+ for (cl <- accessedAmmCmdFields.keys if cl.getName ==
ownerExternalName) {
+ logTrace(s" found Ammonite command field access $name on
$ownerExternalName")
+ accessedAmmCmdFields(cl) += name
+ }
}
}
@@ -714,6 +910,10 @@ private[spark] object IndylambdaScalaClosures extends
Logging {
logTrace(s" found intra class call to
$ownerExternalName.$name$desc")
// could be invoking a helper method or a field accessor method,
just follow it.
pushIfNotVisited(MethodIdentifier(currentClass, name, desc))
+ } else if (owner.startsWith("ammonite/$sess/cmd")) {
+ // we're inside Ammonite command / command helper object, track
all calls from here
+ val classInfo = getOrUpdateClassInfo(owner)
+ pushIfNotVisited(MethodIdentifier(classInfo._1, name, desc))
} else if (isInnerClassCtorCapturingOuter(
op, owner, name, desc, currentClassInternalName)) {
// Discover inner classes.
@@ -894,8 +1094,10 @@ private class InnerClosureFinder(output: Set[Class[_]])
extends ClassVisitor(ASM
if (op == INVOKESPECIAL && name == "<init>" && argTypes.length > 0
&& argTypes(0).toString.startsWith("L") // is it an object?
&& argTypes(0).getInternalName == myName) {
- output += Utils.classForName(owner.replace('/', '.'),
- initialize = false, noSparkClassLoader = true)
+ output += SparkClassUtils.classForName(
+ owner.replace('/', '.'),
+ initialize = false,
+ noSparkClassLoader = true)
}
}
}
diff --git
a/common/utils/src/main/scala/org/apache/spark/util/SparkStreamUtils.scala
b/common/utils/src/main/scala/org/apache/spark/util/SparkStreamUtils.scala
new file mode 100644
index 00000000000..b9148901f1a
--- /dev/null
+++ b/common/utils/src/main/scala/org/apache/spark/util/SparkStreamUtils.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.util
+
+import java.io.{FileInputStream, FileOutputStream, InputStream, OutputStream}
+import java.nio.channels.{FileChannel, WritableByteChannel}
+
+import org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally
+
+private[spark] trait SparkStreamUtils {
+
+ /**
+ * Copy all data from an InputStream to an OutputStream. NIO way of file
stream to file stream
+ * copying is disabled by default unless explicitly set transferToEnabled as
true, the parameter
+ * transferToEnabled should be configured by spark.file.transferTo =
[true|false].
+ */
+ def copyStream(
+ in: InputStream,
+ out: OutputStream,
+ closeStreams: Boolean = false,
+ transferToEnabled: Boolean = false): Long = {
+ tryWithSafeFinally {
+ (in, out) match {
+ case (input: FileInputStream, output: FileOutputStream) if
transferToEnabled =>
+ // When both streams are File stream, use transferTo to improve copy
performance.
+ val inChannel = input.getChannel
+ val outChannel = output.getChannel
+ val size = inChannel.size()
+ copyFileStreamNIO(inChannel, outChannel, 0, size)
+ size
+ case (input, output) =>
+ var count = 0L
+ val buf = new Array[Byte](8192)
+ var n = 0
+ while (n != -1) {
+ n = input.read(buf)
+ if (n != -1) {
+ output.write(buf, 0, n)
+ count += n
+ }
+ }
+ count
+ }
+ } {
+ if (closeStreams) {
+ try {
+ in.close()
+ } finally {
+ out.close()
+ }
+ }
+ }
+ }
+
+ def copyFileStreamNIO(
+ input: FileChannel,
+ output: WritableByteChannel,
+ startPosition: Long,
+ bytesToCopy: Long): Unit = {
+ val outputInitialState = output match {
+ case outputFileChannel: FileChannel =>
+ Some((outputFileChannel.position(), outputFileChannel))
+ case _ => None
+ }
+ var count = 0L
+ // In case transferTo method transferred less data than we have required.
+ while (count < bytesToCopy) {
+ count += input.transferTo(count + startPosition, bytesToCopy - count,
output)
+ }
+ assert(
+ count == bytesToCopy,
+ s"request to copy $bytesToCopy bytes, but actually copied $count bytes.")
+
+ // Check the position after transferTo loop to see if it is in the right
position and
+ // give user information if not.
+ // Position will not be increased to the expected length after calling
transferTo in
+ // kernel version 2.6.32, this issue can be seen in
+ // https://bugs.openjdk.java.net/browse/JDK-7052359
+ // This will lead to stream corruption issue when using sort-based shuffle
(SPARK-3948).
+ outputInitialState.foreach { case (initialPos, outputFileChannel) =>
+ val finalPos = outputFileChannel.position()
+ val expectedPos = initialPos + bytesToCopy
+ assert(
+ finalPos == expectedPos,
+ s"""
+ |Current position $finalPos do not equal to expected position
$expectedPos
+ |after transferTo, please check your kernel version to see if it is
2.6.32,
+ |this is a kernel bug which will lead to unexpected behavior when
using transferTo.
+ |You can set spark.file.transferTo = false to disable this NIO
feature.
+ """.stripMargin)
+ }
+ }
+}
+
+private [spark] object SparkStreamUtils extends SparkStreamUtils
diff --git
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index dcc038eb51d..c4431e9a87f 100644
---
a/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++
b/connector/connect/client/jvm/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -16,6 +16,7 @@
*/
package org.apache.spark.sql.expressions
+import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@@ -29,7 +30,7 @@ import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, UdfPacket}
import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.{SparkClassUtils, SparkSerDeUtils}
+import org.apache.spark.util.{ClosureCleaner, SparkClassUtils, SparkSerDeUtils}
/**
* A user-defined function. To create one, use the `udf` functions in
`functions`.
@@ -183,6 +184,7 @@ object ScalarUserDefinedFunction {
function: AnyRef,
inputEncoders: Seq[AgnosticEncoder[_]],
outputEncoder: AgnosticEncoder[_]): ScalarUserDefinedFunction = {
+ SparkConnectClosureCleaner.clean(function)
val udfPacketBytes =
SparkSerDeUtils.serialize(UdfPacket(function, inputEncoders,
outputEncoder))
checkDeserializable(udfPacketBytes)
@@ -202,3 +204,9 @@ object ScalarUserDefinedFunction {
outputEncoder = RowEncoder.encoderForDataType(returnType, lenient =
false))
}
}
+
+private object SparkConnectClosureCleaner {
+ def clean(closure: AnyRef): Unit = {
+ ClosureCleaner.clean(closure, cleanTransitively = true, mutable.Map.empty)
+ }
+}
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 5bb8cbf3543..51e58f9b0bb 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -362,4 +362,147 @@ class ReplE2ESuite extends RemoteSparkSession with
BeforeAndAfterEach {
val output = runCommandsInShell(input)
assertContains("noException: Boolean = true", output)
}
+
+ test("closure cleaner") {
+ val input =
+ """
+ |class NonSerializable(val id: Int = -1) { }
+ |
+ |{
+ | val x = 100
+ | val y = new NonSerializable
+ |}
+ |
+ |val t = 200
+ |
+ |{
+ | def foo(): Int = { x }
+ | def bar(): Int = { y.id }
+ | val z = new NonSerializable
+ |}
+ |
+ |{
+ | val myLambda = (a: Int) => a + t + foo()
+ | val myUdf = udf(myLambda)
+ |}
+ |
+ |spark.range(0, 10).
+ | withColumn("result", myUdf(col("id"))).
+ | agg(sum("result")).
+ | collect()(0)(0).asInstanceOf[Long]
+ |""".stripMargin
+ val output = runCommandsInShell(input)
+ assertContains(": Long = 3045", output)
+ }
+
+ test("closure cleaner with function") {
+ val input =
+ """
+ |class NonSerializable(val id: Int = -1) { }
+ |
+ |{
+ | val x = 100
+ | val y = new NonSerializable
+ |}
+ |
+ |{
+ | def foo(): Int = { x }
+ | def bar(): Int = { y.id }
+ | val z = new NonSerializable
+ |}
+ |
+ |def example() = {
+ | val myLambda = (a: Int) => a + foo()
+ | val myUdf = udf(myLambda)
+ | spark.range(0, 10).
+ | withColumn("result", myUdf(col("id"))).
+ | agg(sum("result")).
+ | collect()(0)(0).asInstanceOf[Long]
+ |}
+ |
+ |example()
+ |""".stripMargin
+ val output = runCommandsInShell(input)
+ assertContains(": Long = 1045", output)
+ }
+
+ test("closure cleaner nested") {
+ val input =
+ """
+ |class NonSerializable(val id: Int = -1) { }
+ |
+ |{
+ | val x = 100
+ | val y = new NonSerializable
+ |}
+ |
+ |{
+ | def foo(): Int = { x }
+ | def bar(): Int = { y.id }
+ | val z = new NonSerializable
+ |}
+ |
+ |val example = () => {
+ | val nested = () => {
+ | val myLambda = (a: Int) => a + foo()
+ | val myUdf = udf(myLambda)
+ | spark.range(0, 10).
+ | withColumn("result", myUdf(col("id"))).
+ | agg(sum("result")).
+ | collect()(0)(0).asInstanceOf[Long]
+ | }
+ | nested()
+ |}
+ |example()
+ |""".stripMargin
+ val output = runCommandsInShell(input)
+ assertContains(": Long = 1045", output)
+ }
+
+ test("closure cleaner with enclosing lambdas") {
+ val input =
+ """
+ |class NonSerializable(val id: Int = -1) { }
+ |
+ |{
+ | val x = 100
+ | val y = new NonSerializable
+ |}
+ |
+ |val z = new NonSerializable
+ |
+ |spark.range(0, 10).
+ |// for this call UdfUtils will create a new lambda and this lambda
becomes enclosing
+ | map(i => i + x).
+ | agg(sum("value")).
+ | collect()(0)(0).asInstanceOf[Long]
+ |""".stripMargin
+ val output = runCommandsInShell(input)
+ assertContains(": Long = 1045", output)
+ }
+
+ test("closure cleaner cleans capturing class") {
+ val input =
+ """
+ |class NonSerializable(val id: Int = -1) { }
+ |
+ |{
+ | val x = 100
+ | val y = new NonSerializable
+ |}
+ |
+ |class Test extends Serializable {
+ | // capturing class is cmd$Helper$Test
+ | val myUdf = udf((i: Int) => i + x)
+ | val z = new NonSerializable
+ | val res = spark.range(0, 10).
+ | withColumn("result", myUdf(col("id"))).
+ | agg(sum("result")).
+ | collect()(0)(0).asInstanceOf[Long]
+ |}
+ |(new Test()).res
+ |""".stripMargin
+ val output = runCommandsInShell(input)
+ assertContains(": Long = 1045", output)
+ }
}
diff --git
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
index 785e1fa4017..7ddb339b12d 100644
---
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
+++
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/CheckConnectJvmClientCompatibility.scala
@@ -324,6 +324,14 @@ object CheckConnectJvmClientCompatibility {
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.expressions.ScalarUserDefinedFunction$"),
+ // New private API added in the client
+ ProblemFilters
+ .exclude[MissingClassProblem](
+ "org.apache.spark.sql.expressions.SparkConnectClosureCleaner"),
+ ProblemFilters
+ .exclude[MissingClassProblem](
+ "org.apache.spark.sql.expressions.SparkConnectClosureCleaner$"),
+
// Dataset
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.Dataset.plan"
diff --git a/core/pom.xml b/core/pom.xml
index 5ac3d5bb4de..e55283b75fa 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -64,10 +64,6 @@
<artifactId>jnr-posix</artifactId>
<scope>test</scope>
</dependency>
- <dependency>
- <groupId>org.apache.xbean</groupId>
- <artifactId>xbean-asm9-shaded</artifactId>
- </dependency>
<dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-client-api</artifactId>
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala
b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 893895e8fb2..c86f755bbd1 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2699,7 +2699,7 @@ class SparkContext(config: SparkConf) extends Logging {
* @return the cleaned closure
*/
private[spark] def clean[F <: AnyRef](f: F, checkSerializable: Boolean =
true): F = {
- ClosureCleaner.clean(f, checkSerializable)
+ SparkClosureCleaner.clean(f, checkSerializable)
f
}
diff --git
a/core/src/main/scala/org/apache/spark/util/SparkClosureCleaner.scala
b/core/src/main/scala/org/apache/spark/util/SparkClosureCleaner.scala
new file mode 100644
index 00000000000..44e0efb4494
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/SparkClosureCleaner.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import scala.collection.mutable
+
+import org.apache.spark.{SparkEnv, SparkException}
+
+private[spark] object SparkClosureCleaner {
+ /**
+ * Clean the given closure in place.
+ *
+ * More specifically, this renders the given closure serializable as long as
it does not
+ * explicitly reference unserializable objects.
+ *
+ * @param closure the closure to clean
+ * @param checkSerializable whether to verify that the closure is
serializable after cleaning
+ * @param cleanTransitively whether to clean enclosing closures transitively
+ */
+ def clean(
+ closure: AnyRef,
+ checkSerializable: Boolean = true,
+ cleanTransitively: Boolean = true): Unit = {
+ if (ClosureCleaner.clean(closure, cleanTransitively, mutable.Map.empty)) {
+ try {
+ if (checkSerializable && SparkEnv.get != null) {
+ SparkEnv.get.closureSerializer.newInstance().serialize(closure)
+ }
+ } catch {
+ case ex: Exception => throw new SparkException("Task not
serializable", ex)
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala
b/core/src/main/scala/org/apache/spark/util/Utils.scala
index f8decbcff5f..f22bec5c2be 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -24,7 +24,7 @@ import java.lang.reflect.InvocationTargetException
import java.math.{MathContext, RoundingMode}
import java.net._
import java.nio.ByteBuffer
-import java.nio.channels.{Channels, FileChannel, WritableByteChannel}
+import java.nio.channels.Channels
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.security.SecureRandom
@@ -97,7 +97,8 @@ private[spark] object Utils
with SparkClassUtils
with SparkErrorUtils
with SparkFileUtils
- with SparkSerDeUtils {
+ with SparkSerDeUtils
+ with SparkStreamUtils {
private val sparkUncaughtExceptionHandler = new SparkUncaughtExceptionHandler
@volatile private var cachedLocalDir: String = ""
@@ -244,49 +245,6 @@ private[spark] object Utils
dir
}
- /**
- * Copy all data from an InputStream to an OutputStream. NIO way of file
stream to file stream
- * copying is disabled by default unless explicitly set transferToEnabled as
true,
- * the parameter transferToEnabled should be configured by
spark.file.transferTo = [true|false].
- */
- def copyStream(
- in: InputStream,
- out: OutputStream,
- closeStreams: Boolean = false,
- transferToEnabled: Boolean = false): Long = {
- tryWithSafeFinally {
- (in, out) match {
- case (input: FileInputStream, output: FileOutputStream) if
transferToEnabled =>
- // When both streams are File stream, use transferTo to improve copy
performance.
- val inChannel = input.getChannel
- val outChannel = output.getChannel
- val size = inChannel.size()
- copyFileStreamNIO(inChannel, outChannel, 0, size)
- size
- case (input, output) =>
- var count = 0L
- val buf = new Array[Byte](8192)
- var n = 0
- while (n != -1) {
- n = input.read(buf)
- if (n != -1) {
- output.write(buf, 0, n)
- count += n
- }
- }
- count
- }
- } {
- if (closeStreams) {
- try {
- in.close()
- } finally {
- out.close()
- }
- }
- }
- }
-
/**
* Copy the first `maxSize` bytes of data from the InputStream to an
in-memory
* buffer, primarily to check for corruption.
@@ -331,43 +289,6 @@ private[spark] object Utils
}
}
- def copyFileStreamNIO(
- input: FileChannel,
- output: WritableByteChannel,
- startPosition: Long,
- bytesToCopy: Long): Unit = {
- val outputInitialState = output match {
- case outputFileChannel: FileChannel =>
- Some((outputFileChannel.position(), outputFileChannel))
- case _ => None
- }
- var count = 0L
- // In case transferTo method transferred less data than we have required.
- while (count < bytesToCopy) {
- count += input.transferTo(count + startPosition, bytesToCopy - count,
output)
- }
- assert(count == bytesToCopy,
- s"request to copy $bytesToCopy bytes, but actually copied $count bytes.")
-
- // Check the position after transferTo loop to see if it is in the right
position and
- // give user information if not.
- // Position will not be increased to the expected length after calling
transferTo in
- // kernel version 2.6.32, this issue can be seen in
- // https://bugs.openjdk.java.net/browse/JDK-7052359
- // This will lead to stream corruption issue when using sort-based shuffle
(SPARK-3948).
- outputInitialState.foreach { case (initialPos, outputFileChannel) =>
- val finalPos = outputFileChannel.position()
- val expectedPos = initialPos + bytesToCopy
- assert(finalPos == expectedPos,
- s"""
- |Current position $finalPos do not equal to expected position
$expectedPos
- |after transferTo, please check your kernel version to see if it is
2.6.32,
- |this is a kernel bug which will lead to unexpected behavior when
using transferTo.
- |You can set spark.file.transferTo = false to disable this NIO
feature.
- """.stripMargin)
- }
- }
-
/**
* A file name may contain some invalid URI characters, such as " ". This
method will convert the
* file name to a raw path accepted by `java.net.URI(String)`.
diff --git
a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index cef0d8c1de0..2f084b2037e 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -373,7 +373,7 @@ class TestCreateNullValue {
println(getX)
}
// scalastyle:on println
- ClosureCleaner.clean(closure)
+ SparkClosureCleaner.clean(closure)
}
nestedClosure()
}
diff --git
a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
index 0635b4a358a..b055dae3994 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala
@@ -96,10 +96,10 @@ class ClosureCleanerSuite2 extends SparkFunSuite with
BeforeAndAfterAll with Pri
// If the resulting closure is not serializable even after
// cleaning, we expect ClosureCleaner to throw a SparkException
if (serializableAfter) {
- ClosureCleaner.clean(closure, checkSerializable = true, transitive)
+ SparkClosureCleaner.clean(closure, checkSerializable = true, transitive)
} else {
intercept[SparkException] {
- ClosureCleaner.clean(closure, checkSerializable = true, transitive)
+ SparkClosureCleaner.clean(closure, checkSerializable = true,
transitive)
}
}
assertSerializable(closure, serializableAfter)
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 47fd7881d2f..10864390e3f 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -43,7 +43,9 @@ object MimaExcludes {
// [SPARK-44198][CORE] Support propagation of the log level to the
executors
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages$SparkAppConfig$"),
// [SPARK-45427][CORE] Add RPC SSL settings to SSLOptions and
SparkTransportConf
-
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.SparkTransportConf.fromSparkConf")
+
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.network.netty.SparkTransportConf.fromSparkConf"),
+ // [SPARK-45136][CONNECT] Enhance ClosureCleaner with Ammonite support
+
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.MethodIdentifier$")
)
// Default exclude rules
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 1c77b87dbf1..dc5e22f0571 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -38,7 +38,7 @@ import org.apache.spark.sql.errors.QueryErrorsBase
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
-import org.apache.spark.util.{ClosureCleaner, Utils}
+import org.apache.spark.util.{SparkClosureCleaner, Utils}
case class RepeatedStruct(s: Seq[PrimitiveData])
@@ -689,7 +689,7 @@ class ExpressionEncoderSuite extends
CodegenInterpretedPlanTest with AnalysisTes
val encoder = implicitly[ExpressionEncoder[T]]
// Make sure encoder is serializable.
- ClosureCleaner.clean((s: String) => encoder.getClass.getName)
+ SparkClosureCleaner.clean((s: String) => encoder.getClass.getName)
val row = encoder.createSerializer().apply(input)
val schema = toAttributes(encoder.schema)
diff --git
a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
index dcd698c860d..f04b9da9b45 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala
@@ -22,7 +22,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.{JavaPairRDD, JavaUtils, Optional}
import org.apache.spark.api.java.function.{Function3 => JFunction3, Function4
=> JFunction4}
import org.apache.spark.rdd.RDD
-import org.apache.spark.util.ClosureCleaner
+import org.apache.spark.util.SparkClosureCleaner
/**
* :: Experimental ::
@@ -157,7 +157,7 @@ object StateSpec {
def function[KeyType, ValueType, StateType, MappedType](
mappingFunction: (Time, KeyType, Option[ValueType], State[StateType]) =>
Option[MappedType]
): StateSpec[KeyType, ValueType, StateType, MappedType] = {
- ClosureCleaner.clean(mappingFunction, checkSerializable = true)
+ SparkClosureCleaner.clean(mappingFunction, checkSerializable = true)
new StateSpecImpl(mappingFunction)
}
@@ -175,7 +175,7 @@ object StateSpec {
def function[KeyType, ValueType, StateType, MappedType](
mappingFunction: (KeyType, Option[ValueType], State[StateType]) =>
MappedType
): StateSpec[KeyType, ValueType, StateType, MappedType] = {
- ClosureCleaner.clean(mappingFunction, checkSerializable = true)
+ SparkClosureCleaner.clean(mappingFunction, checkSerializable = true)
val wrappedFunction =
(time: Time, key: KeyType, value: Option[ValueType], state:
State[StateType]) => {
Some(mappingFunction(key, value, state))