This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new c08700f05f9 [SPARK-44293][CONNECT] Fix invalid URI for custom JARs in Spark Connect
c08700f05f9 is described below

commit c08700f05f96e083dd8dec12fb2ca90a49d16a52
Author: vicennial <[email protected]>
AuthorDate: Wed Jul 5 08:41:47 2023 +0900

    [SPARK-44293][CONNECT] Fix invalid URI for custom JARs in Spark Connect
    
    ### What changes were proposed in this pull request?
    
    This PR fixes a bug where invalid JAR URIs were generated: the URI was
    stored as `artifactURI + "/" + target.toString` (where `target` is the
    absolute filesystem path of the file) instead of
    `artifactURI + "/" + remoteRelativePath.toString` (where
    `remoteRelativePath` has the form `jars/...`).
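    
    A minimal sketch of the two constructions, using hypothetical values
    (the real logic lives in `SparkConnectArtifactManager`, shown in the
    diff below):
    ```
    import java.nio.file.Paths

    // All values below are hypothetical, for illustration only.
    val artifactURI = "spark://host:43743/artifacts/d9548b02"
    val target = Paths.get("/home/user/session-dir/jars/TestHelloV2.jar") // absolute server path
    val remoteRelativePath = Paths.get("jars/TestHelloV2.jar")            // session-relative path

    // Before the fix: the absolute path is appended, yielding a malformed URI
    // like ".../artifacts/d9548b02//home/user/session-dir/jars/TestHelloV2.jar".
    val buggyURI = artifactURI + "/" + target.toString

    // After the fix: the relative path is appended, yielding the valid
    // ".../artifacts/d9548b02/jars/TestHelloV2.jar".
    val fixedURI = artifactURI + "/" + remoteRelativePath.toString
    ```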
    
    ### Why are the changes needed?
    
    Without this change, Spark Connect users attempting to use a custom JAR
    (for example, in UDFs) hit task failures, because an exception is thrown
    when the executor tries to fetch the JAR file. Note the malformed stream
    name in the trace below: the absolute server-side path is appended
    directly after the artifact URI, producing the double slash.
    Example stack trace:
    ```
    23/07/03 17:00:15 INFO Executor: Fetching spark://ip-10-110-22-170.us-west-2.compute.internal:43743/artifacts/d9548b02-ff3b-4278-ab52-aef5d1fc724e//home/venkata.gudesa/spark/artifacts/spark-d6141194-c487-40fd-ba40-444d922808ea/d9548b02-ff3b-4278-ab52-aef5d1fc724e/jars/TestHelloV2.jar with timestamp 0
    23/07/03 17:00:15 ERROR Executor: Exception in task 6.0 in stage 4.0 (TID 55)
    java.lang.RuntimeException: Stream '/artifacts/d9548b02-ff3b-4278-ab52-aef5d1fc724e//home/venkata.gudesa/spark/artifacts/spark-d6141194-c487-40fd-ba40-444d922808ea/d9548b02-ff3b-4278-ab52-aef5d1fc724e/jars/TestHelloV2.jar' was not found.
            at org.apache.spark.network.client.TransportResponseHandler.handle(TransportResponseHandler.java:260)
            at org.apache.spark.network.server.TransportChannelHandler.channelRead0(TransportChannelHandler.java:142)
            at org.apache.spark.network.server.TransportChannelHandler.channelRead0(TransportChannelHandler.java:53)
            at io.netty.channel.SimpleChannelInboundHandler.channelRead(SimpleChannelInboundHandler.java:99)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:444)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420)
            at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412)
            at io.netty.handler.timeout.IdleStateHandler.channelRead(IdleStateHandler.java:286)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:442)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420)
            at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412)
            at io.netty.handler.codec.MessageToMessageDecoder.channelRead(MessageToMessageDecoder.java:103)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:444)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420)
            at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412)
            at org.apache.spark.network.util.TransportFrameDecoder.channelRead(TransportFrameDecoder.java:102)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:444)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420)
            at io.netty.channel.AbstractChannelHandlerContext.fireChannelRead(AbstractChannelHandlerContext.java:412)
            at io.netty.channel.DefaultChannelPipeline$HeadContext.channelRead(DefaultChannelPipeline.java:1410)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:440)
            at io.netty.channel.AbstractChannelHandlerContext.invokeChannelRead(AbstractChannelHandlerContext.java:420)
            at io.netty.channel.DefaultChannelPipeline.fireChannelRead(DefaultChannelPipeline.java:919)
            at io.netty.channel.nio.AbstractNioByteChannel$NioByteUnsafe.read(AbstractNioByteChannel.java:166)
            at io.netty.channel.nio.NioEventLoop.processSelectedKey(NioEventLoop.java:788)
            at io.netty.channel.nio.NioEventLoop.processSelectedKeysOptimized(NioEventLoop.java:724)
            at io.netty.channel.nio.NioEventLoop.processSelectedKeys(NioEventLoop.java:650)
            at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:562)
            at io.netty.util.concurrent.SingleThreadEventExecutor$4.run(SingleThreadEventExecutor.java:997)
            at io.netty.util.internal.ThreadExecutorMap$2.run(ThreadExecutorMap.java:74)
            at io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30)
            at java.lang.Thread.run(Thread.java:748)
    ```
    
    ### Does this PR introduce _any_ user-facing change?
    
    No (the bug fix restores the behavior users expect).
    
    ### How was this patch tested?
    
    New E2E test in `ReplE2ESuite`.
    
    Closes #41844 from vicennial/SPARK-44293.
    
    Authored-by: vicennial <[email protected]>
    Signed-off-by: Hyukjin Kwon <[email protected]>
---
 .../jvm/src/test/resources/TestHelloV2_2.12.jar    | Bin 0 -> 3784 bytes
 .../jvm/src/test/resources/TestHelloV2_2.13.jar    | Bin 0 -> 4118 bytes
 .../spark/sql/application/ReplE2ESuite.scala       |  44 ++++++++++++++++++++-
 .../artifact/SparkConnectArtifactManager.scala     |   2 +-
 4 files changed, 43 insertions(+), 3 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar b/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar
new file mode 100644
index 00000000000..d89cf6543a2
Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.12.jar differ
diff --git a/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.13.jar b/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.13.jar
new file mode 100644
index 00000000000..6dee8fcd9c9
Binary files /dev/null and b/connector/connect/client/jvm/src/test/resources/TestHelloV2_2.13.jar differ
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
index 61959234c87..720f66680ee 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/application/ReplE2ESuite.scala
@@ -17,12 +17,15 @@
 package org.apache.spark.sql.application
 
 import java.io.{PipedInputStream, PipedOutputStream}
+import java.nio.file.Paths
 import java.util.concurrent.{Executors, Semaphore, TimeUnit}
 
+import scala.util.Properties
+
 import org.apache.commons.io.output.ByteArrayOutputStream
 import org.scalatest.BeforeAndAfterEach
 
-import org.apache.spark.sql.connect.client.util.RemoteSparkSession
+import org.apache.spark.sql.connect.client.util.{IntegrationTestUtils, RemoteSparkSession}
 
 class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
 
@@ -35,6 +38,11 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
   private var ammoniteIn: PipedInputStream = _
   private val semaphore: Semaphore = new Semaphore(0)
 
+  private val scalaVersion = Properties.versionNumberString
+    .split("\\.")
+    .take(2)
+    .mkString(".")
+
   private def getCleanString(out: ByteArrayOutputStream): String = {
     // Remove ANSI colour codes
     // Regex taken from https://stackoverflow.com/a/25189932
@@ -96,7 +104,10 @@ class ReplE2ESuite extends RemoteSparkSession with BeforeAndAfterEach {
 
   def assertContains(message: String, output: String): Unit = {
     val isContain = output.contains(message)
-    assert(isContain, "Ammonite output did not contain '" + message + "':\n" + output)
+    assert(
+      isContain,
+      "Ammonite output did not contain '" + message + "':\n" + output +
+        s"\nError Output: ${getCleanString(errorStream)}")
   }
 
   test("Simple query") {
@@ -151,4 +162,33 @@ class ReplE2ESuite extends RemoteSparkSession with 
BeforeAndAfterEach {
     assertContains("Array[java.lang.Long] = Array(0L, 2L, 4L, 6L, 8L)", output)
   }
 
+  test("Client-side JAR") {
+    // scalastyle:off classforname line.size.limit
+    val sparkHome = IntegrationTestUtils.sparkHome
+    val testJar = Paths
+      .get(
+        s"$sparkHome/connector/connect/client/jvm/src/test/resources/TestHelloV2_$scalaVersion.jar")
+      .toFile
+
+    assert(testJar.exists(), "Missing TestHelloV2 jar!")
+    val input = s"""
+        |import java.nio.file.Paths
+        |def classLoadingTest(x: Int): Int = {
+        |  val classloader =
+        |    Option(Thread.currentThread().getContextClassLoader).getOrElse(getClass.getClassLoader)
+        |  val cls = Class.forName("com.example.Hello$$", true, classloader)
+        |  val module = cls.getField("MODULE$$").get(null)
+        |  cls.getMethod("test").invoke(module).asInstanceOf[Int]
+        |}
+        |val classLoaderUdf = udf(classLoadingTest _)
+        |
+        |val jarPath = Paths.get("${testJar.toString}").toUri
+        |spark.addArtifact(jarPath)
+        |
+        |spark.range(5).select(classLoaderUdf(col("id"))).as[Int].collect()
+      """.stripMargin
+    val output = runCommandsInShell(input)
+    assertContains("Array[Int] = Array(2, 2, 2, 2, 2)", output)
+    // scalastyle:on classforname line.size.limit
+  }
 }
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
index 0a91c6b9550..9fd8e367e4a 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/artifact/SparkConnectArtifactManager.scala
@@ -133,7 +133,7 @@ class SparkConnectArtifactManager(sessionHolder: SessionHolder) extends Logging
       Files.move(serverLocalStagingPath, target)
       if (remoteRelativePath.startsWith(s"jars${File.separator}")) {
         jarsList.add(target)
-        jarsURI.add(artifactURI + "/" + target.toString)
+        jarsURI.add(artifactURI + "/" + remoteRelativePath.toString)
       } else if (remoteRelativePath.startsWith(s"pyfiles${File.separator}")) {
         sessionHolder.session.sparkContext.addFile(target.toString)
         val stringRemotePath = remoteRelativePath.toString


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
