This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 6e41ac5d7015 [SPARK-53978][PYTHON] Support logging in driver-side workers
6e41ac5d7015 is described below

commit 6e41ac5d7015d96a238a990d5b9697e58458d7c8
Author: Takuya Ueshin <[email protected]>
AuthorDate: Fri Oct 31 08:04:45 2025 -0700

    [SPARK-53978][PYTHON] Support logging in driver-side workers
    
    ### What changes were proposed in this pull request?
    
    Supports logging in driver-side workers.
    
    ### Why are the changes needed?
    
    The basic logging infrastructure was introduced in https://github.com/apache/spark/pull/52689, and the driver-side workers should also support logging.
    
    This PR adds logging support to the driver-side workers.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the logging feature will be available in driver-side workers.
    
    ### How was this patch tested?
    
    Added the related tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #52808 from ueshin/issues/SPARK-53978/driverside.
    
    Authored-by: Takuya Ueshin <[email protected]>
    Signed-off-by: Dongjoon Hyun <[email protected]>
---
 python/pyspark/sql/tests/test_python_datasource.py | 356 +++++++++++++++++++++
 python/pyspark/sql/tests/test_udtf.py              |  42 +++
 python/pyspark/sql/worker/analyze_udtf.py          |  12 +-
 .../pyspark/sql/worker/commit_data_source_write.py |  14 +-
 python/pyspark/sql/worker/create_data_source.py    |  89 +++---
 .../sql/worker/data_source_pushdown_filters.py     | 112 +++----
 python/pyspark/sql/worker/plan_data_source_read.py |  75 ++---
 .../sql/worker/python_streaming_sink_runner.py     |  58 ++--
 .../pyspark/sql/worker/write_into_data_source.py   | 140 ++++----
 .../v2/python/PythonBatchWriterFactory.scala       |   6 +-
 .../v2/python/PythonMicroBatchStream.scala         |   2 +-
 .../v2/python/PythonPartitionReaderFactory.scala   |   6 +-
 .../datasources/v2/python/PythonScan.scala         |   9 +-
 .../PythonStreamingPartitionReaderFactory.scala    |   6 +-
 .../v2/python/PythonStreamingWrite.scala           |  16 +-
 .../datasources/v2/python/PythonWrite.scala        |  10 +-
 .../v2/python/UserDefinedPythonDataSource.scala    |   5 +-
 .../sql/execution/python/PythonPlannerRunner.scala |  10 +
 18 files changed, 717 insertions(+), 251 deletions(-)

diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 28f8bf3c832b..cfedf1cf075b 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -19,6 +19,10 @@ import platform
 import sys
 import tempfile
 import unittest
+import logging
+import json
+import os
+from dataclasses import dataclass
 from datetime import datetime
 from decimal import Decimal
 from typing import Callable, Iterable, List, Union, Iterator, Tuple
@@ -57,6 +61,7 @@ from pyspark.testing.sqlutils import (
     have_pyarrow,
     pyarrow_requirement_message,
 )
+from pyspark.util import is_remote_only
 
 
 @unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
@@ -907,6 +912,357 @@ class BasePythonDataSourceTestsMixin:
                             "test_table"
                         )
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_data_source_reader_with_logging(self):
+        logger = logging.getLogger("test_data_source_reader")
+
+        class TestJsonReader(DataSourceReader):
+            def __init__(self, options):
+                logger.warning(f"TestJsonReader.__init__: {list(options)}")
+                self.options = options
+
+            def partitions(self):
+                logger.warning("TestJsonReader.partitions")
+                return super().partitions()
+
+            def read(self, partition):
+                logger.warning(f"TestJsonReader.read: {partition}")
+                path = self.options.get("path")
+                if path is None:
+                    raise Exception("path is not specified")
+                with open(path, "r") as file:
+                    for line in file.readlines():
+                        if line.strip():
+                            data = json.loads(line)
+                            yield data.get("name"), data.get("age")
+
+        class TestJsonDataSource(DataSource):
+            def __init__(self, options):
+                super().__init__(options)
+                logger.warning(f"TestJsonDataSource.__init__: {list(options)}")
+
+            @classmethod
+            def name(cls):
+                logger.warning("TestJsonDataSource.name")
+                return "my-json"
+
+            def schema(self):
+                logger.warning("TestJsonDataSource.schema")
+                return "name STRING, age INT"
+
+            def reader(self, schema) -> "DataSourceReader":
+                logger.warning(f"TestJsonDataSource.reader: {schema.fieldNames()}")
+                return TestJsonReader(self.options)
+
+        self.spark.dataSource.register(TestJsonDataSource)
+        path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+            assertDataFrameEqual(
+                self.spark.read.format("my-json").load(path1),
+                [
+                    Row(name="Michael", age=None),
+                    Row(name="Andy", age=30),
+                    Row(name="Justin", age=19),
+                ],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=msg,
+                    context=context,
+                    logger="test_data_source_reader",
+                )
+                for msg, context in [
+                    (
+                        "TestJsonDataSource.__init__: ['path']",
+                        {"class_name": "TestJsonDataSource", "func_name": "__init__"},
+                    ),
+                    (
+                        "TestJsonDataSource.name",
+                        {"class_name": "TestJsonDataSource", "func_name": "name"},
+                    ),
+                    (
+                        "TestJsonDataSource.schema",
+                        {"class_name": "TestJsonDataSource", "func_name": "schema"},
+                    ),
+                    (
+                        "TestJsonDataSource.reader: ['name', 'age']",
+                        {"class_name": "TestJsonDataSource", "func_name": "reader"},
+                    ),
+                    (
+                        "TestJsonReader.__init__: ['path']",
+                        {"class_name": "TestJsonDataSource", "func_name": "reader"},
+                    ),
+                    (
+                        "TestJsonReader.partitions",
+                        {"class_name": "TestJsonReader", "func_name": "partitions"},
+                    ),
+                    (
+                        "TestJsonReader.read: None",
+                        {"class_name": "TestJsonReader", "func_name": "read"},
+                    ),
+                ]
+            ],
+        )
+
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_data_source_reader_pushdown_with_logging(self):
+        logger = logging.getLogger("test_data_source_reader_pushdown")
+
+        class TestJsonReader(DataSourceReader):
+            def __init__(self, options):
+                logger.warning(f"TestJsonReader.__init__: {list(options)}")
+                self.options = options
+
+            def pushFilters(self, filters):
+                logger.warning(f"TestJsonReader.pushFilters: {filters}")
+                return super().pushFilters(filters)
+
+            def partitions(self):
+                logger.warning("TestJsonReader.partitions")
+                return super().partitions()
+
+            def read(self, partition):
+                logger.warning(f"TestJsonReader.read: {partition}")
+                path = self.options.get("path")
+                if path is None:
+                    raise Exception("path is not specified")
+                with open(path, "r") as file:
+                    for line in file.readlines():
+                        if line.strip():
+                            data = json.loads(line)
+                            yield data.get("name"), data.get("age")
+
+        class TestJsonDataSource(DataSource):
+            def __init__(self, options):
+                super().__init__(options)
+                logger.warning(f"TestJsonDataSource.__init__: {list(options)}")
+
+            @classmethod
+            def name(cls):
+                logger.warning("TestJsonDataSource.name")
+                return "my-json"
+
+            def schema(self):
+                logger.warning("TestJsonDataSource.schema")
+                return "name STRING, age INT"
+
+            def reader(self, schema) -> "DataSourceReader":
+                logger.warning(f"TestJsonDataSource.reader: {schema.fieldNames()}")
+                return TestJsonReader(self.options)
+
+        self.spark.dataSource.register(TestJsonDataSource)
+        path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
+
+        with self.sql_conf(
+            {
+                "spark.sql.python.filterPushdown.enabled": "true",
+                "spark.sql.pyspark.worker.logging.enabled": "true",
+            }
+        ):
+            assertDataFrameEqual(
+                self.spark.read.format("my-json").load(path1).filter("age is not null"),
+                [
+                    Row(name="Andy", age=30),
+                    Row(name="Justin", age=19),
+                ],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=msg,
+                    context=context,
+                    logger="test_data_source_reader_pushdown",
+                )
+                for msg, context in [
+                    (
+                        "TestJsonDataSource.__init__: ['path']",
+                        {"class_name": "TestJsonDataSource", "func_name": "__init__"},
+                    ),
+                    (
+                        "TestJsonDataSource.name",
+                        {"class_name": "TestJsonDataSource", "func_name": "name"},
+                    ),
+                    (
+                        "TestJsonDataSource.schema",
+                        {"class_name": "TestJsonDataSource", "func_name": "schema"},
+                    ),
+                    (
+                        "TestJsonDataSource.reader: ['name', 'age']",
+                        {"class_name": "TestJsonDataSource", "func_name": "reader"},
+                    ),
+                    (
+                        "TestJsonReader.pushFilters: [IsNotNull(attribute=('age',))]",
+                        {"class_name": "TestJsonReader", "func_name": "pushFilters"},
+                    ),
+                    (
+                        "TestJsonReader.__init__: ['path']",
+                        {"class_name": "TestJsonDataSource", "func_name": "reader"},
+                    ),
+                    (
+                        "TestJsonReader.partitions",
+                        {"class_name": "TestJsonReader", "func_name": "partitions"},
+                    ),
+                    (
+                        "TestJsonReader.read: None",
+                        {"class_name": "TestJsonReader", "func_name": "read"},
+                    ),
+                ]
+            ],
+        )
+
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_data_source_writer_with_logging(self):
+        logger = logging.getLogger("test_datasource_writer")
+
+        @dataclass
+        class TestCommitMessage(WriterCommitMessage):
+            count: int
+
+        class TestJsonWriter(DataSourceWriter):
+            def __init__(self, options):
+                logger.warning(f"TestJsonWriter.__init__: {list(options)}")
+                self.options = options
+                self.path = self.options.get("path")
+
+            def write(self, iterator):
+                from pyspark import TaskContext
+
+                if self.options.get("abort", None):
+                    logger.warning("TestJsonWriter.write: abort test")
+                    raise Exception("abort test")
+
+                context = TaskContext.get()
+                output_path = os.path.join(self.path, f"{context.partitionId()}.json")
+                count = 0
+                rows = []
+                with open(output_path, "w") as file:
+                    for row in iterator:
+                        count += 1
+                        rows.append(row.asDict())
+                        file.write(json.dumps(row.asDict()) + "\n")
+
+                logger.warning(f"TestJsonWriter.write: {count}, {rows}")
+
+                return TestCommitMessage(count=count)
+
+            def commit(self, messages):
+                total_count = sum(message.count for message in messages)
+                with open(os.path.join(self.path, "_success.txt"), "w") as file:
+                    file.write(f"count: {total_count}\n")
+
+                logger.warning(f"TestJsonWriter.commit: {total_count}")
+
+            def abort(self, messages):
+                with open(os.path.join(self.path, "_failed.txt"), "w") as file:
+                    file.write("failed")
+
+                logger.warning("TestJsonWriter.abort")
+
+        class TestJsonDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                logger.warning("TestJsonDataSource.name")
+                return "my-json"
+
+            def writer(self, schema, overwrite):
+                logger.warning(f"TestJsonDataSource.writer: {schema.fieldNames(), {overwrite}}")
+                return TestJsonWriter(self.options)
+
+        # Register the data source
+        self.spark.dataSource.register(TestJsonDataSource)
+
+        with tempfile.TemporaryDirectory(prefix="test_datasource_write_logging") as d:
+            with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+                # Create a simple DataFrame and write it using our custom datasource
+                df = self.spark.createDataFrame(
+                    [("Charlie", 35), ("Diana", 28)], "name STRING, age INT"
+                ).repartitionByRange(2, "age")
+                df.write.format("my-json").mode("overwrite").save(d)
+
+                # Verify the write worked by checking the success file
+                with open(os.path.join(d, "_success.txt"), "r") as file:
+                    text = file.read()
+                self.assertEqual(text, "count: 2\n")
+
+                with self.assertRaises(Exception, msg="abort test"):
+                    df.write.format("my-json").mode("append").option("abort", "true").save(d)
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select("level", "msg", "context", "logger"),
+            [
+                Row(
+                    level="WARNING",
+                    msg=msg,
+                    context=context,
+                    logger="test_datasource_writer",
+                )
+                for msg, context in [
+                    (
+                        "TestJsonDataSource.name",
+                        {"class_name": "TestJsonDataSource", "func_name": "name"},
+                    ),
+                    (
+                        "TestJsonDataSource.writer: (['name', 'age'], {True})",
+                        {"class_name": "TestJsonDataSource", "func_name": "writer"},
+                    ),
+                    (
+                        "TestJsonWriter.__init__: ['path']",
+                        {"class_name": "TestJsonDataSource", "func_name": "writer"},
+                    ),
+                    (
+                        "TestJsonWriter.write: 1, [{'name': 'Diana', 'age': 28}]",
+                        {"class_name": "TestJsonWriter", "func_name": "write"},
+                    ),
+                    (
+                        "TestJsonWriter.write: 1, [{'name': 'Charlie', 'age': 35}]",
+                        {"class_name": "TestJsonWriter", "func_name": "write"},
+                    ),
+                    (
+                        "TestJsonWriter.commit: 2",
+                        {"class_name": "TestJsonWriter", "func_name": "commit"},
+                    ),
+                    (
+                        "TestJsonDataSource.name",
+                        {"class_name": "TestJsonDataSource", "func_name": "name"},
+                    ),
+                    (
+                        "TestJsonDataSource.writer: (['name', 'age'], {False})",
+                        {"class_name": "TestJsonDataSource", "func_name": "writer"},
+                    ),
+                    (
+                        "TestJsonWriter.__init__: ['abort', 'path']",
+                        {"class_name": "TestJsonDataSource", "func_name": "writer"},
+                    ),
+                    (
+                        "TestJsonWriter.write: abort test",
+                        {"class_name": "TestJsonWriter", "func_name": "write"},
+                    ),
+                    (
+                        "TestJsonWriter.write: abort test",
+                        {"class_name": "TestJsonWriter", "func_name": "write"},
+                    ),
+                    (
+                        "TestJsonWriter.abort",
+                        {"class_name": "TestJsonWriter", "func_name": "abort"},
+                    ),
+                ]
+            ],
+        )
+
 
 class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
     ...
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 579ef8a68972..b86a2624acd5 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -3093,6 +3093,48 @@ class BaseUDTFTestsMixin:
             ],
         )
 
+    @unittest.skipIf(is_remote_only(), "Requires JVM access")
+    def test_udtf_analyze_with_logging(self):
+        @udtf
+        class TestUDTFWithLogging:
+            @staticmethod
+            def analyze(x: AnalyzeArgument) -> AnalyzeResult:
+                logger = logging.getLogger("test_udtf")
+                logger.warning(f"udtf analyze: {x.dataType.json()}")
+                return AnalyzeResult(StructType().add("a", IntegerType()).add("b", IntegerType()))
+
+            def eval(self, x: int):
+                yield x * 2, x + 10
+
+        with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+            assertDataFrameEqual(
+                self.spark.createDataFrame([(5,), (10,)], ["x"]).lateralJoin(
+                    TestUDTFWithLogging(col("x").outer())
+                ),
+                [Row(x=x, a=x * 2, b=x + 10) for x in [5, 10]],
+            )
+
+        logs = self.spark.table("system.session.python_worker_logs")
+
+        assertDataFrameEqual(
+            logs.select(
+                "level",
+                "msg",
+                col("context.class_name").alias("context_class_name"),
+                col("context.func_name").alias("context_func_name"),
+                "logger",
+            ).distinct(),
+            [
+                Row(
+                    level="WARNING",
+                    msg='udtf analyze: "long"',
+                    context_class_name="TestUDTFWithLogging",
+                    context_func_name="analyze",
+                    logger="test_udtf",
+                )
+            ],
+        )
+
 
 class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
     @classmethod
diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py
index 892130bbae16..665b1297fbc1 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -24,6 +24,7 @@ from typing import Dict, List, IO, Tuple
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkRuntimeError, PySparkValueError
+from pyspark.logger.worker_io import capture_outputs, context_provider as default_context_provider
 from pyspark.serializers import (
     read_bool,
     read_int,
@@ -154,8 +155,15 @@ def main(infile: IO, outfile: IO) -> None:
                 )
             )
 
-        # Invoke the UDTF's 'analyze' method.
-        result = handler.analyze(*args, **kwargs)  # type: ignore[attr-defined]
+        # The default context provider can't detect the class name from static methods.
+        def context_provider() -> dict[str, str]:
+            context = default_context_provider()
+            context["class_name"] = handler.__name__
+            return context
+
+        with capture_outputs(context_provider):
+            # Invoke the UDTF's 'analyze' method.
+            result = handler.analyze(*args, **kwargs)  # type: ignore[attr-defined]
 
         # Check invariants about the 'analyze' method after running it.
         if not isinstance(result, AnalyzeResult):
diff --git a/python/pyspark/sql/worker/commit_data_source_write.py b/python/pyspark/sql/worker/commit_data_source_write.py
index dd080f1feb6c..fb82b65f3122 100644
--- a/python/pyspark/sql/worker/commit_data_source_write.py
+++ b/python/pyspark/sql/worker/commit_data_source_write.py
@@ -21,6 +21,7 @@ from typing import IO
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError
+from pyspark.logger.worker_io import capture_outputs
 from pyspark.serializers import (
     read_bool,
     read_int,
@@ -91,12 +92,13 @@ def main(infile: IO, outfile: IO) -> None:
         # Receive a boolean to indicate whether to invoke `abort`.
         abort = read_bool(infile)
 
-        # Commit or abort the Python data source write.
-        # Note the commit messages can be None if there are failed tasks.
-        if abort:
-            writer.abort(commit_messages)
-        else:
-            writer.commit(commit_messages)
+        with capture_outputs():
+            # Commit or abort the Python data source write.
+            # Note the commit messages can be None if there are failed tasks.
+            if abort:
+                writer.abort(commit_messages)
+            else:
+                writer.commit(commit_messages)
 
         # Send a status code back to JVM.
         write_int(0, outfile)
diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py
index fc1b8eaffdaa..15e8fdc618e2 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -22,6 +22,7 @@ from typing import IO
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkTypeError
+from pyspark.logger.worker_io import capture_outputs
 from pyspark.serializers import (
     read_bool,
     read_int,
@@ -106,55 +107,57 @@ def main(infile: IO, outfile: IO) -> None:
         # Receive the provider name.
         provider = utf8_deserializer.loads(infile)
 
-        # Check if the provider name matches the data source's name.
-        if provider.lower() != data_source_cls.name().lower():
-            raise PySparkAssertionError(
-                errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                messageParameters={
-                    "expected": f"provider with name {data_source_cls.name()}",
-                    "actual": f"'{provider}'",
-                },
-            )
-
-        # Receive the user-specified schema
-        user_specified_schema = None
-        if read_bool(infile):
-            user_specified_schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
-            if not isinstance(user_specified_schema, StructType):
+        with capture_outputs():
+            # Check if the provider name matches the data source's name.
+            name = data_source_cls.name()
+            if provider.lower() != name.lower():
                 raise PySparkAssertionError(
                     errorClass="DATA_SOURCE_TYPE_MISMATCH",
                     messageParameters={
-                        "expected": "the user-defined schema to be a 'StructType'",
-                        "actual": f"'{type(data_source_cls).__name__}'",
+                        "expected": f"provider with name {name}",
+                        "actual": f"'{provider}'",
                     },
                 )
 
-        # Receive the options.
-        options = CaseInsensitiveDict()
-        num_options = read_int(infile)
-        for _ in range(num_options):
-            key = utf8_deserializer.loads(infile)
-            value = utf8_deserializer.loads(infile)
-            options[key] = value
-
-        # Instantiate a data source.
-        data_source = data_source_cls(options=options)  # type: ignore
-
-        # Get the schema of the data source.
-        # If user_specified_schema is not None, use user_specified_schema.
-        # Otherwise, use the schema of the data source.
-        # Throw exception if the data source does not implement schema().
-        is_ddl_string = False
-        if user_specified_schema is None:
-            schema = data_source.schema()
-            if isinstance(schema, str):
-                # Here we cannot use _parse_datatype_string to parse the DDL string schema.
-                # as it requires an active Spark session.
-                is_ddl_string = True
-        else:
-            schema = user_specified_schema  # type: ignore
-
-        assert schema is not None
+            # Receive the user-specified schema
+            user_specified_schema = None
+            if read_bool(infile):
+                user_specified_schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
+                if not isinstance(user_specified_schema, StructType):
+                    raise PySparkAssertionError(
+                        errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                        messageParameters={
+                            "expected": "the user-defined schema to be a 'StructType'",
+                            "actual": f"'{type(data_source_cls).__name__}'",
+                        },
+                    )
+
+            # Receive the options.
+            options = CaseInsensitiveDict()
+            num_options = read_int(infile)
+            for _ in range(num_options):
+                key = utf8_deserializer.loads(infile)
+                value = utf8_deserializer.loads(infile)
+                options[key] = value
+
+            # Instantiate a data source.
+            data_source = data_source_cls(options=options)  # type: ignore
+
+            # Get the schema of the data source.
+            # If user_specified_schema is not None, use user_specified_schema.
+            # Otherwise, use the schema of the data source.
+            # Throw exception if the data source does not implement schema().
+            is_ddl_string = False
+            if user_specified_schema is None:
+                schema = data_source.schema()
+                if isinstance(schema, str):
+                    # Here we cannot use _parse_datatype_string to parse the DDL string schema.
+                    # as it requires an active Spark session.
+                    is_ddl_string = True
+            else:
+                schema = user_specified_schema  # type: ignore
+
+            assert schema is not None
 
         # Return the pickled data source instance.
         pickleSer._write_with_length(data_source, outfile)
diff --git a/python/pyspark/sql/worker/data_source_pushdown_filters.py b/python/pyspark/sql/worker/data_source_pushdown_filters.py
index b523cab7c49e..8601521bcfb1 100644
--- a/python/pyspark/sql/worker/data_source_pushdown_filters.py
+++ b/python/pyspark/sql/worker/data_source_pushdown_filters.py
@@ -27,6 +27,7 @@ from typing import IO, Type, Union
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkValueError
 from pyspark.errors.exceptions.base import PySparkNotImplementedError
+from pyspark.logger.worker_io import capture_outputs
 from pyspark.serializers import SpecialLengths, UTF8Deserializer, read_int, read_bool, write_int
 from pyspark.sql.datasource import (
     DataSource,
@@ -187,63 +188,64 @@ def main(infile: IO, outfile: IO) -> None:
                 },
             )
 
-        # Get the reader.
-        reader = data_source.reader(schema=schema)
-        # Validate the reader.
-        if not isinstance(reader, DataSourceReader):
-            raise PySparkAssertionError(
-                errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                messageParameters={
-                    "expected": "an instance of DataSourceReader",
-                    "actual": f"'{type(reader).__name__}'",
-                },
+        with capture_outputs():
+            # Get the reader.
+            reader = data_source.reader(schema=schema)
+            # Validate the reader.
+            if not isinstance(reader, DataSourceReader):
+                raise PySparkAssertionError(
+                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                    messageParameters={
+                        "expected": "an instance of DataSourceReader",
+                        "actual": f"'{type(reader).__name__}'",
+                    },
+                )
+
+            # Receive the pushdown filters.
+            json_str = utf8_deserializer.loads(infile)
+            filter_dicts = json.loads(json_str)
+            filters = [FilterRef(deserializeFilter(f)) for f in filter_dicts]
+
+            # Push down the filters and get the indices of the unsupported 
filters.
+            unsupported_filters = set(
+                FilterRef(f) for f in reader.pushFilters([ref.filter for ref 
in filters])
             )
-
-        # Receive the pushdown filters.
-        json_str = utf8_deserializer.loads(infile)
-        filter_dicts = json.loads(json_str)
-        filters = [FilterRef(deserializeFilter(f)) for f in filter_dicts]
-
-        # Push down the filters and get the indices of the unsupported filters.
-        unsupported_filters = set(
-            FilterRef(f) for f in reader.pushFilters([ref.filter for ref in 
filters])
-        )
-        supported_filter_indices = []
-        for i, filter in enumerate(filters):
-            if filter in unsupported_filters:
-                unsupported_filters.remove(filter)
-            else:
-                supported_filter_indices.append(i)
-
-        # If it returned any filters that are not in the original filters, 
raise an error.
-        if len(unsupported_filters) > 0:
-            raise PySparkValueError(
-                errorClass="DATA_SOURCE_EXTRANEOUS_FILTERS",
-                messageParameters={
-                    "type": type(reader).__name__,
-                    "input": str(list(filters)),
-                    "extraneous": str(list(unsupported_filters)),
-                },
+            supported_filter_indices = []
+            for i, filter in enumerate(filters):
+                if filter in unsupported_filters:
+                    unsupported_filters.remove(filter)
+                else:
+                    supported_filter_indices.append(i)
+
+            # If it returned any filters that are not in the original filters, 
raise an error.
+            if len(unsupported_filters) > 0:
+                raise PySparkValueError(
+                    errorClass="DATA_SOURCE_EXTRANEOUS_FILTERS",
+                    messageParameters={
+                        "type": type(reader).__name__,
+                        "input": str(list(filters)),
+                        "extraneous": str(list(unsupported_filters)),
+                    },
+                )
+
+            # Receive the max arrow batch size.
+            max_arrow_batch_size = read_int(infile)
+            assert max_arrow_batch_size > 0, (
+                "The maximum arrow batch size should be greater than 0, but 
got "
+                f"'{max_arrow_batch_size}'"
+            )
+            binary_as_bytes = read_bool(infile)
+
+            # Return the read function and partitions. Doing this in the same 
worker
+            # as filter pushdown helps reduce the number of Python worker 
calls.
+            write_read_func_and_partitions(
+                outfile,
+                reader=reader,
+                data_source=data_source,
+                schema=schema,
+                max_arrow_batch_size=max_arrow_batch_size,
+                binary_as_bytes=binary_as_bytes,
             )
-
-        # Receive the max arrow batch size.
-        max_arrow_batch_size = read_int(infile)
-        assert max_arrow_batch_size > 0, (
-            "The maximum arrow batch size should be greater than 0, but got "
-            f"'{max_arrow_batch_size}'"
-        )
-        binary_as_bytes = read_bool(infile)
-
-        # Return the read function and partitions. Doing this in the same 
worker as filter pushdown
-        # helps reduce the number of Python worker calls.
-        write_read_func_and_partitions(
-            outfile,
-            reader=reader,
-            data_source=data_source,
-            schema=schema,
-            max_arrow_batch_size=max_arrow_batch_size,
-            binary_as_bytes=binary_as_bytes,
-        )
 
         # Return the supported filter indices.
         write_int(len(supported_filter_indices), outfile)
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py 
b/python/pyspark/sql/worker/plan_data_source_read.py
index 7c3c31095c14..db79e58f2ec4 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -25,6 +25,7 @@ from typing import IO, List, Iterator, Iterable, Tuple, Union
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
+from pyspark.logger.worker_io import capture_outputs
 from pyspark.serializers import (
     read_bool,
     read_int,
@@ -357,45 +358,47 @@ def main(infile: IO, outfile: IO) -> None:
         is_streaming = read_bool(infile)
         binary_as_bytes = read_bool(infile)
 
-        # Instantiate data source reader.
-        if is_streaming:
-            reader: Union[DataSourceReader, DataSourceStreamReader] = 
_streamReader(
-                data_source, schema
-            )
-        else:
-            reader = data_source.reader(schema=schema)
-            # Validate the reader.
-            if not isinstance(reader, DataSourceReader):
-                raise PySparkAssertionError(
-                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                    messageParameters={
-                        "expected": "an instance of DataSourceReader",
-                        "actual": f"'{type(reader).__name__}'",
-                    },
+        with capture_outputs():
+            # Instantiate data source reader.
+            if is_streaming:
+                reader: Union[DataSourceReader, DataSourceStreamReader] = 
_streamReader(
+                    data_source, schema
                 )
-            is_pushdown_implemented = (
-                getattr(reader.pushFilters, "__func__", None) is not 
DataSourceReader.pushFilters
-            )
-            if is_pushdown_implemented and not enable_pushdown:
-                # Do not silently ignore pushFilters when pushdown is disabled.
-                # Raise an error to ask the user to enable pushdown.
-                raise PySparkAssertionError(
-                    errorClass="DATA_SOURCE_PUSHDOWN_DISABLED",
-                    messageParameters={
-                        "type": type(reader).__name__,
-                        "conf": "spark.sql.python.filterPushdown.enabled",
-                    },
+            else:
+                reader = data_source.reader(schema=schema)
+                # Validate the reader.
+                if not isinstance(reader, DataSourceReader):
+                    raise PySparkAssertionError(
+                        errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                        messageParameters={
+                            "expected": "an instance of DataSourceReader",
+                            "actual": f"'{type(reader).__name__}'",
+                        },
+                    )
+                is_pushdown_implemented = (
+                    getattr(reader.pushFilters, "__func__", None)
+                    is not DataSourceReader.pushFilters
                 )
+                if is_pushdown_implemented and not enable_pushdown:
+                    # Do not silently ignore pushFilters when pushdown is 
disabled.
+                    # Raise an error to ask the user to enable pushdown.
+                    raise PySparkAssertionError(
+                        errorClass="DATA_SOURCE_PUSHDOWN_DISABLED",
+                        messageParameters={
+                            "type": type(reader).__name__,
+                            "conf": "spark.sql.python.filterPushdown.enabled",
+                        },
+                    )
 
-        # Send the read function and partitions to the JVM.
-        write_read_func_and_partitions(
-            outfile,
-            reader=reader,
-            data_source=data_source,
-            schema=schema,
-            max_arrow_batch_size=max_arrow_batch_size,
-            binary_as_bytes=binary_as_bytes,
-        )
+            # Send the read function and partitions to the JVM.
+            write_read_func_and_partitions(
+                outfile,
+                reader=reader,
+                data_source=data_source,
+                schema=schema,
+                max_arrow_batch_size=max_arrow_batch_size,
+                binary_as_bytes=binary_as_bytes,
+            )
     except BaseException as e:
         handle_worker_exception(e, outfile)
         sys.exit(-1)
diff --git a/python/pyspark/sql/worker/python_streaming_sink_runner.py 
b/python/pyspark/sql/worker/python_streaming_sink_runner.py
index 83ba027a0601..ed6907ce5b63 100644
--- a/python/pyspark/sql/worker/python_streaming_sink_runner.py
+++ b/python/pyspark/sql/worker/python_streaming_sink_runner.py
@@ -22,6 +22,7 @@ from typing import IO
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.errors import PySparkAssertionError
+from pyspark.logger.worker_io import capture_outputs
 from pyspark.serializers import (
     read_bool,
     read_int,
@@ -101,33 +102,36 @@ def main(infile: IO, outfile: IO) -> None:
             )
         # Receive the `overwrite` flag.
         overwrite = read_bool(infile)
-        # Create the data source writer instance.
-        writer = data_source.streamWriter(schema=schema, overwrite=overwrite)
-        # Receive the commit messages.
-        num_messages = read_int(infile)
-
-        commit_messages = []
-        for _ in range(num_messages):
-            message = pickleSer._read_with_length(infile)
-            if message is not None and not isinstance(message, 
WriterCommitMessage):
-                raise PySparkAssertionError(
-                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                    messageParameters={
-                        "expected": "an instance of WriterCommitMessage",
-                        "actual": f"'{type(message).__name__}'",
-                    },
-                )
-            commit_messages.append(message)
-
-        batch_id = read_long(infile)
-        abort = read_bool(infile)
-
-        # Commit or abort the Python data source write.
-        # Note the commit messages can be None if there are failed tasks.
-        if abort:
-            writer.abort(commit_messages, batch_id)
-        else:
-            writer.commit(commit_messages, batch_id)
+
+        with capture_outputs():
+            # Create the data source writer instance.
+            writer = data_source.streamWriter(schema=schema, 
overwrite=overwrite)
+            # Receive the commit messages.
+            num_messages = read_int(infile)
+
+            commit_messages = []
+            for _ in range(num_messages):
+                message = pickleSer._read_with_length(infile)
+                if message is not None and not isinstance(message, 
WriterCommitMessage):
+                    raise PySparkAssertionError(
+                        errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                        messageParameters={
+                            "expected": "an instance of WriterCommitMessage",
+                            "actual": f"'{type(message).__name__}'",
+                        },
+                    )
+                commit_messages.append(message)
+
+            batch_id = read_long(infile)
+            abort = read_bool(infile)
+
+            # Commit or abort the Python data source write.
+            # Note the commit messages can be None if there are failed tasks.
+            if abort:
+                writer.abort(commit_messages, batch_id)
+            else:
+                writer.commit(commit_messages, batch_id)
+
         # Send a status code back to JVM.
         write_int(0, outfile)
         outfile.flush()
diff --git a/python/pyspark/sql/worker/write_into_data_source.py 
b/python/pyspark/sql/worker/write_into_data_source.py
index 46f9b168f067..917d0ca8e007 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -23,6 +23,7 @@ from typing import IO, Iterable, Iterator
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.sql.conversion import ArrowTableToRowsConversion
 from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, 
PySparkTypeError
+from pyspark.logger.worker_io import capture_outputs
 from pyspark.serializers import (
     read_bool,
     read_int,
@@ -122,85 +123,88 @@ def main(infile: IO, outfile: IO) -> None:
         # Receive the provider name.
         provider = utf8_deserializer.loads(infile)
 
-        # Check if the provider name matches the data source's name.
-        if provider.lower() != data_source_cls.name().lower():
-            raise PySparkAssertionError(
-                errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                messageParameters={
-                    "expected": f"provider with name {data_source_cls.name()}",
-                    "actual": f"'{provider}'",
-                },
-            )
-
-        # Receive the input schema
-        schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
-        if not isinstance(schema, StructType):
-            raise PySparkAssertionError(
-                errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                messageParameters={
-                    "expected": "the schema to be a 'StructType'",
-                    "actual": f"'{type(data_source_cls).__name__}'",
-                },
-            )
+        with capture_outputs():
+            # Check if the provider name matches the data source's name.
+            name = data_source_cls.name()
+            if provider.lower() != name.lower():
+                raise PySparkAssertionError(
+                    errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                    messageParameters={
+                        "expected": f"provider with name {name}",
+                        "actual": f"'{provider}'",
+                    },
+                )
 
-        # Receive the return type
-        return_type = 
_parse_datatype_json_string(utf8_deserializer.loads(infile))
-        if not isinstance(return_type, StructType):
-            raise PySparkAssertionError(
-                errorClass="DATA_SOURCE_TYPE_MISMATCH",
-                messageParameters={
-                    "expected": "a return type of type 'StructType'",
-                    "actual": f"'{type(return_type).__name__}'",
-                },
-            )
-        assert len(return_type) == 1 and isinstance(return_type[0].dataType, 
BinaryType), (
-            "The output schema of Python data source write should contain only 
one column of type "
-            f"'BinaryType', but got '{return_type}'"
-        )
-        return_col_name = return_type[0].name
-
-        # Receive the options.
-        options = CaseInsensitiveDict()
-        num_options = read_int(infile)
-        for _ in range(num_options):
-            key = utf8_deserializer.loads(infile)
-            value = utf8_deserializer.loads(infile)
-            options[key] = value
-
-        # Receive the `overwrite` flag.
-        overwrite = read_bool(infile)
-
-        is_streaming = read_bool(infile)
-        binary_as_bytes = read_bool(infile)
-
-        # Instantiate a data source.
-        data_source = data_source_cls(options=options)  # type: ignore
-
-        if is_streaming:
-            # Instantiate the streaming data source writer.
-            writer = data_source.streamWriter(schema, overwrite)
-            if not isinstance(writer, (DataSourceStreamWriter, 
DataSourceStreamArrowWriter)):
+            # Receive the input schema
+            schema = 
_parse_datatype_json_string(utf8_deserializer.loads(infile))
+            if not isinstance(schema, StructType):
                 raise PySparkAssertionError(
                     errorClass="DATA_SOURCE_TYPE_MISMATCH",
                     messageParameters={
-                        "expected": (
-                            "an instance of DataSourceStreamWriter or "
-                            "DataSourceStreamArrowWriter"
-                        ),
-                        "actual": f"'{type(writer).__name__}'",
+                        "expected": "the schema to be a 'StructType'",
+                        "actual": f"'{type(data_source_cls).__name__}'",
                     },
                 )
-        else:
-            # Instantiate the data source writer.
-            writer = data_source.writer(schema, overwrite)  # type: 
ignore[assignment]
-            if not isinstance(writer, DataSourceWriter):
+
+            # Receive the return type
+            return_type = 
_parse_datatype_json_string(utf8_deserializer.loads(infile))
+            if not isinstance(return_type, StructType):
                 raise PySparkAssertionError(
                     errorClass="DATA_SOURCE_TYPE_MISMATCH",
                     messageParameters={
-                        "expected": "an instance of DataSourceWriter",
-                        "actual": f"'{type(writer).__name__}'",
+                        "expected": "a return type of type 'StructType'",
+                        "actual": f"'{type(return_type).__name__}'",
                     },
                 )
+            assert len(return_type) == 1 and 
isinstance(return_type[0].dataType, BinaryType), (
+                "The output schema of Python data source write should contain 
only one column "
+                f"of type 'BinaryType', but got '{return_type}'"
+            )
+            return_col_name = return_type[0].name
+
+            # Receive the options.
+            options = CaseInsensitiveDict()
+            num_options = read_int(infile)
+            for _ in range(num_options):
+                key = utf8_deserializer.loads(infile)
+                value = utf8_deserializer.loads(infile)
+                options[key] = value
+
+            # Receive the `overwrite` flag.
+            overwrite = read_bool(infile)
+
+            is_streaming = read_bool(infile)
+            binary_as_bytes = read_bool(infile)
+
+            # Instantiate a data source.
+            data_source = data_source_cls(options=options)  # type: ignore
+
+            if is_streaming:
+                # Instantiate the streaming data source writer.
+                writer = data_source.streamWriter(schema, overwrite)
+                if not isinstance(writer, (DataSourceStreamWriter, 
DataSourceStreamArrowWriter)):
+                    raise PySparkAssertionError(
+                        errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                        messageParameters={
+                            "expected": (
+                                "an instance of DataSourceStreamWriter or "
+                                "DataSourceStreamArrowWriter"
+                            ),
+                            "actual": f"'{type(writer).__name__}'",
+                        },
+                    )
+            else:
+                # Instantiate the data source writer.
+
+                writer = data_source.writer(schema, overwrite)  # type: 
ignore[assignment]
+                if not isinstance(writer, DataSourceWriter):
+                    raise PySparkAssertionError(
+                        errorClass="DATA_SOURCE_TYPE_MISMATCH",
+                        messageParameters={
+                            "expected": "an instance of DataSourceWriter",
+                            "actual": f"'{type(writer).__name__}'",
+                        },
+                    )
 
         # Create a function that can be used in mapInArrow.
         import pyarrow as pa
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonBatchWriterFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonBatchWriterFactory.scala
index d5412f1bdd38..aaf8f56f0a9f 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonBatchWriterFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonBatchWriterFactory.scala
@@ -32,7 +32,8 @@ case class PythonBatchWriterFactory(
     source: UserDefinedPythonDataSource,
     pickledWriteFunc: Array[Byte],
     inputSchema: StructType,
-    jobArtifactUUID: Option[String]) extends DataWriterFactory {
+    jobArtifactUUID: Option[String],
+    sessionUUID: Option[String]) extends DataWriterFactory {
   override def createWriter(partitionId: Int, taskId: Long): 
DataWriter[InternalRow] = {
     new DataWriter[InternalRow] {
 
@@ -47,7 +48,8 @@ case class PythonBatchWriterFactory(
           inputSchema,
           UserDefinedPythonDataSource.writeOutputSchema,
           metrics,
-          jobArtifactUUID)
+          jobArtifactUUID,
+          sessionUUID)
         val outputIter = evaluatorFactory.createEvaluator().eval(partitionId, 
records.asScala)
         outputIter.foreach { row =>
           if (commitMessage == null) {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
index 65c71dd4eeb7..50ea7616061c 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
@@ -95,7 +95,7 @@ class PythonMicroBatchStream(
 
   override def createReaderFactory(): PartitionReaderFactory = {
     new PythonStreamingPartitionReaderFactory(
-      ds.source, readInfo.func, outputSchema, None)
+      ds.source, readInfo.func, outputSchema, None, None)
   }
 
   override def commit(end: Offset): Unit = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala
index 44933779c26a..4496aae65ee9 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala
@@ -30,7 +30,8 @@ class PythonPartitionReaderFactory(
     source: UserDefinedPythonDataSource,
     pickledReadFunc: Array[Byte],
     outputSchema: StructType,
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    sessionUUID: Option[String])
   extends PartitionReaderFactory {
 
   override def createReader(partition: InputPartition): 
PartitionReader[InternalRow] = {
@@ -45,7 +46,8 @@ class PythonPartitionReaderFactory(
           UserDefinedPythonDataSource.readInputSchema,
           outputSchema,
           metrics,
-          jobArtifactUUID)
+          jobArtifactUUID,
+          sessionUUID)
 
         val part = partition.asInstanceOf[PythonInputPartition]
         evaluatorFactory.createEvaluator().eval(
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
index 52af33e7aa99..a133c40cde60 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.execution.datasources.v2.python
 
 import org.apache.spark.JobArtifactSet
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.connector.metric.CustomMetric
 import org.apache.spark.sql.connector.read._
 import org.apache.spark.sql.connector.read.streaming.MicroBatchStream
@@ -61,6 +62,12 @@ class PythonBatch(
     outputSchema: StructType,
     options: CaseInsensitiveStringMap) extends Batch {
   private val jobArtifactUUID = 
JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+  private val sessionUUID = {
+    SparkSession.getActiveSession.collect {
+      case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+        session.sessionUUID
+    }
+  }
 
   private lazy val infoInPython: PythonDataSourceReadInfo = {
     ds.getOrCreateReadInfo(shortName, options, outputSchema, isStreaming = 
false)
@@ -72,6 +79,6 @@ class PythonBatch(
   override def createReaderFactory(): PartitionReaderFactory = {
     val readerFunc = infoInPython.func
     new PythonPartitionReaderFactory(
-      ds.source, readerFunc, outputSchema, jobArtifactUUID)
+      ds.source, readerFunc, outputSchema, jobArtifactUUID, sessionUUID)
   }
 }
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala
index 466ecf609093..2d29a0226397 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala
@@ -41,7 +41,8 @@ class PythonStreamingPartitionReaderFactory(
     source: UserDefinedPythonDataSource,
     pickledReadFunc: Array[Byte],
     outputSchema: StructType,
-    jobArtifactUUID: Option[String])
+    jobArtifactUUID: Option[String],
+    sessionUUID: Option[String])
   extends PartitionReaderFactory with Logging {
 
   override def createReader(partition: InputPartition): 
PartitionReader[InternalRow] = {
@@ -70,7 +71,8 @@ class PythonStreamingPartitionReaderFactory(
           UserDefinedPythonDataSource.readInputSchema,
           outputSchema,
           metrics,
-          jobArtifactUUID)
+          jobArtifactUUID,
+          sessionUUID)
 
         evaluatorFactory.createEvaluator().eval(
           part.index, Iterator.single(InternalRow(part.pickedPartition)))
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingWrite.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingWrite.scala
index 4c149437a300..be7a0429e512 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingWrite.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingWrite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.python
 
 import org.apache.spark.JobArtifactSet
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.connector.write._
 import 
org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, 
StreamingWrite}
 import org.apache.spark.sql.types.StructType
@@ -36,6 +37,12 @@ class PythonStreamingWrite(
   private var pythonDataSourceWriter: Array[Byte] = _
 
   private[this] val jobArtifactUUID = 
JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+  private val sessionUUID = {
+    SparkSession.getActiveSession.collect {
+      case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+        session.sessionUUID
+    }
+  }
 
   private def createDataSourceFunc =
     ds.source.createPythonFunction(
@@ -53,7 +60,8 @@ class PythonStreamingWrite(
 
     pythonDataSourceWriter = writeInfo.writer
 
-    new PythonStreamingWriterFactory(ds.source, writeInfo.func, info.schema(), 
jobArtifactUUID)
+    new PythonStreamingWriterFactory(ds.source, writeInfo.func, info.schema(),
+      jobArtifactUUID, sessionUUID)
   }
 
   override def commit(epochId: Long, messages: Array[WriterCommitMessage]): 
Unit = {
@@ -81,8 +89,10 @@ class PythonStreamingWriterFactory(
     source: UserDefinedPythonDataSource,
     pickledWriteFunc: Array[Byte],
     inputSchema: StructType,
-    jobArtifactUUID: Option[String])
-  extends PythonBatchWriterFactory(source, pickledWriteFunc, inputSchema, 
jobArtifactUUID)
+    jobArtifactUUID: Option[String],
+    sessionUUID: Option[String])
+  extends PythonBatchWriterFactory(source, pickledWriteFunc, inputSchema,
+      jobArtifactUUID, sessionUUID)
     with StreamingDataWriterFactory {
   override def createWriter(
       partitionId: Int,
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala
index 447221715264..156dcb242320 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.execution.datasources.v2.python
 
 import org.apache.spark.JobArtifactSet
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.connector.metric.CustomMetric
 import org.apache.spark.sql.connector.write.{BatchWrite, _}
 import org.apache.spark.sql.connector.write.streaming.StreamingWrite
@@ -56,6 +57,12 @@ class PythonBatchWrite(
   private var pythonDataSourceWriter: Array[Byte] = _
 
   private[this] val jobArtifactUUID = 
JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+  private val sessionUUID = {
+    SparkSession.getActiveSession.collect {
+      case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+        session.sessionUUID
+    }
+  }
 
   override def createBatchWriterFactory(physicalInfo: PhysicalWriteInfo): 
DataWriterFactory =
   {
@@ -68,7 +75,8 @@ class PythonBatchWrite(
 
     pythonDataSourceWriter = writeInfo.writer
 
-    PythonBatchWriterFactory(ds.source, writeInfo.func, info.schema(), 
jobArtifactUUID)
+    PythonBatchWriterFactory(ds.source, writeInfo.func, info.schema(),
+      jobArtifactUUID, sessionUUID)
   }
 
   override def commit(messages: Array[WriterCommitMessage]): Unit = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index 63e7e32c1c7b..c147030037cd 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -143,7 +143,8 @@ case class UserDefinedPythonDataSource(dataSourceCls: 
PythonFunction) {
       inputSchema: StructType,
       outputSchema: StructType,
       metrics: Map[String, SQLMetric],
-      jobArtifactUUID: Option[String]): MapInBatchEvaluatorFactory = {
+      jobArtifactUUID: Option[String],
+      sessionUUID: Option[String]): MapInBatchEvaluatorFactory = {
     val pythonFunc = createPythonFunction(pickledFunc)
 
     val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
@@ -171,7 +172,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: 
PythonFunction) {
       pythonRunnerConf,
       metrics,
       jobArtifactUUID,
-      None) // TODO: Python worker logging
+      sessionUUID)
   }
 
   def createPythonMetrics(): Array[CustomMetric] = {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
index 9ff934a07a16..0f4ac4ddad71 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
@@ -32,6 +32,7 @@ import org.apache.spark.api.python.{BasePythonRunner, 
PythonFunction, PythonWork
 import org.apache.spark.internal.{Logging, LogKeys}
 import org.apache.spark.internal.config.BUFFER_SIZE
 import org.apache.spark.internal.config.Python._
+import org.apache.spark.sql.classic.SparkSession
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.DirectByteBufferOutputStream
 
@@ -62,6 +63,12 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) 
extends Logging {
     val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory
 
     val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+    val sessionUUID = {
+      SparkSession.getActiveSession.collect {
+        case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+          session.sessionUUID
+      }
+    }
 
     val envVars = new HashMap[String, String](func.envVars)
     val pythonExec = func.pythonExec
@@ -93,6 +100,9 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) 
extends Logging {
     }
 
     envVars.put("SPARK_JOB_ARTIFACT_UUID", 
jobArtifactUUID.getOrElse("default"))
+    sessionUUID.foreach { uuid =>
+      envVars.put("PYSPARK_SPARK_SESSION_UUID", uuid)
+    }
 
     EvaluatePython.registerPicklers()
     val pickler = new Pickler(/* useMemo = */ true,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to