This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push:
new 75a38b9024a [SPARK-45616][CORE] Avoid ParVector, which does not
propagate ThreadLocals or SparkSession
75a38b9024a is described below
commit 75a38b9024af3c9cfd85e916c46359f7e7315c87
Author: Ankur Dave <[email protected]>
AuthorDate: Mon Oct 23 10:47:42 2023 +0800
[SPARK-45616][CORE] Avoid ParVector, which does not propagate ThreadLocals or SparkSession
### What changes were proposed in this pull request?
`CastSuiteBase` and `ExpressionInfoSuite` use `ParVector.foreach()` to run
Spark SQL queries in parallel. They incorrectly assume that each parallel
operation will inherit the main thread’s active SparkSession. This is only true
when these parallel operations run in freshly-created threads. However, when
other code has already run some parallel operations before Spark was started,
then there may be existing threads that do not have an active SparkSession. In
that case, these tests fail [...]
The fix is to use the existing method `ThreadUtils.parmap()`. This method
creates fresh threads that inherit the current active SparkSession, and it
propagates the Spark ThreadLocals.
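This PR also adds a scalastyle check (at error level) against creating a ParVector; intentional uses must be wrapped in `// scalastyle:off parvector` / `// scalastyle:on parvector`.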
### Why are the changes needed?
This change makes `CastSuiteBase` and `ExpressionInfoSuite` less brittle to
future changes that may run parallel operations during test startup.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Reproduced the test failures by running a ParVector operation before Spark
starts. Verified that this PR fixes the test failures under this condition.
```scala
protected override def beforeAll(): Unit = {
  // Run a ParVector operation before initializing the SparkSession. This starts some Scala
  // execution context threads that have no active SparkSession. These threads will be reused
  // for later ParVector operations, reproducing SPARK-45616.
  // The new `parvector` scalastyle rule rejects `new ... ParVector` unless explicitly
  // suppressed, so the deliberate use here is wrapped in off/on markers.
  // scalastyle:off parvector
  new ParVector((0 until 100).toVector).foreach { _ => }
  // scalastyle:on parvector
  super.beforeAll()
}
```
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43466 from ankurdave/SPARK-45616.
Authored-by: Ankur Dave <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
(cherry picked from commit 376de8a502fca6b46d7f21560a60024d643144ea)
Signed-off-by: Wenchen Fan <[email protected]>
---
core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala | 2 ++
core/src/main/scala/org/apache/spark/util/ThreadUtils.scala | 4 ++++
scalastyle-config.xml | 12 ++++++++++++
.../spark/sql/catalyst/expressions/CastSuiteBase.scala | 9 ++++++---
.../scala/org/apache/spark/sql/execution/command/ddl.scala | 2 ++
.../apache/spark/sql/expressions/ExpressionInfoSuite.scala | 11 ++++++-----
.../main/scala/org/apache/spark/streaming/DStreamGraph.scala | 4 ++++
.../apache/spark/streaming/util/FileBasedWriteAheadLog.scala | 2 ++
8 files changed, 38 insertions(+), 8 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index 0a930234437..3c1451a0185 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -76,8 +76,10 @@ class UnionRDD[T: ClassTag](
override def getPartitions: Array[Partition] = {
val parRDDs = if (isPartitionListingParallel) {
+ // scalastyle:off parvector
val parArray = new ParVector(rdds.toVector)
parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
+ // scalastyle:on parvector
parArray
} else {
rdds
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 16d7de56c39..2d3d6ec89ff 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -363,6 +363,10 @@ private[spark] object ThreadUtils {
* Comparing to the map() method of Scala parallel collections, this method
can be interrupted
* at any time. This is useful on canceling of task execution, for example.
*
+ * Functions are guaranteed to be executed in freshly-created threads that
inherit the calling
+ * thread's Spark thread-local variables. These threads also inherit the
calling thread's active
+ * SparkSession.
+ *
* @param in - the input collection which should be transformed in parallel.
* @param prefix - the prefix assigned to the underlying thread pool.
* @param maxThreads - maximum number of thread can be created during
execution.
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 74e8480deaf..0ccd937e72e 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -227,6 +227,18 @@ This file is divided into 3 sections:
]]></customMessage>
</check>
+ <check customId="parvector" level="error"
class="org.scalastyle.file.RegexChecker" enabled="true">
+ <parameters><parameter name="regex">new.*ParVector</parameter></parameters>
+ <customMessage><![CDATA[
+ Are you sure you want to create a ParVector? It will not automatically
propagate Spark ThreadLocals or the
+ active SparkSession for the submitted tasks. In most cases, you should
use ThreadUtils.parmap instead.
+ If you must use ParVector, then wrap your creation of the ParVector with
+ // scalastyle:off parvector
+ ...ParVector...
+ // scalastyle:on parvector
+ ]]></customMessage>
+ </check>
+
<check customId="caselocale" level="error"
class="org.scalastyle.file.RegexChecker" enabled="true">
<parameters><parameter
name="regex">(\.toUpperCase|\.toLowerCase)(?!(\(|\(Locale.ROOT\)))</parameter></parameters>
<customMessage><![CDATA[
diff --git
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index 0172fd9b3e4..1ce311a5544 100644
---
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -22,8 +22,6 @@ import java.time.{Duration, LocalDate, LocalDateTime, Period}
import java.time.temporal.ChronoUnit
import java.util.{Calendar, Locale, TimeZone}
-import scala.collection.parallel.immutable.ParVector
-
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
@@ -42,6 +40,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType.{DAY,
HOUR, MINUTE, SECOND
import org.apache.spark.sql.types.UpCastRule.numericPrecedence
import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.ThreadUtils
/**
* Common test suite for [[Cast]] with ansi mode on and off. It only includes
test cases that work
@@ -126,7 +125,11 @@ abstract class CastSuiteBase extends SparkFunSuite with
ExpressionEvalHelper {
}
test("cast string to timestamp") {
- new ParVector(ALL_TIMEZONES.toVector).foreach { zid =>
+ ThreadUtils.parmap(
+ ALL_TIMEZONES,
+ prefix = "CastSuiteBase-cast-string-to-timestamp",
+ maxThreads = Runtime.getRuntime.availableProcessors
+ ) { zid =>
def checkCastStringToTimestamp(str: String, expected: Timestamp): Unit =
{
checkEvaluation(cast(Literal(str), TimestampType, Option(zid.getId)),
expected)
}
diff --git
a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index a8f7cdb2600..bb8fea71019 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -755,8 +755,10 @@ case class RepairTableCommand(
val statusPar: Seq[FileStatus] =
if (partitionNames.length > 1 && statuses.length > threshold ||
partitionNames.length > 2) {
// parallelize the list of partitions here, then we can have better
parallelism later.
+ // scalastyle:off parvector
val parArray = new ParVector(statuses.toVector)
parArray.tasksupport = evalTaskSupport
+ // scalastyle:on parvector
parArray.seq
} else {
statuses
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
index 4dd93983e87..a02137a56aa 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.expressions
-import scala.collection.parallel.immutable.ParVector
-
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
@@ -26,7 +24,7 @@ import
org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.tags.SlowSQLTest
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
@SlowSQLTest
class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
@@ -197,8 +195,11 @@ class ExpressionInfoSuite extends SparkFunSuite with
SharedSparkSession {
// The encrypt expression includes a random initialization vector to its
encrypted result
classOf[AesEncrypt].getName)
- val parFuncs = new
ParVector(spark.sessionState.functionRegistry.listFunction().toVector)
- parFuncs.foreach { funcId =>
+ ThreadUtils.parmap(
+ spark.sessionState.functionRegistry.listFunction(),
+ prefix = "ExpressionInfoSuite-check-outputs-of-expression-examples",
+ maxThreads = Runtime.getRuntime.availableProcessors
+ ) { funcId =>
// Examples can change settings. We clone the session to prevent tests
clashing.
val clonedSpark = spark.cloneSession()
// Coalescing partitions can change result order, so disable it.
diff --git
a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index 43aaa7e1eea..a8f55c8b4d6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -52,7 +52,9 @@ final private[streaming] class DStreamGraph extends
Serializable with Logging {
outputStreams.foreach(_.validateAtStart())
numReceivers =
inputStreams.count(_.isInstanceOf[ReceiverInputDStream[_]])
inputStreamNameAndID = inputStreams.map(is => (is.name, is.id)).toSeq
+ // scalastyle:off parvector
new ParVector(inputStreams.toVector).foreach(_.start())
+ // scalastyle:on parvector
}
}
@@ -62,7 +64,9 @@ final private[streaming] class DStreamGraph extends
Serializable with Logging {
def stop(): Unit = {
this.synchronized {
+ // scalastyle:off parvector
new ParVector(inputStreams.toVector).foreach(_.stop())
+ // scalastyle:on parvector
}
}
diff --git
a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
index d1f9dfb7913..4e65bc75e43 100644
---
a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
+++
b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
@@ -314,8 +314,10 @@ private[streaming] object FileBasedWriteAheadLog {
val groupSize = taskSupport.parallelismLevel.max(8)
source.grouped(groupSize).flatMap { group =>
+ // scalastyle:off parvector
val parallelCollection = new ParVector(group.toVector)
parallelCollection.tasksupport = taskSupport
+ // scalastyle:on parvector
parallelCollection.map(handler)
}.flatten
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]