This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6b7771d7438 feat(sagemaker): Add SageMakerConditionOperator and SageMakerFailOperator (#64545)
6b7771d7438 is described below
commit 6b7771d74388af6950c3911eeaa7c7a0387d42e6
Author: Bhavya Sharma <[email protected]>
AuthorDate: Mon Apr 6 11:20:22 2026 -0700
feat(sagemaker): Add SageMakerConditionOperator and SageMakerFailOperator (#64545)
Co-authored-by: Niko Oliveira <[email protected]>
---
providers/amazon/docs/operators/sagemaker.rst | 32 +++
.../providers/amazon/aws/operators/sagemaker.py | 266 ++++++++++++++++++++-
.../amazon/aws/example_sagemaker_condition.py | 179 ++++++++++++++
.../aws/operators/test_sagemaker_condition.py | 222 +++++++++++++++++
4 files changed, 698 insertions(+), 1 deletion(-)
diff --git a/providers/amazon/docs/operators/sagemaker.rst
b/providers/amazon/docs/operators/sagemaker.rst
index 5103b868455..71f61a32890 100644
--- a/providers/amazon/docs/operators/sagemaker.rst
+++ b/providers/amazon/docs/operators/sagemaker.rst
@@ -385,6 +385,38 @@ you can use
:class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerPro
:start-after: [START howto_sensor_sagemaker_processing]
:end-before: [END howto_sensor_sagemaker_processing]
+.. _howto/operator:SageMakerConditionOperator:
+
+Branch a DAG based on condition evaluation
+==========================================
+
+To branch an Airflow DAG based on upstream task outputs you can use
+:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerConditionOperator`.
+
+Simple usage with flat parameters (single condition):
+
+.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_sagemaker_condition.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_condition_flat]
+ :end-before: [END howto_operator_sagemaker_condition_flat]
+
+Advanced usage with conditions list (multiple AND-ed conditions):
+
+.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_sagemaker_condition.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_condition]
+ :end-before: [END howto_operator_sagemaker_condition]
+
+Using Not and Or conditions:
+
+.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_sagemaker_condition.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_sagemaker_condition_not_or]
+ :end-before: [END howto_operator_sagemaker_condition_not_or]
+
Reference
---------
diff --git
a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py
b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py
index ff5e754913a..fe1597f5b1f 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -41,7 +41,12 @@ from airflow.providers.amazon.aws.utils import
trim_none_values, validate_execut
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
from airflow.providers.amazon.aws.utils.tags import format_tags
-from airflow.providers.common.compat.sdk import AirflowException, conf
+from airflow.providers.common.compat.sdk import (
+ AirflowException,
+ AirflowFailException,
+ BaseBranchOperator,
+ conf,
+)
from airflow.utils.helpers import prune_dict
if TYPE_CHECKING:
@@ -1991,3 +1996,262 @@ class
SageMakerStartNoteBookOperator(AwsBaseOperator[SageMakerHook]):
self.hook.conn.get_waiter("notebook_instance_in_service").wait(
NotebookInstanceName=self.instance_name
)
+
+
class SageMakerConditionOperator(BaseBranchOperator):
    """
    Evaluates a single condition or a list of conditions, and routes tasks based on the result.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerConditionOperator`

    :param condition_type: Condition type for the simple (flat) interface.
        Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``,
        ``LessThan``, ``LessThanOrEqualTo``, ``In``.
        Mutually exclusive with ``conditions``.
    :param left_value: Left operand for the flat interface. For ``In`` conditions
        this is the value to check membership of.
    :param right_value: Right operand for the flat interface. For ``In`` conditions
        this is the list of allowed values.
    :param conditions: List of condition dicts to evaluate (AND-ed together).
        Each dict must have a ``type`` key. Must not be empty.
        Mutually exclusive with ``condition_type``/``left_value``/``right_value``.
    :param if_task_ids: Task ID(s) to execute when all conditions are True.
    :param else_task_ids: Task ID(s) to execute when any condition is False.
        If omitted, the task fails with ``AirflowFailException`` when conditions are not met.
    """

    # Condition types accepted by the flat (single-condition) interface.
    _VALID_FLAT_TYPES: ClassVar[set[str]] = {
        "Equals",
        "GreaterThan",
        "GreaterThanOrEqualTo",
        "LessThan",
        "LessThanOrEqualTo",
        "In",
    }

    template_fields: Sequence[str] = (
        "condition_type",
        "left_value",
        "right_value",
        "conditions",
        "if_task_ids",
        "else_task_ids",
    )

    def __init__(
        self,
        *,
        condition_type: str | None = None,
        left_value: Any = None,
        right_value: Any = None,
        conditions: list[dict] | None = None,
        if_task_ids: str | list[str],
        else_task_ids: str | list[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        has_flat = condition_type is not None
        has_list = conditions is not None

        if has_flat and has_list:
            raise ValueError(
                "Cannot use 'condition_type' and 'conditions' together. "
                "Use 'condition_type' with 'left_value'/'right_value' for a single condition, "
                "or 'conditions' for multiple/nested conditions."
            )
        if not has_flat and not has_list:
            raise ValueError(
                "Missing condition: provide 'condition_type' with 'left_value'/'right_value' "
                "for a single condition, or 'conditions' for multiple/nested conditions."
            )

        if has_flat:
            if condition_type not in self._VALID_FLAT_TYPES:
                raise ValueError(
                    f"Unknown condition_type '{condition_type}'. "
                    f"Expected one of: {', '.join(sorted(self._VALID_FLAT_TYPES))}."
                )
            self.condition_type: str | None = condition_type
            self.left_value = left_value
            self.right_value = right_value
            # Normalize the flat parameters into the internal conditions-list form
            # so that evaluation only ever deals with one representation.
            if condition_type == "In":
                self.conditions: list[dict[str, Any]] = [
                    {"type": "In", "value": left_value, "in_values": right_value}
                ]
            else:
                self.conditions = [
                    {"type": condition_type, "left_value": left_value, "right_value": right_value}
                ]
        else:
            self.condition_type = None
            self.left_value = None
            self.right_value = None
            self.conditions = conditions  # type: ignore[assignment]

        if not self.conditions:
            raise ValueError("At least 1 condition is required, but got an empty list.")
        self.if_task_ids = [if_task_ids] if isinstance(if_task_ids, str) else if_task_ids
        self.else_task_ids = [else_task_ids] if isinstance(else_task_ids, str) else (else_task_ids or [])

    @staticmethod
    def _cast(value: Any) -> Any:
        """
        Cast Jinja-rendered string values to appropriate Python types.

        This is a compatibility shim for environments where
        ``render_template_as_native_obj=True`` is not available at the DAG or
        task level (e.g., YAML DAGs). Once task-level native rendering
        is widely supported, this method can be removed in favor of letting
        Airflow handle the casting natively.

        - Numeric strings become int or float.
        - ``"true"``/``"false"`` become booleans.
        - ``"None"`` becomes ``None`` (common when ``xcom_pull`` returns nothing).
        - Other strings are returned unchanged.
        - Non-string types pass through as-is.
        """
        if not isinstance(value, str):
            return value
        if value == "None":
            return None
        # int first so that "42" does not become 42.0; float as fallback.
        try:
            return int(value)
        except (ValueError, TypeError):
            pass
        try:
            return float(value)
        except (ValueError, TypeError):
            pass
        if value.lower() == "true":
            return True
        if value.lower() == "false":
            return False
        return value

    # Dispatch table for the binary comparison condition types.
    _COMPARISON_OPERATORS: ClassVar[dict[str, Callable[[Any, Any], bool]]] = {
        "Equals": lambda left, right: left == right,
        "GreaterThan": lambda left, right: left > right,
        "GreaterThanOrEqualTo": lambda left, right: left >= right,
        "LessThan": lambda left, right: left < right,
        "LessThanOrEqualTo": lambda left, right: left <= right,
    }

    def _evaluate(self, condition: dict, depth: int = 0) -> bool:
        """
        Recursively evaluate a single condition dict.

        :param condition: A condition dictionary with a ``type`` key.
        :param depth: Current nesting depth (used for log indentation only).
        :returns: Boolean result of the condition evaluation.
        :raises ValueError: If a required key is missing or the type is unknown.
        :raises TypeError: If operands are None or of incompatible types.
        """
        log_indent = " " * depth
        try:
            condition_type = condition["type"]
        except KeyError:
            # `from None` keeps the traceback focused on the user-facing error,
            # consistent with the other missing-key raise sites below.
            raise ValueError("Condition dict is missing required key 'type'.") from None

        if condition_type in self._COMPARISON_OPERATORS:
            try:
                left = self._cast(condition["left_value"])
                right = self._cast(condition["right_value"])
            except KeyError as e:
                raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None

            # None check — likely an XCom that was not pushed
            if left is None or right is None:
                raise TypeError(
                    f"Condition '{condition_type}' received None: left={left!r}, right={right!r}. "
                    "This usually means the upstream task did not run or did not push a value to XCom."
                )

            # Type compatibility check: int/float may be mixed, but bool is
            # deliberately excluded from "numeric" so True is not comparable to 1.
            left_type = type(left)
            right_type = type(right)
            numeric_types = (int, float)
            left_is_numeric = isinstance(left, numeric_types) and not isinstance(left, bool)
            right_is_numeric = isinstance(right, numeric_types) and not isinstance(right, bool)

            if not (left_is_numeric and right_is_numeric) and left_type is not right_type:
                raise TypeError(
                    f"Cannot compare {left_type.__name__} ({left!r}) with {right_type.__name__} ({right!r}) "
                    f"in condition '{condition_type}'. Both values must be the same type."
                )

            comparison_result = self._COMPARISON_OPERATORS[condition_type](left, right)
            self.log.info(
                "%s%s: %r %s %r -> %s",
                log_indent,
                condition_type,
                left,
                condition_type,
                right,
                comparison_result,
            )
            return comparison_result

        if condition_type == "In":
            try:
                query_value = self._cast(condition["value"])
                allowed_values = [self._cast(val) for val in condition["in_values"]]
            except KeyError as e:
                raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None
            membership_result = query_value in allowed_values
            self.log.info("%sIn: %r in %r -> %s", log_indent, query_value, allowed_values, membership_result)
            return membership_result

        if condition_type == "Not":
            try:
                inner_condition = condition["condition"]
            except KeyError as e:
                raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None
            inner_result = self._evaluate(inner_condition, depth + 1)
            negated_result = not inner_result
            self.log.info("%sNot: not %s -> %s", log_indent, inner_result, negated_result)
            return negated_result

        if condition_type == "Or":
            try:
                inner_conditions = condition["conditions"]
            except KeyError as e:
                raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None
            inner_results = [self._evaluate(inner_cond, depth + 1) for inner_cond in inner_conditions]
            or_result = any(inner_results)
            self.log.info("%sOr: any(%r) -> %s", log_indent, inner_results, or_result)
            return or_result

        raise ValueError(f"Unknown condition type '{condition_type}'.")

    def choose_branch(self, context: Context) -> list[str]:
        """
        Evaluate all conditions and return the appropriate branch task IDs.

        :param context: Airflow context dictionary.
        :returns: ``if_task_ids`` when all conditions are True, ``else_task_ids`` otherwise.
        :raises AirflowFailException: If conditions are not met and no else branch was given.
        """
        condition_count = len(self.conditions)
        self.log.info("Evaluating %d condition(s).", condition_count)

        # Evaluate every condition (no short-circuit) so the log shows all results.
        evaluation_results = [self._evaluate(condition) for condition in self.conditions]
        all_conditions_met = all(evaluation_results)

        if all_conditions_met:
            self.log.info(
                "All %d condition(s) evaluated to True. Routing to if_task_ids=%r.",
                condition_count,
                self.if_task_ids,
            )
            return self.if_task_ids

        if not self.else_task_ids:
            raise AirflowFailException(
                f"Condition check failed in task '{self.task_id}': results={evaluation_results}"
            )

        self.log.info(
            "Not all conditions are True (results=%r). Routing to else_task_ids=%r.",
            evaluation_results,
            self.else_task_ids,
        )
        return self.else_task_ids
diff --git
a/providers/amazon/tests/system/amazon/aws/example_sagemaker_condition.py
b/providers/amazon/tests/system/amazon/aws/example_sagemaker_condition.py
new file mode 100644
index 00000000000..90a392fa0c1
--- /dev/null
+++ b/providers/amazon/tests/system/amazon/aws/example_sagemaker_condition.py
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+System test for SageMakerConditionOperator.
+
+This operator evaluates conditions against XCom values passed from upstream
tasks.
+
+The DAG simulates an ML accuracy-gate workflow:
+
+1. ``produce_metrics`` pushes a dict of metrics to XCom.
+2. ``check_accuracy`` uses SageMakerConditionOperator to branch:
+ - accuracy >= 0.9 AND loss < 0.1 -> ``deploy_model``
+ - otherwise -> ``retrain_model``
+3. Only the correct branch task runs; the other is skipped.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow.providers.amazon.aws.operators.sagemaker import
SageMakerConditionOperator
+
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+
+if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk import DAG, chain, task
+else:
+ from airflow.decorators import task # type: ignore[attr-defined,no-redef]
+ from airflow.models.baseoperator import chain # type:
ignore[attr-defined,no-redef]
+ from airflow.models.dag import DAG # type:
ignore[attr-defined,no-redef,assignment]
+
+from system.amazon.aws.utils import SystemTestContextBuilder
+
DAG_ID = "example_sagemaker_condition"

sys_test_context_task = SystemTestContextBuilder().build()


with DAG(
    DAG_ID,
    schedule="@once",
    start_date=datetime(2021, 1, 1),
    catchup=False,
) as dag:
    test_context = sys_test_context_task()

    # TEST SETUP: push simulated ML metrics to XCom

    @task
    def produce_metrics():
        """Simulate an ML training job that returns accuracy and loss metrics."""
        return {"accuracy": 0.95, "loss": 0.04}

    good_metrics = produce_metrics()

    # TEST BODY

    # [START howto_operator_sagemaker_condition]
    check_accuracy = SageMakerConditionOperator(
        task_id="check_accuracy",
        conditions=[
            {
                "type": "GreaterThanOrEqualTo",
                "left_value": "{{ ti.xcom_pull(task_ids='produce_metrics')['accuracy'] }}",
                "right_value": 0.9,
            },
            {
                "type": "LessThan",
                "left_value": "{{ ti.xcom_pull(task_ids='produce_metrics')['loss'] }}",
                "right_value": 0.1,
            },
        ],
        if_task_ids=["deploy_model"],
        else_task_ids=["retrain_model"],
    )
    # [END howto_operator_sagemaker_condition]

    @task
    def deploy_model():
        """Placeholder: model meets quality bar, proceed to deployment."""
        return "deployed"

    @task
    def retrain_model():
        """Placeholder: model does not meet quality bar, retrain."""
        return "retrained"

    # Scenario 2: condition evaluates to False -> else branch

    @task
    def produce_bad_metrics():
        """Simulate a training job with poor accuracy."""
        return {"accuracy": 0.5, "loss": 0.8}

    bad_metrics = produce_bad_metrics()

    # [START howto_operator_sagemaker_condition_flat]
    check_bad_accuracy = SageMakerConditionOperator(
        task_id="check_bad_accuracy",
        condition_type="GreaterThanOrEqualTo",
        left_value="{{ ti.xcom_pull(task_ids='produce_bad_metrics')['accuracy'] }}",
        right_value=0.9,
        if_task_ids=["should_not_run"],
        else_task_ids=["should_run"],
    )
    # [END howto_operator_sagemaker_condition_flat]

    @task
    def should_not_run():
        """This task should be skipped because accuracy < 0.9."""
        return "error: should not have run"

    @task
    def should_run():
        """This task should execute because accuracy < 0.9 -> else branch."""
        return "correctly routed to else branch"

    # Scenario 3: Or condition + Not condition

    # [START howto_operator_sagemaker_condition_not_or]
    check_logical = SageMakerConditionOperator(
        task_id="check_logical",
        conditions=[
            {
                "type": "Or",
                "conditions": [
                    {"type": "Equals", "left_value": 1, "right_value": 2},
                    {"type": "Equals", "left_value": 3, "right_value": 3},
                ],
            },
            {
                "type": "Not",
                "condition": {"type": "Equals", "left_value": "a", "right_value": "b"},
            },
        ],
        if_task_ids=["logical_pass"],
        else_task_ids=["logical_fail"],
    )
    # [END howto_operator_sagemaker_condition_not_or]

    @task
    def logical_pass():
        """Or(1==2, 3==3) -> True AND Not(a==b) -> True -> if branch."""
        return "logical conditions passed"

    @task
    def logical_fail():
        return "error: logical conditions should have passed"

    # Fan out from the system-test context into the three independent scenarios.
    test_context >> [good_metrics, bad_metrics, check_logical]

    chain(good_metrics, check_accuracy, [deploy_model(), retrain_model()])
    chain(bad_metrics, check_bad_accuracy, [should_not_run(), should_run()])
    chain(check_logical, [logical_pass(), logical_fail()])

    from tests_common.test_utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()

from tests_common.test_utils.system_tests import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)
diff --git
a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_condition.py
b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_condition.py
new file mode 100644
index 00000000000..56d461e50a3
--- /dev/null
+++
b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_condition.py
@@ -0,0 +1,222 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.providers.amazon.aws.operators.sagemaker import
SageMakerConditionOperator
+from airflow.providers.common.compat.sdk import AirflowFailException
+
+from unit.amazon.aws.utils.test_template_fields import validate_template_fields
+
+
def _choose(conditions, if_ids="if_task", else_ids="else_task"):
    """Build an operator from a conditions list and return its branch choice."""
    operator = SageMakerConditionOperator(
        task_id="test",
        conditions=conditions,
        if_task_ids=if_ids,
        else_task_ids=else_ids,
    )
    return operator.choose_branch(context=MagicMock(spec=dict))
+
+
def _choose_flat(condition_type, left, right, if_ids="if_task", else_ids="else_task"):
    """Build an operator from flat parameters and return its branch choice."""
    operator = SageMakerConditionOperator(
        task_id="test",
        condition_type=condition_type,
        left_value=left,
        right_value=right,
        if_task_ids=if_ids,
        else_task_ids=else_ids,
    )
    return operator.choose_branch(context=MagicMock(spec=dict))
+
+
def test_template_fields():
    """Operator template_fields must all exist as instance attributes."""
    operator = SageMakerConditionOperator(
        task_id="test",
        conditions=[{"type": "Equals", "left_value": 1, "right_value": 1}],
        if_task_ids=["a"],
        else_task_ids=["b"],
    )
    validate_template_fields(operator)
+
+
class TestConditionTypes:
    """One true + one false per condition type, plus logical combinators."""

    @pytest.mark.parametrize(
        ("cond_type", "left", "right", "expected"),
        [
            ("Equals", 1, 1, "if"),
            ("Equals", 1, 2, "else"),
            ("GreaterThan", 5, 3, "if"),
            ("GreaterThan", 3, 5, "else"),
            ("GreaterThanOrEqualTo", 3, 3, "if"),
            ("GreaterThanOrEqualTo", 2, 3, "else"),
            ("LessThan", 3, 5, "if"),
            ("LessThan", 5, 3, "else"),
            ("LessThanOrEqualTo", 3, 3, "if"),
            ("LessThanOrEqualTo", 5, 3, "else"),
        ],
    )
    def test_comparison(self, cond_type, left, right, expected):
        expected_branch = ["if_task"] if expected == "if" else ["else_task"]
        chosen = _choose([{"type": cond_type, "left_value": left, "right_value": right}])
        assert chosen == expected_branch

    def test_in_true(self):
        chosen = _choose([{"type": "In", "value": 1, "in_values": [1, 2]}])
        assert chosen == ["if_task"]

    def test_in_false(self):
        chosen = _choose([{"type": "In", "value": 4, "in_values": [1, 2]}])
        assert chosen == ["else_task"]

    def test_not_negates(self):
        negated = [{"type": "Not", "condition": {"type": "Equals", "left_value": 1, "right_value": 1}}]
        assert _choose(negated) == ["else_task"]

    def test_or_any_true(self):
        or_condition = [
            {
                "type": "Or",
                "conditions": [
                    {"type": "Equals", "left_value": 1, "right_value": 2},
                    {"type": "Equals", "left_value": 1, "right_value": 1},
                ],
            }
        ]
        assert _choose(or_condition) == ["if_task"]
+
+
class TestAndSemantics:
    def test_multiple_conditions_and(self):
        """All true -> if, one false -> else."""
        both_true = [
            {"type": "Equals", "left_value": 1, "right_value": 1},
            {"type": "GreaterThan", "left_value": 5, "right_value": 3},
        ]
        assert _choose(both_true) == ["if_task"]

        one_false = [
            {"type": "Equals", "left_value": 1, "right_value": 1},
            {"type": "GreaterThan", "left_value": 2, "right_value": 10},
        ]
        assert _choose(one_false) == ["else_task"]
+
+
class TestValueCasting:
    def test_cast_numeric_and_passthrough(self):
        """int string, float string, bool string, non-numeric string, non-string type."""
        cast = SageMakerConditionOperator._cast
        assert cast("42") == 42
        assert cast("0.9") == 0.9
        assert cast("true") is True
        assert cast("us-east-1") == "us-east-1"
        assert cast(42) == 42
+
+
class TestValidation:
    def test_empty_conditions_raises(self):
        with pytest.raises(ValueError, match="At least 1 condition is required"):
            SageMakerConditionOperator(task_id="t", conditions=[], if_task_ids="a", else_task_ids="b")

    def test_unknown_type_raises(self):
        with pytest.raises(ValueError, match="Unknown condition type"):
            _choose([{"type": "FooBar", "left_value": 1, "right_value": 2}])

    def test_type_mismatch_raises(self):
        with pytest.raises(TypeError, match="Cannot compare"):
            _choose([{"type": "GreaterThanOrEqualTo", "left_value": "hello", "right_value": 0.9}])

    def test_none_operand_raises(self):
        """Covers both Python None and Jinja-rendered string 'None'."""
        with pytest.raises(TypeError, match="received None"):
            _choose([{"type": "Equals", "left_value": None, "right_value": 1}])
        with pytest.raises(TypeError, match="received None"):
            _choose([{"type": "GreaterThanOrEqualTo", "left_value": "None", "right_value": 0.9}])

    @pytest.mark.parametrize(
        ("condition", "match_pattern"),
        [
            ({}, "missing required key 'type'"),
            ({"type": "Equals", "left_value": 1}, "missing required key"),
            ({"type": "Not"}, "missing required key"),
            ({"type": "Or"}, "missing required key"),
        ],
    )
    def test_missing_key_raises(self, condition, match_pattern):
        with pytest.raises(ValueError, match=match_pattern):
            _choose([condition])
+
+
class TestFlatInterface:
    def test_flat_condition(self):
        """Flat params work for comparison and In types."""
        assert _choose_flat("Equals", 1, 1) == ["if_task"]
        assert _choose_flat("Equals", 1, 2) == ["else_task"]
        assert _choose_flat("In", 1, [1, 2, 3]) == ["if_task"]
        assert _choose_flat("In", 99, [1, 2, 3]) == ["else_task"]

    def test_invalid_condition_type_raises(self):
        with pytest.raises(ValueError, match="Unknown condition_type"):
            SageMakerConditionOperator(
                task_id="t",
                condition_type="NotEquals",
                left_value=1,
                right_value=1,
                if_task_ids="a",
                else_task_ids="b",
            )

    def test_mutual_exclusion_raises(self):
        with pytest.raises(ValueError, match="Cannot use 'condition_type' and 'conditions' together"):
            SageMakerConditionOperator(
                task_id="t",
                condition_type="Equals",
                left_value=1,
                right_value=1,
                conditions=[{"type": "Equals", "left_value": 1, "right_value": 1}],
                if_task_ids="a",
                else_task_ids="b",
            )

    def test_neither_provided_raises(self):
        with pytest.raises(ValueError, match="Missing condition"):
            SageMakerConditionOperator(task_id="t", if_task_ids="a", else_task_ids="b")

    def test_optional_else_task_ids(self):
        """else_task_ids defaults to empty list when omitted."""
        operator = SageMakerConditionOperator(
            task_id="t",
            conditions=[{"type": "Equals", "left_value": 1, "right_value": 1}],
            if_task_ids=["deploy"],
        )
        assert operator.else_task_ids == []

    def test_no_else_branch_raises_on_false(self):
        """When else_task_ids is empty and conditions are false, raises AirflowFailException."""
        operator = SageMakerConditionOperator(
            task_id="check",
            conditions=[{"type": "Equals", "left_value": 1, "right_value": 2}],
            if_task_ids=["deploy"],
        )
        with pytest.raises(AirflowFailException, match="Condition check failed in task 'check'"):
            operator.choose_branch(context=MagicMock(spec=dict))