potiuk commented on code in PR #54792:
URL: https://github.com/apache/airflow/pull/54792#discussion_r2291798799
##########
task-sdk/tests/task_sdk/execution_time/test_supervisor.py:
##########
@@ -1276,773 +1278,789 @@ def test_max_wait_time_calculation_edge_cases(
assert actual_timeout >= expected_min_timeout
-class TestHandleRequest:
- @pytest.fixture
- def watched_subprocess(self, mocker):
- read_end, write_end = socket.socketpair()
+@dataclass
+class ClientMock:
+ """Configuration for mocking client method calls."""
- subprocess = ActivitySubprocess(
- process_log=mocker.MagicMock(),
- id=TI_ID,
- pid=12345,
- stdin=write_end,
- client=mocker.Mock(),
- process=mocker.Mock(),
- )
+ method_path: str
+ """Path to the client method to mock (e.g., 'connections.get',
'variables.set')."""
- return subprocess, read_end
+ args: tuple = field(default_factory=tuple)
+ """Positional arguments the client method should be called with."""
- @patch("airflow.sdk.execution_time.supervisor.mask_secret")
- @pytest.mark.parametrize(
- [
- "message",
- "expected_body",
- "client_attr_path",
- "method_arg",
- "method_kwarg",
- "mock_response",
- "mask_secret_args",
- ],
- [
- pytest.param(
- GetConnection(conn_id="test_conn"),
- {"conn_id": "test_conn", "conn_type": "mysql", "type":
"ConnectionResult"},
- "connections.get",
- ("test_conn",),
- {},
- ConnectionResult(conn_id="test_conn", conn_type="mysql"),
- None,
- id="get_connection",
- ),
- pytest.param(
- GetConnection(conn_id="test_conn"),
- {
- "conn_id": "test_conn",
- "conn_type": "mysql",
- "password": "password",
- "type": "ConnectionResult",
- },
- "connections.get",
- ("test_conn",),
- {},
- ConnectionResult(conn_id="test_conn", conn_type="mysql",
password="password"),
- ["password"],
- id="get_connection_with_password",
- ),
- pytest.param(
- GetConnection(conn_id="test_conn"),
- {"conn_id": "test_conn", "conn_type": "mysql", "schema":
"mysql", "type": "ConnectionResult"},
- "connections.get",
- ("test_conn",),
- {},
- ConnectionResult(conn_id="test_conn", conn_type="mysql",
schema="mysql"), # type: ignore[call-arg]
- None,
- id="get_connection_with_alias",
- ),
- pytest.param(
- GetVariable(key="test_key"),
- {"key": "test_key", "value": "test_value", "type":
"VariableResult"},
- "variables.get",
- ("test_key",),
- {},
- VariableResult(key="test_key", value="test_value"),
- ["test_value", "test_key"],
- id="get_variable",
- ),
- pytest.param(
- PutVariable(key="test_key", value="test_value",
description="test_description"),
- None,
- "variables.set",
- ("test_key", "test_value", "test_description"),
- {},
- OKResponse(ok=True),
- None,
- id="set_variable",
- ),
- pytest.param(
- DeleteVariable(key="test_key"),
- {"ok": True, "type": "OKResponse"},
- "variables.delete",
- ("test_key",),
- {},
- OKResponse(ok=True),
- None,
- id="delete_variable",
- ),
- pytest.param(
- DeferTask(next_method="execute_callback",
classpath="my-classpath"),
- None,
- "task_instances.defer",
- (TI_ID, DeferTask(next_method="execute_callback",
classpath="my-classpath")),
- {},
- "",
- None,
- id="patch_task_instance_to_deferred",
- ),
- pytest.param(
+ kwargs: dict = field(default_factory=dict)
+ """Keyword arguments the client method should be called with."""
+
+ response: Any = None
+ """What the mocked client method should return when called."""
+
+
+@dataclass
+class RequestTestCase:
+ """Test case data for request handling tests in `TestHandleRequest`
class."""
+
+ message: Any
+ """The request message to send to the supervisor (e.g., GetConnection,
SetXCom)."""
+
+ test_id: str
+ """Unique identifier for this test case, used in pytest
parameterization."""
+
+ client_mock: ClientMock | None = None
+ """Client method mocking configuration. None for messages that don't
require client calls."""
+
+ expected_body: dict | None = None
+ """Expected response body from supervisor. None if no response body
expected."""
+
+ mask_secret_args: tuple[str] | None = None
+ """Arguments that should be passed to the secret masker for redaction."""
+
+
+# Test cases for request handling
+REQUEST_TEST_CASES = [
+ RequestTestCase(
+ message=GetConnection(conn_id="test_conn"),
+ test_id="get_connection",
+ client_mock=ClientMock(
+ method_path="connections.get",
+ args=("test_conn",),
+ response=ConnectionResult(conn_id="test_conn", conn_type="mysql"),
+ ),
+ expected_body={"conn_id": "test_conn", "conn_type": "mysql", "type":
"ConnectionResult"},
+ ),
+ RequestTestCase(
+ message=GetConnection(conn_id="test_conn"),
+ test_id="get_connection_with_password",
+ client_mock=ClientMock(
+ method_path="connections.get",
+ args=("test_conn",),
+ response=ConnectionResult(conn_id="test_conn", conn_type="mysql",
password="password"),
+ ),
+ expected_body={
+ "conn_id": "test_conn",
+ "conn_type": "mysql",
+ "password": "password",
+ "type": "ConnectionResult",
+ },
+ mask_secret_args=("password",),
+ ),
+ RequestTestCase(
+ message=GetConnection(conn_id="test_conn"),
+ test_id="get_connection_with_alias",
+ client_mock=ClientMock(
+ method_path="connections.get",
+ args=("test_conn",),
+ response=ConnectionResult(conn_id="test_conn", conn_type="mysql",
schema="mysql"), # type: ignore[call-arg]
+ ),
+ expected_body={
+ "conn_id": "test_conn",
+ "conn_type": "mysql",
+ "schema": "mysql",
+ "type": "ConnectionResult",
+ },
+ ),
+ RequestTestCase(
+ message=GetVariable(key="test_key"),
+ test_id="get_variable",
+ client_mock=ClientMock(
+ method_path="variables.get",
+ args=("test_key",),
+ response=VariableResult(key="test_key", value="test_value"),
+ ),
+ expected_body={"key": "test_key", "value": "test_value", "type":
"VariableResult"},
+ mask_secret_args=("test_value", "test_key"),
+ ),
+ RequestTestCase(
+ message=PutVariable(key="test_key", value="test_value",
description="test_description"),
+ test_id="set_variable",
+ client_mock=ClientMock(
+ method_path="variables.set",
+ args=("test_key", "test_value", "test_description"),
+ response=OKResponse(ok=True),
+ ),
+ ),
+ RequestTestCase(
+ message=DeleteVariable(key="test_key"),
+ test_id="delete_variable",
+ client_mock=ClientMock(
+ method_path="variables.delete",
+ args=("test_key",),
+ response=OKResponse(ok=True),
+ ),
+ expected_body={"ok": True, "type": "OKResponse"},
+ ),
+ RequestTestCase(
+ message=DeferTask(next_method="execute_callback",
classpath="my-classpath"),
+ test_id="patch_task_instance_to_deferred",
+ client_mock=ClientMock(
+ method_path="task_instances.defer",
+ args=(TI_ID, DeferTask(next_method="execute_callback",
classpath="my-classpath")),
+ ),
+ ),
+ RequestTestCase(
+ message=RescheduleTask(
+ reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
+ end_date=timezone.parse("2024-10-31T12:00:00Z"),
+ ),
+ test_id="patch_task_instance_to_up_for_reschedule",
+ client_mock=ClientMock(
+ method_path="task_instances.reschedule",
+ args=(
+ TI_ID,
RescheduleTask(
reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
end_date=timezone.parse("2024-10-31T12:00:00Z"),
),
- None,
- "task_instances.reschedule",
- (
- TI_ID,
- RescheduleTask(
- reschedule_date=timezone.parse("2024-10-31T12:00:00Z"),
- end_date=timezone.parse("2024-10-31T12:00:00Z"),
- ),
- ),
- {},
- "",
- None,
- id="patch_task_instance_to_up_for_reschedule",
- ),
- pytest.param(
- GetXCom(dag_id="test_dag", run_id="test_run",
task_id="test_task", key="test_key"),
- {"key": "test_key", "value": "test_value", "type":
"XComResult"},
- "xcoms.get",
- ("test_dag", "test_run", "test_task", "test_key", None, False),
- {},
- XComResult(key="test_key", value="test_value"),
- None,
- id="get_xcom",
- ),
- pytest.param(
- GetXCom(
- dag_id="test_dag", run_id="test_run", task_id="test_task",
key="test_key", map_index=2
- ),
- {"key": "test_key", "value": "test_value", "type":
"XComResult"},
- "xcoms.get",
- ("test_dag", "test_run", "test_task", "test_key", 2, False),
- {},
- XComResult(key="test_key", value="test_value"),
- None,
- id="get_xcom_map_index",
- ),
- pytest.param(
- GetXCom(dag_id="test_dag", run_id="test_run",
task_id="test_task", key="test_key"),
- {"key": "test_key", "value": None, "type": "XComResult"},
- "xcoms.get",
- ("test_dag", "test_run", "test_task", "test_key", None, False),
- {},
- XComResult(key="test_key", value=None, type="XComResult"),
- None,
- id="get_xcom_not_found",
- ),
- pytest.param(
- GetXCom(
- dag_id="test_dag",
- run_id="test_run",
- task_id="test_task",
- key="test_key",
- include_prior_dates=True,
- ),
- {"key": "test_key", "value": None, "type": "XComResult"},
- "xcoms.get",
- ("test_dag", "test_run", "test_task", "test_key", None, True),
- {},
- XComResult(key="test_key", value=None, type="XComResult"),
- None,
- id="get_xcom_include_prior_dates",
- ),
- pytest.param(
- SetXCom(
- dag_id="test_dag",
- run_id="test_run",
- task_id="test_task",
- key="test_key",
- value='{"key": "test_key", "value": {"key2": "value2"}}',
- ),
- None,
- "xcoms.set",
- (
- "test_dag",
- "test_run",
- "test_task",
- "test_key",
- '{"key": "test_key", "value": {"key2": "value2"}}',
- None,
- None,
- ),
- {},
- OKResponse(ok=True),
- None,
- id="set_xcom",
- ),
- pytest.param(
- SetXCom(
- dag_id="test_dag",
- run_id="test_run",
- task_id="test_task",
- key="test_key",
- value='{"key": "test_key", "value": {"key2": "value2"}}',
- map_index=2,
- ),
- None,
- "xcoms.set",
- (
- "test_dag",
- "test_run",
- "test_task",
- "test_key",
- '{"key": "test_key", "value": {"key2": "value2"}}',
- 2,
- None,
- ),
- {},
- OKResponse(ok=True),
- None,
- id="set_xcom_with_map_index",
- ),
- pytest.param(
- SetXCom(
- dag_id="test_dag",
- run_id="test_run",
- task_id="test_task",
- key="test_key",
- value='{"key": "test_key", "value": {"key2": "value2"}}',
- map_index=2,
- mapped_length=3,
- ),
- None,
- "xcoms.set",
- (
- "test_dag",
- "test_run",
- "test_task",
- "test_key",
- '{"key": "test_key", "value": {"key2": "value2"}}',
- 2,
- 3,
- ),
- {},
- OKResponse(ok=True),
- None,
- id="set_xcom_with_map_index_and_mapped_length",
- ),
- pytest.param(
- DeleteXCom(
- dag_id="test_dag",
- run_id="test_run",
- task_id="test_task",
- key="test_key",
- map_index=2,
- ),
- None,
- "xcoms.delete",
- ("test_dag", "test_run", "test_task", "test_key", 2),
- {},
- OKResponse(ok=True),
- None,
- id="delete_xcom",
- ),
- # we aren't adding all states under TaskInstanceState here,
because this test's scope is only to check
- # if it can handle TaskState message
- pytest.param(
- TaskState(state=TaskInstanceState.SKIPPED,
end_date=timezone.parse("2024-10-31T12:00:00Z")),
- None,
- "",
- (),
- {},
- "",
- None,
- id="patch_task_instance_to_skipped",
- ),
- pytest.param(
- RetryTask(
- end_date=timezone.parse("2024-10-31T12:00:00Z"),
rendered_map_index="test retry task"
- ),
- None,
- "task_instances.retry",
- (),
- {
- "id": TI_ID,
- "end_date": timezone.parse("2024-10-31T12:00:00Z"),
- "rendered_map_index": "test retry task",
- },
- "",
- None,
- id="up_for_retry",
- ),
- pytest.param(
- SetRenderedFields(rendered_fields={"field1":
"rendered_value1", "field2": "rendered_value2"}),
- None,
- "task_instances.set_rtif",
- (TI_ID, {"field1": "rendered_value1", "field2":
"rendered_value2"}),
- {},
- OKResponse(ok=True),
- None,
- id="set_rtif",
- ),
- pytest.param(
- GetAssetByName(name="asset"),
- {"name": "asset", "uri": "s3://bucket/obj", "group": "asset",
"type": "AssetResult"},
- "assets.get",
- [],
- {"name": "asset"},
- AssetResult(name="asset", uri="s3://bucket/obj",
group="asset"),
- None,
- id="get_asset_by_name",
- ),
- pytest.param(
- GetAssetByUri(uri="s3://bucket/obj"),
- {"name": "asset", "uri": "s3://bucket/obj", "group": "asset",
"type": "AssetResult"},
- "assets.get",
- [],
- {"uri": "s3://bucket/obj"},
- AssetResult(name="asset", uri="s3://bucket/obj",
group="asset"),
- None,
- id="get_asset_by_uri",
- ),
- pytest.param(
- GetAssetEventByAsset(uri="s3://bucket/obj", name="test"),
- {
- "asset_events": [
- {
- "id": 1,
- "timestamp":
timezone.parse("2024-10-31T12:00:00Z"),
- "asset": {"name": "asset", "uri":
"s3://bucket/obj", "group": "asset"},
- "created_dagruns": [],
- }
- ],
- "type": "AssetEventsResult",
- },
- "asset_events.get",
- [],
- {"uri": "s3://bucket/obj", "name": "test"},
- AssetEventsResult(
- asset_events=[
- AssetEventResponse(
- id=1,
- asset=AssetResponse(name="asset",
uri="s3://bucket/obj", group="asset"),
- created_dagruns=[],
- timestamp=timezone.parse("2024-10-31T12:00:00Z"),
- )
- ]
- ),
- None,
- id="get_asset_events_by_uri_and_name",
- ),
- pytest.param(
- GetAssetEventByAsset(uri="s3://bucket/obj", name=None),
- {
- "asset_events": [
- {
- "id": 1,
- "timestamp":
timezone.parse("2024-10-31T12:00:00Z"),
- "asset": {"name": "asset", "uri":
"s3://bucket/obj", "group": "asset"},
- "created_dagruns": [],
- }
- ],
- "type": "AssetEventsResult",
- },
- "asset_events.get",
- [],
- {"uri": "s3://bucket/obj", "name": None},
- AssetEventsResult(
- asset_events=[
- AssetEventResponse(
- id=1,
- asset=AssetResponse(name="asset",
uri="s3://bucket/obj", group="asset"),
- created_dagruns=[],
- timestamp=timezone.parse("2024-10-31T12:00:00Z"),
- )
- ]
- ),
- None,
- id="get_asset_events_by_uri",
- ),
- pytest.param(
- GetAssetEventByAsset(uri=None, name="test"),
- {
- "asset_events": [
- {
- "id": 1,
- "timestamp":
timezone.parse("2024-10-31T12:00:00Z"),
- "asset": {"name": "asset", "uri":
"s3://bucket/obj", "group": "asset"},
- "created_dagruns": [],
- }
- ],
- "type": "AssetEventsResult",
- },
- "asset_events.get",
- [],
- {"uri": None, "name": "test"},
- AssetEventsResult(
- asset_events=[
- AssetEventResponse(
- id=1,
- asset=AssetResponse(name="asset",
uri="s3://bucket/obj", group="asset"),
- created_dagruns=[],
- timestamp=timezone.parse("2024-10-31T12:00:00Z"),
- )
- ]
- ),
- None,
- id="get_asset_events_by_name",
- ),
- pytest.param(
- GetAssetEventByAssetAlias(alias_name="test_alias"),
- {
- "asset_events": [
- {
- "id": 1,
- "timestamp":
timezone.parse("2024-10-31T12:00:00Z"),
- "asset": {"name": "asset", "uri":
"s3://bucket/obj", "group": "asset"},
- "created_dagruns": [],
- }
- ],
- "type": "AssetEventsResult",
- },
- "asset_events.get",
- [],
- {"alias_name": "test_alias"},
- AssetEventsResult(
- asset_events=[
- AssetEventResponse(
- id=1,
- asset=AssetResponse(name="asset",
uri="s3://bucket/obj", group="asset"),
- created_dagruns=[],
- timestamp=timezone.parse("2024-10-31T12:00:00Z"),
- )
- ]
- ),
- None,
- id="get_asset_events_by_asset_alias",
- ),
- pytest.param(
- ValidateInletsAndOutlets(ti_id=TI_ID),
- {
- "inactive_assets": [{"name": "asset_name", "uri":
"asset_uri", "type": "asset"}],
- "type": "InactiveAssetsResult",
- },
- "task_instances.validate_inlets_and_outlets",
- (TI_ID,),
- {},
- InactiveAssetsResult(
- inactive_assets=[AssetProfile(name="asset_name",
uri="asset_uri", type="asset")]
- ),
- None,
- id="validate_inlets_and_outlets",
- ),
- pytest.param(
- SucceedTask(
- end_date=timezone.parse("2024-10-31T12:00:00Z"),
rendered_map_index="test success task"
- ),
- None,
- "task_instances.succeed",
- (),
- {
- "id": TI_ID,
- "outlet_events": None,
- "task_outlets": None,
- "when": timezone.parse("2024-10-31T12:00:00Z"),
- "rendered_map_index": "test success task",
- },
- "",
- None,
- id="succeed_task",
- ),
- pytest.param(
- GetPrevSuccessfulDagRun(ti_id=TI_ID),
- {
- "data_interval_start":
timezone.parse("2025-01-10T12:00:00Z"),
- "data_interval_end":
timezone.parse("2025-01-10T14:00:00Z"),
- "start_date": timezone.parse("2025-01-10T12:00:00Z"),
- "end_date": timezone.parse("2025-01-10T14:00:00Z"),
- "type": "PrevSuccessfulDagRunResult",
- },
- "task_instances.get_previous_successful_dagrun",
- (TI_ID,),
- {},
- PrevSuccessfulDagRunResult(
- start_date=timezone.parse("2025-01-10T12:00:00Z"),
- end_date=timezone.parse("2025-01-10T14:00:00Z"),
- data_interval_start=timezone.parse("2025-01-10T12:00:00Z"),
- data_interval_end=timezone.parse("2025-01-10T14:00:00Z"),
- ),
- None,
- id="get_prev_successful_dagrun",
),
- pytest.param(
- TriggerDagRun(
- dag_id="test_dag",
- run_id="test_run",
- conf={"key": "value"},
- logical_date=timezone.datetime(2025, 1, 1),
- reset_dag_run=True,
- ),
- {"ok": True, "type": "OKResponse"},
- "dag_runs.trigger",
- ("test_dag", "test_run", {"key": "value"},
timezone.datetime(2025, 1, 1), True),
- {},
- OKResponse(ok=True),
- None,
- id="dag_run_trigger",
- ),
- pytest.param(
- # TODO: This should be raise an exception, not returning an
ErrorResponse. Fix this before PR
- TriggerDagRun(dag_id="test_dag", run_id="test_run"),
- {"error": "DAGRUN_ALREADY_EXISTS", "detail": None, "type":
"ErrorResponse"},
- "dag_runs.trigger",
- ("test_dag", "test_run", None, None, False),
- {},
- ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS),
+ ),
+ ),
+ RequestTestCase(
+ message=GetXCom(dag_id="test_dag", run_id="test_run",
task_id="test_task", key="test_key"),
+ test_id="get_xcom",
+ client_mock=ClientMock(
+ method_path="xcoms.get",
+ args=("test_dag", "test_run", "test_task", "test_key", None,
False),
+ response=XComResult(key="test_key", value="test_value"),
+ ),
+ expected_body={"key": "test_key", "value": "test_value", "type":
"XComResult"},
+ ),
+ RequestTestCase(
+ message=GetXCom(
+ dag_id="test_dag", run_id="test_run", task_id="test_task",
key="test_key", map_index=2
+ ),
+ test_id="get_xcom_map_index",
+ client_mock=ClientMock(
+ method_path="xcoms.get",
+ args=("test_dag", "test_run", "test_task", "test_key", 2, False),
+ response=XComResult(key="test_key", value="test_value"),
+ ),
+ expected_body={"key": "test_key", "value": "test_value", "type":
"XComResult"},
+ ),
+ RequestTestCase(
+ message=GetXCom(dag_id="test_dag", run_id="test_run",
task_id="test_task", key="test_key"),
+ test_id="get_xcom_not_found",
+ client_mock=ClientMock(
+ method_path="xcoms.get",
+ args=("test_dag", "test_run", "test_task", "test_key", None,
False),
+ response=XComResult(key="test_key", value=None, type="XComResult"),
+ ),
+ expected_body={"key": "test_key", "value": None, "type": "XComResult"},
+ ),
+ RequestTestCase(
+ message=GetXCom(
+ dag_id="test_dag",
+ run_id="test_run",
+ task_id="test_task",
+ key="test_key",
+ include_prior_dates=True,
+ ),
+ test_id="get_xcom_include_prior_dates",
+ client_mock=ClientMock(
+ method_path="xcoms.get",
+ args=("test_dag", "test_run", "test_task", "test_key", None, True),
+ response=XComResult(key="test_key", value=None, type="XComResult"),
+ ),
+ expected_body={"key": "test_key", "value": None, "type": "XComResult"},
+ ),
+ RequestTestCase(
+ message=SetXCom(
+ dag_id="test_dag",
+ run_id="test_run",
+ task_id="test_task",
+ key="test_key",
+ value='{"key": "test_key", "value": {"key2": "value2"}}',
+ ),
+ client_mock=ClientMock(
+ method_path="xcoms.set",
+ args=(
+ "test_dag",
+ "test_run",
+ "test_task",
+ "test_key",
+ '{"key": "test_key", "value": {"key2": "value2"}}',
None,
- id="dag_run_trigger_already_exists",
- ),
- pytest.param(
- GetDagRunState(dag_id="test_dag", run_id="test_run"),
- {"state": "running", "type": "DagRunStateResult"},
- "dag_runs.get_state",
- ("test_dag", "test_run"),
- {},
- DagRunStateResult(state=DagRunState.RUNNING),
None,
- id="get_dag_run_state",
),
- pytest.param(
- GetTaskRescheduleStartDate(ti_id=TI_ID),
- {"start_date": timezone.parse("2024-10-31T12:00:00Z"), "type":
"TaskRescheduleStartDate"},
- "task_instances.get_reschedule_start_date",
- (TI_ID, 1),
- {},
-
TaskRescheduleStartDate(start_date=timezone.parse("2024-10-31T12:00:00Z")),
+ response=OKResponse(ok=True),
+ ),
+ test_id="set_xcom",
+ ),
+ RequestTestCase(
+ message=SetXCom(
+ dag_id="test_dag",
+ run_id="test_run",
+ task_id="test_task",
+ key="test_key",
+ value='{"key": "test_key", "value": {"key2": "value2"}}',
+ map_index=2,
+ ),
+ client_mock=ClientMock(
+ method_path="xcoms.set",
+ args=(
+ "test_dag",
+ "test_run",
+ "test_task",
+ "test_key",
+ '{"key": "test_key", "value": {"key2": "value2"}}',
+ 2,
None,
- id="get_task_reschedule_start_date",
),
- pytest.param(
- GetTICount(dag_id="test_dag", task_ids=["task1", "task2"]),
- {"count": 2, "type": "TICount"},
- "task_instances.get_count",
- (),
- {
- "dag_id": "test_dag",
- "map_index": None,
- "logical_dates": None,
- "run_ids": None,
- "states": None,
- "task_group_id": None,
- "task_ids": ["task1", "task2"],
- },
- TICount(count=2),
- None,
- id="get_ti_count",
+ response=OKResponse(ok=True),
+ ),
+ test_id="set_xcom_with_map_index",
+ ),
+ RequestTestCase(
+ message=SetXCom(
+ dag_id="test_dag",
+ run_id="test_run",
+ task_id="test_task",
+ key="test_key",
+ value='{"key": "test_key", "value": {"key2": "value2"}}',
+ map_index=2,
+ mapped_length=3,
+ ),
+ client_mock=ClientMock(
+ method_path="xcoms.set",
+ args=(
+ "test_dag",
+ "test_run",
+ "test_task",
+ "test_key",
+ '{"key": "test_key", "value": {"key2": "value2"}}',
+ 2,
+ 3,
),
- pytest.param(
- GetDRCount(dag_id="test_dag", states=["success", "failed"]),
- {"count": 2, "type": "DRCount"},
- "dag_runs.get_count",
- (),
+ response=OKResponse(ok=True),
+ ),
+ test_id="set_xcom_with_map_index_and_mapped_length",
+ ),
+ RequestTestCase(
+ message=DeleteXCom(
+ dag_id="test_dag",
+ run_id="test_run",
+ task_id="test_task",
+ key="test_key",
+ map_index=2,
+ ),
+ client_mock=ClientMock(
+ method_path="xcoms.delete",
+ args=("test_dag", "test_run", "test_task", "test_key", 2),
+ response=OKResponse(ok=True),
+ ),
+ test_id="delete_xcom",
+ ),
+ RequestTestCase(
+ message=RetryTask(
+ end_date=timezone.parse("2024-10-31T12:00:00Z"),
rendered_map_index="test retry task"
+ ),
+ client_mock=ClientMock(
+ method_path="task_instances.retry",
+ kwargs={
+ "id": TI_ID,
+ "end_date": timezone.parse("2024-10-31T12:00:00Z"),
+ "rendered_map_index": "test retry task",
+ },
+ response=OKResponse(ok=True),
+ ),
+ test_id="up_for_retry",
+ ),
+ RequestTestCase(
+ message=SetRenderedFields(rendered_fields={"field1":
"rendered_value1", "field2": "rendered_value2"}),
+ client_mock=ClientMock(
+ method_path="task_instances.set_rtif",
+ args=(TI_ID, {"field1": "rendered_value1", "field2":
"rendered_value2"}),
+ response=OKResponse(ok=True),
+ ),
+ test_id="set_rtif",
+ ),
+ RequestTestCase(
+ message=SucceedTask(
+ end_date=timezone.parse("2024-10-31T12:00:00Z"),
rendered_map_index="test success task"
+ ),
+ client_mock=ClientMock(
+ method_path="task_instances.succeed",
+ kwargs={
+ "id": TI_ID,
+ "outlet_events": None,
+ "task_outlets": None,
+ "when": timezone.parse("2024-10-31T12:00:00Z"),
+ "rendered_map_index": "test success task",
+ },
+ ),
+ test_id="succeed_task",
+ ),
+ RequestTestCase(
+ message=GetAssetByName(name="asset"),
+ expected_body={"name": "asset", "uri": "s3://bucket/obj", "group":
"asset", "type": "AssetResult"},
+ client_mock=ClientMock(
+ method_path="assets.get",
+ kwargs={"name": "asset"},
+ response=AssetResult(name="asset", uri="s3://bucket/obj",
group="asset"),
+ ),
+ test_id="get_asset_by_name",
+ ),
+ RequestTestCase(
+ message=GetAssetByUri(uri="s3://bucket/obj"),
+ expected_body={"name": "asset", "uri": "s3://bucket/obj", "group":
"asset", "type": "AssetResult"},
+ client_mock=ClientMock(
+ method_path="assets.get",
+ kwargs={"uri": "s3://bucket/obj"},
+ response=AssetResult(name="asset", uri="s3://bucket/obj",
group="asset"),
+ ),
+ test_id="get_asset_by_uri",
+ ),
+ RequestTestCase(
+ message=GetAssetEventByAsset(uri="s3://bucket/obj", name="test"),
+ expected_body={
+ "asset_events": [
{
- "dag_id": "test_dag",
- "logical_dates": None,
- "run_ids": None,
- "states": ["success", "failed"],
- },
- DRCount(count=2),
- None,
- id="get_dr_count",
+ "id": 1,
+ "timestamp": timezone.parse("2024-10-31T12:00:00Z"),
+ "asset": {"name": "asset", "uri": "s3://bucket/obj",
"group": "asset"},
+ "created_dagruns": [],
+ }
+ ],
+ "type": "AssetEventsResult",
+ },
+ client_mock=ClientMock(
+ method_path="asset_events.get",
+ kwargs={"uri": "s3://bucket/obj", "name": "test"},
+ response=AssetEventsResult(
+ asset_events=[
+ AssetEventResponse(
+ id=1,
+ asset=AssetResponse(name="asset",
uri="s3://bucket/obj", group="asset"),
+ created_dagruns=[],
+ timestamp=timezone.parse("2024-10-31T12:00:00Z"),
+ ),
+ ],
),
- pytest.param(
- GetPreviousDagRun(
- dag_id="test_dag",
- logical_date=timezone.parse("2024-01-15T12:00:00Z"),
- ),
- {
- "dag_run": {
- "dag_id": "test_dag",
- "run_id": "prev_run",
- "logical_date": timezone.parse("2024-01-14T12:00:00Z"),
- "run_type": "scheduled",
- "start_date": timezone.parse("2024-01-15T12:00:00Z"),
- "run_after": timezone.parse("2024-01-15T12:00:00Z"),
- "consumed_asset_events": [],
- "state": "success",
- "data_interval_start": None,
- "data_interval_end": None,
- "end_date": None,
- "clear_number": 0,
- "conf": None,
- },
- "type": "PreviousDagRunResult",
- },
- "dag_runs.get_previous",
- (),
+ ),
+ test_id="get_asset_events_by_uri_and_name",
+ ),
+ RequestTestCase(
+ message=GetAssetEventByAsset(uri="s3://bucket/obj", name=None),
+ expected_body={
+ "asset_events": [
{
- "dag_id": "test_dag",
- "logical_date": timezone.parse("2024-01-15T12:00:00Z"),
- "state": None,
- },
- PreviousDagRunResult(
- dag_run=DagRun(
- dag_id="test_dag",
- run_id="prev_run",
- logical_date=timezone.parse("2024-01-14T12:00:00Z"),
- run_type=DagRunType.SCHEDULED,
- start_date=timezone.parse("2024-01-15T12:00:00Z"),
- run_after=timezone.parse("2024-01-15T12:00:00Z"),
- consumed_asset_events=[],
- state=DagRunState.SUCCESS,
+ "id": 1,
+ "timestamp": timezone.parse("2024-10-31T12:00:00Z"),
+ "asset": {"name": "asset", "uri": "s3://bucket/obj",
"group": "asset"},
+ "created_dagruns": [],
+ }
+ ],
+ "type": "AssetEventsResult",
+ },
+ client_mock=ClientMock(
+ method_path="asset_events.get",
+ kwargs={"uri": "s3://bucket/obj", "name": None},
+ response=AssetEventsResult(
+ asset_events=[
+ AssetEventResponse(
+ id=1,
+ asset=AssetResponse(name="asset",
uri="s3://bucket/obj", group="asset"),
+ created_dagruns=[],
+ timestamp=timezone.parse("2024-10-31T12:00:00Z"),
)
- ),
- None,
- id="get_previous_dagrun",
+ ],
),
- pytest.param(
- GetPreviousDagRun(
- dag_id="test_dag",
- logical_date=timezone.parse("2024-01-15T12:00:00Z"),
- state="success",
- ),
- {
- "dag_run": None,
- "type": "PreviousDagRunResult",
- },
- "dag_runs.get_previous",
- (),
+ ),
+ test_id="get_asset_events_by_uri",
+ ),
+ RequestTestCase(
+ message=GetAssetEventByAsset(uri=None, name="test"),
+ expected_body={
+ "asset_events": [
{
- "dag_id": "test_dag",
- "logical_date": timezone.parse("2024-01-15T12:00:00Z"),
- "state": "success",
- },
- PreviousDagRunResult(dag_run=None),
- None,
- id="get_previous_dagrun_with_state",
+ "id": 1,
+ "timestamp": timezone.parse("2024-10-31T12:00:00Z"),
+ "asset": {"name": "asset", "uri": "s3://bucket/obj",
"group": "asset"},
+ "created_dagruns": [],
+ }
+ ],
+ "type": "AssetEventsResult",
+ },
+ client_mock=ClientMock(
+ method_path="asset_events.get",
+ kwargs={"uri": None, "name": "test"},
+ response=AssetEventsResult(
+ asset_events=[
+ AssetEventResponse(
+ id=1,
+ asset=AssetResponse(name="asset",
uri="s3://bucket/obj", group="asset"),
+ created_dagruns=[],
+ timestamp=timezone.parse("2024-10-31T12:00:00Z"),
+ )
+ ]
),
- pytest.param(
- GetTaskStates(dag_id="test_dag", task_group_id="test_group"),
- {
- "task_states": {"run_id": {"task1": "success", "task2":
"failed"}},
- "type": "TaskStatesResult",
- },
- "task_instances.get_task_states",
- (),
+ ),
+ test_id="get_asset_events_by_name",
+ ),
+ RequestTestCase(
+ message=GetAssetEventByAssetAlias(alias_name="test_alias"),
+ expected_body={
+ "asset_events": [
{
- "dag_id": "test_dag",
- "map_index": None,
- "task_ids": None,
- "logical_dates": None,
- "run_ids": None,
- "task_group_id": "test_group",
- },
- TaskStatesResult(task_states={"run_id": {"task1": "success",
"task2": "failed"}}),
- None,
- id="get_task_states",
+ "id": 1,
+ "timestamp": timezone.parse("2024-10-31T12:00:00Z"),
+ "asset": {"name": "asset", "uri": "s3://bucket/obj",
"group": "asset"},
+ "created_dagruns": [],
+ }
+ ],
+ "type": "AssetEventsResult",
+ },
+ client_mock=ClientMock(
+ method_path="asset_events.get",
+ kwargs={"alias_name": "test_alias"},
+ response=AssetEventsResult(
+ asset_events=[
+ AssetEventResponse(
+ id=1,
+ asset=AssetResponse(name="asset",
uri="s3://bucket/obj", group="asset"),
+ created_dagruns=[],
+ timestamp=timezone.parse("2024-10-31T12:00:00Z"),
+ )
+ ]
),
- pytest.param(
- GetXComSequenceItem(
- key="test_key",
- dag_id="test_dag",
- run_id="test_run",
- task_id="test_task",
- offset=0,
- ),
- {"root": "test_value", "type": "XComSequenceIndexResult"},
- "xcoms.get_sequence_item",
- ("test_dag", "test_run", "test_task", "test_key", 0),
- {},
- XComSequenceIndexResult(root="test_value"),
- None,
- id="get_xcom_seq_item",
+ ),
+ test_id="get_asset_events_by_asset_alias",
+ ),
+ RequestTestCase(
+ message=ValidateInletsAndOutlets(ti_id=TI_ID),
+ expected_body={
+ "inactive_assets": [{"name": "asset_name", "uri": "asset_uri",
"type": "asset"}],
+ "type": "InactiveAssetsResult",
+ },
+ client_mock=ClientMock(
+ method_path="task_instances.validate_inlets_and_outlets",
+ args=(TI_ID,),
+ response=InactiveAssetsResult(
+ inactive_assets=[AssetProfile(name="asset_name",
uri="asset_uri", type="asset")]
),
- pytest.param(
- # TODO: This should be raise an exception, not returning an
ErrorResponse. Fix this before PR
- GetXComSequenceItem(
- key="test_key",
- dag_id="test_dag",
- run_id="test_run",
- task_id="test_task",
- offset=2,
- ),
- {"error": "XCOM_NOT_FOUND", "detail": None, "type":
"ErrorResponse"},
- "xcoms.get_sequence_item",
- ("test_dag", "test_run", "test_task", "test_key", 2),
- {},
- ErrorResponse(error=ErrorType.XCOM_NOT_FOUND),
- None,
- id="get_xcom_seq_item_not_found",
+ ),
+ test_id="validate_inlets_and_outlets",
+ ),
+ RequestTestCase(
+ message=GetPrevSuccessfulDagRun(ti_id=TI_ID),
+ expected_body={
+ "data_interval_start": timezone.parse("2025-01-10T12:00:00Z"),
+ "data_interval_end": timezone.parse("2025-01-10T14:00:00Z"),
+ "start_date": timezone.parse("2025-01-10T12:00:00Z"),
+ "end_date": timezone.parse("2025-01-10T14:00:00Z"),
+ "type": "PrevSuccessfulDagRunResult",
+ },
+ client_mock=ClientMock(
+ method_path="task_instances.get_previous_successful_dagrun",
+ args=(TI_ID,),
+ response=PrevSuccessfulDagRunResult(
+ start_date=timezone.parse("2025-01-10T12:00:00Z"),
+ end_date=timezone.parse("2025-01-10T14:00:00Z"),
+ data_interval_start=timezone.parse("2025-01-10T12:00:00Z"),
+ data_interval_end=timezone.parse("2025-01-10T14:00:00Z"),
),
- pytest.param(
- GetXComSequenceSlice(
- key="test_key",
+ ),
+ test_id="get_prev_successful_dagrun",
+ ),
+ RequestTestCase(
+ message=TriggerDagRun(
+ dag_id="test_dag",
+ run_id="test_run",
+ conf={"key": "value"},
+ logical_date=timezone.datetime(2025, 1, 1),
+ reset_dag_run=True,
+ ),
+ expected_body={"ok": True, "type": "OKResponse"},
+ client_mock=ClientMock(
+ method_path="dag_runs.trigger",
+ args=("test_dag", "test_run", {"key": "value"},
timezone.datetime(2025, 1, 1), True),
+ response=OKResponse(ok=True),
+ ),
+ test_id="dag_run_trigger",
+ ),
+ RequestTestCase(
+ message=TriggerDagRun(dag_id="test_dag", run_id="test_run"),
+ expected_body={"error": "DAGRUN_ALREADY_EXISTS", "detail": None,
"type": "ErrorResponse"},
+ client_mock=ClientMock(
+ method_path="dag_runs.trigger",
+ args=("test_dag", "test_run", None, None, False),
+ response=ErrorResponse(error=ErrorType.DAGRUN_ALREADY_EXISTS),
+ ),
+ test_id="dag_run_trigger_already_exists",
+ ),
+ RequestTestCase(
+ message=GetDagRunState(dag_id="test_dag", run_id="test_run"),
+ expected_body={"state": "running", "type": "DagRunStateResult"},
+ client_mock=ClientMock(
+ method_path="dag_runs.get_state",
+ args=("test_dag", "test_run"),
+ response=DagRunStateResult(state=DagRunState.RUNNING),
+ ),
+ test_id="get_dag_run_state",
+ ),
+ RequestTestCase(
+ message=GetPreviousDagRun(
+ dag_id="test_dag",
+ logical_date=timezone.parse("2024-01-15T12:00:00Z"),
+ ),
+ expected_body={
+ "dag_run": {
+ "dag_id": "test_dag",
+ "run_id": "prev_run",
+ "logical_date": timezone.parse("2024-01-14T12:00:00Z"),
+ "run_type": "scheduled",
+ "start_date": timezone.parse("2024-01-15T12:00:00Z"),
+ "run_after": timezone.parse("2024-01-15T12:00:00Z"),
+ "consumed_asset_events": [],
+ "state": "success",
+ "data_interval_start": None,
+ "data_interval_end": None,
+ "end_date": None,
+ "clear_number": 0,
+ "conf": None,
+ },
+ "type": "PreviousDagRunResult",
+ },
+ client_mock=ClientMock(
+ method_path="dag_runs.get_previous",
+ kwargs={
+ "dag_id": "test_dag",
+ "logical_date": timezone.parse("2024-01-15T12:00:00Z"),
+ "state": None,
+ },
+ response=PreviousDagRunResult(
+ dag_run=DagRun(
dag_id="test_dag",
- run_id="test_run",
- task_id="test_task",
- start=None,
- stop=None,
- step=None,
- include_prior_dates=False,
- ),
- {"root": ["foo", "bar"], "type": "XComSequenceSliceResult"},
- "xcoms.get_sequence_slice",
- ("test_dag", "test_run", "test_task", "test_key", None, None,
None, False),
- {},
- XComSequenceSliceResult(root=["foo", "bar"]),
- None,
- id="get_xcom_seq_slice",
+ run_id="prev_run",
+ logical_date=timezone.parse("2024-01-14T12:00:00Z"),
+ run_type=DagRunType.SCHEDULED,
+ start_date=timezone.parse("2024-01-15T12:00:00Z"),
+ run_after=timezone.parse("2024-01-15T12:00:00Z"),
+ consumed_asset_events=[],
+ state=DagRunState.SUCCESS,
+ )
),
- pytest.param(
- CreateHITLDetailPayload(
- ti_id=TI_ID,
- options=["Approve", "Reject"],
- subject="This is subject",
- body="This is body",
- defaults=["Approve"],
- multiple=False,
- params={},
- ),
- {
- "ti_id": str(TI_ID),
- "options": ["Approve", "Reject"],
- "subject": "This is subject",
- "body": "This is body",
- "defaults": ["Approve"],
- "multiple": False,
- "params": {},
- "respondents": None,
- "type": "HITLDetailRequestResult",
- },
- "hitl.add_response",
- (),
- {
- "body": "This is body",
- "defaults": ["Approve"],
- "multiple": False,
- "options": ["Approve", "Reject"],
- "params": {},
- "respondents": None,
- "subject": "This is subject",
- "ti_id": TI_ID,
- },
- HITLDetailRequestResult(
- ti_id=TI_ID,
- options=["Approve", "Reject"],
- subject="This is subject",
- body="This is body",
- defaults=["Approve"],
- multiple=False,
- params={},
- ),
- None,
- id="create_hitl_detail_payload",
+ ),
+ test_id="get_previous_dagrun",
+ ),
+ RequestTestCase(
+ message=GetPreviousDagRun(
+ dag_id="test_dag",
+ logical_date=timezone.parse("2024-01-15T12:00:00Z"),
+ state="success",
+ ),
+ expected_body={
+ "dag_run": None,
+ "type": "PreviousDagRunResult",
+ },
+ client_mock=ClientMock(
+ method_path="dag_runs.get_previous",
+ kwargs={
+ "dag_id": "test_dag",
+ "logical_date": timezone.parse("2024-01-15T12:00:00Z"),
+ "state": "success",
+ },
+ response=PreviousDagRunResult(dag_run=None),
+ ),
+ test_id="get_previous_dagrun_with_state",
+ ),
+ RequestTestCase(
+ message=GetTaskRescheduleStartDate(ti_id=TI_ID),
+ expected_body={
+ "start_date": timezone.parse("2024-10-31T12:00:00Z"),
+ "type": "TaskRescheduleStartDate",
+ },
+ client_mock=ClientMock(
+ method_path="task_instances.get_reschedule_start_date",
+ args=(TI_ID, 1),
+
response=TaskRescheduleStartDate(start_date=timezone.parse("2024-10-31T12:00:00Z")),
+ ),
+ test_id="get_task_reschedule_start_date",
+ ),
+ RequestTestCase(
+ message=GetTICount(dag_id="test_dag", task_ids=["task1", "task2"]),
+ expected_body={"count": 2, "type": "TICount"},
+ client_mock=ClientMock(
+ method_path="task_instances.get_count",
+ kwargs={
+ "dag_id": "test_dag",
+ "map_index": None,
+ "logical_dates": None,
+ "run_ids": None,
+ "states": None,
+ "task_group_id": None,
+ "task_ids": ["task1", "task2"],
+ },
+ response=TICount(count=2),
+ ),
+ test_id="get_ti_count",
+ ),
+ RequestTestCase(
+ message=GetDRCount(dag_id="test_dag", states=["success", "failed"]),
+ expected_body={"count": 2, "type": "DRCount"},
+ client_mock=ClientMock(
+ method_path="dag_runs.get_count",
+ kwargs={
+ "dag_id": "test_dag",
+ "logical_dates": None,
+ "run_ids": None,
+ "states": ["success", "failed"],
+ },
+ response=DRCount(count=2),
+ ),
+ test_id="get_dr_count",
+ ),
+ RequestTestCase(
+ message=GetTaskStates(dag_id="test_dag", task_group_id="test_group"),
+ expected_body={
+ "task_states": {"run_id": {"task1": "success", "task2": "failed"}},
+ "type": "TaskStatesResult",
+ },
+ client_mock=ClientMock(
+ method_path="task_instances.get_task_states",
+ kwargs={
+ "dag_id": "test_dag",
+ "map_index": None,
+ "task_ids": None,
+ "logical_dates": None,
+ "run_ids": None,
+ "task_group_id": "test_group",
+ },
+ response=TaskStatesResult(task_states={"run_id": {"task1":
"success", "task2": "failed"}}),
+ ),
+ test_id="get_task_states",
+ ),
+ RequestTestCase(
+ message=GetXComSequenceItem(
+ key="test_key",
+ dag_id="test_dag",
+ run_id="test_run",
+ task_id="test_task",
+ offset=0,
+ ),
+ expected_body={"root": "test_value", "type":
"XComSequenceIndexResult"},
+ client_mock=ClientMock(
+ method_path="xcoms.get_sequence_item",
+ args=("test_dag", "test_run", "test_task", "test_key", 0),
+ response=XComSequenceIndexResult(root="test_value"),
+ ),
+ test_id="get_xcom_seq_item",
+ ),
+ RequestTestCase(
+ message=GetXComSequenceItem(
+ key="test_key",
+ dag_id="test_dag",
+ run_id="test_run",
+ task_id="test_task",
+ offset=2,
+ ),
+ expected_body={"error": "XCOM_NOT_FOUND", "detail": None, "type":
"ErrorResponse"},
+ client_mock=ClientMock(
+ method_path="xcoms.get_sequence_item",
+ args=("test_dag", "test_run", "test_task", "test_key", 2),
+ response=ErrorResponse(error=ErrorType.XCOM_NOT_FOUND),
+ ),
+ test_id="get_xcom_seq_item_not_found",
+ ),
+ RequestTestCase(
+ message=GetXComSequenceSlice(
+ key="test_key",
+ dag_id="test_dag",
+ run_id="test_run",
+ task_id="test_task",
+ start=None,
+ stop=None,
+ step=None,
+ include_prior_dates=False,
+ ),
+ expected_body={"root": ["foo", "bar"], "type":
"XComSequenceSliceResult"},
+ client_mock=ClientMock(
+ method_path="xcoms.get_sequence_slice",
+ args=("test_dag", "test_run", "test_task", "test_key", None, None,
None, False),
+ response=XComSequenceSliceResult(root=["foo", "bar"]),
+ ),
+ test_id="get_xcom_seq_slice",
+ ),
+ RequestTestCase(
+ message=TaskState(state=TaskInstanceState.SKIPPED,
end_date=timezone.parse("2024-10-31T12:00:00Z")),
+ test_id="patch_task_instance_to_skipped",
+ ),
+ RequestTestCase(
+ message=CreateHITLDetailPayload(
+ ti_id=TI_ID,
+ options=["Approve", "Reject"],
+ subject="This is subject",
+ body="This is body",
+ defaults=["Approve"],
+ multiple=False,
+ params={},
+ ),
+ expected_body={
+ "ti_id": str(TI_ID),
+ "options": ["Approve", "Reject"],
+ "subject": "This is subject",
+ "body": "This is body",
+ "defaults": ["Approve"],
+ "multiple": False,
+ "params": {},
+ "respondents": None,
+ "type": "HITLDetailRequestResult",
+ },
+ client_mock=ClientMock(
+ method_path="hitl.add_response",
+ kwargs={
+ "body": "This is body",
+ "defaults": ["Approve"],
+ "multiple": False,
+ "options": ["Approve", "Reject"],
+ "params": {},
+ "respondents": None,
+ "subject": "This is subject",
+ "ti_id": TI_ID,
+ },
+ response=HITLDetailRequestResult(
+ ti_id=TI_ID,
+ options=["Approve", "Reject"],
+ subject="This is subject",
+ body="This is body",
+ defaults=["Approve"],
+ multiple=False,
+ params={},
),
- ],
- )
+ ),
+ test_id="create_hitl_detail_payload",
+ ),
+ RequestTestCase(
+ message=MaskSecret(value=["iter1", "iter2", {"key": "value"}],
name="test_secret"),
+ mask_secret_args=(["iter1", "iter2", {"key": "value"}], "test_secret"),
+ test_id="mask_secret_list",
+ ),
+]
Review Comment:
nice!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]