This is an automated email from the ASF dual-hosted git repository.
ueshin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 36983443112 [SPARK-45810][PYTHON] Create Python UDTF API to stop consuming rows from the input table
36983443112 is described below
commit 36983443112799dc2ee4462828e7c0552a63a229
Author: Daniel Tenedorio <[email protected]>
AuthorDate: Wed Nov 15 13:47:04 2023 -0800
[SPARK-45810][PYTHON] Create Python UDTF API to stop consuming rows from the input table
### What changes were proposed in this pull request?
This PR creates a Python UDTF API to stop consuming rows from the input
table.
If the UDTF raises a `SkipRestOfInputTableException` from its `eval` method, the UDTF
stops consuming rows from the input table for that input partition and then calls the
`terminate` method (if any), treating the UDTF call as successful.
For example:
```
@udtf(returnType="total: int")
class TestUDTF:
    def __init__(self):
        self._total = 0

    def eval(self, _: Row):
        self._total += 1
        if self._total >= 3:
            raise SkipRestOfInputTableException("Stop at self._total >= 3")

    def terminate(self):
        yield self._total,
```
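For reference, here is a minimal usage sketch (mirroring the new test coverage in this
patch; it assumes an active `spark` session, and the `test_udtf` registration name and
`range` input are illustrative):
```
from pyspark.sql import Row
from pyspark.sql.functions import udtf, SkipRestOfInputTableException

# Register the decorated class above, then invoke it with a table argument.
spark.udtf.register("test_udtf", TestUDTF)
spark.sql(
    "SELECT total FROM test_udtf(TABLE(SELECT id FROM range(0, 10)) WITH SINGLE PARTITION)"
).show()
# Expected: a single row with total=3, since `eval` raises
# SkipRestOfInputTableException once self._total reaches 3 and
# `terminate` then yields the count.
```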
### Why are the changes needed?
This is useful when the UDTF logic knows that it does not need to scan the rest of the
input table, letting it skip the remaining I/O in that case.
### Does this PR introduce _any_ user-facing change?
Yes, see above.
### How was this patch tested?
This PR adds test coverage.
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #43682 from dtenedor/udtf-api-stop-consuming-input-rows.
Authored-by: Daniel Tenedorio <[email protected]>
Signed-off-by: Takuya UESHIN <[email protected]>
---
python/docs/source/user_guide/sql/python_udtf.rst | 38 ++++++++++++-----
python/pyspark/sql/functions.py | 1 +
python/pyspark/sql/tests/test_udtf.py | 51 +++++++++++++++++++++++
python/pyspark/sql/udtf.py | 19 ++++++++-
python/pyspark/worker.py | 30 ++++++++++---
5 files changed, 123 insertions(+), 16 deletions(-)
diff --git a/python/docs/source/user_guide/sql/python_udtf.rst b/python/docs/source/user_guide/sql/python_udtf.rst
index 0e0c6e28578..3e3c7634438 100644
--- a/python/docs/source/user_guide/sql/python_udtf.rst
+++ b/python/docs/source/user_guide/sql/python_udtf.rst
@@ -65,8 +65,8 @@ To implement a Python UDTF, you first need to define a class implementing the me
def analyze(self, *args: Any) -> AnalyzeResult:
"""
- Computes the output schema of a particular call to this function in response to the
- arguments provided.
+ Static method to compute the output schema of a particular call to this function in
+ response to the arguments provided.
This method is optional and only needed if the registration of the UDTF did not provide
a static output schema to be used for all calls to the function. In this context,
@@ -101,12 +101,20 @@ To implement a Python UDTF, you first need to define a class implementing the me
partitionBy: Sequence[PartitioningColumn] = field(default_factory=tuple)
orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
+ Notes
+ -----
+ - It is possible for the `analyze` method to accept the exact arguments expected,
+ mapping 1:1 with the arguments provided to the UDTF call.
+ - The `analyze` method can instead choose to accept positional arguments if desired
+ (using `*args`) or keyword arguments (using `**kwargs`).
+
Examples
--------
- analyze implementation that returns one output column for each word in the input string
- argument.
+ This is an `analyze` implementation that returns one output column for each word in the
+ input string argument.
- >>> def analyze(self, text: str) -> AnalyzeResult:
+ >>> @staticmethod
+ ... def analyze(text: str) -> AnalyzeResult:
... schema = StructType()
... for index, word in enumerate(text.split(" ")):
... schema = schema.add(f"word_{index}")
@@ -114,7 +122,8 @@ To implement a Python UDTF, you first need to define a class implementing the me
Same as above, but using *args to accept the arguments.
- >>> def analyze(self, *args) -> AnalyzeResult:
+ >>> @staticmethod
+ ... def analyze(*args) -> AnalyzeResult:
... assert len(args) == 1, "This function accepts one argument only"
... assert args[0].dataType == StringType(), "Only string arguments are supported"
... text = args[0]
@@ -125,7 +134,8 @@ To implement a Python UDTF, you first need to define a class implementing the me
Same as above, but using **kwargs to accept the arguments.
- >>> def analyze(self, **kwargs) -> AnalyzeResult:
+ >>> @staticmethod
+ ... def analyze(**kwargs) -> AnalyzeResult:
... assert len(kwargs) == 1, "This function accepts one argument only"
... assert "text" in kwargs, "An argument named 'text' is required"
... assert kwargs["text"].dataType == StringType(), "Only strings are supported"
@@ -135,10 +145,11 @@ To implement a Python UDTF, you first need to define a class implementing the me
... schema = schema.add(f"word_{index}")
... return AnalyzeResult(schema=schema)
- analyze implementation that returns a constant output schema, but add custom information
- in the result metadata to be consumed by future __init__ method calls:
+ An `analyze` implementation that returns a constant output schema, but adds custom
+ information in the result metadata to be consumed by future __init__ method calls:
- >>> def analyze(self, text: str) -> AnalyzeResult:
+ >>> @staticmethod
+ ... def analyze(text: str) -> AnalyzeResult:
... @dataclass
... class AnalyzeResultWithOtherMetadata(AnalyzeResult):
... num_words: int
@@ -190,6 +201,13 @@ To implement a Python UDTF, you first need to define a class implementing the me
- It is also possible for UDTFs to accept the exact arguments expected, along with
their types.
- UDTFs can instead accept keyword arguments during the function call if needed.
+ - The `eval` method can raise a `SkipRestOfInputTableException` to indicate that the
+ UDTF wants to skip consuming all remaining rows from the current partition of the
+ input table. This will cause the UDTF to proceed directly to the `terminate` method.
+ - The `eval` method can raise any other exception to indicate that the UDTF should be
+ aborted entirely. This will cause the UDTF to skip the `terminate` method and proceed
+ directly to the `cleanup` method, and then the exception will be propagated to the
+ query processor causing the invoking query to fail.
Examples
--------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index ae0f1e70be6..e3b8e4965e4 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -51,6 +51,7 @@ from pyspark.sql.types import ArrayType, DataType, StringType, StructType, _from
from pyspark.sql.udf import UserDefinedFunction, _create_py_udf # noqa: F401
from pyspark.sql.udtf import AnalyzeArgument, AnalyzeResult # noqa: F401
from pyspark.sql.udtf import OrderingColumn, PartitioningColumn # noqa: F401
+from pyspark.sql.udtf import SkipRestOfInputTableException # noqa: F401
from pyspark.sql.udtf import UserDefinedTableFunction, _create_py_udtf
# Keep pandas_udf and PandasUDFType import for backwards compatible import; moved in SPARK-28264
diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py
index 3beb916de66..2794b51eb70 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -44,6 +44,7 @@ from pyspark.sql.functions import (
AnalyzeResult,
OrderingColumn,
PartitioningColumn,
+ SkipRestOfInputTableException,
)
from pyspark.sql.types import (
ArrayType,
@@ -2467,6 +2468,56 @@ class BaseUDTFTestsMixin:
[Row(count=20, buffer="abc")],
)
+ def test_udtf_with_skip_rest_of_input_table_exception(self):
+ @udtf(returnType="current: int, total: int")
+ class TestUDTF:
+ def __init__(self):
+ self._current = 0
+ self._total = 0
+
+ def eval(self, input: Row):
+ self._current = input["id"]
+ self._total += 1
+ if self._total >= 4:
+ raise SkipRestOfInputTableException("Stop at self._total >= 4")
+
+ def terminate(self):
+ yield self._current, self._total
+
+ self.spark.udtf.register("test_udtf", TestUDTF)
+
+ # Run a test case including WITH SINGLE PARTITION on the UDTF call. The
+ # SkipRestOfInputTableException stops scanning rows after the fourth input row is consumed.
+ assertDataFrameEqual(
+ self.spark.sql(
+ """
+ WITH t AS (
+ SELECT id FROM range(1, 21)
+ )
+ SELECT current, total
+ FROM test_udtf(TABLE(t) WITH SINGLE PARTITION ORDER BY id)
+ """
+ ),
+ [Row(current=4, total=4)],
+ )
+
+ # Run a test case including PARTITION BY on the UDTF call. The
+ # SkipRestOfInputTableException stops scanning rows for each of the first two
+ # partitions separately.
+ assertDataFrameEqual(
+ self.spark.sql(
+ """
+ WITH t AS (
+ SELECT id FROM range(1, 21)
+ )
+ SELECT current, total
+ FROM test_udtf(TABLE(t) PARTITION BY floor(id / 10) ORDER BY id)
+ ORDER BY ALL
+ """
+ ),
+ [Row(current=4, total=4), Row(current=13, total=4), Row(current=20, total=1)],
+ )
+
class UDTFTests(BaseUDTFTestsMixin, ReusedSQLTestCase):
@classmethod
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index aac212ffde9..ab330141514 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -38,7 +38,14 @@ if TYPE_CHECKING:
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.session import SparkSession
-__all__ = ["AnalyzeArgument", "AnalyzeResult", "UDTFRegistration"]
+__all__ = [
+ "AnalyzeArgument",
+ "AnalyzeResult",
+ "PartitioningColumn",
+ "OrderingColumn",
+ "SkipRestOfInputTableException",
+ "UDTFRegistration",
+]
@dataclass(frozen=True)
@@ -118,6 +125,16 @@ class AnalyzeResult:
orderBy: Sequence[OrderingColumn] = field(default_factory=tuple)
+class SkipRestOfInputTableException(Exception):
+ """
+ This represents an exception that the 'eval' method may raise to indicate that it is done
+ consuming rows from the current partition of the input table. Then the UDTF's 'terminate'
+ method runs (if any).
+ """
+
+ pass
+
+
def _create_udtf(
cls: Type,
returnType: Optional[Union[StructType, str]],
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index f6208032d9a..195c989c410 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -43,6 +43,7 @@ from pyspark.serializers import (
CPickleSerializer,
BatchedSerializer,
)
+from pyspark.sql.functions import SkipRestOfInputTableException
from pyspark.sql.pandas.serializers import (
ArrowStreamPandasUDFSerializer,
ArrowStreamPandasUDTFSerializer,
@@ -763,6 +764,7 @@ def read_udtf(pickleSer, infile, eval_type):
self._udtf = create_udtf()
self._prev_arguments: list = list()
self._partition_child_indexes: list = partition_child_indexes
+ self._eval_raised_skip_rest_of_input_table: bool = False
def eval(self, *args, **kwargs) -> Iterator:
changed_partitions = self._check_partition_boundaries(
@@ -775,16 +777,24 @@ def read_udtf(pickleSer, infile, eval_type):
for row in result:
yield row
self._udtf = self._create_udtf()
- if self._udtf.eval is not None:
+ self._eval_raised_skip_rest_of_input_table = False
+ if self._udtf.eval is not None and not self._eval_raised_skip_rest_of_input_table:
# Filter the arguments to exclude projected PARTITION BY values added by Catalyst.
filtered_args = [self._remove_partition_by_exprs(arg) for arg in args]
filtered_kwargs = {
key: self._remove_partition_by_exprs(value) for (key, value) in kwargs.items()
}
- result = self._udtf.eval(*filtered_args, **filtered_kwargs)
- if result is not None:
- for row in result:
- yield row
+ try:
+ result = self._udtf.eval(*filtered_args, **filtered_kwargs)
+ if result is not None:
+ for row in result:
+ yield row
+ except SkipRestOfInputTableException:
+ # If the 'eval' method raised this exception, then we should skip the rest of
+ # the rows in the current partition. Set this field to True here and then for
+ # each subsequent row in the partition, we will skip calling the 'eval' method
+ # until we see a change in the partition boundaries.
+ self._eval_raised_skip_rest_of_input_table = True
def terminate(self) -> Iterator:
if self._udtf.terminate is not None:
@@ -995,6 +1005,8 @@ def read_udtf(pickleSer, infile, eval_type):
def func(*a: Any) -> Any:
try:
return f(*a)
+ except SkipRestOfInputTableException:
+ raise
except Exception as e:
raise PySparkRuntimeError(
error_class="UDTF_EXEC_ERROR",
@@ -1057,6 +1069,9 @@ def read_udtf(pickleSer, infile, eval_type):
yield from eval(*[a[o] for o in args_kwargs_offsets])
if terminate is not None:
yield from terminate()
+ except SkipRestOfInputTableException:
+ if terminate is not None:
+ yield from terminate()
finally:
if cleanup is not None:
cleanup()
@@ -1098,6 +1113,8 @@ def read_udtf(pickleSer, infile, eval_type):
def evaluate(*a) -> tuple:
try:
res = f(*a)
+ except SkipRestOfInputTableException:
+ raise
except Exception as e:
raise PySparkRuntimeError(
error_class="UDTF_EXEC_ERROR",
@@ -1144,6 +1161,9 @@ def read_udtf(pickleSer, infile, eval_type):
yield eval(*[a[o] for o in args_kwargs_offsets])
if terminate is not None:
yield terminate()
+ except SkipRestOfInputTableException:
+ if terminate is not None:
+ yield terminate()
finally:
if cleanup is not None:
cleanup()
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]