This is an automated email from the ASF dual-hosted git repository.
wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 376de8a502f [SPARK-45616][CORE] Avoid ParVector, which does not propagate ThreadLocals or SparkSession
376de8a502f is described below
commit 376de8a502fca6b46d7f21560a60024d643144ea
Author: Ankur Dave <[email protected]>
AuthorDate: Mon Oct 23 10:47:42 2023 +0800
[SPARK-45616][CORE] Avoid ParVector, which does not propagate ThreadLocals or SparkSession
### What changes were proposed in this pull request?
`CastSuiteBase` and `ExpressionInfoSuite` use `ParVector.foreach()` to run
Spark SQL queries in parallel. They incorrectly assume that each parallel
operation will inherit the main thread’s active SparkSession. This is only
true when these parallel operations run in freshly-created threads. However,
when other code has already run some parallel operations before Spark was
started, there may be existing threads that do not have an active
SparkSession. In that case, these tests fail [...]
The fix is to use the existing method `ThreadUtils.parmap()`. This method
creates fresh threads that inherit the current active SparkSession, and it
propagates the Spark ThreadLocals.
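For illustration, the replacement follows the pattern in this minimal sketch
(`items` and `runQuery` are placeholder names; the actual call sites are in
the diff below):
```scala
import org.apache.spark.util.ThreadUtils

// Before (placeholder): relies on the parallel collection's pool threads
// already having an active SparkSession.
// new ParVector(items.toVector).foreach { item => runQuery(item) }

// After (placeholder): parmap runs each function in a freshly-created thread
// that inherits the calling thread's Spark ThreadLocals and active SparkSession.
ThreadUtils.parmap(
  items,
  prefix = "MySuite-parallel-queries",
  maxThreads = Runtime.getRuntime.availableProcessors
) { item =>
  runQuery(item)
}
```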
This PR also adds a scalastyle warning against use of ParVector.
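Where ParVector use is intentional (as at the call sites updated below), the
new check can be suppressed with scalastyle comments, for example:
```scala
// scalastyle:off parvector
val parArray = new ParVector(rdds.toVector)
parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
// scalastyle:on parvector
```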
### Why are the changes needed?
This change makes `CastSuiteBase` and `ExpressionInfoSuite` less brittle to
future changes that may run parallel operations during test startup.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Reproduced the test failures by running a ParVector operation before Spark
starts. Verified that this PR fixes the test failures in this condition.
```scala
  protected override def beforeAll(): Unit = {
    // Run a ParVector operation before initializing the SparkSession. This starts some Scala
    // execution context threads that have no active SparkSession. These threads will be reused for
    // later ParVector operations, reproducing SPARK-45616.
    new ParVector((0 until 100).toVector).foreach { _ => }
    super.beforeAll()
  }
```
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #43466 from ankurdave/SPARK-45616.
Authored-by: Ankur Dave <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
---
core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala | 2 ++
core/src/main/scala/org/apache/spark/util/ThreadUtils.scala | 4 ++++
scalastyle-config.xml | 12 ++++++++++++
.../spark/sql/catalyst/expressions/CastSuiteBase.scala | 9 ++++++---
.../scala/org/apache/spark/sql/execution/command/ddl.scala | 2 ++
.../apache/spark/sql/expressions/ExpressionInfoSuite.scala | 11 ++++++-----
.../main/scala/org/apache/spark/streaming/DStreamGraph.scala | 4 ++++
.../apache/spark/streaming/util/FileBasedWriteAheadLog.scala | 2 ++
8 files changed, 38 insertions(+), 8 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index 0a930234437..3c1451a0185 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -76,8 +76,10 @@ class UnionRDD[T: ClassTag](
override def getPartitions: Array[Partition] = {
val parRDDs = if (isPartitionListingParallel) {
+ // scalastyle:off parvector
val parArray = new ParVector(rdds.toVector)
parArray.tasksupport = UnionRDD.partitionEvalTaskSupport
+ // scalastyle:on parvector
parArray
} else {
rdds
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
index 16d7de56c39..2d3d6ec89ff 100644
--- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala
@@ -363,6 +363,10 @@ private[spark] object ThreadUtils {
* Comparing to the map() method of Scala parallel collections, this method can be interrupted
* at any time. This is useful on canceling of task execution, for example.
*
+ * Functions are guaranteed to be executed in freshly-created threads that inherit the calling
+ * thread's Spark thread-local variables. These threads also inherit the calling thread's active
+ * SparkSession.
+ *
* @param in - the input collection which should be transformed in parallel.
* @param prefix - the prefix assigned to the underlying thread pool.
* @param maxThreads - maximum number of thread can be created during execution.
diff --git a/scalastyle-config.xml b/scalastyle-config.xml
index 987b4235c19..2077769c71d 100644
--- a/scalastyle-config.xml
+++ b/scalastyle-config.xml
@@ -227,6 +227,18 @@ This file is divided into 3 sections:
]]></customMessage>
</check>
+ <check customId="parvector" level="error"
class="org.scalastyle.file.RegexChecker" enabled="true">
+ <parameters><parameter name="regex">new.*ParVector</parameter></parameters>
+ <customMessage><![CDATA[
+ Are you sure you want to create a ParVector? It will not automatically propagate Spark ThreadLocals or the
+ active SparkSession for the submitted tasks. In most cases, you should use ThreadUtils.parmap instead.
+ If you must use ParVector, then wrap your creation of the ParVector with
+ // scalastyle:off parvector
+ ...ParVector...
+ // scalastyle:on parvector
+ ]]></customMessage>
+ </check>
+
<check customId="caselocale" level="error"
class="org.scalastyle.file.RegexChecker" enabled="true">
<parameters><parameter
name="regex">(\.toUpperCase|\.toLowerCase)(?!(\(|\(Locale.ROOT\)))</parameter></parameters>
<customMessage><![CDATA[
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
index 0172fd9b3e4..1ce311a5544 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuiteBase.scala
@@ -22,8 +22,6 @@ import java.time.{Duration, LocalDate, LocalDateTime, Period}
import java.time.temporal.ChronoUnit
import java.util.{Calendar, Locale, TimeZone}
-import scala.collection.parallel.immutable.ParVector
-
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
@@ -42,6 +40,7 @@ import org.apache.spark.sql.types.DayTimeIntervalType.{DAY, HOUR, MINUTE, SECOND
import org.apache.spark.sql.types.UpCastRule.numericPrecedence
import org.apache.spark.sql.types.YearMonthIntervalType.{MONTH, YEAR}
import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.ThreadUtils
/**
* Common test suite for [[Cast]] with ansi mode on and off. It only includes test cases that work
@@ -126,7 +125,11 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
}
test("cast string to timestamp") {
- new ParVector(ALL_TIMEZONES.toVector).foreach { zid =>
+ ThreadUtils.parmap(
+ ALL_TIMEZONES,
+ prefix = "CastSuiteBase-cast-string-to-timestamp",
+ maxThreads = Runtime.getRuntime.availableProcessors
+ ) { zid =>
def checkCastStringToTimestamp(str: String, expected: Timestamp): Unit = {
checkEvaluation(cast(Literal(str), TimestampType, Option(zid.getId)), expected)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index 1465e32924a..a30734abfa7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -759,8 +759,10 @@ case class RepairTableCommand(
val statusPar: Seq[FileStatus] =
if (partitionNames.length > 1 && statuses.length > threshold || partitionNames.length > 2) {
// parallelize the list of partitions here, then we can have better parallelism later.
+ // scalastyle:off parvector
val parArray = new ParVector(statuses.toVector)
parArray.tasksupport = evalTaskSupport
+ // scalastyle:on parvector
parArray.seq
} else {
statuses
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
index fd6f0adccf7..f8dde124b31 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ExpressionInfoSuite.scala
@@ -17,8 +17,6 @@
package org.apache.spark.sql.expressions
-import scala.collection.parallel.immutable.ParVector
-
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
@@ -26,7 +24,7 @@ import org.apache.spark.sql.execution.HiveResult.hiveResultString
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.tags.SlowSQLTest
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}
@SlowSQLTest
class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
@@ -201,8 +199,11 @@ class ExpressionInfoSuite extends SparkFunSuite with SharedSparkSession {
// The encrypt expression includes a random initialization vector to its encrypted result
classOf[AesEncrypt].getName)
- val parFuncs = new ParVector(spark.sessionState.functionRegistry.listFunction().toVector)
- parFuncs.foreach { funcId =>
+ ThreadUtils.parmap(
+ spark.sessionState.functionRegistry.listFunction(),
+ prefix = "ExpressionInfoSuite-check-outputs-of-expression-examples",
+ maxThreads = Runtime.getRuntime.availableProcessors
+ ) { funcId =>
// Examples can change settings. We clone the session to prevent tests clashing.
val clonedSpark = spark.cloneSession()
// Coalescing partitions can change result order, so disable it.
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index 43aaa7e1eea..a8f55c8b4d6 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -52,7 +52,9 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
outputStreams.foreach(_.validateAtStart())
numReceivers = inputStreams.count(_.isInstanceOf[ReceiverInputDStream[_]])
inputStreamNameAndID = inputStreams.map(is => (is.name, is.id)).toSeq
+ // scalastyle:off parvector
new ParVector(inputStreams.toVector).foreach(_.start())
+ // scalastyle:on parvector
}
}
@@ -62,7 +64,9 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
def stop(): Unit = {
this.synchronized {
+ // scalastyle:off parvector
new ParVector(inputStreams.toVector).foreach(_.stop())
+ // scalastyle:on parvector
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
index c3f2a04d1f0..908d155908f 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala
@@ -314,8 +314,10 @@ private[streaming] object FileBasedWriteAheadLog {
val groupSize = taskSupport.parallelismLevel.max(8)
source.grouped(groupSize).flatMap { group =>
+ // scalastyle:off parvector
val parallelCollection = new ParVector(group.toVector)
parallelCollection.tasksupport = taskSupport
+ // scalastyle:on parvector
parallelCollection.map(handler)
}.flatten
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]