This is an automated email from the ASF dual-hosted git repository.
dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 6e41ac5d7015 [SPARK-53978][PYTHON] Support logging in driver-side workers
6e41ac5d7015 is described below
commit 6e41ac5d7015d96a238a990d5b9697e58458d7c8
Author: Takuya Ueshin <[email protected]>
AuthorDate: Fri Oct 31 08:04:45 2025 -0700
[SPARK-53978][PYTHON] Support logging in driver-side workers
### What changes were proposed in this pull request?
Supports logging in driver-side workers.
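Concretely, each driver-side worker now wraps its call into user code with `capture_outputs()` from `pyspark.logger.worker_io`, so records emitted through the standard `logging` module are forwarded to the JVM. A simplified sketch of the pattern, modeled on the change to `analyze_udtf.py` (the real worker reads the handler and its arguments from the JVM socket; `analyze_with_logging` is an illustrative helper, not part of this patch):
```python
from pyspark.logger.worker_io import capture_outputs, context_provider as default_context_provider


def analyze_with_logging(handler, *args, **kwargs):
    # The default context provider cannot derive the class name when `analyze`
    # is a @staticmethod, so the worker supplies it explicitly.
    def context_provider() -> dict[str, str]:
        context = default_context_provider()
        context["class_name"] = handler.__name__
        return context

    # While this block is active, user logging output is captured and forwarded
    # to the JVM, ending up in system.session.python_worker_logs.
    with capture_outputs(context_provider):
        return handler.analyze(*args, **kwargs)
```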
### Why are the changes needed?
The basic logging infrastructure was introduced in
https://github.com/apache/spark/pull/52689, and the driver-side workers
(data source creation, read planning, filter pushdown, write/commit, and
UDTF analysis) should also support logging. This patch adds that support.
### Does this PR introduce _any_ user-facing change?
Yes, the logging feature will be available in driver-side workers.
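For example (a condensed sketch derived from the tests added below; the data source name, logger name, and sample rows are illustrative):
```python
import logging

from pyspark.sql import SparkSession
from pyspark.sql.datasource import DataSource, DataSourceReader

spark = SparkSession.builder.getOrCreate()
logger = logging.getLogger("my_datasource")


class MyReader(DataSourceReader):
    def read(self, partition):
        logger.warning("MyReader.read")  # executor-side worker (already supported)
        yield ("Andy", 30)


class MyDataSource(DataSource):
    @classmethod
    def name(cls):
        return "my-source"

    def schema(self):
        logger.warning("MyDataSource.schema")  # driver-side worker (new in this PR)
        return "name STRING, age INT"

    def reader(self, schema):
        return MyReader()


spark.dataSource.register(MyDataSource)
spark.conf.set("spark.sql.pyspark.worker.logging.enabled", "true")
spark.read.format("my-source").load().show()

# Logs from both driver-side and executor-side Python workers are queryable here.
spark.table("system.session.python_worker_logs").select("level", "msg", "logger").show()
```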
### How was this patch tested?
Added the related tests.
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #52808 from ueshin/issues/SPARK-53978/driverside.
Authored-by: Takuya Ueshin <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
---
python/pyspark/sql/tests/test_python_datasource.py | 356 +++++++++++++++++++++
python/pyspark/sql/tests/test_udtf.py | 42 +++
python/pyspark/sql/worker/analyze_udtf.py | 12 +-
.../pyspark/sql/worker/commit_data_source_write.py | 14 +-
python/pyspark/sql/worker/create_data_source.py | 89 +++---
.../sql/worker/data_source_pushdown_filters.py | 112 +++----
python/pyspark/sql/worker/plan_data_source_read.py | 75 ++---
.../sql/worker/python_streaming_sink_runner.py | 58 ++--
.../pyspark/sql/worker/write_into_data_source.py | 140 ++++----
.../v2/python/PythonBatchWriterFactory.scala | 6 +-
.../v2/python/PythonMicroBatchStream.scala | 2 +-
.../v2/python/PythonPartitionReaderFactory.scala | 6 +-
.../datasources/v2/python/PythonScan.scala | 9 +-
.../PythonStreamingPartitionReaderFactory.scala | 6 +-
.../v2/python/PythonStreamingWrite.scala | 16 +-
.../datasources/v2/python/PythonWrite.scala | 10 +-
.../v2/python/UserDefinedPythonDataSource.scala | 5 +-
.../sql/execution/python/PythonPlannerRunner.scala | 10 +
18 files changed, 717 insertions(+), 251 deletions(-)
diff --git a/python/pyspark/sql/tests/test_python_datasource.py b/python/pyspark/sql/tests/test_python_datasource.py
index 28f8bf3c832b..cfedf1cf075b 100644
--- a/python/pyspark/sql/tests/test_python_datasource.py
+++ b/python/pyspark/sql/tests/test_python_datasource.py
@@ -19,6 +19,10 @@ import platform
import sys
import tempfile
import unittest
+import logging
+import json
+import os
+from dataclasses import dataclass
from datetime import datetime
from decimal import Decimal
from typing import Callable, Iterable, List, Union, Iterator, Tuple
@@ -57,6 +61,7 @@ from pyspark.testing.sqlutils import (
have_pyarrow,
pyarrow_requirement_message,
)
+from pyspark.util import is_remote_only
@unittest.skipIf(not have_pyarrow, pyarrow_requirement_message)
@@ -907,6 +912,357 @@ class BasePythonDataSourceTestsMixin:
"test_table"
)
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_data_source_reader_with_logging(self):
+ logger = logging.getLogger("test_data_source_reader")
+
+ class TestJsonReader(DataSourceReader):
+ def __init__(self, options):
+ logger.warning(f"TestJsonReader.__init__: {list(options)}")
+ self.options = options
+
+ def partitions(self):
+ logger.warning("TestJsonReader.partitions")
+ return super().partitions()
+
+ def read(self, partition):
+ logger.warning(f"TestJsonReader.read: {partition}")
+ path = self.options.get("path")
+ if path is None:
+ raise Exception("path is not specified")
+ with open(path, "r") as file:
+ for line in file.readlines():
+ if line.strip():
+ data = json.loads(line)
+ yield data.get("name"), data.get("age")
+
+ class TestJsonDataSource(DataSource):
+ def __init__(self, options):
+ super().__init__(options)
+ logger.warning(f"TestJsonDataSource.__init__: {list(options)}")
+
+ @classmethod
+ def name(cls):
+ logger.warning("TestJsonDataSource.name")
+ return "my-json"
+
+ def schema(self):
+ logger.warning("TestJsonDataSource.schema")
+ return "name STRING, age INT"
+
+ def reader(self, schema) -> "DataSourceReader":
+ logger.warning(f"TestJsonDataSource.reader:
{schema.fieldNames()}")
+ return TestJsonReader(self.options)
+
+ self.spark.dataSource.register(TestJsonDataSource)
+ path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
+
+ with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled":
"true"}):
+ assertDataFrameEqual(
+ self.spark.read.format("my-json").load(path1),
+ [
+ Row(name="Michael", age=None),
+ Row(name="Andy", age=30),
+ Row(name="Justin", age=19),
+ ],
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=msg,
+ context=context,
+ logger="test_data_source_reader",
+ )
+ for msg, context in [
+ (
+ "TestJsonDataSource.__init__: ['path']",
+ {"class_name": "TestJsonDataSource", "func_name":
"__init__"},
+ ),
+ (
+ "TestJsonDataSource.name",
+ {"class_name": "TestJsonDataSource", "func_name":
"name"},
+ ),
+ (
+ "TestJsonDataSource.schema",
+ {"class_name": "TestJsonDataSource", "func_name":
"schema"},
+ ),
+ (
+ "TestJsonDataSource.reader: ['name', 'age']",
+ {"class_name": "TestJsonDataSource", "func_name":
"reader"},
+ ),
+ (
+ "TestJsonReader.__init__: ['path']",
+ {"class_name": "TestJsonDataSource", "func_name":
"reader"},
+ ),
+ (
+ "TestJsonReader.partitions",
+ {"class_name": "TestJsonReader", "func_name":
"partitions"},
+ ),
+ (
+ "TestJsonReader.read: None",
+ {"class_name": "TestJsonReader", "func_name": "read"},
+ ),
+ ]
+ ],
+ )
+
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_data_source_reader_pushdown_with_logging(self):
+ logger = logging.getLogger("test_data_source_reader_pushdown")
+
+ class TestJsonReader(DataSourceReader):
+ def __init__(self, options):
+ logger.warning(f"TestJsonReader.__init__: {list(options)}")
+ self.options = options
+
+ def pushFilters(self, filters):
+ logger.warning(f"TestJsonReader.pushFilters: {filters}")
+ return super().pushFilters(filters)
+
+ def partitions(self):
+ logger.warning("TestJsonReader.partitions")
+ return super().partitions()
+
+ def read(self, partition):
+ logger.warning(f"TestJsonReader.read: {partition}")
+ path = self.options.get("path")
+ if path is None:
+ raise Exception("path is not specified")
+ with open(path, "r") as file:
+ for line in file.readlines():
+ if line.strip():
+ data = json.loads(line)
+ yield data.get("name"), data.get("age")
+
+ class TestJsonDataSource(DataSource):
+ def __init__(self, options):
+ super().__init__(options)
+ logger.warning(f"TestJsonDataSource.__init__: {list(options)}")
+
+ @classmethod
+ def name(cls):
+ logger.warning("TestJsonDataSource.name")
+ return "my-json"
+
+ def schema(self):
+ logger.warning("TestJsonDataSource.schema")
+ return "name STRING, age INT"
+
+ def reader(self, schema) -> "DataSourceReader":
+ logger.warning(f"TestJsonDataSource.reader:
{schema.fieldNames()}")
+ return TestJsonReader(self.options)
+
+ self.spark.dataSource.register(TestJsonDataSource)
+ path1 = os.path.join(SPARK_HOME, "python/test_support/sql/people.json")
+
+ with self.sql_conf(
+ {
+ "spark.sql.python.filterPushdown.enabled": "true",
+ "spark.sql.pyspark.worker.logging.enabled": "true",
+ }
+ ):
+ assertDataFrameEqual(
+ self.spark.read.format("my-json").load(path1).filter("age is
not null"),
+ [
+ Row(name="Andy", age=30),
+ Row(name="Justin", age=19),
+ ],
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=msg,
+ context=context,
+ logger="test_data_source_reader_pushdown",
+ )
+ for msg, context in [
+ (
+ "TestJsonDataSource.__init__: ['path']",
+ {"class_name": "TestJsonDataSource", "func_name":
"__init__"},
+ ),
+ (
+ "TestJsonDataSource.name",
+ {"class_name": "TestJsonDataSource", "func_name":
"name"},
+ ),
+ (
+ "TestJsonDataSource.schema",
+ {"class_name": "TestJsonDataSource", "func_name":
"schema"},
+ ),
+ (
+ "TestJsonDataSource.reader: ['name', 'age']",
+ {"class_name": "TestJsonDataSource", "func_name":
"reader"},
+ ),
+ (
+ "TestJsonReader.pushFilters:
[IsNotNull(attribute=('age',))]",
+ {"class_name": "TestJsonReader", "func_name":
"pushFilters"},
+ ),
+ (
+ "TestJsonReader.__init__: ['path']",
+ {"class_name": "TestJsonDataSource", "func_name":
"reader"},
+ ),
+ (
+ "TestJsonReader.partitions",
+ {"class_name": "TestJsonReader", "func_name":
"partitions"},
+ ),
+ (
+ "TestJsonReader.read: None",
+ {"class_name": "TestJsonReader", "func_name": "read"},
+ ),
+ ]
+ ],
+ )
+
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_data_source_writer_with_logging(self):
+ logger = logging.getLogger("test_datasource_writer")
+
+ @dataclass
+ class TestCommitMessage(WriterCommitMessage):
+ count: int
+
+ class TestJsonWriter(DataSourceWriter):
+ def __init__(self, options):
+ logger.warning(f"TestJsonWriter.__init__: {list(options)}")
+ self.options = options
+ self.path = self.options.get("path")
+
+ def write(self, iterator):
+ from pyspark import TaskContext
+
+ if self.options.get("abort", None):
+ logger.warning("TestJsonWriter.write: abort test")
+ raise Exception("abort test")
+
+ context = TaskContext.get()
+ output_path = os.path.join(self.path, f"{context.partitionId()}.json")
+ count = 0
+ rows = []
+ with open(output_path, "w") as file:
+ for row in iterator:
+ count += 1
+ rows.append(row.asDict())
+ file.write(json.dumps(row.asDict()) + "\n")
+
+ logger.warning(f"TestJsonWriter.write: {count}, {rows}")
+
+ return TestCommitMessage(count=count)
+
+ def commit(self, messages):
+ total_count = sum(message.count for message in messages)
+ with open(os.path.join(self.path, "_success.txt"), "w") as
file:
+ file.write(f"count: {total_count}\n")
+
+ logger.warning(f"TestJsonWriter.commit: {total_count}")
+
+ def abort(self, messages):
+ with open(os.path.join(self.path, "_failed.txt"), "w") as file:
+ file.write("failed")
+
+ logger.warning("TestJsonWriter.abort")
+
+ class TestJsonDataSource(DataSource):
+ @classmethod
+ def name(cls):
+ logger.warning("TestJsonDataSource.name")
+ return "my-json"
+
+ def writer(self, schema, overwrite):
+ logger.warning(f"TestJsonDataSource.writer:
{schema.fieldNames(), {overwrite}}")
+ return TestJsonWriter(self.options)
+
+ # Register the data source
+ self.spark.dataSource.register(TestJsonDataSource)
+
+ with tempfile.TemporaryDirectory(prefix="test_datasource_write_logging") as d:
+ with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled": "true"}):
+ # Create a simple DataFrame and write it using our custom datasource
+ df = self.spark.createDataFrame(
+ [("Charlie", 35), ("Diana", 28)], "name STRING, age INT"
+ ).repartitionByRange(2, "age")
+ df.write.format("my-json").mode("overwrite").save(d)
+
+ # Verify the write worked by checking the success file
+ with open(os.path.join(d, "_success.txt"), "r") as file:
+ text = file.read()
+ self.assertEqual(text, "count: 2\n")
+
+ with self.assertRaises(Exception, msg="abort test"):
+ df.write.format("my-json").mode("append").option("abort",
"true").save(d)
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select("level", "msg", "context", "logger"),
+ [
+ Row(
+ level="WARNING",
+ msg=msg,
+ context=context,
+ logger="test_datasource_writer",
+ )
+ for msg, context in [
+ (
+ "TestJsonDataSource.name",
+ {"class_name": "TestJsonDataSource", "func_name":
"name"},
+ ),
+ (
+ "TestJsonDataSource.writer: (['name', 'age'], {True})",
+ {"class_name": "TestJsonDataSource", "func_name":
"writer"},
+ ),
+ (
+ "TestJsonWriter.__init__: ['path']",
+ {"class_name": "TestJsonDataSource", "func_name":
"writer"},
+ ),
+ (
+ "TestJsonWriter.write: 1, [{'name': 'Diana', 'age':
28}]",
+ {"class_name": "TestJsonWriter", "func_name": "write"},
+ ),
+ (
+ "TestJsonWriter.write: 1, [{'name': 'Charlie', 'age':
35}]",
+ {"class_name": "TestJsonWriter", "func_name": "write"},
+ ),
+ (
+ "TestJsonWriter.commit: 2",
+ {"class_name": "TestJsonWriter", "func_name":
"commit"},
+ ),
+ (
+ "TestJsonDataSource.name",
+ {"class_name": "TestJsonDataSource", "func_name":
"name"},
+ ),
+ (
+ "TestJsonDataSource.writer: (['name', 'age'],
{False})",
+ {"class_name": "TestJsonDataSource", "func_name":
"writer"},
+ ),
+ (
+ "TestJsonWriter.__init__: ['abort', 'path']",
+ {"class_name": "TestJsonDataSource", "func_name":
"writer"},
+ ),
+ (
+ "TestJsonWriter.write: abort test",
+ {"class_name": "TestJsonWriter", "func_name": "write"},
+ ),
+ (
+ "TestJsonWriter.write: abort test",
+ {"class_name": "TestJsonWriter", "func_name": "write"},
+ ),
+ (
+ "TestJsonWriter.abort",
+ {"class_name": "TestJsonWriter", "func_name": "abort"},
+ ),
+ ]
+ ],
+ )
+
class PythonDataSourceTests(BasePythonDataSourceTestsMixin, ReusedSQLTestCase):
...
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 579ef8a68972..b86a2624acd5 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -3093,6 +3093,48 @@ class BaseUDTFTestsMixin:
],
)
+ @unittest.skipIf(is_remote_only(), "Requires JVM access")
+ def test_udtf_analyze_with_logging(self):
+ @udtf
+ class TestUDTFWithLogging:
+ @staticmethod
+ def analyze(x: AnalyzeArgument) -> AnalyzeResult:
+ logger = logging.getLogger("test_udtf")
+ logger.warning(f"udtf analyze: {x.dataType.json()}")
+ return AnalyzeResult(StructType().add("a",
IntegerType()).add("b", IntegerType()))
+
+ def eval(self, x: int):
+ yield x * 2, x + 10
+
+ with self.sql_conf({"spark.sql.pyspark.worker.logging.enabled":
"true"}):
+ assertDataFrameEqual(
+ self.spark.createDataFrame([(5,), (10,)], ["x"]).lateralJoin(
+ TestUDTFWithLogging(col("x").outer())
+ ),
+ [Row(x=x, a=x * 2, b=x + 10) for x in [5, 10]],
+ )
+
+ logs = self.spark.table("system.session.python_worker_logs")
+
+ assertDataFrameEqual(
+ logs.select(
+ "level",
+ "msg",
+ col("context.class_name").alias("context_class_name"),
+ col("context.func_name").alias("context_func_name"),
+ "logger",
+ ).distinct(),
+ [
+ Row(
+ level="WARNING",
+ msg='udtf analyze: "long"',
+ context_class_name="TestUDTFWithLogging",
+ context_func_name="analyze",
+ logger="test_udtf",
+ )
+ ],
+ )
+
class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/sql/worker/analyze_udtf.py b/python/pyspark/sql/worker/analyze_udtf.py
index 892130bbae16..665b1297fbc1 100644
--- a/python/pyspark/sql/worker/analyze_udtf.py
+++ b/python/pyspark/sql/worker/analyze_udtf.py
@@ -24,6 +24,7 @@ from typing import Dict, List, IO, Tuple
from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import PySparkRuntimeError, PySparkValueError
+from pyspark.logger.worker_io import capture_outputs, context_provider as default_context_provider
from pyspark.serializers import (
read_bool,
read_int,
@@ -154,8 +155,15 @@ def main(infile: IO, outfile: IO) -> None:
)
)
- # Invoke the UDTF's 'analyze' method.
- result = handler.analyze(*args, **kwargs) # type: ignore[attr-defined]
+ # The default context provider can't detect the class name from static methods.
+ def context_provider() -> dict[str, str]:
+ context = default_context_provider()
+ context["class_name"] = handler.__name__
+ return context
+
+ with capture_outputs(context_provider):
+ # Invoke the UDTF's 'analyze' method.
+ result = handler.analyze(*args, **kwargs) # type: ignore[attr-defined]
# Check invariants about the 'analyze' method after running it.
if not isinstance(result, AnalyzeResult):
diff --git a/python/pyspark/sql/worker/commit_data_source_write.py b/python/pyspark/sql/worker/commit_data_source_write.py
index dd080f1feb6c..fb82b65f3122 100644
--- a/python/pyspark/sql/worker/commit_data_source_write.py
+++ b/python/pyspark/sql/worker/commit_data_source_write.py
@@ -21,6 +21,7 @@ from typing import IO
from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import PySparkAssertionError
+from pyspark.logger.worker_io import capture_outputs
from pyspark.serializers import (
read_bool,
read_int,
@@ -91,12 +92,13 @@ def main(infile: IO, outfile: IO) -> None:
# Receive a boolean to indicate whether to invoke `abort`.
abort = read_bool(infile)
- # Commit or abort the Python data source write.
- # Note the commit messages can be None if there are failed tasks.
- if abort:
- writer.abort(commit_messages)
- else:
- writer.commit(commit_messages)
+ with capture_outputs():
+ # Commit or abort the Python data source write.
+ # Note the commit messages can be None if there are failed tasks.
+ if abort:
+ writer.abort(commit_messages)
+ else:
+ writer.commit(commit_messages)
# Send a status code back to JVM.
write_int(0, outfile)
diff --git a/python/pyspark/sql/worker/create_data_source.py b/python/pyspark/sql/worker/create_data_source.py
index fc1b8eaffdaa..15e8fdc618e2 100644
--- a/python/pyspark/sql/worker/create_data_source.py
+++ b/python/pyspark/sql/worker/create_data_source.py
@@ -22,6 +22,7 @@ from typing import IO
from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import PySparkAssertionError, PySparkTypeError
+from pyspark.logger.worker_io import capture_outputs
from pyspark.serializers import (
read_bool,
read_int,
@@ -106,55 +107,57 @@ def main(infile: IO, outfile: IO) -> None:
# Receive the provider name.
provider = utf8_deserializer.loads(infile)
- # Check if the provider name matches the data source's name.
- if provider.lower() != data_source_cls.name().lower():
- raise PySparkAssertionError(
- errorClass="DATA_SOURCE_TYPE_MISMATCH",
- messageParameters={
- "expected": f"provider with name {data_source_cls.name()}",
- "actual": f"'{provider}'",
- },
- )
-
- # Receive the user-specified schema
- user_specified_schema = None
- if read_bool(infile):
- user_specified_schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
- if not isinstance(user_specified_schema, StructType):
+ with capture_outputs():
+ # Check if the provider name matches the data source's name.
+ name = data_source_cls.name()
+ if provider.lower() != name.lower():
raise PySparkAssertionError(
errorClass="DATA_SOURCE_TYPE_MISMATCH",
messageParameters={
- "expected": "the user-defined schema to be a
'StructType'",
- "actual": f"'{type(data_source_cls).__name__}'",
+ "expected": f"provider with name {name}",
+ "actual": f"'{provider}'",
},
)
- # Receive the options.
- options = CaseInsensitiveDict()
- num_options = read_int(infile)
- for _ in range(num_options):
- key = utf8_deserializer.loads(infile)
- value = utf8_deserializer.loads(infile)
- options[key] = value
-
- # Instantiate a data source.
- data_source = data_source_cls(options=options) # type: ignore
-
- # Get the schema of the data source.
- # If user_specified_schema is not None, use user_specified_schema.
- # Otherwise, use the schema of the data source.
- # Throw exception if the data source does not implement schema().
- is_ddl_string = False
- if user_specified_schema is None:
- schema = data_source.schema()
- if isinstance(schema, str):
- # Here we cannot use _parse_datatype_string to parse the DDL string schema.
- # as it requires an active Spark session.
- is_ddl_string = True
- else:
- schema = user_specified_schema # type: ignore
-
- assert schema is not None
+ # Receive the user-specified schema
+ user_specified_schema = None
+ if read_bool(infile):
+ user_specified_schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
+ if not isinstance(user_specified_schema, StructType):
+ raise PySparkAssertionError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": "the user-defined schema to be a
'StructType'",
+ "actual": f"'{type(data_source_cls).__name__}'",
+ },
+ )
+
+ # Receive the options.
+ options = CaseInsensitiveDict()
+ num_options = read_int(infile)
+ for _ in range(num_options):
+ key = utf8_deserializer.loads(infile)
+ value = utf8_deserializer.loads(infile)
+ options[key] = value
+
+ # Instantiate a data source.
+ data_source = data_source_cls(options=options) # type: ignore
+
+ # Get the schema of the data source.
+ # If user_specified_schema is not None, use user_specified_schema.
+ # Otherwise, use the schema of the data source.
+ # Throw exception if the data source does not implement schema().
+ is_ddl_string = False
+ if user_specified_schema is None:
+ schema = data_source.schema()
+ if isinstance(schema, str):
+ # Here we cannot use _parse_datatype_string to parse the DDL string schema.
+ # as it requires an active Spark session.
+ is_ddl_string = True
+ else:
+ schema = user_specified_schema # type: ignore
+
+ assert schema is not None
# Return the pickled data source instance.
pickleSer._write_with_length(data_source, outfile)
diff --git a/python/pyspark/sql/worker/data_source_pushdown_filters.py b/python/pyspark/sql/worker/data_source_pushdown_filters.py
index b523cab7c49e..8601521bcfb1 100644
--- a/python/pyspark/sql/worker/data_source_pushdown_filters.py
+++ b/python/pyspark/sql/worker/data_source_pushdown_filters.py
@@ -27,6 +27,7 @@ from typing import IO, Type, Union
from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import PySparkAssertionError, PySparkValueError
from pyspark.errors.exceptions.base import PySparkNotImplementedError
+from pyspark.logger.worker_io import capture_outputs
from pyspark.serializers import SpecialLengths, UTF8Deserializer, read_int, read_bool, write_int
from pyspark.sql.datasource import (
DataSource,
@@ -187,63 +188,64 @@ def main(infile: IO, outfile: IO) -> None:
},
)
- # Get the reader.
- reader = data_source.reader(schema=schema)
- # Validate the reader.
- if not isinstance(reader, DataSourceReader):
- raise PySparkAssertionError(
- errorClass="DATA_SOURCE_TYPE_MISMATCH",
- messageParameters={
- "expected": "an instance of DataSourceReader",
- "actual": f"'{type(reader).__name__}'",
- },
+ with capture_outputs():
+ # Get the reader.
+ reader = data_source.reader(schema=schema)
+ # Validate the reader.
+ if not isinstance(reader, DataSourceReader):
+ raise PySparkAssertionError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": "an instance of DataSourceReader",
+ "actual": f"'{type(reader).__name__}'",
+ },
+ )
+
+ # Receive the pushdown filters.
+ json_str = utf8_deserializer.loads(infile)
+ filter_dicts = json.loads(json_str)
+ filters = [FilterRef(deserializeFilter(f)) for f in filter_dicts]
+
+ # Push down the filters and get the indices of the unsupported filters.
+ unsupported_filters = set(
+ FilterRef(f) for f in reader.pushFilters([ref.filter for ref in filters])
)
-
- # Receive the pushdown filters.
- json_str = utf8_deserializer.loads(infile)
- filter_dicts = json.loads(json_str)
- filters = [FilterRef(deserializeFilter(f)) for f in filter_dicts]
-
- # Push down the filters and get the indices of the unsupported filters.
- unsupported_filters = set(
- FilterRef(f) for f in reader.pushFilters([ref.filter for ref in filters])
- )
- supported_filter_indices = []
- for i, filter in enumerate(filters):
- if filter in unsupported_filters:
- unsupported_filters.remove(filter)
- else:
- supported_filter_indices.append(i)
-
- # If it returned any filters that are not in the original filters, raise an error.
- if len(unsupported_filters) > 0:
- raise PySparkValueError(
- errorClass="DATA_SOURCE_EXTRANEOUS_FILTERS",
- messageParameters={
- "type": type(reader).__name__,
- "input": str(list(filters)),
- "extraneous": str(list(unsupported_filters)),
- },
+ supported_filter_indices = []
+ for i, filter in enumerate(filters):
+ if filter in unsupported_filters:
+ unsupported_filters.remove(filter)
+ else:
+ supported_filter_indices.append(i)
+
+ # If it returned any filters that are not in the original filters, raise an error.
+ if len(unsupported_filters) > 0:
+ raise PySparkValueError(
+ errorClass="DATA_SOURCE_EXTRANEOUS_FILTERS",
+ messageParameters={
+ "type": type(reader).__name__,
+ "input": str(list(filters)),
+ "extraneous": str(list(unsupported_filters)),
+ },
+ )
+
+ # Receive the max arrow batch size.
+ max_arrow_batch_size = read_int(infile)
+ assert max_arrow_batch_size > 0, (
+ "The maximum arrow batch size should be greater than 0, but
got "
+ f"'{max_arrow_batch_size}'"
+ )
+ binary_as_bytes = read_bool(infile)
+
+ # Return the read function and partitions. Doing this in the same worker
+ # as filter pushdown helps reduce the number of Python worker calls.
+ write_read_func_and_partitions(
+ outfile,
+ reader=reader,
+ data_source=data_source,
+ schema=schema,
+ max_arrow_batch_size=max_arrow_batch_size,
+ binary_as_bytes=binary_as_bytes,
)
-
- # Receive the max arrow batch size.
- max_arrow_batch_size = read_int(infile)
- assert max_arrow_batch_size > 0, (
- "The maximum arrow batch size should be greater than 0, but got "
- f"'{max_arrow_batch_size}'"
- )
- binary_as_bytes = read_bool(infile)
-
- # Return the read function and partitions. Doing this in the same worker as filter pushdown
- # helps reduce the number of Python worker calls.
- write_read_func_and_partitions(
- outfile,
- reader=reader,
- data_source=data_source,
- schema=schema,
- max_arrow_batch_size=max_arrow_batch_size,
- binary_as_bytes=binary_as_bytes,
- )
# Return the supported filter indices.
write_int(len(supported_filter_indices), outfile)
diff --git a/python/pyspark/sql/worker/plan_data_source_read.py b/python/pyspark/sql/worker/plan_data_source_read.py
index 7c3c31095c14..db79e58f2ec4 100644
--- a/python/pyspark/sql/worker/plan_data_source_read.py
+++ b/python/pyspark/sql/worker/plan_data_source_read.py
@@ -25,6 +25,7 @@ from typing import IO, List, Iterator, Iterable, Tuple, Union
from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import PySparkAssertionError, PySparkRuntimeError
+from pyspark.logger.worker_io import capture_outputs
from pyspark.serializers import (
read_bool,
read_int,
@@ -357,45 +358,47 @@ def main(infile: IO, outfile: IO) -> None:
is_streaming = read_bool(infile)
binary_as_bytes = read_bool(infile)
- # Instantiate data source reader.
- if is_streaming:
- reader: Union[DataSourceReader, DataSourceStreamReader] = _streamReader(
- data_source, schema
- )
- else:
- reader = data_source.reader(schema=schema)
- # Validate the reader.
- if not isinstance(reader, DataSourceReader):
- raise PySparkAssertionError(
- errorClass="DATA_SOURCE_TYPE_MISMATCH",
- messageParameters={
- "expected": "an instance of DataSourceReader",
- "actual": f"'{type(reader).__name__}'",
- },
+ with capture_outputs():
+ # Instantiate data source reader.
+ if is_streaming:
+ reader: Union[DataSourceReader, DataSourceStreamReader] = _streamReader(
+ data_source, schema
)
- is_pushdown_implemented = (
- getattr(reader.pushFilters, "__func__", None) is not
DataSourceReader.pushFilters
- )
- if is_pushdown_implemented and not enable_pushdown:
- # Do not silently ignore pushFilters when pushdown is disabled.
- # Raise an error to ask the user to enable pushdown.
- raise PySparkAssertionError(
- errorClass="DATA_SOURCE_PUSHDOWN_DISABLED",
- messageParameters={
- "type": type(reader).__name__,
- "conf": "spark.sql.python.filterPushdown.enabled",
- },
+ else:
+ reader = data_source.reader(schema=schema)
+ # Validate the reader.
+ if not isinstance(reader, DataSourceReader):
+ raise PySparkAssertionError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": "an instance of DataSourceReader",
+ "actual": f"'{type(reader).__name__}'",
+ },
+ )
+ is_pushdown_implemented = (
+ getattr(reader.pushFilters, "__func__", None)
+ is not DataSourceReader.pushFilters
)
+ if is_pushdown_implemented and not enable_pushdown:
+ # Do not silently ignore pushFilters when pushdown is disabled.
+ # Raise an error to ask the user to enable pushdown.
+ raise PySparkAssertionError(
+ errorClass="DATA_SOURCE_PUSHDOWN_DISABLED",
+ messageParameters={
+ "type": type(reader).__name__,
+ "conf": "spark.sql.python.filterPushdown.enabled",
+ },
+ )
- # Send the read function and partitions to the JVM.
- write_read_func_and_partitions(
- outfile,
- reader=reader,
- data_source=data_source,
- schema=schema,
- max_arrow_batch_size=max_arrow_batch_size,
- binary_as_bytes=binary_as_bytes,
- )
+ # Send the read function and partitions to the JVM.
+ write_read_func_and_partitions(
+ outfile,
+ reader=reader,
+ data_source=data_source,
+ schema=schema,
+ max_arrow_batch_size=max_arrow_batch_size,
+ binary_as_bytes=binary_as_bytes,
+ )
except BaseException as e:
handle_worker_exception(e, outfile)
sys.exit(-1)
diff --git a/python/pyspark/sql/worker/python_streaming_sink_runner.py b/python/pyspark/sql/worker/python_streaming_sink_runner.py
index 83ba027a0601..ed6907ce5b63 100644
--- a/python/pyspark/sql/worker/python_streaming_sink_runner.py
+++ b/python/pyspark/sql/worker/python_streaming_sink_runner.py
@@ -22,6 +22,7 @@ from typing import IO
from pyspark.accumulators import _accumulatorRegistry
from pyspark.errors import PySparkAssertionError
+from pyspark.logger.worker_io import capture_outputs
from pyspark.serializers import (
read_bool,
read_int,
@@ -101,33 +102,36 @@ def main(infile: IO, outfile: IO) -> None:
)
# Receive the `overwrite` flag.
overwrite = read_bool(infile)
- # Create the data source writer instance.
- writer = data_source.streamWriter(schema=schema, overwrite=overwrite)
- # Receive the commit messages.
- num_messages = read_int(infile)
-
- commit_messages = []
- for _ in range(num_messages):
- message = pickleSer._read_with_length(infile)
- if message is not None and not isinstance(message, WriterCommitMessage):
- raise PySparkAssertionError(
- errorClass="DATA_SOURCE_TYPE_MISMATCH",
- messageParameters={
- "expected": "an instance of WriterCommitMessage",
- "actual": f"'{type(message).__name__}'",
- },
- )
- commit_messages.append(message)
-
- batch_id = read_long(infile)
- abort = read_bool(infile)
-
- # Commit or abort the Python data source write.
- # Note the commit messages can be None if there are failed tasks.
- if abort:
- writer.abort(commit_messages, batch_id)
- else:
- writer.commit(commit_messages, batch_id)
+
+ with capture_outputs():
+ # Create the data source writer instance.
+ writer = data_source.streamWriter(schema=schema, overwrite=overwrite)
+ # Receive the commit messages.
+ num_messages = read_int(infile)
+
+ commit_messages = []
+ for _ in range(num_messages):
+ message = pickleSer._read_with_length(infile)
+ if message is not None and not isinstance(message, WriterCommitMessage):
+ raise PySparkAssertionError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": "an instance of WriterCommitMessage",
+ "actual": f"'{type(message).__name__}'",
+ },
+ )
+ commit_messages.append(message)
+
+ batch_id = read_long(infile)
+ abort = read_bool(infile)
+
+ # Commit or abort the Python data source write.
+ # Note the commit messages can be None if there are failed tasks.
+ if abort:
+ writer.abort(commit_messages, batch_id)
+ else:
+ writer.commit(commit_messages, batch_id)
+
# Send a status code back to JVM.
write_int(0, outfile)
outfile.flush()
diff --git a/python/pyspark/sql/worker/write_into_data_source.py b/python/pyspark/sql/worker/write_into_data_source.py
index 46f9b168f067..917d0ca8e007 100644
--- a/python/pyspark/sql/worker/write_into_data_source.py
+++ b/python/pyspark/sql/worker/write_into_data_source.py
@@ -23,6 +23,7 @@ from typing import IO, Iterable, Iterator
from pyspark.accumulators import _accumulatorRegistry
from pyspark.sql.conversion import ArrowTableToRowsConversion
from pyspark.errors import PySparkAssertionError, PySparkRuntimeError, PySparkTypeError
+from pyspark.logger.worker_io import capture_outputs
from pyspark.serializers import (
read_bool,
read_int,
@@ -122,85 +123,88 @@ def main(infile: IO, outfile: IO) -> None:
# Receive the provider name.
provider = utf8_deserializer.loads(infile)
- # Check if the provider name matches the data source's name.
- if provider.lower() != data_source_cls.name().lower():
- raise PySparkAssertionError(
- errorClass="DATA_SOURCE_TYPE_MISMATCH",
- messageParameters={
- "expected": f"provider with name {data_source_cls.name()}",
- "actual": f"'{provider}'",
- },
- )
-
- # Receive the input schema
- schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
- if not isinstance(schema, StructType):
- raise PySparkAssertionError(
- errorClass="DATA_SOURCE_TYPE_MISMATCH",
- messageParameters={
- "expected": "the schema to be a 'StructType'",
- "actual": f"'{type(data_source_cls).__name__}'",
- },
- )
+ with capture_outputs():
+ # Check if the provider name matches the data source's name.
+ name = data_source_cls.name()
+ if provider.lower() != name.lower():
+ raise PySparkAssertionError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": f"provider with name {name}",
+ "actual": f"'{provider}'",
+ },
+ )
- # Receive the return type
- return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
- if not isinstance(return_type, StructType):
- raise PySparkAssertionError(
- errorClass="DATA_SOURCE_TYPE_MISMATCH",
- messageParameters={
- "expected": "a return type of type 'StructType'",
- "actual": f"'{type(return_type).__name__}'",
- },
- )
- assert len(return_type) == 1 and isinstance(return_type[0].dataType, BinaryType), (
- "The output schema of Python data source write should contain only one column of type "
- f"'BinaryType', but got '{return_type}'"
- )
- return_col_name = return_type[0].name
-
- # Receive the options.
- options = CaseInsensitiveDict()
- num_options = read_int(infile)
- for _ in range(num_options):
- key = utf8_deserializer.loads(infile)
- value = utf8_deserializer.loads(infile)
- options[key] = value
-
- # Receive the `overwrite` flag.
- overwrite = read_bool(infile)
-
- is_streaming = read_bool(infile)
- binary_as_bytes = read_bool(infile)
-
- # Instantiate a data source.
- data_source = data_source_cls(options=options) # type: ignore
-
- if is_streaming:
- # Instantiate the streaming data source writer.
- writer = data_source.streamWriter(schema, overwrite)
- if not isinstance(writer, (DataSourceStreamWriter, DataSourceStreamArrowWriter)):
+ # Receive the input schema
+ schema = _parse_datatype_json_string(utf8_deserializer.loads(infile))
+ if not isinstance(schema, StructType):
raise PySparkAssertionError(
errorClass="DATA_SOURCE_TYPE_MISMATCH",
messageParameters={
- "expected": (
- "an instance of DataSourceStreamWriter or "
- "DataSourceStreamArrowWriter"
- ),
- "actual": f"'{type(writer).__name__}'",
+ "expected": "the schema to be a 'StructType'",
+ "actual": f"'{type(data_source_cls).__name__}'",
},
)
- else:
- # Instantiate the data source writer.
- writer = data_source.writer(schema, overwrite) # type: ignore[assignment]
- if not isinstance(writer, DataSourceWriter):
+
+ # Receive the return type
+ return_type = _parse_datatype_json_string(utf8_deserializer.loads(infile))
+ if not isinstance(return_type, StructType):
raise PySparkAssertionError(
errorClass="DATA_SOURCE_TYPE_MISMATCH",
messageParameters={
- "expected": "an instance of DataSourceWriter",
- "actual": f"'{type(writer).__name__}'",
+ "expected": "a return type of type 'StructType'",
+ "actual": f"'{type(return_type).__name__}'",
},
)
+ assert len(return_type) == 1 and isinstance(return_type[0].dataType, BinaryType), (
+ "The output schema of Python data source write should contain only one column "
+ f"of type 'BinaryType', but got '{return_type}'"
+ )
+ return_col_name = return_type[0].name
+
+ # Receive the options.
+ options = CaseInsensitiveDict()
+ num_options = read_int(infile)
+ for _ in range(num_options):
+ key = utf8_deserializer.loads(infile)
+ value = utf8_deserializer.loads(infile)
+ options[key] = value
+
+ # Receive the `overwrite` flag.
+ overwrite = read_bool(infile)
+
+ is_streaming = read_bool(infile)
+ binary_as_bytes = read_bool(infile)
+
+ # Instantiate a data source.
+ data_source = data_source_cls(options=options) # type: ignore
+
+ if is_streaming:
+ # Instantiate the streaming data source writer.
+ writer = data_source.streamWriter(schema, overwrite)
+ if not isinstance(writer, (DataSourceStreamWriter, DataSourceStreamArrowWriter)):
+ raise PySparkAssertionError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": (
+ "an instance of DataSourceStreamWriter or "
+ "DataSourceStreamArrowWriter"
+ ),
+ "actual": f"'{type(writer).__name__}'",
+ },
+ )
+ else:
+ # Instantiate the data source writer.
+
+ writer = data_source.writer(schema, overwrite) # type: ignore[assignment]
+ if not isinstance(writer, DataSourceWriter):
+ raise PySparkAssertionError(
+ errorClass="DATA_SOURCE_TYPE_MISMATCH",
+ messageParameters={
+ "expected": "an instance of DataSourceWriter",
+ "actual": f"'{type(writer).__name__}'",
+ },
+ )
# Create a function that can be used in mapInArrow.
import pyarrow as pa
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonBatchWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonBatchWriterFactory.scala
index d5412f1bdd38..aaf8f56f0a9f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonBatchWriterFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonBatchWriterFactory.scala
@@ -32,7 +32,8 @@ case class PythonBatchWriterFactory(
source: UserDefinedPythonDataSource,
pickledWriteFunc: Array[Byte],
inputSchema: StructType,
- jobArtifactUUID: Option[String]) extends DataWriterFactory {
+ jobArtifactUUID: Option[String],
+ sessionUUID: Option[String]) extends DataWriterFactory {
override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = {
new DataWriter[InternalRow] {
@@ -47,7 +48,8 @@ case class PythonBatchWriterFactory(
inputSchema,
UserDefinedPythonDataSource.writeOutputSchema,
metrics,
- jobArtifactUUID)
+ jobArtifactUUID,
+ sessionUUID)
val outputIter = evaluatorFactory.createEvaluator().eval(partitionId, records.asScala)
outputIter.foreach { row =>
if (commitMessage == null) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
index 65c71dd4eeb7..50ea7616061c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonMicroBatchStream.scala
@@ -95,7 +95,7 @@ class PythonMicroBatchStream(
override def createReaderFactory(): PartitionReaderFactory = {
new PythonStreamingPartitionReaderFactory(
- ds.source, readInfo.func, outputSchema, None)
+ ds.source, readInfo.func, outputSchema, None, None)
}
override def commit(end: Offset): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala
index 44933779c26a..4496aae65ee9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonPartitionReaderFactory.scala
@@ -30,7 +30,8 @@ class PythonPartitionReaderFactory(
source: UserDefinedPythonDataSource,
pickledReadFunc: Array[Byte],
outputSchema: StructType,
- jobArtifactUUID: Option[String])
+ jobArtifactUUID: Option[String],
+ sessionUUID: Option[String])
extends PartitionReaderFactory {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
@@ -45,7 +46,8 @@ class PythonPartitionReaderFactory(
UserDefinedPythonDataSource.readInputSchema,
outputSchema,
metrics,
- jobArtifactUUID)
+ jobArtifactUUID,
+ sessionUUID)
val part = partition.asInstanceOf[PythonInputPartition]
evaluatorFactory.createEvaluator().eval(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
index 52af33e7aa99..a133c40cde60 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScan.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.datasources.v2.python
import org.apache.spark.JobArtifactSet
+import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.read._
import org.apache.spark.sql.connector.read.streaming.MicroBatchStream
@@ -61,6 +62,12 @@ class PythonBatch(
outputSchema: StructType,
options: CaseInsensitiveStringMap) extends Batch {
private val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+ private val sessionUUID = {
+ SparkSession.getActiveSession.collect {
+ case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+ session.sessionUUID
+ }
+ }
private lazy val infoInPython: PythonDataSourceReadInfo = {
ds.getOrCreateReadInfo(shortName, options, outputSchema, isStreaming = false)
@@ -72,6 +79,6 @@ class PythonBatch(
override def createReaderFactory(): PartitionReaderFactory = {
val readerFunc = infoInPython.func
new PythonPartitionReaderFactory(
- ds.source, readerFunc, outputSchema, jobArtifactUUID)
+ ds.source, readerFunc, outputSchema, jobArtifactUUID, sessionUUID)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala
index 466ecf609093..2d29a0226397 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingPartitionReaderFactory.scala
@@ -41,7 +41,8 @@ class PythonStreamingPartitionReaderFactory(
source: UserDefinedPythonDataSource,
pickledReadFunc: Array[Byte],
outputSchema: StructType,
- jobArtifactUUID: Option[String])
+ jobArtifactUUID: Option[String],
+ sessionUUID: Option[String])
extends PartitionReaderFactory with Logging {
override def createReader(partition: InputPartition): PartitionReader[InternalRow] = {
@@ -70,7 +71,8 @@ class PythonStreamingPartitionReaderFactory(
UserDefinedPythonDataSource.readInputSchema,
outputSchema,
metrics,
- jobArtifactUUID)
+ jobArtifactUUID,
+ sessionUUID)
evaluatorFactory.createEvaluator().eval(
part.index, Iterator.single(InternalRow(part.pickedPartition)))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingWrite.scala
index 4c149437a300..be7a0429e512 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingWrite.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonStreamingWrite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.v2.python
import org.apache.spark.JobArtifactSet
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.connector.write._
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.types.StructType
@@ -36,6 +37,12 @@ class PythonStreamingWrite(
private var pythonDataSourceWriter: Array[Byte] = _
private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+ private val sessionUUID = {
+ SparkSession.getActiveSession.collect {
+ case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+ session.sessionUUID
+ }
+ }
private def createDataSourceFunc =
ds.source.createPythonFunction(
@@ -53,7 +60,8 @@ class PythonStreamingWrite(
pythonDataSourceWriter = writeInfo.writer
- new PythonStreamingWriterFactory(ds.source, writeInfo.func, info.schema(), jobArtifactUUID)
+ new PythonStreamingWriterFactory(ds.source, writeInfo.func, info.schema(),
+ jobArtifactUUID, sessionUUID)
}
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {
@@ -81,8 +89,10 @@ class PythonStreamingWriterFactory(
source: UserDefinedPythonDataSource,
pickledWriteFunc: Array[Byte],
inputSchema: StructType,
- jobArtifactUUID: Option[String])
- extends PythonBatchWriterFactory(source, pickledWriteFunc, inputSchema, jobArtifactUUID)
+ jobArtifactUUID: Option[String],
+ sessionUUID: Option[String])
+ extends PythonBatchWriterFactory(source, pickledWriteFunc, inputSchema,
+ jobArtifactUUID, sessionUUID)
with StreamingDataWriterFactory {
override def createWriter(
partitionId: Int,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala
index 447221715264..156dcb242320 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonWrite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.datasources.v2.python
import org.apache.spark.JobArtifactSet
+import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.connector.write.{BatchWrite, _}
import org.apache.spark.sql.connector.write.streaming.StreamingWrite
@@ -56,6 +57,12 @@ class PythonBatchWrite(
private var pythonDataSourceWriter: Array[Byte] = _
private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+ private val sessionUUID = {
+ SparkSession.getActiveSession.collect {
+ case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+ session.sessionUUID
+ }
+ }
override def createBatchWriterFactory(physicalInfo: PhysicalWriteInfo): DataWriterFactory = {
@@ -68,7 +75,8 @@ class PythonBatchWrite(
pythonDataSourceWriter = writeInfo.writer
- PythonBatchWriterFactory(ds.source, writeInfo.func, info.schema(), jobArtifactUUID)
+ PythonBatchWriterFactory(ds.source, writeInfo.func, info.schema(),
+ jobArtifactUUID, sessionUUID)
}
override def commit(messages: Array[WriterCommitMessage]): Unit = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
index 63e7e32c1c7b..c147030037cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/UserDefinedPythonDataSource.scala
@@ -143,7 +143,8 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
inputSchema: StructType,
outputSchema: StructType,
metrics: Map[String, SQLMetric],
- jobArtifactUUID: Option[String]): MapInBatchEvaluatorFactory = {
+ jobArtifactUUID: Option[String],
+ sessionUUID: Option[String]): MapInBatchEvaluatorFactory = {
val pythonFunc = createPythonFunction(pickledFunc)
val pythonEvalType = PythonEvalType.SQL_MAP_ARROW_ITER_UDF
@@ -171,7 +172,7 @@ case class UserDefinedPythonDataSource(dataSourceCls: PythonFunction) {
pythonRunnerConf,
metrics,
jobArtifactUUID,
- None) // TODO: Python worker logging
+ sessionUUID)
}
def createPythonMetrics(): Array[CustomMetric] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
index 9ff934a07a16..0f4ac4ddad71 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonPlannerRunner.scala
@@ -32,6 +32,7 @@ import org.apache.spark.api.python.{BasePythonRunner, PythonFunction, PythonWork
import org.apache.spark.internal.{Logging, LogKeys}
import org.apache.spark.internal.config.BUFFER_SIZE
import org.apache.spark.internal.config.Python._
+import org.apache.spark.sql.classic.SparkSession
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.DirectByteBufferOutputStream
@@ -62,6 +63,12 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) extends Logging {
val workerMemoryMb = SQLConf.get.pythonPlannerExecMemory
val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+ val sessionUUID = {
+ SparkSession.getActiveSession.collect {
+ case session if session.sessionState.conf.pythonWorkerLoggingEnabled =>
+ session.sessionUUID
+ }
+ }
val envVars = new HashMap[String, String](func.envVars)
val pythonExec = func.pythonExec
@@ -93,6 +100,9 @@ abstract class PythonPlannerRunner[T](func: PythonFunction) extends Logging {
}
envVars.put("SPARK_JOB_ARTIFACT_UUID",
jobArtifactUUID.getOrElse("default"))
+ sessionUUID.foreach { uuid =>
+ envVars.put("PYSPARK_SPARK_SESSION_UUID", uuid)
+ }
EvaluatePython.registerPicklers()
val pickler = new Pickler(/* useMemo = */ true,