This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e4114f67e12 [SPARK-45048][CONNECT] Add additional tests for Python
client and attachable execution
e4114f67e12 is described below
commit e4114f67e12a235b4784fcbfa6f6ba9b44a5e715
Author: Martin Grund <[email protected]>
AuthorDate: Fri Sep 1 22:15:23 2023 +0800
[SPARK-45048][CONNECT] Add additional tests for Python client and
attachable execution
### What changes were proposed in this pull request?
For better test coverage add additional tests of the attachable Spark
Connect Python client.
### Why are the changes needed?
Stability
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New test
### Was this patch authored or co-authored using generative AI tooling?
No
Closes #42769 from grundprinzip/SPARK-45048.
Authored-by: Martin Grund <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../sql/tests/connect/client/test_client.py | 156 ++++++++++++++++++++-
1 file changed, 154 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/sql/tests/connect/client/test_client.py b/python/pyspark/sql/tests/connect/client/test_client.py
index 2ba42cabf84..70280c1d24a 100644
--- a/python/pyspark/sql/tests/connect/client/test_client.py
+++ b/python/pyspark/sql/tests/connect/client/test_client.py
@@ -17,16 +17,21 @@
import unittest
import uuid
-from typing import Optional
+from collections.abc import Generator
+from typing import Optional, Any
from pyspark.testing.connectutils import should_test_connect,
connect_requirement_message
if should_test_connect:
+ import grpc
import pandas as pd
import pyarrow as pa
from pyspark.sql.connect.client import SparkConnectClient, ChannelBuilder
from pyspark.sql.connect.client.core import Retrying
- from pyspark.sql.connect.client.reattach import RetryException
+ from pyspark.sql.connect.client.reattach import (
+ RetryException,
+ ExecutePlanResponseReattachableIterator,
+ )
import pyspark.sql.connect.proto as proto
@@ -119,6 +124,153 @@ class SparkConnectClientTestCase(unittest.TestCase):
self.assertEqual(client._session_id, chan.session_id)
+@unittest.skipIf(not should_test_connect, connect_requirement_message)
+class SparkConnectClientReattachTestCase(unittest.TestCase):
+    """Exercise ExecutePlanResponseReattachableIterator against a mocked gRPC stub.
+
+    Each test scripts the responses (or failures) replayed by the stub's
+    ExecutePlan / ReattachExecute streams, drains the iterator, and then checks
+    how many times each stub RPC was invoked.
+    """
+
+    def setUp(self) -> None:
+        """Build the request, retry policy and canned responses shared by the tests."""
+        self.request = proto.ExecutePlanRequest()
+        # Retry settings handed to the iterator under test.
+        self.policy = {
+            "max_retries": 3,
+            "backoff_multiplier": 4.0,
+            "initial_backoff": 10,
+            "max_backoff": 10,
+            "jitter": 10,
+            "min_jitter_threshold": 10,
+        }
+        # An empty intermediate response, and one flagged with `result_complete`.
+        self.response = proto.ExecutePlanResponse()
+        self.finished = proto.ExecutePlanResponse(
+            result_complete=proto.ExecutePlanResponse.ResultComplete()
+        )
+
+    def _stub_with(self, execute=None, attach=None):
+        """Return a MockSparkConnectStub that replays the given scripts.
+
+        `execute` and `attach` are sequences of responses or zero-argument
+        callables; each is wrapped in a ResponseGenerator. Passing None leaves
+        the corresponding stream unset on the stub.
+        """
+        return MockSparkConnectStub(
+            execute_ops=ResponseGenerator(execute) if execute is not None else None,
+            attach_ops=ResponseGenerator(attach) if attach is not None else None,
+        )
+
+    def test_basic_flow(self):
+        """A stream that ends with a `result_complete` response needs no reattach."""
+        stub = self._stub_with([self.response, self.finished])
+        ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, [])
+        for _ in ite:
+            pass
+
+        self.assertEqual(0, stub.attach_calls)
+        # NOTE(review): assertGreater(1, x) asserts 1 > x, i.e. that no release call
+        # has been counted at this point. If "at least one release" was intended the
+        # arguments are swapped -- confirm against the iterator's release behavior.
+        self.assertGreater(1, stub.release_calls)
+        self.assertEqual(1, stub.execute_calls)
+
+    def test_fail_during_execute(self):
+        """A failure with the default (INTERNAL) status must surface without a reattach."""
+
+        def fatal():
+            raise TestException("Fatal")
+
+        stub = self._stub_with([self.response, fatal])
+        with self.assertRaises(TestException):
+            ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, [])
+            for _ in ite:
+                pass
+
+        self.assertEqual(0, stub.attach_calls)
+        self.assertEqual(0, stub.release_calls)
+        self.assertEqual(1, stub.execute_calls)
+
+    def test_fail_and_retry_during_execute(self):
+        """An UNAVAILABLE failure on the execute stream is expected to cause one reattach."""
+
+        def non_fatal():
+            raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE)
+
+        stub = self._stub_with(
+            [self.response, non_fatal], [self.response, self.response, self.finished]
+        )
+        ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, [])
+        for _ in ite:
+            pass
+
+        self.assertEqual(1, stub.attach_calls)
+        self.assertEqual(1, stub.release_calls)
+        self.assertEqual(1, stub.execute_calls)
+
+    def test_fail_and_retry_during_reattach(self):
+        """UNAVAILABLE failures on both streams are expected to cause two reattaches."""
+        count = 0
+
+        def non_fatal():
+            # Shared by both scripts: raises on its first two invocations overall,
+            # then yields an ordinary empty response.
+            nonlocal count
+            if count < 2:
+                count += 1
+                raise TestException("Non Fatal", grpc.StatusCode.UNAVAILABLE)
+            else:
+                return proto.ExecutePlanResponse()
+
+        stub = self._stub_with(
+            [self.response, non_fatal], [self.response, non_fatal, self.response, self.finished]
+        )
+        ite = ExecutePlanResponseReattachableIterator(self.request, stub, self.policy, [])
+        for _ in ite:
+            pass
+
+        self.assertEqual(2, stub.attach_calls)
+        self.assertEqual(2, stub.release_calls)
+        self.assertEqual(1, stub.execute_calls)
+
+
+if should_test_connect:
+    class TestException(grpc.RpcError, grpc.Call):
+        """Exception mock to test retryable exceptions.
+
+        Defined only when the Spark Connect dependencies are available: the base
+        classes come from ``grpc``, which this module imports conditionally, so an
+        unguarded definition would raise NameError at import time without them.
+        """
+
+        # The name starts with "Test"; tell pytest not to collect it as a test class.
+        __test__ = False
+
+        def __init__(self, msg, code=grpc.StatusCode.INTERNAL):
+            self.msg = msg
+            self._code = code
+
+        def code(self):
+            """Return the gRPC status code supplied at construction."""
+            return self._code
+
+        def __str__(self):
+            return self.msg
+
+        def trailing_metadata(self):
+            """Return empty trailing metadata."""
+            return ()
+
+
+class ResponseGenerator(Generator):
+    """Generate the values returned by the streaming iterator of the GRPC stub.
+
+    The generator walks the supplied sequence once. A plain entry is returned as
+    is; a callable entry is invoked and its result returned, so a callable that
+    raises makes the stream fail at exactly that position.
+    """
+
+    def __init__(self, funs):
+        self._funs = funs
+        self._iterator = iter(self._funs)
+
+    # The return annotation is a string so that it is not evaluated when the class
+    # is defined; `proto` is only imported when the Connect dependencies exist.
+    def send(self, value: Any) -> "proto.ExecutePlanResponse":
+        """Return the next scripted value; `value` is ignored.
+
+        StopIteration from the exhausted underlying iterator propagates and ends
+        the stream.
+        """
+        val = next(self._iterator)
+        return val() if callable(val) else val
+
+    def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> Any:
+        """Delegate to the default ``Generator.throw`` and hand back its result."""
+        return super().throw(type, value, traceback)
+
+    def close(self) -> None:
+        return super().close()
+
+
+class MockSparkConnectStub:
+    """Minimal stand-in for the GRPC stub used by the re-attachable execution.
+
+    Every RPC method accepts and ignores arbitrary arguments, bumps its own call
+    counter, and (for the two streaming calls) hands back the object supplied at
+    construction time.
+    """
+
+    def __init__(self, execute_ops=None, attach_ops=None):
+        # Invocation counters, one per RPC, read back by the tests.
+        self.execute_calls = self.release_calls = self.attach_calls = 0
+        # Objects returned from ExecutePlan / ReattachExecute respectively.
+        self._execute_ops, self._attach_ops = execute_ops, attach_ops
+
+    def ExecutePlan(self, *args, **kwargs):
+        """Count the call and return the configured execute stream."""
+        self.execute_calls = self.execute_calls + 1
+        return self._execute_ops
+
+    def ReattachExecute(self, *args, **kwargs):
+        """Count the call and return the configured attach stream."""
+        self.attach_calls = self.attach_calls + 1
+        return self._attach_ops
+
+    def ReleaseExecute(self, *args, **kwargs):
+        """Count the call; nothing is returned."""
+        self.release_calls = self.release_calls + 1
+
+
class MockService:
# Simplest mock of the SparkConnectService.
# If this needs more complex logic, it needs to be replaced with Python
mocking.
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]