This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d3b6dd13e9e9 [SPARK-51688][PYTHON] Use Unix Domain Socket between Python and JVM communication
d3b6dd13e9e9 is described below

commit d3b6dd13e9e9f1c995f9c1152d8958a29c8ccd54
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Tue Apr 15 10:35:11 2025 +0900

    [SPARK-51688][PYTHON] Use Unix Domain Socket between Python and JVM communication
    
    ### What changes were proposed in this pull request?
    
    This PR proposes to use Unix Domain Sockets (UDS) for the communication between the Python process and the JVM (except Py4J, which does not support UDS).
    
    It adds a new configuration `spark.python.unix.domain.socket.enabled`, which is disabled by default. When enabled, UDS is used; when disabled, TCP/IP sockets are used as before.
    
    When UDS is used, the data is protected by file permissions, so it also avoids the authentication step that is otherwise needed for TCP/IP sockets.
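    
    A minimal sketch of enabling the new configuration from PySpark (not part of this patch; it only shows the setting introduced here, which is off by default):
    
    ```python
    from pyspark.sql import SparkSession
    
    spark = (
        SparkSession.builder
        .config("spark.python.unix.domain.socket.enabled", "true")
        .getOrCreate()
    )
    ```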
    
    ### Why are the changes needed?
    
    1. UDS is known to be faster than TCP/IP, see also https://www.researchgate.net/figure/Performance-Comparison-of-TCP-vs-Unix-Domain-Sockets-as-a-Function-of-Message-Size_fig3_221461399
    2. It does not require the network stack: avoiding the TCP/IP layer avoids its overhead (see the sketch below).
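    
    For reference, a minimal Python sketch (not part of this patch) of a client connecting over a Unix domain socket versus a loopback TCP socket, mirroring what the worker-side changes do; the socket path and port below are hypothetical:
    
    ```python
    import socket
    
    # Unix domain socket: addressed by a filesystem path and protected by file
    # permissions, so no extra authentication handshake is needed.
    uds = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    uds.connect("/tmp/.example.sock")  # hypothetical path; Spark uses a UUID-based name
    
    # Loopback TCP socket: goes through the TCP/IP stack and needs the secret-based auth.
    tcp = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    tcp.connect(("127.0.0.1", 12345))  # hypothetical port
    ```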
    
    ### Does this PR introduce _any_ user-facing change?
    
    To end users, no. This is an implementation-level change.
    
    ### How was this patch tested?
    
    Manually ran the tests after enabling this configuration.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #50466 from HyukjinKwon/unix-domain-socket.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 R/pkg/R/context.R                                  |   2 +-
 .../org/apache/spark/api/python/PythonRDD.scala    |  31 +++----
 .../org/apache/spark/api/python/PythonRunner.scala | 100 ++++++++++++---------
 .../spark/api/python/PythonWorkerFactory.scala     |  70 ++++++++++++---
 .../spark/api/python/PythonWorkerUtils.scala       |  12 ++-
 .../spark/api/python/StreamingPythonRunner.scala   |  24 +++--
 .../scala/org/apache/spark/api/r/RAuthHelper.scala |   1 +
 .../main/scala/org/apache/spark/api/r/RRDD.scala   |   6 +-
 .../org/apache/spark/internal/config/Python.scala  |  24 +++++
 .../apache/spark/security/SocketAuthHelper.scala   |  14 ++-
 .../apache/spark/security/SocketAuthServer.scala   |  66 ++++++++++----
 .../apache/spark/api/python/PythonRDDSuite.scala   |  16 ++--
 .../spark/security/SocketAuthHelperSuite.scala     |  12 ++-
 python/pyspark/core/broadcast.py                   |   8 +-
 python/pyspark/core/context.py                     |   2 +-
 python/pyspark/daemon.py                           |  28 +++++-
 .../deepspeed/tests/test_deepspeed_distributor.py  |   4 +-
 .../streaming/worker/foreach_batch_worker.py       |   8 +-
 .../connect/streaming/worker/listener_worker.py    |   8 +-
 .../streaming/python_streaming_source_runner.py    |   8 +-
 .../sql/streaming/stateful_processor_api_client.py |  32 ++++---
 .../transform_with_state_driver_worker.py          |  12 ++-
 python/pyspark/sql/worker/analyze_udtf.py          |   8 +-
 .../pyspark/sql/worker/commit_data_source_write.py |   8 +-
 python/pyspark/sql/worker/create_data_source.py    |   8 +-
 .../sql/worker/data_source_pushdown_filters.py     |   8 +-
 python/pyspark/sql/worker/lookup_data_sources.py   |   8 +-
 python/pyspark/sql/worker/plan_data_source_read.py |   8 +-
 .../sql/worker/python_streaming_sink_runner.py     |   8 +-
 .../pyspark/sql/worker/write_into_data_source.py   |   8 +-
 python/pyspark/taskcontext.py                      |  22 ++---
 python/pyspark/tests/test_appsubmit.py             |   2 +
 python/pyspark/util.py                             |  35 ++++++--
 python/pyspark/worker.py                           |  19 ++--
 python/pyspark/worker_util.py                      |  10 ++-
 python/run-tests.py                                |   2 +
 .../spark/deploy/yarn/YarnClusterSuite.scala       |  14 ++-
 .../spark/sql/api/python/PythonSQLUtils.scala      |   7 +-
 .../streaming/PythonStreamingSourceRunner.scala    |   5 +-
 .../TransformWithStateInPandasPythonRunner.scala   |  57 +++++++++---
 .../TransformWithStateInPandasStateServer.scala    |  16 ++--
 ...ransformWithStateInPandasStateServerSuite.scala |   4 +-
 42 files changed, 515 insertions(+), 230 deletions(-)

diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index eea83aa5ab52..0242e7114978 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -181,7 +181,7 @@ parallelize <- function(sc, coll, numSlices = 1) {
       parallelism <- as.integer(numSlices)
       jserver <- newJObject("org.apache.spark.api.r.RParallelizeServer", sc, 
parallelism)
       authSecret <- callJMethod(jserver, "secret")
-      port <- callJMethod(jserver, "port")
+      port <- callJMethod(jserver, "connInfo")
       conn <- socketConnection(
         port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)
       doServerAuth(conn, authSecret)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index d643983ef5df..2152724c4c13 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,6 +19,7 @@ package org.apache.spark.api.python
 
 import java.io._
 import java.net._
+import java.nio.channels.{Channels, SocketChannel}
 import java.nio.charset.StandardCharsets
 import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
 
@@ -231,9 +232,9 @@ private[spark] object PythonRDD extends Logging {
    *         server object that can be used to join the JVM serving thread in 
Python.
    */
   def toLocalIteratorAndServe[T](rdd: RDD[T], prefetchPartitions: Boolean = 
false): Array[Any] = {
-    val handleFunc = (sock: Socket) => {
-      val out = new DataOutputStream(sock.getOutputStream)
-      val in = new DataInputStream(sock.getInputStream)
+    val handleFunc = (sock: SocketChannel) => {
+      val out = new DataOutputStream(Channels.newOutputStream(sock))
+      val in = new DataInputStream(Channels.newInputStream(sock))
       Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
         // Collects a partition on each iteration
         val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
@@ -287,7 +288,7 @@ private[spark] object PythonRDD extends Logging {
     }
 
     val server = new SocketFuncServer(authHelper, "serve toLocalIterator", 
handleFunc)
-    Array(server.port, server.secret, server)
+    Array(server.connInfo, server.secret, server)
   }
 
   def readRDDFromFile(
@@ -831,21 +832,21 @@ private[spark] class PythonBroadcast(@transient var path: 
String) extends Serial
 
   def setupEncryptionServer(): Array[Any] = {
     encryptionServer = new SocketAuthServer[Unit]("broadcast-encrypt-server") {
-      override def handleConnection(sock: Socket): Unit = {
+      override def handleConnection(sock: SocketChannel): Unit = {
         val env = SparkEnv.get
-        val in = sock.getInputStream()
+        val in = Channels.newInputStream(sock)
         val abspath = new File(path).getAbsolutePath
         val out = env.serializerManager.wrapForEncryption(new 
FileOutputStream(abspath))
         DechunkedInputStream.dechunkAndCopyToOutput(in, out)
       }
     }
-    Array(encryptionServer.port, encryptionServer.secret)
+    Array(encryptionServer.connInfo, encryptionServer.secret)
   }
 
   def setupDecryptionServer(): Array[Any] = {
     decryptionServer = new 
SocketAuthServer[Unit]("broadcast-decrypt-server-for-driver") {
-      override def handleConnection(sock: Socket): Unit = {
-        val out = new DataOutputStream(new 
BufferedOutputStream(sock.getOutputStream()))
+      override def handleConnection(sock: SocketChannel): Unit = {
+        val out = new DataOutputStream(new 
BufferedOutputStream(Channels.newOutputStream(sock)))
         Utils.tryWithSafeFinally {
           val in = SparkEnv.get.serializerManager.wrapForEncryption(new 
FileInputStream(path))
           Utils.tryWithSafeFinally {
@@ -859,7 +860,7 @@ private[spark] class PythonBroadcast(@transient var path: 
String) extends Serial
         }
       }
     }
-    Array(decryptionServer.port, decryptionServer.secret)
+    Array(decryptionServer.connInfo, decryptionServer.secret)
   }
 
   def waitTillBroadcastDataSent(): Unit = decryptionServer.getResult()
@@ -945,8 +946,8 @@ private[spark] class EncryptedPythonBroadcastServer(
     val idsAndFiles: Seq[(Long, String)])
     extends SocketAuthServer[Unit]("broadcast-decrypt-server") with Logging {
 
-  override def handleConnection(socket: Socket): Unit = {
-    val out = new DataOutputStream(new 
BufferedOutputStream(socket.getOutputStream()))
+  override def handleConnection(socket: SocketChannel): Unit = {
+    val out = new DataOutputStream(new 
BufferedOutputStream(Channels.newOutputStream(socket)))
     var socketIn: InputStream = null
     // send the broadcast id, then the decrypted data.  We don't need to send 
the length, the
     // the python pickle module just needs a stream.
@@ -962,7 +963,7 @@ private[spark] class EncryptedPythonBroadcastServer(
       }
       logTrace("waiting for python to accept broadcast data over socket")
       out.flush()
-      socketIn = socket.getInputStream()
+      socketIn = Channels.newInputStream(socket)
       socketIn.read()
       logTrace("done serving broadcast data")
     } {
@@ -983,8 +984,8 @@ private[spark] class EncryptedPythonBroadcastServer(
 private[spark] abstract class PythonRDDServer
     extends 
SocketAuthServer[JavaRDD[Array[Byte]]]("pyspark-parallelize-server") {
 
-  def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
-    val in = sock.getInputStream()
+  def handleConnection(sock: SocketChannel): JavaRDD[Array[Byte]] = {
+    val in = Channels.newInputStream(sock)
     val dechunkedInput: InputStream = new DechunkedInputStream(in)
     streamToRDD(dechunkedInput)
   }
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 84701ee593c1..c2539ee05f21 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -20,9 +20,9 @@ package org.apache.spark.api.python
 import java.io._
 import java.net._
 import java.nio.ByteBuffer
-import java.nio.channels.SelectionKey
-import java.nio.charset.StandardCharsets.UTF_8
+import java.nio.channels.{AsynchronousCloseException, Channels, SelectionKey, 
ServerSocketChannel, SocketChannel}
 import java.nio.file.{Files => JavaFiles, Path}
+import java.util.UUID
 import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
 import java.util.concurrent.atomic.AtomicBoolean
 
@@ -201,9 +201,10 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
   // Python accumulator is always set in production except in tests. See 
SPARK-27893
   private val maybeAccumulator: Option[PythonAccumulator] = Option(accumulator)
 
-  // Expose a ServerSocket to support method calls via socket from Python 
side. Only relevant for
-  // for tasks that are a part of barrier stage, refer [[BarrierTaskContext]] 
for details.
-  private[spark] var serverSocket: Option[ServerSocket] = None
+  // Expose a ServerSocketChannel to support method calls via socket from 
Python side.
+  // Only relevant for tasks that are a part of barrier stage, refer
+  // `BarrierTaskContext` for details.
+  private[spark] var serverSocketChannel: Option[ServerSocketChannel] = None
 
   // Authentication helper used when serving method calls via socket from 
Python side.
   private lazy val authHelper = new SocketAuthHelper(conf)
@@ -347,6 +348,11 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     def writeNextInputToStream(dataOut: DataOutputStream): Boolean
 
     def open(dataOut: DataOutputStream): Unit = Utils.logUncaughtExceptions {
+      val isUnixDomainSock = 
authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
+      lazy val sockPath = new File(
+        authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
+          .getOrElse(System.getProperty("java.io.tmpdir")),
+        s".${UUID.randomUUID()}.sock")
       try {
         // Partition index
         dataOut.writeInt(partitionIndex)
@@ -356,27 +362,34 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         // Init a ServerSocket to accept method calls from Python side.
         val isBarrier = context.isInstanceOf[BarrierTaskContext]
         if (isBarrier) {
-          serverSocket = Some(new ServerSocket(/* port */ 0,
-            /* backlog */ 1,
-            InetAddress.getByName("localhost")))
-          // A call to accept() for ServerSocket shall block infinitely.
-          serverSocket.foreach(_.setSoTimeout(0))
+          if (isUnixDomainSock) {
+            serverSocketChannel = 
Some(ServerSocketChannel.open(StandardProtocolFamily.UNIX))
+            sockPath.deleteOnExit()
+            
serverSocketChannel.get.bind(UnixDomainSocketAddress.of(sockPath.getPath))
+          } else {
+            serverSocketChannel = Some(ServerSocketChannel.open())
+            serverSocketChannel.foreach(_.bind(
+              new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1))
+            // A call to accept() for ServerSocket shall block infinitely.
+            serverSocketChannel.foreach(_.socket().setSoTimeout(0))
+          }
+
           new Thread("accept-connections") {
             setDaemon(true)
 
             override def run(): Unit = {
-              while (!serverSocket.get.isClosed()) {
-                var sock: Socket = null
+              while (serverSocketChannel.get.isOpen()) {
+                var sock: SocketChannel = null
                 try {
-                  sock = serverSocket.get.accept()
+                  sock = serverSocketChannel.get.accept()
                   // Wait for function call from python side.
-                  sock.setSoTimeout(10000)
+                  if (!isUnixDomainSock) sock.socket().setSoTimeout(10000)
                   authHelper.authClient(sock)
-                  val input = new DataInputStream(sock.getInputStream())
+                  val input = new 
DataInputStream(Channels.newInputStream(sock))
                   val requestMethod = input.readInt()
                   // The BarrierTaskContext function may wait infinitely, 
socket shall not timeout
                   // before the function finishes.
-                  sock.setSoTimeout(0)
+                  if (!isUnixDomainSock) sock.socket().setSoTimeout(0)
                   requestMethod match {
                     case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
                       barrierAndServe(requestMethod, sock)
@@ -385,13 +398,14 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
                       barrierAndServe(requestMethod, sock, message)
                     case _ =>
                       val out = new DataOutputStream(new BufferedOutputStream(
-                        sock.getOutputStream))
+                        Channels.newOutputStream(sock)))
                       
writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
                   }
                 } catch {
-                  case e: SocketException if e.getMessage.contains("Socket 
closed") =>
-                    // It is possible that the ServerSocket is not closed, but 
the native socket
-                    // has already been closed, we shall catch and silently 
ignore this case.
+                  case _: AsynchronousCloseException =>
+                    // Ignore to make less noisy. These will be closed when 
tasks
+                    // are finished by listeners.
+                    if (isUnixDomainSock) sockPath.delete()
                 } finally {
                   if (sock != null) {
                     sock.close()
@@ -401,33 +415,35 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
             }
           }.start()
         }
-        val secret = if (isBarrier) {
-          authHelper.secret
-        } else {
-          ""
-        }
         if (isBarrier) {
           // Close ServerSocket on task completion.
-          serverSocket.foreach { server =>
-            context.addTaskCompletionListener[Unit](_ => server.close())
+          serverSocketChannel.foreach { server =>
+            context.addTaskCompletionListener[Unit] { _ =>
+              server.close()
+              if (isUnixDomainSock) sockPath.delete()
+            }
           }
-          val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
-          if (boundPort == -1) {
-            val message = "ServerSocket failed to bind to Java side."
-            logError(message)
-            throw new SparkException(message)
+          if (isUnixDomainSock) {
+            logDebug(s"Started ServerSocket on with Unix Domain Socket 
$sockPath.")
+            dataOut.writeBoolean(/* isBarrier = */true)
+            dataOut.writeInt(-1)
+            PythonRDD.writeUTF(sockPath.getPath, dataOut)
+          } else {
+            val boundPort: Int = 
serverSocketChannel.map(_.socket().getLocalPort).getOrElse(-1)
+            if (boundPort == -1) {
+              val message = "ServerSocket failed to bind to Java side."
+              logError(message)
+              throw new SparkException(message)
+            }
+            logDebug(s"Started ServerSocket on port $boundPort.")
+            dataOut.writeBoolean(/* isBarrier = */true)
+            dataOut.writeInt(boundPort)
+            PythonRDD.writeUTF(authHelper.secret, dataOut)
           }
-          logDebug(s"Started ServerSocket on port $boundPort.")
-          dataOut.writeBoolean(/* isBarrier = */true)
-          dataOut.writeInt(boundPort)
         } else {
           dataOut.writeBoolean(/* isBarrier = */false)
-          dataOut.writeInt(0)
         }
         // Write out the TaskContextInfo
-        val secretBytes = secret.getBytes(UTF_8)
-        dataOut.writeInt(secretBytes.length)
-        dataOut.write(secretBytes, 0, secretBytes.length)
         dataOut.writeInt(context.stageId())
         dataOut.writeInt(context.partitionId())
         dataOut.writeInt(context.attemptNumber())
@@ -485,12 +501,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     /**
      * Gateway to call BarrierTaskContext methods.
      */
-    def barrierAndServe(requestMethod: Int, sock: Socket, message: String = 
""): Unit = {
+    def barrierAndServe(requestMethod: Int, sock: SocketChannel, message: 
String = ""): Unit = {
       require(
-        serverSocket.isDefined,
+        serverSocketChannel.isDefined,
         "No available ServerSocket to redirect the BarrierTaskContext method 
call."
       )
-      val out = new DataOutputStream(new 
BufferedOutputStream(sock.getOutputStream))
+      val out = new DataOutputStream(new 
BufferedOutputStream(Channels.newOutputStream(sock)))
       try {
         val messages = requestMethod match {
           case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 19a067076967..64b29585a0d9 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -18,10 +18,11 @@
 package org.apache.spark.api.python
 
 import java.io.{DataInputStream, DataOutputStream, EOFException, File, 
InputStream}
-import java.net.{InetAddress, InetSocketAddress, SocketException}
+import java.net.{InetAddress, InetSocketAddress, SocketException, 
StandardProtocolFamily, UnixDomainSocketAddress}
 import java.net.SocketTimeoutException
 import java.nio.channels._
 import java.util.Arrays
+import java.util.UUID
 import java.util.concurrent.TimeUnit
 import javax.annotation.concurrent.GuardedBy
 
@@ -33,6 +34,7 @@ import org.apache.spark._
 import org.apache.spark.errors.SparkCoreErrors
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys._
+import org.apache.spark.internal.config.Python.{PYTHON_UNIX_DOMAIN_SOCKET_DIR, 
PYTHON_UNIX_DOMAIN_SOCKET_ENABLED}
 import org.apache.spark.security.SocketAuthHelper
 import org.apache.spark.util.{RedirectThread, Utils}
 
@@ -97,6 +99,7 @@ private[spark] class PythonWorkerFactory(
   }
 
   private val authHelper = new SocketAuthHelper(SparkEnv.get.conf)
+  private val isUnixDomainSock = 
authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
 
   @GuardedBy("self")
   private var daemon: Process = null
@@ -106,6 +109,8 @@ private[spark] class PythonWorkerFactory(
   @GuardedBy("self")
   private val daemonWorkers = new mutable.WeakHashMap[PythonWorker, 
ProcessHandle]()
   @GuardedBy("self")
+  private var daemonSockPath: String = _
+  @GuardedBy("self")
   private val idleWorkers = new mutable.Queue[PythonWorker]()
   @GuardedBy("self")
   private var lastActivityNs = 0L
@@ -152,7 +157,11 @@ private[spark] class PythonWorkerFactory(
   private def createThroughDaemon(): (PythonWorker, Option[ProcessHandle]) = {
 
     def createWorker(): (PythonWorker, Option[ProcessHandle]) = {
-      val socketChannel = SocketChannel.open(new InetSocketAddress(daemonHost, 
daemonPort))
+      val socketChannel = if (isUnixDomainSock) {
+        SocketChannel.open(UnixDomainSocketAddress.of(daemonSockPath))
+      } else {
+        SocketChannel.open(new InetSocketAddress(daemonHost, daemonPort))
+      }
       // These calls are blocking.
       val pid = new 
DataInputStream(Channels.newInputStream(socketChannel)).readInt()
       if (pid < 0) {
@@ -161,7 +170,7 @@ private[spark] class PythonWorkerFactory(
       val processHandle = ProcessHandle.of(pid).orElseThrow(
         () => new IllegalStateException("Python daemon failed to launch 
worker.")
       )
-      authHelper.authToServer(socketChannel.socket())
+      authHelper.authToServer(socketChannel)
       socketChannel.configureBlocking(false)
       val worker = PythonWorker(socketChannel)
       daemonWorkers.put(worker, processHandle)
@@ -192,9 +201,19 @@ private[spark] class PythonWorkerFactory(
   private[spark] def createSimpleWorker(
       blockingMode: Boolean): (PythonWorker, Option[ProcessHandle]) = {
     var serverSocketChannel: ServerSocketChannel = null
+    lazy val sockPath = new File(
+      authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
+        .getOrElse(System.getProperty("java.io.tmpdir")),
+      s".${UUID.randomUUID()}.sock")
     try {
-      serverSocketChannel = ServerSocketChannel.open()
-      serverSocketChannel.bind(new 
InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1)
+      if (isUnixDomainSock) {
+        serverSocketChannel = 
ServerSocketChannel.open(StandardProtocolFamily.UNIX)
+        sockPath.deleteOnExit()
+        serverSocketChannel.bind(UnixDomainSocketAddress.of(sockPath.getPath))
+      } else {
+        serverSocketChannel = ServerSocketChannel.open()
+        serverSocketChannel.bind(new 
InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1)
+      }
 
       // Create and start the worker
       val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", 
workerModule))
@@ -209,9 +228,14 @@ private[spark] class PythonWorkerFactory(
       workerEnv.put("PYTHONPATH", pythonPath)
       // This is equivalent to setting the -u flag; we use it because ipython 
doesn't support -u:
       workerEnv.put("PYTHONUNBUFFERED", "YES")
-      workerEnv.put("PYTHON_WORKER_FACTORY_PORT", 
serverSocketChannel.socket().getLocalPort
-        .toString)
-      workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
+      if (isUnixDomainSock) {
+        workerEnv.put("PYTHON_WORKER_FACTORY_SOCK_PATH", sockPath.getPath)
+        workerEnv.put("PYTHON_UNIX_DOMAIN_ENABLED", "True")
+      } else {
+        workerEnv.put("PYTHON_WORKER_FACTORY_PORT", 
serverSocketChannel.socket().getLocalPort
+          .toString)
+        workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
+      }
       if (Utils.preferIPv6) {
         workerEnv.put("SPARK_PREFER_IPV6", "True")
       }
@@ -233,7 +257,7 @@ private[spark] class PythonWorkerFactory(
             throw new SocketTimeoutException(
               "Timed out while waiting for the Python worker to connect back")
           }
-        authHelper.authClient(socketChannel.socket())
+        authHelper.authClient(socketChannel)
         // TODO: When we drop JDK 8, we can just use workerProcess.pid()
         val pid = new 
DataInputStream(Channels.newInputStream(socketChannel)).readInt()
         if (pid < 0) {
@@ -254,6 +278,7 @@ private[spark] class PythonWorkerFactory(
     } finally {
       if (serverSocketChannel != null) {
         serverSocketChannel.close()
+        if (isUnixDomainSock) sockPath.delete()
       }
     }
   }
@@ -278,7 +303,15 @@ private[spark] class PythonWorkerFactory(
         val workerEnv = pb.environment()
         workerEnv.putAll(envVars.asJava)
         workerEnv.put("PYTHONPATH", pythonPath)
-        workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
+        if (isUnixDomainSock) {
+          workerEnv.put(
+            "PYTHON_WORKER_FACTORY_SOCK_DIR",
+            authHelper.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
+              .getOrElse(System.getProperty("java.io.tmpdir")))
+          workerEnv.put("PYTHON_UNIX_DOMAIN_ENABLED", "True")
+        } else {
+          workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
+        }
         if (Utils.preferIPv6) {
           workerEnv.put("SPARK_PREFER_IPV6", "True")
         }
@@ -288,7 +321,11 @@ private[spark] class PythonWorkerFactory(
 
         val in = new DataInputStream(daemon.getInputStream)
         try {
-          daemonPort = in.readInt()
+          if (isUnixDomainSock) {
+            daemonSockPath = PythonWorkerUtils.readUTF(in)
+          } else {
+            daemonPort = in.readInt()
+          }
         } catch {
           case _: EOFException if daemon.isAlive =>
             throw SparkCoreErrors.eofExceptionWhileReadPortNumberError(
@@ -301,10 +338,14 @@ private[spark] class PythonWorkerFactory(
         // test that the returned port number is within a valid range.
         // note: this does not cover the case where the port number
         // is arbitrary data but is also coincidentally within range
-        if (daemonPort < 1 || daemonPort > 0xffff) {
+        val isMalformedPort = !isUnixDomainSock && (daemonPort < 1 || 
daemonPort > 0xffff)
+        val isMalformedSockPath = isUnixDomainSock && !new 
File(daemonSockPath).exists()
+        val errorMsg =
+          if (isUnixDomainSock) daemonSockPath else f"$daemonPort 
(0x$daemonPort%08x)"
+        if (isMalformedPort || isMalformedSockPath) {
           val exceptionMessage = f"""
-            |Bad data in $daemonModule's standard output. Invalid port number:
-            |  $daemonPort (0x$daemonPort%08x)
+            |Bad data in $daemonModule's standard output. Invalid port 
number/socket path:
+            |  $errorMsg
             |Python command to execute the daemon was:
             |  ${command.asScala.mkString(" ")}
             |Check that you don't have any unexpected modules or libraries in
@@ -407,6 +448,7 @@ private[spark] class PythonWorkerFactory(
 
         daemon = null
         daemonPort = 0
+        daemonSockPath = null
       } else {
         simpleWorkers.values.foreach(_.destroy())
       }
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
index ae3614445be6..0a6def051a34 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerUtils.scala
@@ -117,9 +117,15 @@ private[spark] object PythonWorkerUtils extends Logging {
         }
       }
       val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
-      dataOut.writeInt(server.port)
-      logTrace(s"broadcast decryption server setup on ${server.port}")
-      writeUTF(server.secret, dataOut)
+      server.connInfo match {
+        case portNum: Int =>
+          dataOut.writeInt(portNum)
+          writeUTF(server.secret, dataOut)
+        case sockPath: String =>
+          dataOut.writeInt(-1)
+          writeUTF(sockPath, dataOut)
+      }
+      logTrace(s"broadcast decryption server setup on ${server.connInfo}")
       sendBidsToRemove()
       idsAndFiles.foreach { case (id, _) =>
         // send new broadcast
diff --git 
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala 
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
index 6f9708def2f2..7eba574751b4 100644
--- 
a/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
+++ 
b/core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.api.python
 
 import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, 
DataOutputStream}
+import java.nio.channels.Channels
 
 import scala.jdk.CollectionConverters._
 
@@ -25,7 +26,7 @@ import org.apache.spark.{SparkEnv, SparkPythonException}
 import org.apache.spark.internal.{Logging, MDC}
 import org.apache.spark.internal.LogKeys.{PYTHON_WORKER_MODULE, 
PYTHON_WORKER_RESPONSE, SESSION_ID}
 import org.apache.spark.internal.config.BUFFER_SIZE
-import org.apache.spark.internal.config.Python.PYTHON_AUTH_SOCKET_TIMEOUT
+import org.apache.spark.internal.config.Python.{PYTHON_AUTH_SOCKET_TIMEOUT, 
PYTHON_UNIX_DOMAIN_SOCKET_ENABLED}
 
 
 private[spark] object StreamingPythonRunner {
@@ -45,6 +46,7 @@ private[spark] class StreamingPythonRunner(
     sessionId: String,
     workerModule: String) extends Logging {
   private val conf = SparkEnv.get.conf
+  private val isUnixDomainSock = conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
   protected val bufferSize: Int = conf.get(BUFFER_SIZE)
   protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
 
@@ -78,14 +80,20 @@ private[spark] class StreamingPythonRunner(
     pythonWorker = Some(worker)
     pythonWorkerFactory = Some(workerFactory)
 
-    val socket = pythonWorker.get.channel.socket()
-    val stream = new BufferedOutputStream(socket.getOutputStream, bufferSize)
-    val dataIn = new DataInputStream(new 
BufferedInputStream(socket.getInputStream, bufferSize))
+    val socketChannel = pythonWorker.get.channel
+    val stream = new 
BufferedOutputStream(Channels.newOutputStream(socketChannel), bufferSize)
+    val dataIn = new DataInputStream(
+      new BufferedInputStream(Channels.newInputStream(socketChannel), 
bufferSize))
     val dataOut = new DataOutputStream(stream)
 
-    val originalTimeout = socket.getSoTimeout()
-    // Set timeout to 5 minute during initialization config transmission
-    socket.setSoTimeout(5 * 60 * 1000)
+    val originalTimeout = if (!isUnixDomainSock) {
+      val timeout = socketChannel.socket().getSoTimeout()
+      // Set timeout to 5 minute during initialization config transmission
+      socketChannel.socket().setSoTimeout(5 * 60 * 1000)
+      Some(timeout)
+    } else {
+      None
+    }
 
     val resFromPython = try {
       PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
@@ -111,7 +119,7 @@ private[spark] class StreamingPythonRunner(
 
     // Set timeout back to the original timeout
     // Should be infinity by default
-    socket.setSoTimeout(originalTimeout)
+    originalTimeout.foreach(v => socketChannel.socket().setSoTimeout(v))
 
     if (resFromPython != 0) {
       val errMessage = PythonWorkerUtils.readUTF(dataIn)
diff --git a/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala 
b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
index ac6826a9ec77..5c45986a8f9a 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RAuthHelper.scala
@@ -24,6 +24,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.security.SocketAuthHelper
 
 private[spark] class RAuthHelper(conf: SparkConf) extends 
SocketAuthHelper(conf) {
+  override val isUnixDomainSock = false
 
   override protected def readUtf8(s: Socket): String = {
     SerDe.readString(new DataInputStream(s.getInputStream()))
diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala 
b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
index ff6ed9f86b55..3b309e093970 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.api.r
 
 import java.io.{File, OutputStream}
-import java.net.Socket
+import java.nio.channels.{Channels, SocketChannel}
 import java.util.{Map => JMap}
 
 import scala.jdk.CollectionConverters._
@@ -179,8 +179,8 @@ private[spark] class RParallelizeServer(sc: 
JavaSparkContext, parallelism: Int)
     extends SocketAuthServer[JavaRDD[Array[Byte]]](
       new RAuthHelper(SparkEnv.get.conf), "sparkr-parallelize-server") {
 
-  override def handleConnection(sock: Socket): JavaRDD[Array[Byte]] = {
-    val in = sock.getInputStream()
+  override def handleConnection(sock: SocketChannel): JavaRDD[Array[Byte]] = {
+    val in = Channels.newInputStream(sock)
     JavaRDD.readRDDFromInputStream(sc.sc, in, parallelism)
   }
 }
diff --git a/core/src/main/scala/org/apache/spark/internal/config/Python.scala 
b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
index 1f827e8dc449..7f9921d58dba 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/Python.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/Python.scala
@@ -70,6 +70,30 @@ private[spark] object Python {
     .booleanConf
     .createWithDefault(false)
 
+  val PYTHON_UNIX_DOMAIN_SOCKET_ENABLED = 
ConfigBuilder("spark.python.unix.domain.socket.enabled")
+    .doc("When set to true, the Python driver uses a Unix domain socket for 
operations like " +
+      "creating or collecting a DataFrame from local data, using accumulators, 
and executing " +
+      "Python functions with PySpark such as Python UDFs. This configuration 
only applies " +
+      "to Spark Classic and Spark Connect server.")
+    .version("4.1.0")
+    .booleanConf
+    .createWithDefault(false)
+
+  val PYTHON_UNIX_DOMAIN_SOCKET_DIR = 
ConfigBuilder("spark.python.unix.domain.socket.dir")
+    .doc("When specified, it uses the directory to create Unix domain socket 
files. " +
+      "Otherwise, it uses the default location of the temporary directory set 
in " +
+      s"'java.io.tmpdir' property. This is used when 
${PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key} " +
+      "is enabled.")
+    .internal()
+    .version("4.1.0")
+    .stringConf
+    // UDS requires the length of path lower than 104 characters. We use UUID 
(36 characters)
+    // and additional prefix "." (1), postfix ".sock" (5), and the path 
separator (1).
+    .checkValue(
+      _.length <= (104 - (36 + 1 + 5 + 1)),
+      s"The directory path should be lower than ${(104 - (36 + 1 + 5 + 1))}")
+    .createOptional
+
   private val PYTHON_WORKER_IDLE_TIMEOUT_SECONDS_KEY = 
"spark.python.worker.idleTimeoutSeconds"
   private val PYTHON_WORKER_KILL_ON_IDLE_TIMEOUT_KEY = 
"spark.python.worker.killOnIdleTimeout"
 
diff --git 
a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala 
b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
index f800553c5388..ecebb97ecfc1 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthHelper.scala
@@ -19,9 +19,11 @@ package org.apache.spark.security
 
 import java.io.{DataInputStream, DataOutputStream}
 import java.net.Socket
+import java.nio.channels.SocketChannel
 import java.nio.charset.StandardCharsets.UTF_8
 
 import org.apache.spark.SparkConf
+import org.apache.spark.internal.config.Python.{PYTHON_UNIX_DOMAIN_SOCKET_DIR, 
PYTHON_UNIX_DOMAIN_SOCKET_ENABLED}
 import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.util.Utils
 
@@ -35,6 +37,9 @@ import org.apache.spark.util.Utils
  * There's no secrecy, so this relies on the sockets being either local or 
somehow encrypted.
  */
 private[spark] class SocketAuthHelper(val conf: SparkConf) {
+  val isUnixDomainSock: Boolean = conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
+  lazy val sockDir: String =
+    
conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR).getOrElse(System.getProperty("java.io.tmpdir"))
 
   val secret = Utils.createSecret(conf)
 
@@ -47,6 +52,11 @@ private[spark] class SocketAuthHelper(val conf: SparkConf) {
    * @param s The client socket.
    * @throws IllegalArgumentException If authentication fails.
    */
+  def authClient(socket: SocketChannel): Unit = {
+    if (isUnixDomainSock) return
+    authClient(socket.socket())
+  }
+
   def authClient(s: Socket): Unit = {
     var shouldClose = true
     try {
@@ -80,7 +90,9 @@ private[spark] class SocketAuthHelper(val conf: SparkConf) {
    * @param s The socket connected to the server.
    * @throws IllegalArgumentException If authentication fails.
    */
-  def authToServer(s: Socket): Unit = {
+  def authToServer(socket: SocketChannel): Unit = {
+    if (isUnixDomainSock) return
+    val s = socket.socket()
     var shouldClose = true
     try {
       writeUtf8(secret, s)
diff --git 
a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala 
b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
index 9efe2af5fcc8..b0446a4f2feb 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -17,8 +17,10 @@
 
 package org.apache.spark.security
 
-import java.io.{BufferedOutputStream, OutputStream}
-import java.net.{InetAddress, ServerSocket, Socket}
+import java.io.{BufferedOutputStream, File, OutputStream}
+import java.net.{InetAddress, InetSocketAddress, StandardProtocolFamily, 
UnixDomainSocketAddress}
+import java.nio.channels.{Channels, ServerSocketChannel, SocketChannel}
+import java.util.UUID
 
 import scala.concurrent.Promise
 import scala.concurrent.duration.Duration
@@ -46,44 +48,70 @@ private[spark] abstract class SocketAuthServer[T](
   def this(threadName: String) = this(SparkEnv.get, threadName)
 
   private val promise = Promise[T]()
+  private val isUnixDomainSock: Boolean = authHelper.isUnixDomainSock
 
-  private def startServer(): (Int, String) = {
+  private def startServer(): (Any, String) = {
     logTrace("Creating listening socket")
-    val address = InetAddress.getLoopbackAddress()
-    val serverSocket = new ServerSocket(0, 1, address)
+    lazy val sockPath = new File(authHelper.sockDir, 
s".${UUID.randomUUID()}.sock")
+
+    val (serverSocketChannel, address) = if (isUnixDomainSock) {
+      val address = UnixDomainSocketAddress.of(sockPath.getPath)
+      val serverChannel = ServerSocketChannel.open(StandardProtocolFamily.UNIX)
+      sockPath.deleteOnExit()
+      serverChannel.bind(address)
+      (serverChannel, address)
+    } else {
+      val address = InetAddress.getLoopbackAddress()
+      val serverChannel = ServerSocketChannel.open()
+      serverChannel.bind(new InetSocketAddress(address, 0), 1)
+      (serverChannel, address)
+    }
+
     // Close the socket if no connection in the configured seconds
     val timeout = authHelper.conf.get(PYTHON_AUTH_SOCKET_TIMEOUT).toInt
     logTrace(s"Setting timeout to $timeout sec")
-    serverSocket.setSoTimeout(timeout * 1000)
+    if (!isUnixDomainSock) serverSocketChannel.socket().setSoTimeout(timeout * 
1000)
 
     new Thread(threadName) {
       setDaemon(true)
       override def run(): Unit = {
-        var sock: Socket = null
+        var sock: SocketChannel = null
         try {
-          logTrace(s"Waiting for connection on $address with port 
${serverSocket.getLocalPort}")
-          sock = serverSocket.accept()
-          logTrace(s"Connection accepted from address 
${sock.getRemoteSocketAddress}")
+          if (isUnixDomainSock) {
+            logTrace(s"Waiting for connection on $address.")
+          } else {
+            logTrace(
+              s"Waiting for connection on $address with port " +
+                s"${serverSocketChannel.socket().getLocalPort}")
+          }
+          sock = serverSocketChannel.accept()
+          logTrace(s"Connection accepted from address 
${sock.getRemoteAddress}")
           authHelper.authClient(sock)
           logTrace("Client authenticated")
           promise.complete(Try(handleConnection(sock)))
         } finally {
           logTrace("Closing server")
-          JavaUtils.closeQuietly(serverSocket)
+          JavaUtils.closeQuietly(serverSocketChannel)
           JavaUtils.closeQuietly(sock)
+          if (isUnixDomainSock) sockPath.delete()
         }
       }
     }.start()
-    (serverSocket.getLocalPort, authHelper.secret)
+    if (isUnixDomainSock) {
+      (sockPath.getPath, null)
+    } else {
+      (serverSocketChannel.socket().getLocalPort, authHelper.secret)
+    }
   }
 
-  val (port, secret) = startServer()
+  // connInfo is either a string (for UDS) or a port number (for TCP/IP).
+  val (connInfo, secret) = startServer()
 
   /**
    * Handle a connection which has already been authenticated.  Any error from 
this function
    * will clean up this connection and the entire server, and get propagated 
to [[getResult]].
    */
-  def handleConnection(sock: Socket): T
+  def handleConnection(sock: SocketChannel): T
 
   /**
    * Blocks indefinitely for [[handleConnection]] to finish, and returns that 
result.  If
@@ -108,9 +136,9 @@ private[spark] abstract class SocketAuthServer[T](
 private[spark] class SocketFuncServer(
     authHelper: SocketAuthHelper,
     threadName: String,
-    func: Socket => Unit) extends SocketAuthServer[Unit](authHelper, 
threadName) {
+    func: SocketChannel => Unit) extends SocketAuthServer[Unit](authHelper, 
threadName) {
 
-  override def handleConnection(sock: Socket): Unit = {
+  override def handleConnection(sock: SocketChannel): Unit = {
     func(sock)
   }
 }
@@ -134,8 +162,8 @@ private[spark] object SocketAuthServer {
   def serveToStream(
       threadName: String,
       authHelper: SocketAuthHelper)(writeFunc: OutputStream => Unit): 
Array[Any] = {
-    val handleFunc = (sock: Socket) => {
-      val out = new BufferedOutputStream(sock.getOutputStream())
+    val handleFunc = (sock: SocketChannel) => {
+      val out = new BufferedOutputStream(Channels.newOutputStream(sock))
       Utils.tryWithSafeFinally {
         writeFunc(out)
       } {
@@ -144,6 +172,6 @@ private[spark] object SocketAuthServer {
     }
 
     val server = new SocketFuncServer(authHelper, threadName, handleFunc)
-    Array(server.port, server.secret, server)
+    Array(server.connInfo, server.secret, server)
   }
 }
diff --git 
a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala 
b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index 88ad5b3a7483..4efd2870cccb 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -18,7 +18,8 @@
 package org.apache.spark.api.python
 
 import java.io.{ByteArrayOutputStream, DataOutputStream, File}
-import java.net.{InetAddress, Socket}
+import java.net.{InetAddress, InetSocketAddress}
+import java.nio.channels.SocketChannel
 import java.nio.charset.StandardCharsets
 import java.util
 
@@ -33,6 +34,7 @@ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
 
 import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, 
SparkFunSuite}
 import org.apache.spark.api.java.JavaSparkContext
+import 
org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
 import org.apache.spark.rdd.{HadoopRDD, RDD}
 import org.apache.spark.security.{SocketAuthHelper, SocketAuthServer}
 import org.apache.spark.util.Utils
@@ -76,10 +78,14 @@ class PythonRDDSuite extends SparkFunSuite with 
LocalSparkContext {
   }
 
   test("python server error handling") {
-    val authHelper = new SocketAuthHelper(new SparkConf())
+    val conf = new SparkConf()
+    conf.set(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key, false.toString)
+    val authHelper = new SocketAuthHelper(conf)
     val errorServer = new ExceptionPythonServer(authHelper)
-    val client = new Socket(InetAddress.getLoopbackAddress(), errorServer.port)
-    authHelper.authToServer(client)
+    val socketChannel = SocketChannel.open(
+      new InetSocketAddress(InetAddress.getLoopbackAddress(),
+        errorServer.connInfo.asInstanceOf[Int]))
+    authHelper.authToServer(socketChannel)
     val ex = intercept[Exception] { errorServer.getResult(Duration(1, 
"second")) }
     assert(ex.getCause().getMessage().contains("exception within 
handleConnection"))
   }
@@ -87,7 +93,7 @@ class PythonRDDSuite extends SparkFunSuite with 
LocalSparkContext {
   class ExceptionPythonServer(authHelper: SocketAuthHelper)
       extends SocketAuthServer[Unit](authHelper, "error-server") {
 
-    override def handleConnection(sock: Socket): Unit = {
+    override def handleConnection(sock: SocketChannel): Unit = {
       throw new Exception("exception within handleConnection")
     }
   }
diff --git 
a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala 
b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
index e57cb701b628..c5a6199cf4c1 100644
--- a/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
+++ b/core/src/test/scala/org/apache/spark/security/SocketAuthHelperSuite.scala
@@ -18,14 +18,17 @@ package org.apache.spark.security
 
 import java.io.Closeable
 import java.net._
+import java.nio.channels.SocketChannel
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.internal.config._
+import 
org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
 import org.apache.spark.util.Utils
 
 class SocketAuthHelperSuite extends SparkFunSuite {
 
   private val conf = new SparkConf()
+  conf.set(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key, false.toString)
   private val authHelper = new SocketAuthHelper(conf)
 
   test("successful auth") {
@@ -43,7 +46,9 @@ class SocketAuthHelperSuite extends SparkFunSuite {
   test("failed auth") {
     Utils.tryWithResource(new ServerThread()) { server =>
       Utils.tryWithResource(server.createClient()) { client =>
-        val badHelper = new SocketAuthHelper(new 
SparkConf().set(AUTH_SECRET_BIT_LENGTH, 128))
+        val badHelper = new SocketAuthHelper(new SparkConf()
+          .set(AUTH_SECRET_BIT_LENGTH, 128)
+          .set(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key, false.toString))
         intercept[IllegalArgumentException] {
           badHelper.authToServer(client)
         }
@@ -66,8 +71,9 @@ class SocketAuthHelperSuite extends SparkFunSuite {
     setDaemon(true)
     start()
 
-    def createClient(): Socket = {
-      new Socket(InetAddress.getLoopbackAddress(), ss.getLocalPort())
+    def createClient(): SocketChannel = {
+      SocketChannel.open(new InetSocketAddress(
+        InetAddress.getLoopbackAddress(), ss.getLocalPort))
     }
 
     override def run(): Unit = {
diff --git a/python/pyspark/core/broadcast.py b/python/pyspark/core/broadcast.py
index 69d57c35614d..2d5658284be8 100644
--- a/python/pyspark/core/broadcast.py
+++ b/python/pyspark/core/broadcast.py
@@ -125,8 +125,8 @@ class Broadcast(Generic[T]):
             if sc._encryption_enabled:
                 # with encryption, we ask the jvm to do the encryption for us, 
we send it data
                 # over a socket
-                port, auth_secret = 
self._python_broadcast.setupEncryptionServer()
-                (encryption_sock_file, _) = local_connect_and_auth(port, 
auth_secret)
+                conn_info, auth_secret = 
self._python_broadcast.setupEncryptionServer()
+                (encryption_sock_file, _) = local_connect_and_auth(conn_info, 
auth_secret)
                 broadcast_out = ChunkedStream(encryption_sock_file, 8192)
             else:
                 # no encryption, we can just write pickled data directly to 
the file from python
@@ -270,8 +270,8 @@ class Broadcast(Generic[T]):
             # we only need to decrypt it here when encryption is enabled and
             # if its on the driver, since executor decryption is handled 
already
             if self._sc is not None and self._sc._encryption_enabled:
-                port, auth_secret = 
self._python_broadcast.setupDecryptionServer()
-                (decrypted_sock_file, _) = local_connect_and_auth(port, 
auth_secret)
+                conn_info, auth_secret = 
self._python_broadcast.setupDecryptionServer()
+                (decrypted_sock_file, _) = local_connect_and_auth(conn_info, 
auth_secret)
                 self._python_broadcast.waitTillBroadcastDataSent()
                 return self.load(decrypted_sock_file)
             else:
diff --git a/python/pyspark/core/context.py b/python/pyspark/core/context.py
index 5fcd4ffb0921..4d5c03fd1900 100644
--- a/python/pyspark/core/context.py
+++ b/python/pyspark/core/context.py
@@ -880,7 +880,7 @@ class SparkContext:
         if self._encryption_enabled:
             # with encryption, we open a server in java and send the data 
directly
             server = server_func()
-            (sock_file, _) = local_connect_and_auth(server.port(), 
server.secret())
+            (sock_file, _) = local_connect_and_auth(server.connInfo(), 
server.secret())
             chunked_out = ChunkedStream(sock_file, 8192)
             serializer.dump_stream(data, chunked_out)
             chunked_out.close()
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index a23af109ea6d..ca33ce2c39ef 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import uuid
 import numbers
 import os
 import signal
@@ -93,8 +93,20 @@ def manager():
     # Create a new process group to corral our children
     os.setpgid(0, 0)
 
+    is_unix_domain_sock = os.environ.get("PYTHON_UNIX_DOMAIN_ENABLED", 
"false").lower() == "true"
+    socket_path = None
+
     # Create a listening socket on the loopback interface
-    if os.environ.get("SPARK_PREFER_IPV6", "false").lower() == "true":
+    if is_unix_domain_sock:
+        assert "PYTHON_WORKER_FACTORY_SOCK_DIR" in os.environ
+        socket_path = os.path.join(
+            os.environ["PYTHON_WORKER_FACTORY_SOCK_DIR"], 
f".{uuid.uuid4()}.sock"
+        )
+        listen_sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+        listen_sock.bind(socket_path)
+        listen_sock.listen(max(1024, SOMAXCONN))
+        listen_port = socket_path
+    elif os.environ.get("SPARK_PREFER_IPV6", "false").lower() == "true":
         listen_sock = socket.socket(AF_INET6, SOCK_STREAM)
         listen_sock.bind(("::1", 0, 0, 0))
         listen_sock.listen(max(1024, SOMAXCONN))
@@ -108,10 +120,15 @@ def manager():
     # re-open stdin/stdout in 'wb' mode
     stdin_bin = os.fdopen(sys.stdin.fileno(), "rb", 4)
     stdout_bin = os.fdopen(sys.stdout.fileno(), "wb", 4)
-    write_int(listen_port, stdout_bin)
+    if is_unix_domain_sock:
+        write_with_length(listen_port.encode("utf-8"), stdout_bin)
+    else:
+        write_int(listen_port, stdout_bin)
     stdout_bin.flush()
 
     def shutdown(code):
+        if socket_path is not None and os.path.exists(socket_path):
+            os.remove(socket_path)
         signal.signal(SIGTERM, SIG_DFL)
         # Send SIGHUP to notify workers of shutdown
         os.kill(0, SIGHUP)
@@ -195,7 +212,10 @@ def manager():
                         write_int(os.getpid(), outfile)
                         outfile.flush()
                         outfile.close()
-                        authenticated = False
+                        authenticated = (
+                            os.environ.get("PYTHON_UNIX_DOMAIN_ENABLED", 
"false").lower() == "true"
+                            or False
+                        )
                         while True:
                             code = worker(sock, authenticated)
                             if code == 0:
diff --git a/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py 
b/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
index 66a9b553cc75..e614c347faa9 100644
--- a/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
+++ b/python/pyspark/ml/deepspeed/tests/test_deepspeed_distributor.py
@@ -227,7 +227,7 @@ class 
DeepspeedTorchDistributorDistributedEndToEnd(unittest.TestCase):
             conf = conf.set(k, v)
         conf = conf.set(
             "spark.worker.resource.gpu.discoveryScript", 
cls.gpu_discovery_script_file_name
-        )
+        ).set("spark.python.unix.domain.socket.enabled", "false")
         sc = SparkContext("local-cluster[2,2,512]", cls.__name__, conf=conf)
         cls.spark = SparkSession(sc)
 
@@ -264,7 +264,7 @@ class 
DeepspeedDistributorLocalEndToEndTests(unittest.TestCase):
             conf = conf.set(k, v)
         conf = conf.set(
             "spark.driver.resource.gpu.discoveryScript", 
cls.gpu_discovery_script_file_name
-        )
+        ).set("spark.python.unix.domain.socket.enabled", "false")
         sc = SparkContext("local-cluster[2,2,512]", cls.__name__, conf=conf)
         cls.spark = SparkSession(sc)
 
diff --git 
a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py 
b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
index b471769ad428..b819634adb5a 100644
--- a/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/foreach_batch_worker.py
@@ -91,9 +91,11 @@ def main(infile: IO, outfile: IO) -> None:
 
 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the 
environment.
-    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
-    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+    conn_info = os.environ.get(
+        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+    )
+    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+    (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
     # There could be a long time between each micro batch.
     sock.settimeout(None)
     write_int(os.getpid(), sock_file)
diff --git a/python/pyspark/sql/connect/streaming/worker/listener_worker.py 
b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
index a7a5066ca0d7..2c6ce8715994 100644
--- a/python/pyspark/sql/connect/streaming/worker/listener_worker.py
+++ b/python/pyspark/sql/connect/streaming/worker/listener_worker.py
@@ -105,9 +105,11 @@ def main(infile: IO, outfile: IO) -> None:
 
 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the 
environment.
-    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
-    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+    conn_info = os.environ.get(
+        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+    )
+    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+    (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
     # There could be a long time between each listener event.
     sock.settimeout(None)
     write_int(os.getpid(), sock_file)
diff --git a/python/pyspark/sql/streaming/python_streaming_source_runner.py 
b/python/pyspark/sql/streaming/python_streaming_source_runner.py
index 11aa4e15ab1e..ab988eb714cc 100644
--- a/python/pyspark/sql/streaming/python_streaming_source_runner.py
+++ b/python/pyspark/sql/streaming/python_streaming_source_runner.py
@@ -204,9 +204,11 @@ def main(infile: IO, outfile: IO) -> None:
 
 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the 
environment.
-    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
-    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+    conn_info = os.environ.get(
+        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+    )
+    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+    (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
     # Prevent the socket from timeout error when query trigger interval is 
large.
     sock.settimeout(None)
     write_int(os.getpid(), sock_file)
diff --git a/python/pyspark/sql/streaming/stateful_processor_api_client.py 
b/python/pyspark/sql/streaming/stateful_processor_api_client.py
index 50945198f9c4..e564d7186faa 100644
--- a/python/pyspark/sql/streaming/stateful_processor_api_client.py
+++ b/python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -49,22 +49,26 @@ class StatefulProcessorHandleState(Enum):
 
 class StatefulProcessorApiClient:
     def __init__(
-        self, state_server_port: int, key_schema: StructType, is_driver: bool 
= False
+        self, state_server_port: Union[int, str], key_schema: StructType, 
is_driver: bool = False
     ) -> None:
         self.key_schema = key_schema
-        self._client_socket = socket.socket()
-        self._client_socket.connect(("localhost", state_server_port))
-
-        # SPARK-51667: We have a pattern of sending messages continuously from 
one side
-        # (Python -> JVM, and vice versa) before getting response from other 
side. Since most
-        # messages we are sending are small, this triggers the bad combination 
of Nagle's algorithm
-        # and delayed ACKs, which can cause a significant delay on the latency.
-        # See SPARK-51667 for more details on how this can be a problem.
-        #
-        # Disabling either would work, but it's more common to disable Nagle's 
algorithm; there is
-        # lot less reference to disabling delayed ACKs, while there are lots 
of resources to
-        # disable Nagle's algorithm.
-        self._client_socket.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 
1)
+        if isinstance(state_server_port, str):
+            self._client_socket = socket.socket(socket.AF_UNIX, 
socket.SOCK_STREAM)
+            self._client_socket.connect(state_server_port)
+        else:
+            self._client_socket = socket.socket()
+            self._client_socket.connect(("localhost", state_server_port))
+
+            # SPARK-51667: We have a pattern of sending messages continuously 
from one side
+            # (Python -> JVM, and vice versa) before getting response from 
other side. Since most
+            # messages we are sending are small, this triggers the bad 
combination of Nagle's
+            # algorithm and delayed ACKs, which can cause a significant delay 
on the latency.
+            # See SPARK-51667 for more details on how this can be a problem.
+            #
+            # Disabling either would work, but it's more common to disable 
Nagle's algorithm; there
+            # is lot less reference to disabling delayed ACKs, while there are 
lots of resources to
+            # disable Nagle's algorithm.
+            self._client_socket.setsockopt(socket.IPPROTO_TCP, 
socket.TCP_NODELAY, 1)
 
         self.sockfile = self._client_socket.makefile(
             "rwb", int(os.environ.get("SPARK_BUFFER_SIZE", 65536))
diff --git a/python/pyspark/sql/streaming/transform_with_state_driver_worker.py 
b/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
index 99d386f07b5b..3fe7f68a99e5 100644
--- a/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
+++ b/python/pyspark/sql/streaming/transform_with_state_driver_worker.py
@@ -72,9 +72,11 @@ def main(infile: IO, outfile: IO) -> None:
         # This driver runner will only be used on the first batch of a query,
         # and the following code block should be only run once for each query 
run
         state_server_port = read_int(infile)
+        if state_server_port == -1:
+            state_server_port = utf8_deserializer.loads(infile)
         key_schema = 
StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
         print(
-            f"{log_name} received parameters for UDF. State server port: 
{state_server_port}, "
+            f"{log_name} received parameters for UDF. State server port/path: 
{state_server_port}, "
             f"key schema: {key_schema}.\n"
         )
 
@@ -94,9 +96,11 @@ def main(infile: IO, outfile: IO) -> None:
 
 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the 
environment.
-    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
-    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, sock) = local_connect_and_auth(java_port, auth_secret)
+    conn_info = os.environ.get(
+        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+    )
+    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+    (sock_file, sock) = local_connect_and_auth(conn_info, auth_secret)
     write_int(os.getpid(), sock_file)
     sock_file.flush()
     main(sock_file, sock_file)
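The driver worker above relies on a small wire convention: the JVM writes -1 followed by a UTF-8 string when handing over a socket path, or a non-negative int when it is a TCP port. A self-contained sketch of decoding that convention; the read helpers are inlined here and only mirror the length-prefixed framing of pyspark's UTF8Deserializer, they are not the real serializers.

    import io
    import struct

    def read_int(stream) -> int:
        return struct.unpack("!i", stream.read(4))[0]

    def read_utf8(stream) -> str:
        # Mirrors the length-prefixed UTF-8 framing used for strings.
        length = read_int(stream)
        return stream.read(length).decode("utf-8")

    def read_state_server_conn_info(stream):
        # -1 is the sentinel meaning "a UDS path follows as a UTF-8 string";
        # any other value is the TCP port of the state server itself.
        value = read_int(stream)
        return read_utf8(stream) if value == -1 else value

    path = b"/tmp/example.sock"
    payload = struct.pack("!i", -1) + struct.pack("!i", len(path)) + path
    print(read_state_server_conn_info(io.BytesIO(payload)))  # /tmp/example.sock
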
diff --git a/python/pyspark/sql/worker/analyze_udtf.py 
b/python/pyspark/sql/worker/analyze_udtf.py
index 9247fde78004..1c926f4980a5 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -273,9 +273,11 @@ def main(infile: IO, outfile: IO) -> None:
 
 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the 
environment.
-    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
-    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    conn_info = os.environ.get(
+        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+    )
+    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
     # TODO: Remove the following two lines and use `Process.pid()` when we 
drop JDK 8.
     write_int(os.getpid(), sock_file)
     sock_file.flush()
diff --git a/python/pyspark/sql/worker/commit_data_source_write.py 
b/python/pyspark/sql/worker/commit_data_source_write.py
index c891d9f083cb..d08d65974dfb 100644
--- a/python/pyspark/sql/worker/commit_data_source_write.py
+++ b/python/pyspark/sql/worker/commit_data_source_write.py
@@ -119,9 +119,11 @@ def main(infile: IO, outfile: IO) -> None:
 
 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the 
environment.
-    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
-    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    conn_info = os.environ.get(
+        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+    )
+    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
     write_int(os.getpid(), sock_file)
     sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/create_data_source.py 
b/python/pyspark/sql/worker/create_data_source.py
index 33957616c483..424f07012723 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -184,9 +184,11 @@ def main(infile: IO, outfile: IO) -> None:
 
 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the 
environment.
-    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
-    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    conn_info = os.environ.get(
+        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+    )
+    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
     write_int(os.getpid(), sock_file)
     sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/data_source_pushdown_filters.py 
b/python/pyspark/sql/worker/data_source_pushdown_filters.py
index 9edbaf3a9b72..0415f450fe0f 100644
--- a/python/pyspark/sql/worker/data_source_pushdown_filters.py
+++ b/python/pyspark/sql/worker/data_source_pushdown_filters.py
@@ -269,7 +269,9 @@ def main(infile: IO, outfile: IO) -> None:
 
 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the 
environment.
-    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
-    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    conn_info = os.environ.get(
+        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+    )
+    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
     main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/lookup_data_sources.py 
b/python/pyspark/sql/worker/lookup_data_sources.py
index 18737095fa9c..af138ab68965 100644
--- a/python/pyspark/sql/worker/lookup_data_sources.py
+++ b/python/pyspark/sql/worker/lookup_data_sources.py
@@ -104,9 +104,11 @@ def main(infile: IO, outfile: IO) -> None:
 
 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the 
environment.
-    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
-    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    conn_info = os.environ.get(
+        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+    )
+    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
     write_int(os.getpid(), sock_file)
     sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py 
b/python/pyspark/sql/worker/plan_data_source_read.py
index 7f765a377bea..5edc8185adcf 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -409,9 +409,11 @@ def main(infile: IO, outfile: IO) -> None:
 
 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the 
environment.
-    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
-    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    conn_info = os.environ.get(
+        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+    )
+    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
     write_int(os.getpid(), sock_file)
     sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/python_streaming_sink_runner.py 
b/python/pyspark/sql/worker/python_streaming_sink_runner.py
index 13b8f4d30786..cf6246b54490 100644
--- a/python/pyspark/sql/worker/python_streaming_sink_runner.py
+++ b/python/pyspark/sql/worker/python_streaming_sink_runner.py
@@ -148,9 +148,11 @@ def main(infile: IO, outfile: IO) -> None:
 
 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the 
environment.
-    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
-    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    conn_info = os.environ.get(
+        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+    )
+    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
     write_int(os.getpid(), sock_file)
     sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/sql/worker/write_into_data_source.py 
b/python/pyspark/sql/worker/write_into_data_source.py
index 235e5c249f69..d6d055f01e54 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -255,9 +255,11 @@ def main(infile: IO, outfile: IO) -> None:
 
 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the 
environment.
-    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
-    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    conn_info = os.environ.get(
+        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+    )
+    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
     write_int(os.getpid(), sock_file)
     sock_file.flush()
     main(sock_file, sock_file)
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index 9785664d7a15..957f9d70687b 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -262,8 +262,8 @@ ALL_GATHER_FUNCTION = 2
 
 
 def _load_from_socket(
-    port: Optional[Union[str, int]],
-    auth_secret: str,
+    conn_info: Optional[Union[str, int]],
+    auth_secret: Optional[str],
     function: int,
     all_gather_message: Optional[str] = None,
 ) -> List[str]:
@@ -271,7 +271,7 @@ def _load_from_socket(
     Load data from a given socket, this is a blocking method thus only return 
when the socket
     connection has been closed.
     """
-    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
+    (sockfile, sock) = local_connect_and_auth(conn_info, auth_secret)
 
     # The call may block forever, so no timeout
     sock.settimeout(None)
@@ -331,7 +331,7 @@ class BarrierTaskContext(TaskContext):
     [1]
     """
 
-    _port: ClassVar[Optional[Union[str, int]]] = None
+    _conn_info: ClassVar[Optional[Union[str, int]]] = None
     _secret: ClassVar[Optional[str]] = None
 
     @classmethod
@@ -368,13 +368,13 @@ class BarrierTaskContext(TaskContext):
 
     @classmethod
     def _initialize(
-        cls: Type["BarrierTaskContext"], port: Optional[Union[str, int]], 
secret: str
+        cls: Type["BarrierTaskContext"], conn_info: Optional[Union[str, int]], 
secret: Optional[str]
     ) -> None:
         """
         Initialize :class:`BarrierTaskContext`, other methods within 
:class:`BarrierTaskContext`
         can only be called after BarrierTaskContext is initialized.
         """
-        cls._port = port
+        cls._conn_info = conn_info
         cls._secret = secret
 
     def barrier(self) -> None:
@@ -393,7 +393,7 @@ class BarrierTaskContext(TaskContext):
         calls, in all possible code branches. Otherwise, you may get the job 
hanging
         or a `SparkException` after timeout.
         """
-        if self._port is None or self._secret is None:
+        if self._conn_info is None:
             raise PySparkRuntimeError(
                 errorClass="CALL_BEFORE_INITIALIZE",
                 messageParameters={
@@ -402,7 +402,7 @@ class BarrierTaskContext(TaskContext):
                 },
             )
         else:
-            _load_from_socket(self._port, self._secret, BARRIER_FUNCTION)
+            _load_from_socket(self._conn_info, self._secret, BARRIER_FUNCTION)
 
     def allGather(self, message: str = "") -> List[str]:
         """
@@ -422,7 +422,7 @@ class BarrierTaskContext(TaskContext):
         """
         if not isinstance(message, str):
             raise TypeError("Argument `message` must be of type `str`")
-        elif self._port is None or self._secret is None:
+        elif self._conn_info is None:
             raise PySparkRuntimeError(
                 errorClass="CALL_BEFORE_INITIALIZE",
                 messageParameters={
@@ -431,7 +431,7 @@ class BarrierTaskContext(TaskContext):
                 },
             )
         else:
-            return _load_from_socket(self._port, self._secret, 
ALL_GATHER_FUNCTION, message)
+            return _load_from_socket(self._conn_info, self._secret, 
ALL_GATHER_FUNCTION, message)
 
     def getTaskInfos(self) -> List["BarrierTaskInfo"]:
         """
@@ -453,7 +453,7 @@ class BarrierTaskContext(TaskContext):
         >>> barrier_info.address
         '...:...'
         """
-        if self._port is None or self._secret is None:
+        if self._conn_info is None:
             raise PySparkRuntimeError(
                 errorClass="CALL_BEFORE_INITIALIZE",
                 messageParameters={
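One detail worth noting in the BarrierTaskContext changes above: the guard now checks only _conn_info, because under UDS there is no auth secret and _secret stays None by design, so the old `_secret is None` check would wrongly reject an initialized context. A tiny illustrative sketch (this class is hypothetical, not the real TaskContext):

    from typing import Optional, Union

    class _BarrierLikeContext:
        # Illustrative only: conn_info is an int port (TCP) or a str path
        # (UDS); the secret exists only in the TCP case.
        _conn_info: Optional[Union[int, str]] = None
        _secret: Optional[str] = None

        @classmethod
        def _initialize(cls, conn_info, secret=None):
            cls._conn_info = conn_info
            cls._secret = secret

        def barrier(self):
            # Readiness is keyed off the connection info alone; _secret may
            # legitimately be None when _conn_info is a UDS path.
            if self._conn_info is None:
                raise RuntimeError("barrier() called before initialization")
            return ("connect", self._conn_info, self._secret)

    _BarrierLikeContext._initialize("/tmp/barrier.sock")  # UDS: no secret
    print(_BarrierLikeContext().barrier())
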
diff --git a/python/pyspark/tests/test_appsubmit.py 
b/python/pyspark/tests/test_appsubmit.py
index 5f2c8b49d279..909ed0447154 100644
--- a/python/pyspark/tests/test_appsubmit.py
+++ b/python/pyspark/tests/test_appsubmit.py
@@ -36,6 +36,8 @@ class SparkSubmitTests(unittest.TestCase):
             
"spark.driver.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
             "--conf",
             
"spark.executor.extraJavaOptions=-Djava.io.tmpdir={0}".format(tmp_dir),
+            "--conf",
+            "spark.python.unix.domain.socket.enabled=false",
         ]
 
     def tearDown(self):
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index 5a5a8d31e77d..cdfc8d2a4a4f 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -652,9 +652,9 @@ def _create_local_socket(sock_info: "JavaArray") -> 
"io.BufferedRWPair":
     """
     sockfile: "io.BufferedRWPair"
     sock: "socket.socket"
-    port: int = sock_info[0]
+    conn_info: int = sock_info[0]
     auth_secret: str = sock_info[1]
-    sockfile, sock = local_connect_and_auth(port, auth_secret)
+    sockfile, sock = local_connect_and_auth(conn_info, auth_secret)
     # The RDD materialization time is unpredictable, if we set a timeout for 
socket reading
     # operation, it will very possibly fail. See SPARK-18281.
     sock.settimeout(None)
@@ -731,7 +731,9 @@ def _local_iterator_from_socket(sock_info: "JavaArray", 
serializer: "Serializer"
     return iter(PyLocalIterable(sock_info, serializer))
 
 
-def local_connect_and_auth(port: Optional[Union[str, int]], auth_secret: str) 
-> Tuple:
+def local_connect_and_auth(
+    conn_info: Optional[Union[str, int]], auth_secret: Optional[str]
+) -> Tuple:
     """
     Connect to local host, authenticate with it, and return a (sockfile,sock) 
for that connection.
     Handles IPV4 & IPV6, does some error handling.
@@ -739,26 +741,49 @@ def local_connect_and_auth(port: Optional[Union[str, 
int]], auth_secret: str) ->
     Parameters
     ----------
     port : str or int, optional
-    auth_secret : str
+    auth_secret : str, optional
 
     Returns
     -------
     tuple
         with (sockfile, sock)
     """
+    is_unix_domain_socket = isinstance(conn_info, str) and auth_secret is None
+    if is_unix_domain_socket:
+        sock_path = conn_info
+        assert isinstance(sock_path, str)
+        sock = None
+        try:
+            sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
+            sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 
15)))
+            sock.connect(sock_path)
+            sockfile = sock.makefile("rwb", 
int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
+            return (sockfile, sock)
+        except socket.error as e:
+            if sock is not None:
+                sock.close()
+            raise PySparkRuntimeError(
+                errorClass="CANNOT_OPEN_SOCKET",
+                messageParameters={
+                    "errors": "tried to connect to %s, but an error occurred: 
%s"
+                    % (sock_path, str(e)),
+                },
+            )
+
     sock = None
     errors = []
     # Support for both IPv4 and IPv6.
     addr = "127.0.0.1"
     if os.environ.get("SPARK_PREFER_IPV6", "false").lower() == "true":
         addr = "::1"
-    for res in socket.getaddrinfo(addr, port, socket.AF_UNSPEC, 
socket.SOCK_STREAM):
+    for res in socket.getaddrinfo(addr, conn_info, socket.AF_UNSPEC, 
socket.SOCK_STREAM):
         af, socktype, proto, _, sa = res
         try:
             sock = socket.socket(af, socktype, proto)
             sock.settimeout(int(os.environ.get("SPARK_AUTH_SOCKET_TIMEOUT", 
15)))
             sock.connect(sa)
             sockfile = sock.makefile("rwb", 
int(os.environ.get("SPARK_BUFFER_SIZE", 65536)))
+            assert isinstance(auth_secret, str)
             _do_server_auth(sockfile, auth_secret)
             return (sockfile, sock)
         except socket.error as e:
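The UDS branch of local_connect_and_auth above skips _do_server_auth entirely: with AF_UNIX, the randomly named socket file and its filesystem permissions provide the gatekeeping that the shared secret provides for TCP. A self-contained sketch of that model on a Unix-like platform; the path, mode, and message are illustrative only.

    import os
    import socket
    import tempfile
    import threading

    # Sketch: a UDS listener whose socket file is only accessible to its
    # owner, which is the property the TCP auth secret approximates.
    sock_path = os.path.join(tempfile.mkdtemp(), ".demo.sock")

    server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    server.bind(sock_path)
    os.chmod(sock_path, 0o600)  # owner-only access to the endpoint
    server.listen(1)

    def serve_once():
        conn, _ = server.accept()
        conn.sendall(b"hello over UDS\n")
        conn.close()

    threading.Thread(target=serve_once, daemon=True).start()

    client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    client.connect(sock_path)  # no handshake secret required
    print(client.recv(1024).decode())
    client.close()
    server.close()
    os.unlink(sock_path)
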
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 44a5d0b91131..0724ad42e566 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -45,7 +45,6 @@ from pyspark.serializers import (
     write_long,
     read_int,
     SpecialLengths,
-    UTF8Deserializer,
     CPickleSerializer,
     BatchedSerializer,
 )
@@ -1548,6 +1547,8 @@ def read_udfs(pickleSer, infile, eval_type):
             or eval_type == 
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF
         ):
             state_server_port = read_int(infile)
+            if state_server_port == -1:
+                state_server_port = utf8_deserializer.loads(infile)
             key_schema = 
StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
 
         # NOTE: if timezone is set here, that implies respectSessionTimeZone 
is True
@@ -1983,8 +1984,6 @@ def main(infile, outfile):
 
         # read inputs only for a barrier task
         isBarrier = read_bool(infile)
-        boundPort = read_int(infile)
-        secret = UTF8Deserializer().loads(infile)
 
         memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", 
"-1"))
         setup_memory_limits(memory_limit_mb)
@@ -1992,6 +1991,12 @@ def main(infile, outfile):
         # initialize global state
         taskContext = None
         if isBarrier:
+            boundPort = read_int(infile)
+            secret = None
+            if boundPort == -1:
+                boundPort = utf8_deserializer.loads(infile)
+            else:
+                secret = utf8_deserializer.loads(infile)
             taskContext = BarrierTaskContext._getOrCreate()
             BarrierTaskContext._initialize(boundPort, secret)
             # Set the task context instance here, so we can get it by 
TaskContext.get for
@@ -2085,9 +2090,11 @@ def main(infile, outfile):
 
 if __name__ == "__main__":
     # Read information about how to connect back to the JVM from the 
environment.
-    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
-    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    conn_info = os.environ.get(
+        "PYTHON_WORKER_FACTORY_SOCK_PATH", 
int(os.environ.get("PYTHON_WORKER_FACTORY_PORT", -1))
+    )
+    auth_secret = os.environ.get("PYTHON_WORKER_FACTORY_SECRET")
+    (sock_file, _) = local_connect_and_auth(conn_info, auth_secret)
     # TODO: Remove the following two lines and use `Process.pid()` when we 
drop JDK 8.
     write_int(os.getpid(), sock_file)
     sock_file.flush()
diff --git a/python/pyspark/worker_util.py b/python/pyspark/worker_util.py
index 5c758d3f83fe..c2f35db8d52d 100644
--- a/python/pyspark/worker_util.py
+++ b/python/pyspark/worker_util.py
@@ -156,9 +156,13 @@ def setup_broadcasts(infile: IO) -> None:
     num_broadcast_variables = read_int(infile)
     if needs_broadcast_decryption_server:
         # read the decrypted data from a server in the jvm
-        port = read_int(infile)
-        auth_secret = utf8_deserializer.loads(infile)
-        (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)
+        conn_info = read_int(infile)
+        auth_secret = None
+        if conn_info == -1:
+            conn_info = utf8_deserializer.loads(infile)
+        else:
+            auth_secret = utf8_deserializer.loads(infile)
+        (broadcast_sock_file, _) = local_connect_and_auth(conn_info, 
auth_secret)
 
     for _ in range(num_broadcast_variables):
         bid = read_long(infile)
diff --git a/python/run-tests.py b/python/run-tests.py
index 64ac48e210db..8752f264cd75 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -111,6 +111,7 @@ def run_individual_python_test(target_dir, test_name, 
pyspark_python, keep_test_
     while os.path.isdir(tmp_dir):
         tmp_dir = os.path.join(target_dir, str(uuid.uuid4()))
     os.mkdir(tmp_dir)
+    sock_dir = os.getenv('TMPDIR') or os.getenv('TEMP') or os.getenv('TMP') or 
'/tmp'
     env["TMPDIR"] = tmp_dir
     metastore_dir = os.path.join(tmp_dir, str(uuid.uuid4()))
     while os.path.isdir(metastore_dir):
@@ -124,6 +125,7 @@ def run_individual_python_test(target_dir, test_name, 
pyspark_python, keep_test_
         "--conf", "spark.driver.extraJavaOptions='{0}'".format(java_options),
         "--conf", "spark.executor.extraJavaOptions='{0}'".format(java_options),
         "--conf", "spark.sql.warehouse.dir='{0}'".format(metastore_dir),
+        "--conf", "spark.python.unix.domain.socket.dir={0}".format(sock_dir),
         "pyspark-shell",
     ]
 
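run-tests.py points spark.python.unix.domain.socket.dir at the system temp dir rather than the deeply nested per-test TMPDIR; a likely reason is the short limit on AF_UNIX socket paths (roughly 104 bytes on macOS and 108 on Linux, including the file name). A sketch of guarding against that limit; the 100-byte budget and the fallback choice are assumptions for illustration, not values taken from this patch.

    import os
    import tempfile
    import uuid

    MAX_UNIX_SOCKET_PATH = 100  # conservative; sun_path is about 104-108 bytes

    def pick_socket_dir(preferred: str) -> str:
        """Return a directory whose generated socket paths stay under the limit."""
        candidate = os.path.join(preferred, f".{uuid.uuid4()}.sock")
        if len(candidate.encode("utf-8")) <= MAX_UNIX_SOCKET_PATH:
            return preferred
        # Fall back to the system temp dir, which is normally short (e.g. /tmp).
        return tempfile.gettempdir()

    print(pick_socket_dir(os.environ.get("TMPDIR", tempfile.gettempdir())))
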
diff --git 
a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
 
b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 4408817b0426..b3a792bbfc73 100644
--- 
a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ 
b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -40,6 +40,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.deploy.yarn.config._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
+import 
org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
 import org.apache.spark.internal.config.UI._
 import org.apache.spark.launcher._
 import org.apache.spark.scheduler.{SparkListener, 
SparkListenerApplicationStart, SparkListenerExecutorAdded}
@@ -268,11 +269,19 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
   }
 
   test("run Python application in yarn-client mode") {
-    testPySpark(true)
+    testPySpark(
+      true,
+      // User is unknown in this suite.
+      extraConf = Map(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key -> false.toString)
+    )
   }
 
   test("run Python application in yarn-cluster mode") {
-    testPySpark(false)
+    testPySpark(
+      false,
+      // User is unknown in this suite.
+      extraConf = Map(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key -> false.toString)
+    )
   }
 
   test("run Python application with Spark Connect in yarn-client mode") {
@@ -290,6 +299,7 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
     testPySpark(
       clientMode = false,
       extraConf = Map(
+        PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.key -> false.toString,  // User is 
unknown in this suite.
         "spark.yarn.appMasterEnv.PYSPARK_DRIVER_PYTHON"
           -> sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", pythonExecutablePath),
         "spark.yarn.appMasterEnv.PYSPARK_PYTHON"
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index 374d38db371a..40779c66600f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -18,8 +18,7 @@
 package org.apache.spark.sql.api.python
 
 import java.io.InputStream
-import java.net.Socket
-import java.nio.channels.Channels
+import java.nio.channels.{Channels, SocketChannel}
 
 import net.razorvine.pickle.{Pickler, Unpickler}
 
@@ -197,8 +196,8 @@ private[sql] object PythonSQLUtils extends Logging {
 private[spark] class ArrowIteratorServer
   extends 
SocketAuthServer[Iterator[Array[Byte]]]("pyspark-arrow-batches-server") {
 
-  def handleConnection(sock: Socket): Iterator[Array[Byte]] = {
-    val in = sock.getInputStream()
+  def handleConnection(sock: SocketChannel): Iterator[Array[Byte]] = {
+    val in = Channels.newInputStream(sock)
     val dechunkedInput: InputStream = new DechunkedInputStream(in)
     // Create array to consume iterator so that we can safely close the file
     
ArrowConverters.getBatchesFromStream(Channels.newChannel(dechunkedInput)).toArray.iterator
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
index 89273b7bc80f..3979220618ba 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/PythonStreamingSourceRunner.scala
@@ -19,6 +19,7 @@
 package org.apache.spark.sql.execution.python.streaming
 
 import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, 
DataOutputStream}
+import java.nio.channels.Channels
 
 import scala.collection.mutable.ArrayBuffer
 import scala.jdk.CollectionConverters._
@@ -99,7 +100,7 @@ class PythonStreamingSourceRunner(
     pythonWorkerFactory = Some(workerFactory)
 
     val stream = new BufferedOutputStream(
-      pythonWorker.get.channel.socket().getOutputStream, bufferSize)
+      Channels.newOutputStream(pythonWorker.get.channel), bufferSize)
     dataOut = new DataOutputStream(stream)
 
     PythonWorkerUtils.writePythonVersion(pythonVer, dataOut)
@@ -118,7 +119,7 @@ class PythonStreamingSourceRunner(
     dataOut.flush()
 
     dataIn = new DataInputStream(
-      new 
BufferedInputStream(pythonWorker.get.channel.socket().getInputStream, 
bufferSize))
+      new 
BufferedInputStream(Channels.newInputStream(pythonWorker.get.channel), 
bufferSize))
 
     val initStatus = dataIn.readInt()
     if (initStatus == SpecialLengths.PYTHON_EXCEPTION_THROWN) {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasPythonRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasPythonRunner.scala
index 9b2a2518a7b2..638b2d48ffc4 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasPythonRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasPythonRunner.scala
@@ -17,17 +17,20 @@
 
 package org.apache.spark.sql.execution.python.streaming
 
-import java.io.{DataInputStream, DataOutputStream}
-import java.net.ServerSocket
+import java.io.{DataInputStream, DataOutputStream, File}
+import java.net.{InetAddress, InetSocketAddress, StandardProtocolFamily, 
UnixDomainSocketAddress}
+import java.nio.channels.ServerSocketChannel
+import java.util.UUID
 
 import scala.concurrent.ExecutionContext
 
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 
-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.{SparkEnv, SparkException, TaskContext}
 import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, 
PythonFunction, PythonRDD, PythonWorkerUtils, StreamingPythonRunner}
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.Python.{PYTHON_UNIX_DOMAIN_SOCKET_DIR, 
PYTHON_UNIX_DOMAIN_SOCKET_ENABLED}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.{BasicPythonArrowOutput, 
PythonArrowInput, PythonUDFRunner}
@@ -196,8 +199,13 @@ abstract class 
TransformWithStateInPandasPythonBaseRunner[I](
 
   override protected def handleMetadataBeforeExec(stream: DataOutputStream): 
Unit = {
     super.handleMetadataBeforeExec(stream)
-    // Also write the port number for state server
-    stream.writeInt(stateServerSocketPort)
+    // Also write the port/path number for state server
+    if (isUnixDomainSock) {
+      stream.writeInt(-1)
+      PythonWorkerUtils.writeUTF(stateServerSocketPath, stream)
+    } else {
+      stream.writeInt(stateServerSocketPort)
+    }
     PythonRDD.writeUTF(groupingKeySchema.json, stream)
   }
 
@@ -255,14 +263,19 @@ class TransformWithStateInPandasPythonPreInitRunner(
     dataOut = result._1
     dataIn = result._2
 
-    // start state server, update socket port
+    // start state server, update socket port/path
     startStateServer()
     (dataOut, dataIn)
   }
 
   def process(): Unit = {
-    // Also write the port number for state server
-    dataOut.writeInt(stateServerSocketPort)
+    // Also write the port/path number for state server
+    if (isUnixDomainSock) {
+      dataOut.writeInt(-1)
+      PythonWorkerUtils.writeUTF(stateServerSocketPath, dataOut)
+    } else {
+      dataOut.writeInt(stateServerSocketPort)
+    }
     PythonWorkerUtils.writeUTF(groupingKeySchema.json, dataOut)
     dataOut.flush()
 
@@ -307,14 +320,27 @@ class TransformWithStateInPandasPythonPreInitRunner(
  * in a new daemon thread.
  */
 trait TransformWithStateInPandasPythonRunnerUtils extends Logging {
-  protected var stateServerSocketPort: Int = 0
-  protected var stateServerSocket: ServerSocket = null
+  protected val isUnixDomainSock: Boolean = 
SparkEnv.get.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED)
+  protected var stateServerSocketPort: Int = -1
+  protected var stateServerSocketPath: String = null
+  protected var stateServerSocket: ServerSocketChannel = null
   protected def initStateServer(): Unit = {
     var failed = false
     try {
-      stateServerSocket = new ServerSocket(/* port = */ 0,
-        /* backlog = */ 1)
-      stateServerSocketPort = stateServerSocket.getLocalPort
+      if (isUnixDomainSock) {
+        val sockPath = new File(
+          SparkEnv.get.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_DIR)
+            .getOrElse(System.getProperty("java.io.tmpdir")),
+          s".${UUID.randomUUID()}.sock")
+        stateServerSocket = 
ServerSocketChannel.open(StandardProtocolFamily.UNIX)
+        stateServerSocket.bind(UnixDomainSocketAddress.of(sockPath.getPath), 1)
+        sockPath.deleteOnExit()
+        stateServerSocketPath = sockPath.getPath
+      } else {
+        stateServerSocket = ServerSocketChannel.open()
+          .bind(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 1)
+        stateServerSocketPort = stateServerSocket.socket().getLocalPort
+      }
     } catch {
       case e: Throwable =>
         failed = true
@@ -326,10 +352,13 @@ trait TransformWithStateInPandasPythonRunnerUtils extends 
Logging {
     }
   }
 
-  protected def closeServerSocketChannelSilently(stateServerSocket: 
ServerSocket): Unit = {
+  protected def closeServerSocketChannelSilently(stateServerSocket: 
ServerSocketChannel): Unit = {
     try {
       logInfo(log"closing the state server socket")
       stateServerSocket.close()
+      if (stateServerSocketPath != null) {
+        new File(stateServerSocketPath).delete
+      }
     } catch {
       case e: Exception =>
         logError(log"failed to close state server socket", e)
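The Scala runner above binds the state server with ServerSocketChannel.open(StandardProtocolFamily.UNIX) at a random hidden path and removes the socket file when the server is closed. For readers following along on the Python side, here is the same server lifecycle sketched with the standard socket module; SPARK_UDS_DIR is a made-up environment variable standing in for the socket-dir config, and this is not a drop-in replacement for the Scala code.

    import os
    import socket
    import tempfile
    import uuid

    # Sketch of the state-server lifecycle: bind a UDS at a random hidden
    # path, accept a single peer (backlog 1), remove the file on shutdown.
    sock_dir = os.environ.get("SPARK_UDS_DIR", tempfile.gettempdir())  # assumed
    sock_path = os.path.join(sock_dir, f".{uuid.uuid4()}.sock")

    server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        server.bind(sock_path)
        server.listen(1)
        print(f"state server listening on {sock_path}")
        # ... accept() and serve the single Python worker here ...
    finally:
        server.close()
        if os.path.exists(sock_path):
            os.unlink(sock_path)  # mirrors the delete/deleteOnExit in the patch
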
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
index f46b66204383..3749fb6b7c50 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServer.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.execution.python.streaming
 
 import java.io.{BufferedInputStream, BufferedOutputStream, DataInputStream, 
DataOutputStream, EOFException}
-import java.net.ServerSocket
+import java.nio.channels.{Channels, ServerSocketChannel}
 import java.time.Duration
 
 import scala.collection.mutable
@@ -27,7 +27,9 @@ import com.google.protobuf.ByteString
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.ArrowStreamWriter
 
+import org.apache.spark.SparkEnv
 import org.apache.spark.internal.{Logging, LogKeys, MDC}
+import 
org.apache.spark.internal.config.Python.PYTHON_UNIX_DOMAIN_SOCKET_ENABLED
 import org.apache.spark.sql.{Encoders, Row}
 import org.apache.spark.sql.api.python.PythonSQLUtils
 import org.apache.spark.sql.catalyst.InternalRow
@@ -52,7 +54,7 @@ import org.apache.spark.util.Utils
  * - Requests for managing state variables (e.g. valueState).
  */
 class TransformWithStateInPandasStateServer(
-    stateServerSocket: ServerSocket,
+    stateServerSocket: ServerSocketChannel,
     statefulProcessorHandle: StatefulProcessorHandleImplBase,
     groupingKeySchema: StructType,
     timeZoneId: String,
@@ -80,6 +82,10 @@ class TransformWithStateInPandasStateServer(
   private var inputStream: DataInputStream = _
   private var outputStream: DataOutputStream = outputStreamForTest
 
+  private val isUnixDomainSock = Option(SparkEnv.get)
+    .map(_.conf.get(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED))
+    .getOrElse(PYTHON_UNIX_DOMAIN_SOCKET_ENABLED.defaultValue.get)
+
   /** State variable related class variables */
   // A map to store the value state name -> (value state, schema, value row 
deserializer) mapping.
   private val valueStates = if (valueStateMapForTest != null) {
@@ -148,12 +154,12 @@ class TransformWithStateInPandasStateServer(
     // Disabling either would work, but it's more common to disable Nagle's 
algorithm; there is
     // lot less reference to disabling delayed ACKs, while there are lots of 
resources to
     // disable Nagle's algorithm.
-    listeningSocket.setTcpNoDelay(true)
+    if (!isUnixDomainSock) listeningSocket.socket().setTcpNoDelay(true)
 
     inputStream = new DataInputStream(
-      new BufferedInputStream(listeningSocket.getInputStream))
+      new BufferedInputStream(Channels.newInputStream(listeningSocket)))
     outputStream = new DataOutputStream(
-      new BufferedOutputStream(listeningSocket.getOutputStream)
+      new BufferedOutputStream(Channels.newOutputStream(listeningSocket))
     )
 
     while (listeningSocket.isConnected &&
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
index 305a520f6af8..f1e6379a00c8 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/streaming/TransformWithStateInPandasStateServerSuite.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.execution.python.streaming
 
 import java.io.DataOutputStream
-import java.net.ServerSocket
+import java.nio.channels.ServerSocketChannel
 
 import scala.collection.mutable
 
@@ -39,7 +39,7 @@ import org.apache.spark.sql.types.{IntegerType, StructField, 
StructType}
 class TransformWithStateInPandasStateServerSuite extends SparkFunSuite with 
BeforeAndAfterEach {
   val stateName = "test"
   val iteratorId = "testId"
-  val serverSocket: ServerSocket = mock(classOf[ServerSocket])
+  val serverSocket: ServerSocketChannel = mock(classOf[ServerSocketChannel])
   val groupingKeySchema: StructType = StructType(Seq())
   val stateSchema: StructType = StructType(Array(StructField("value", 
IntegerType)))
   // Below byte array is a serialized row with a single integer value 1.

