This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6b7771d7438 feat(sagemaker): Add SageMakerConditionOperator and 
SageMakerFailOperator (#64545)
6b7771d7438 is described below

commit 6b7771d74388af6950c3911eeaa7c7a0387d42e6
Author: Bhavya Sharma <[email protected]>
AuthorDate: Mon Apr 6 11:20:22 2026 -0700

    feat(sagemaker): Add SageMakerConditionOperator and SageMakerFailOperator 
(#64545)
    
    Co-authored-by: Niko Oliveira <[email protected]>
---
 providers/amazon/docs/operators/sagemaker.rst      |  32 +++
 .../providers/amazon/aws/operators/sagemaker.py    | 266 ++++++++++++++++++++-
 .../amazon/aws/example_sagemaker_condition.py      | 179 ++++++++++++++
 .../aws/operators/test_sagemaker_condition.py      | 222 +++++++++++++++++
 4 files changed, 698 insertions(+), 1 deletion(-)

diff --git a/providers/amazon/docs/operators/sagemaker.rst 
b/providers/amazon/docs/operators/sagemaker.rst
index 5103b868455..71f61a32890 100644
--- a/providers/amazon/docs/operators/sagemaker.rst
+++ b/providers/amazon/docs/operators/sagemaker.rst
@@ -385,6 +385,38 @@ you can use 
:class:`~airflow.providers.amazon.aws.sensors.sagemaker.SageMakerPro
     :start-after: [START howto_sensor_sagemaker_processing]
     :end-before: [END howto_sensor_sagemaker_processing]
 
+.. _howto/operator:SageMakerConditionOperator:
+
+Branch a DAG based on condition evaluation
+==========================================
+
+To branch an Airflow DAG based on upstream task outputs, you can use
+:class:`~airflow.providers.amazon.aws.operators.sagemaker.SageMakerConditionOperator`.
+If ``else_task_ids`` is omitted, the task fails when the conditions are not met.
+
+Simple usage with flat parameters (single condition):
+
+.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_sagemaker_condition.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_sagemaker_condition_flat]
+    :end-before: [END howto_operator_sagemaker_condition_flat]
+
+Advanced usage with conditions list (multiple AND-ed conditions):
+
+.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_sagemaker_condition.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_sagemaker_condition]
+    :end-before: [END howto_operator_sagemaker_condition]
+
+Using ``Not`` and ``Or`` conditions:
+
+.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_sagemaker_condition.py
+    :language: python
+    :dedent: 4
+    :start-after: [START howto_operator_sagemaker_condition_not_or]
+    :end-before: [END howto_operator_sagemaker_condition_not_or]
+
 Reference
 ---------
 
diff --git 
a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py 
b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py
index ff5e754913a..fe1597f5b1f 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -41,7 +41,12 @@ from airflow.providers.amazon.aws.utils import 
trim_none_values, validate_execut
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 from airflow.providers.amazon.aws.utils.sagemaker import ApprovalStatus
 from airflow.providers.amazon.aws.utils.tags import format_tags
-from airflow.providers.common.compat.sdk import AirflowException, conf
+from airflow.providers.common.compat.sdk import (
+    AirflowException,
+    AirflowFailException,
+    BaseBranchOperator,
+    conf,
+)
 from airflow.utils.helpers import prune_dict
 
 if TYPE_CHECKING:
@@ -1991,3 +1996,262 @@ class 
SageMakerStartNoteBookOperator(AwsBaseOperator[SageMakerHook]):
             self.hook.conn.get_waiter("notebook_instance_in_service").wait(
                 NotebookInstanceName=self.instance_name
             )
+
+
class SageMakerConditionOperator(BaseBranchOperator):
    """
    Evaluates a single condition or a list of conditions, and routes tasks based on the result.

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:SageMakerConditionOperator`

    Each condition is a dict with a ``type`` key plus type-specific keys:

    * ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``, ``LessThan``,
      ``LessThanOrEqualTo``: ``left_value`` and ``right_value``. Both operands must be
      numbers, or both must be of the same type.
    * ``In``: ``value`` and ``in_values`` (a list of allowed values, not a string).
    * ``Not``: ``condition`` (a single nested condition dict).
    * ``Or``: ``conditions`` (a non-empty list of nested condition dicts).

    String operands (for example rendered Jinja templates) are converted to ``int``,
    ``float``, ``bool`` or ``None`` where possible before evaluation. A ``None`` operand
    is rejected because it usually means an upstream task did not push a value to XCom.

    :param condition_type: Condition type for the simple (flat) interface.
        Valid types: ``Equals``, ``GreaterThan``, ``GreaterThanOrEqualTo``,
        ``LessThan``, ``LessThanOrEqualTo``, ``In``.
        Mutually exclusive with ``conditions``.
    :param left_value: Left operand for the flat interface. For ``In`` conditions
        this is the value to check membership of.
    :param right_value: Right operand for the flat interface. For ``In`` conditions
        this is the list of allowed values.
    :param conditions: List of condition dicts to evaluate (AND-ed together).
        Each dict must have a ``type`` key. Must not be empty.
        Mutually exclusive with ``condition_type``/``left_value``/``right_value``.
    :param if_task_ids: Task ID(s) to execute when all conditions are True.
    :param else_task_ids: Task ID(s) to execute when any condition is False.
        If omitted, the task fails with ``AirflowFailException`` when conditions are not met.
    """

    # Binary comparison types, shared by the flat interface and the ``conditions`` list.
    _COMPARISON_OPERATORS: ClassVar[dict[str, Callable[[Any, Any], bool]]] = {
        "Equals": lambda left, right: left == right,
        "GreaterThan": lambda left, right: left > right,
        "GreaterThanOrEqualTo": lambda left, right: left >= right,
        "LessThan": lambda left, right: left < right,
        "LessThanOrEqualTo": lambda left, right: left <= right,
    }

    # Operator symbols, used only to make the evaluation log read naturally.
    _COMPARISON_SYMBOLS: ClassVar[dict[str, str]] = {
        "Equals": "==",
        "GreaterThan": ">",
        "GreaterThanOrEqualTo": ">=",
        "LessThan": "<",
        "LessThanOrEqualTo": "<=",
    }

    # Derived from the comparison table so flat-interface validation cannot drift
    # out of sync with what ``_evaluate`` supports.
    _VALID_FLAT_TYPES: ClassVar[set[str]] = {*_COMPARISON_OPERATORS, "In"}

    # NOTE: ``condition_type`` is validated in ``__init__`` (before templates are rendered),
    # so it must be a literal type name. ``left_value``/``right_value`` are copied into
    # ``conditions`` in ``__init__``; evaluation only reads the rendered ``conditions``.
    template_fields: Sequence[str] = (
        "condition_type",
        "left_value",
        "right_value",
        "conditions",
        "if_task_ids",
        "else_task_ids",
    )

    def __init__(
        self,
        *,
        condition_type: str | None = None,
        left_value: Any = None,
        right_value: Any = None,
        conditions: list[dict] | None = None,
        if_task_ids: str | list[str],
        else_task_ids: str | list[str] | None = None,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        has_flat = condition_type is not None
        has_list = conditions is not None

        if has_flat and has_list:
            raise ValueError(
                "Cannot use 'condition_type' and 'conditions' together. "
                "Use 'condition_type' with 'left_value'/'right_value' for a single condition, "
                "or 'conditions' for multiple/nested conditions."
            )
        if not has_flat and not has_list:
            raise ValueError(
                "Missing condition: provide 'condition_type' with 'left_value'/'right_value' "
                "for a single condition, or 'conditions' for multiple/nested conditions."
            )

        if has_flat:
            if condition_type not in self._VALID_FLAT_TYPES:
                raise ValueError(
                    f"Unknown condition_type '{condition_type}'. "
                    f"Expected one of: {', '.join(sorted(self._VALID_FLAT_TYPES))}."
                )
            self.condition_type: str | None = condition_type
            self.left_value = left_value
            self.right_value = right_value
            # The flat interface is normalised into the same dict shape as ``conditions``
            # so that ``choose_branch`` has a single evaluation path.
            if condition_type == "In":
                self.conditions: list[dict[str, Any]] = [
                    {"type": "In", "value": left_value, "in_values": right_value}
                ]
            else:
                self.conditions = [
                    {"type": condition_type, "left_value": left_value, "right_value": right_value}
                ]
        else:
            self.condition_type = None
            self.left_value = None
            self.right_value = None
            self.conditions = conditions  # type: ignore[assignment]

        if not self.conditions:
            raise ValueError("At least 1 condition is required, but got an empty list.")
        self.if_task_ids = [if_task_ids] if isinstance(if_task_ids, str) else if_task_ids
        self.else_task_ids = [else_task_ids] if isinstance(else_task_ids, str) else (else_task_ids or [])

    @staticmethod
    def _cast(value: Any) -> Any:
        """
        Cast Jinja-rendered string values to appropriate Python types.

        This is a compatibility shim for environments where
        ``render_template_as_native_obj=True`` is not available at the DAG or
        task level (e.g., YAML DAGs). Once task-level native rendering
        is widely supported, this method can be removed in favor of letting
        Airflow handle the casting natively.

        - Numeric strings become int or float. Note that ``float()`` also accepts
          strings such as ``"nan"`` and ``"inf"``, so those become floats too.
        - ``"true"``/``"false"`` (any case) become booleans.
        - ``"None"`` becomes ``None`` (common when ``xcom_pull`` returns nothing).
        - Other strings are returned unchanged.
        - Non-string types pass through as-is.
        """
        if not isinstance(value, str):
            return value
        if value == "None":
            return None
        try:
            return int(value)
        except (ValueError, TypeError):
            pass
        try:
            return float(value)
        except (ValueError, TypeError):
            pass
        if value.lower() == "true":
            return True
        if value.lower() == "false":
            return False
        return value

    def _evaluate(self, condition: dict, depth: int = 0) -> bool:
        """
        Recursively evaluate a single condition dict.

        :param condition: A condition dictionary with a ``type`` key.
        :param depth: Current nesting depth (used for log indentation only).
        :returns: Boolean result of the condition evaluation.
        :raises ValueError: If a required key is missing, an ``Or`` condition has no
            inner conditions, or the condition type is unknown.
        :raises TypeError: If an operand is ``None``, comparison operands have
            incompatible types, or ``in_values`` is a string instead of a list.
        """
        log_indent = "  " * depth
        try:
            condition_type = condition["type"]
        except KeyError:
            raise ValueError("Condition dict is missing required key 'type'.") from None

        if condition_type in self._COMPARISON_OPERATORS:
            try:
                left = self._cast(condition["left_value"])
                right = self._cast(condition["right_value"])
            except KeyError as e:
                raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None

            # None check — likely an XCom that was not pushed
            if left is None or right is None:
                raise TypeError(
                    f"Condition '{condition_type}' received None: left={left!r}, right={right!r}. "
                    "This usually means the upstream task did not run or did not push a value to XCom."
                )

            # Type compatibility check: int/float may be mixed, but bool is not treated as numeric.
            left_type = type(left)
            right_type = type(right)
            numeric_types = (int, float)
            left_is_numeric = isinstance(left, numeric_types) and not isinstance(left, bool)
            right_is_numeric = isinstance(right, numeric_types) and not isinstance(right, bool)

            if not (left_is_numeric and right_is_numeric) and left_type is not right_type:
                raise TypeError(
                    f"Cannot compare {left_type.__name__} ({left!r}) with {right_type.__name__} ({right!r}) "
                    f"in condition '{condition_type}'. Both values must be the same type."
                )

            comparison_result = self._COMPARISON_OPERATORS[condition_type](left, right)
            self.log.info(
                "%s%s: %r %s %r -> %s",
                log_indent,
                condition_type,
                left,
                self._COMPARISON_SYMBOLS[condition_type],
                right,
                comparison_result,
            )
            return comparison_result

        if condition_type == "In":
            try:
                query_value = self._cast(condition["value"])
                raw_allowed_values = condition["in_values"]
            except KeyError as e:
                raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None
            if query_value is None:
                raise TypeError(
                    f"Condition '{condition_type}' received None as 'value'. "
                    "This usually means the upstream task did not run or did not push a value to XCom."
                )
            # Without native rendering a templated list arrives as one string; iterating it
            # would silently test membership against its individual characters.
            if raw_allowed_values is None or isinstance(raw_allowed_values, (str, bytes)):
                raise TypeError(
                    f"Condition '{condition_type}' expects 'in_values' to be a list of values, got "
                    f"{type(raw_allowed_values).__name__} ({raw_allowed_values!r}). "
                    "If the values come from a template, render them as native Python objects."
                )
            allowed_values = [self._cast(val) for val in raw_allowed_values]
            membership_result = query_value in allowed_values
            self.log.info("%sIn: %r in %r -> %s", log_indent, query_value, allowed_values, membership_result)
            return membership_result

        if condition_type == "Not":
            try:
                inner_condition = condition["condition"]
            except KeyError as e:
                raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None
            inner_result = self._evaluate(inner_condition, depth + 1)
            negated_result = not inner_result
            self.log.info("%sNot: not %s -> %s", log_indent, inner_result, negated_result)
            return negated_result

        if condition_type == "Or":
            try:
                inner_conditions = condition["conditions"]
            except KeyError as e:
                raise ValueError(f"Condition '{condition_type}' missing required key {e}.") from None
            # Mirror the top-level rule: an empty Or would otherwise silently evaluate to False.
            if not inner_conditions:
                raise ValueError(
                    f"Condition '{condition_type}' requires at least 1 inner condition, "
                    f"but got {inner_conditions!r}."
                )
            # Evaluated eagerly (no short-circuit) so every branch is validated and logged.
            inner_results = [self._evaluate(inner_cond, depth + 1) for inner_cond in inner_conditions]
            or_result = any(inner_results)
            self.log.info("%sOr: any(%r) -> %s", log_indent, inner_results, or_result)
            return or_result

        raise ValueError(f"Unknown condition type '{condition_type}'.")

    def choose_branch(self, context: Context) -> list[str]:
        """
        Evaluate all conditions and return the appropriate branch task IDs.

        :param context: Airflow context dictionary.
        :returns: ``if_task_ids`` when all conditions are True, ``else_task_ids`` otherwise.
        :raises AirflowFailException: If a condition is False and no ``else_task_ids`` were given.
        """
        condition_count = len(self.conditions)
        self.log.info("Evaluating %d condition(s).", condition_count)

        # All conditions are evaluated (not short-circuited) so the log shows every result.
        evaluation_results = [self._evaluate(condition) for condition in self.conditions]
        all_conditions_met = all(evaluation_results)

        if all_conditions_met:
            self.log.info(
                "All %d condition(s) evaluated to True. Routing to if_task_ids=%r.",
                condition_count,
                self.if_task_ids,
            )
            return self.if_task_ids

        if not self.else_task_ids:
            raise AirflowFailException(
                f"Condition check failed in task '{self.task_id}': results={evaluation_results}"
            )

        self.log.info(
            "Not all conditions are True (results=%r). Routing to else_task_ids=%r.",
            evaluation_results,
            self.else_task_ids,
        )
        return self.else_task_ids
diff --git 
a/providers/amazon/tests/system/amazon/aws/example_sagemaker_condition.py 
b/providers/amazon/tests/system/amazon/aws/example_sagemaker_condition.py
new file mode 100644
index 00000000000..90a392fa0c1
--- /dev/null
+++ b/providers/amazon/tests/system/amazon/aws/example_sagemaker_condition.py
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+System test for SageMakerConditionOperator.
+
+This operator evaluates conditions against XCom values passed from upstream 
tasks.
+
+The DAG simulates an ML accuracy-gate workflow:
+
+1. ``produce_metrics`` pushes a dict of metrics to XCom.
+2. ``check_accuracy`` uses SageMakerConditionOperator to branch:
+   - accuracy >= 0.9 AND loss < 0.1 -> ``deploy_model``
+   - otherwise -> ``retrain_model``
+3. Only the correct branch task runs; the other is skipped.
+"""
+
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerConditionOperator
+
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk import DAG, chain, task
+else:
+    from airflow.decorators import task  # type: ignore[attr-defined,no-redef]
+    from airflow.models.baseoperator import chain  # type: 
ignore[attr-defined,no-redef]
+    from airflow.models.dag import DAG  # type: 
ignore[attr-defined,no-redef,assignment]
+
+from system.amazon.aws.utils import SystemTestContextBuilder
+
+DAG_ID = "example_sagemaker_condition"
+
+sys_test_context_task = SystemTestContextBuilder().build()
+
+
+with DAG(
+    DAG_ID,
+    schedule="@once",
+    start_date=datetime(2021, 1, 1),
+    catchup=False,
+) as dag:
+    test_context = sys_test_context_task()
+
+    # TEST SETUP: push simulated ML metrics to XCom
+
+    @task
+    def produce_metrics():
+        """Simulate an ML training job that returns accuracy and loss 
metrics."""
+        return {"accuracy": 0.95, "loss": 0.04}
+
+    metrics = produce_metrics()
+
+    # TEST BODY
+
+    # [START howto_operator_sagemaker_condition]
+    check_accuracy = SageMakerConditionOperator(
+        task_id="check_accuracy",
+        conditions=[
+            {
+                "type": "GreaterThanOrEqualTo",
+                "left_value": "{{ 
ti.xcom_pull(task_ids='produce_metrics')['accuracy'] }}",
+                "right_value": 0.9,
+            },
+            {
+                "type": "LessThan",
+                "left_value": "{{ 
ti.xcom_pull(task_ids='produce_metrics')['loss'] }}",
+                "right_value": 0.1,
+            },
+        ],
+        if_task_ids=["deploy_model"],
+        else_task_ids=["retrain_model"],
+    )
+    # [END howto_operator_sagemaker_condition]
+
+    @task
+    def deploy_model():
+        """Placeholder: model meets quality bar, proceed to deployment."""
+        return "deployed"
+
+    @task
+    def retrain_model():
+        """Placeholder: model does not meet quality bar, retrain."""
+        return "retrained"
+
+    # Scenario 2: condition evaluates to False -> else branch
+
+    @task
+    def produce_bad_metrics():
+        """Simulate a training job with poor accuracy."""
+        return {"accuracy": 0.5, "loss": 0.8}
+
+    bad_metrics = produce_bad_metrics()
+
+    # [START howto_operator_sagemaker_condition_flat]
+    check_bad_accuracy = SageMakerConditionOperator(
+        task_id="check_bad_accuracy",
+        condition_type="GreaterThanOrEqualTo",
+        left_value="{{ 
ti.xcom_pull(task_ids='produce_bad_metrics')['accuracy'] }}",
+        right_value=0.9,
+        if_task_ids=["should_not_run"],
+        else_task_ids=["should_run"],
+    )
+    # [END howto_operator_sagemaker_condition_flat]
+
+    @task
+    def should_not_run():
+        """This task should be skipped because accuracy < 0.9."""
+        return "error: should not have run"
+
+    @task
+    def should_run():
+        """This task should execute because accuracy < 0.9 -> else branch."""
+        return "correctly routed to else branch"
+
+    # Scenario 3: Or condition + Not condition
+
+    # [START howto_operator_sagemaker_condition_not_or]
+    check_logical = SageMakerConditionOperator(
+        task_id="check_logical",
+        conditions=[
+            {
+                "type": "Or",
+                "conditions": [
+                    {"type": "Equals", "left_value": 1, "right_value": 2},
+                    {"type": "Equals", "left_value": 3, "right_value": 3},
+                ],
+            },
+            {
+                "type": "Not",
+                "condition": {"type": "Equals", "left_value": "a", 
"right_value": "b"},
+            },
+        ],
+        if_task_ids=["logical_pass"],
+        else_task_ids=["logical_fail"],
+    )
+    # [END howto_operator_sagemaker_condition_not_or]
+
+    @task
+    def logical_pass():
+        """Or(1==2, 3==3) -> True AND Not(a==b) -> True -> if branch."""
+        return "logical conditions passed"
+
+    @task
+    def logical_fail():
+        return "error: logical conditions should have passed"
+
+    test_context >> [metrics, bad_metrics, check_logical]
+
+    chain(metrics, check_accuracy, [deploy_model(), retrain_model()])
+    chain(bad_metrics, check_bad_accuracy, [should_not_run(), should_run()])
+    chain(check_logical, [logical_pass(), logical_fail()])
+
+    from tests_common.test_utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests_common.test_utils.system_tests import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: 
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git 
a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_condition.py 
b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_condition.py
new file mode 100644
index 00000000000..56d461e50a3
--- /dev/null
+++ 
b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_condition.py
@@ -0,0 +1,222 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.providers.amazon.aws.operators.sagemaker import 
SageMakerConditionOperator
+from airflow.providers.common.compat.sdk import AirflowFailException
+
+from unit.amazon.aws.utils.test_template_fields import validate_template_fields
+
+
def _choose(conditions, if_ids="if_task", else_ids="else_task"):
    """Build an operator from a ``conditions`` list and return the branch it selects."""
    operator_kwargs = {
        "task_id": "test",
        "conditions": conditions,
        "if_task_ids": if_ids,
        "else_task_ids": else_ids,
    }
    operator = SageMakerConditionOperator(**operator_kwargs)
    fake_context = MagicMock(spec=dict)
    return operator.choose_branch(context=fake_context)
+
+
def _choose_flat(condition_type, left, right, if_ids="if_task", else_ids="else_task"):
    """Build an operator from the flat parameters and return the branch it selects."""
    operator_kwargs = {
        "task_id": "test",
        "condition_type": condition_type,
        "left_value": left,
        "right_value": right,
        "if_task_ids": if_ids,
        "else_task_ids": else_ids,
    }
    operator = SageMakerConditionOperator(**operator_kwargs)
    fake_context = MagicMock(spec=dict)
    return operator.choose_branch(context=fake_context)
+
+
def test_template_fields():
    """Run the shared template-field validation helper against a minimal operator."""
    single_condition = {"type": "Equals", "left_value": 1, "right_value": 1}
    operator = SageMakerConditionOperator(
        task_id="test",
        conditions=[single_condition],
        if_task_ids=["a"],
        else_task_ids=["b"],
    )
    validate_template_fields(operator)
+
+
class TestConditionTypes:
    """One true + one false per condition type, plus logical combinators."""

    @pytest.mark.parametrize(
        ("cond_type", "left", "right", "expected"),
        [
            ("Equals", 1, 1, "if"),
            ("Equals", 1, 2, "else"),
            ("GreaterThan", 5, 3, "if"),
            ("GreaterThan", 3, 5, "else"),
            ("GreaterThanOrEqualTo", 3, 3, "if"),
            ("GreaterThanOrEqualTo", 2, 3, "else"),
            ("LessThan", 3, 5, "if"),
            ("LessThan", 5, 3, "else"),
            ("LessThanOrEqualTo", 3, 3, "if"),
            ("LessThanOrEqualTo", 5, 3, "else"),
        ],
    )
    def test_comparison(self, cond_type, left, right, expected):
        result = _choose([{"type": cond_type, "left_value": left, "right_value": right}])
        assert result == (["if_task"] if expected == "if" else ["else_task"])

    def test_in_true(self):
        assert _choose([{"type": "In", "value": 1, "in_values": [1, 2]}]) == ["if_task"]

    def test_in_false(self):
        assert _choose([{"type": "In", "value": 4, "in_values": [1, 2]}]) == ["else_task"]

    def test_not_negates(self):
        """Not(True) -> False -> else branch."""
        cond = [{"type": "Not", "condition": {"type": "Equals", "left_value": 1, "right_value": 1}}]
        assert _choose(cond) == ["else_task"]

    def test_not_of_false_condition(self):
        """Not(False) -> True -> if branch."""
        cond = [{"type": "Not", "condition": {"type": "Equals", "left_value": 1, "right_value": 2}}]
        assert _choose(cond) == ["if_task"]

    def test_or_any_true(self):
        """Or is True as soon as one inner condition is True."""
        cond = [
            {
                "type": "Or",
                "conditions": [
                    {"type": "Equals", "left_value": 1, "right_value": 2},
                    {"type": "Equals", "left_value": 1, "right_value": 1},
                ],
            }
        ]
        assert _choose(cond) == ["if_task"]

    def test_or_all_false(self):
        """Or is False when every inner condition is False."""
        cond = [
            {
                "type": "Or",
                "conditions": [
                    {"type": "Equals", "left_value": 1, "right_value": 2},
                    {"type": "Equals", "left_value": 3, "right_value": 4},
                ],
            }
        ]
        assert _choose(cond) == ["else_task"]
+
+
class TestAndSemantics:
    def test_multiple_conditions_and(self):
        """Top-level conditions are AND-ed: all True -> if, any False -> else."""
        always_true = {"type": "Equals", "left_value": 1, "right_value": 1}
        true_greater = {"type": "GreaterThan", "left_value": 5, "right_value": 3}
        false_greater = {"type": "GreaterThan", "left_value": 2, "right_value": 10}

        all_true_branch = _choose([always_true, true_greater])
        assert all_true_branch == ["if_task"]

        one_false_branch = _choose([always_true, false_greater])
        assert one_false_branch == ["else_task"]
+
+
class TestValueCasting:
    def test_cast_numeric_and_passthrough(self):
        """int string, float string, bool string, non-numeric string, non-string type."""
        assert SageMakerConditionOperator._cast("42") == 42
        assert SageMakerConditionOperator._cast("-7") == -7
        assert SageMakerConditionOperator._cast("0.9") == 0.9
        assert SageMakerConditionOperator._cast("true") is True
        assert SageMakerConditionOperator._cast("us-east-1") == "us-east-1"
        assert SageMakerConditionOperator._cast(42) == 42

    def test_cast_bool_case_insensitive_and_none(self):
        """Jinja renders Python True/False/None as 'True'/'False'/'None'; all must round-trip."""
        assert SageMakerConditionOperator._cast("True") is True
        assert SageMakerConditionOperator._cast("false") is False
        assert SageMakerConditionOperator._cast("False") is False
        assert SageMakerConditionOperator._cast("None") is None
+
+
class TestValidation:
    def test_empty_conditions_raises(self):
        """An empty ``conditions`` list is rejected at construction time."""
        kwargs = {"task_id": "t", "conditions": [], "if_task_ids": "a", "else_task_ids": "b"}
        with pytest.raises(ValueError, match="At least 1 condition is required"):
            SageMakerConditionOperator(**kwargs)

    def test_unknown_type_raises(self):
        """An unsupported ``type`` inside the list fails during evaluation."""
        bogus = {"type": "FooBar", "left_value": 1, "right_value": 2}
        with pytest.raises(ValueError, match="Unknown condition type"):
            _choose([bogus])

    def test_type_mismatch_raises(self):
        """A str operand cannot be compared against a float."""
        mismatched = {"type": "GreaterThanOrEqualTo", "left_value": "hello", "right_value": 0.9}
        with pytest.raises(TypeError, match="Cannot compare"):
            _choose([mismatched])

    def test_none_operand_raises(self):
        """Covers both Python None and Jinja-rendered string 'None'."""
        none_cases = [
            {"type": "Equals", "left_value": None, "right_value": 1},
            {"type": "GreaterThanOrEqualTo", "left_value": "None", "right_value": 0.9},
        ]
        for none_condition in none_cases:
            with pytest.raises(TypeError, match="received None"):
                _choose([none_condition])

    @pytest.mark.parametrize(
        ("condition", "match_pattern"),
        [
            ({}, "missing required key 'type'"),
            ({"type": "Equals", "left_value": 1}, "missing required key"),
            ({"type": "Not"}, "missing required key"),
            ({"type": "Or"}, "missing required key"),
        ],
    )
    def test_missing_key_raises(self, condition, match_pattern):
        """Each condition type reports its missing required key as a ValueError."""
        wrapped = [condition]
        with pytest.raises(ValueError, match=match_pattern):
            _choose(wrapped)
+
+
class TestFlatInterface:
    def test_flat_condition(self):
        """Flat params work for comparison and In types."""
        expectations = [
            ("Equals", 1, 1, ["if_task"]),
            ("Equals", 1, 2, ["else_task"]),
            ("In", 1, [1, 2, 3], ["if_task"]),
            ("In", 99, [1, 2, 3], ["else_task"]),
        ]
        for cond_type, left, right, expected_branch in expectations:
            assert _choose_flat(cond_type, left, right) == expected_branch

    def test_invalid_condition_type_raises(self):
        """A flat type outside the supported set is rejected at construction time."""
        kwargs = {
            "task_id": "t",
            "condition_type": "NotEquals",
            "left_value": 1,
            "right_value": 1,
            "if_task_ids": "a",
            "else_task_ids": "b",
        }
        with pytest.raises(ValueError, match="Unknown condition_type"):
            SageMakerConditionOperator(**kwargs)

    def test_mutual_exclusion_raises(self):
        """Flat params and a ``conditions`` list cannot be combined."""
        kwargs = {
            "task_id": "t",
            "condition_type": "Equals",
            "left_value": 1,
            "right_value": 1,
            "conditions": [{"type": "Equals", "left_value": 1, "right_value": 1}],
            "if_task_ids": "a",
            "else_task_ids": "b",
        }
        with pytest.raises(ValueError, match="Cannot use 'condition_type' and 'conditions' together"):
            SageMakerConditionOperator(**kwargs)

    def test_neither_provided_raises(self):
        """Omitting both interfaces is an error."""
        kwargs = {"task_id": "t", "if_task_ids": "a", "else_task_ids": "b"}
        with pytest.raises(ValueError, match="Missing condition"):
            SageMakerConditionOperator(**kwargs)

    def test_optional_else_task_ids(self):
        """else_task_ids defaults to empty list when omitted."""
        operator = SageMakerConditionOperator(
            task_id="t",
            conditions=[{"type": "Equals", "left_value": 1, "right_value": 1}],
            if_task_ids=["deploy"],
        )
        assert operator.else_task_ids == []

    def test_no_else_branch_raises_on_false(self):
        """When else_task_ids is empty and conditions are false, raises AirflowFailException."""
        operator = SageMakerConditionOperator(
            task_id="check",
            conditions=[{"type": "Equals", "left_value": 1, "right_value": 2}],
            if_task_ids=["deploy"],
        )
        fake_context = MagicMock(spec=dict)
        with pytest.raises(AirflowFailException, match="Condition check failed in task 'check'"):
            operator.choose_branch(context=fake_context)

Reply via email to