This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d0e4b8d959 Refactor DAG.dataset_triggers into the timetable class (#39321)
d0e4b8d959 is described below

commit d0e4b8d95992936d8c89a5224107514392f918fb
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Jun 27 05:47:56 2024 -0400

    Refactor DAG.dataset_triggers into the timetable class (#39321)
---
 airflow/datasets/__init__.py                  | 11 +++--
 airflow/models/dag.py                         | 54 ++++++++++------------
 airflow/serialization/schema.json             |  4 --
 airflow/serialization/serialized_objects.py   | 64 +++++++++++++++------------
 airflow/timetables/base.py                    | 63 ++++++++++++++++++++------
 airflow/timetables/datasets.py                | 20 ++++-----
 airflow/timetables/simple.py                  | 16 +++++++
 tests/cli/commands/test_dag_command.py        |  2 +-
 tests/datasets/test_dataset.py                | 40 +++++++++--------
 tests/models/test_dag.py                      |  2 +-
 tests/serialization/test_dag_serialization.py |  6 ---
 tests/timetables/test_datasets_timetable.py   | 12 ++++-
 12 files changed, 176 insertions(+), 118 deletions(-)
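
For context, a minimal sketch of the user-visible effect of this refactor
(assuming the Airflow APIs exactly as patched at this commit; the dag id and
URIs are invented):

    from airflow.datasets import Dataset
    from airflow.models.dag import DAG

    dag = DAG(
        "example",
        schedule=[Dataset("s3://bucket/a"), Dataset("s3://bucket/b")],
    )

    # Before this change the condition lived on dag.dataset_triggers;
    # the timetable now owns it instead.
    print(type(dag.timetable).__name__)     # DatasetTriggeredTimetable
    print(dag.timetable.dataset_condition)  # DatasetAll(Dataset(...), Dataset(...))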

diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index 4de148de84..7d656e59dd 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -124,12 +124,15 @@ class BaseDataset:
     :meta private:
     """
 
-    def __or__(self, other: BaseDataset) -> DatasetAny:
+    def __bool__(self) -> bool:
+        return True
+
+    def __or__(self, other: BaseDataset) -> BaseDataset:
         if not isinstance(other, BaseDataset):
             return NotImplemented
         return DatasetAny(self, other)
 
-    def __and__(self, other: BaseDataset) -> DatasetAll:
+    def __and__(self, other: BaseDataset) -> BaseDataset:
         if not isinstance(other, BaseDataset):
             return NotImplemented
         return DatasetAll(self, other)
@@ -216,7 +219,7 @@ class DatasetAny(_DatasetBooleanCondition):
 
     agg_func = any
 
-    def __or__(self, other: BaseDataset) -> DatasetAny:
+    def __or__(self, other: BaseDataset) -> BaseDataset:
         if not isinstance(other, BaseDataset):
             return NotImplemented
         # Optimization: X | (Y | Z) is equivalent to X | Y | Z.
@@ -238,7 +241,7 @@ class DatasetAll(_DatasetBooleanCondition):
 
     agg_func = all
 
-    def __and__(self, other: BaseDataset) -> DatasetAll:
+    def __and__(self, other: BaseDataset) -> BaseDataset:
         if not isinstance(other, BaseDataset):
             return NotImplemented
         # Optimization: X & (Y & Z) is equivalent to X & Y & Z.
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 97e39d4a93..919267ad52 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -115,7 +115,6 @@ from airflow.security import permissions
 from airflow.settings import json
 from airflow.stats import Stats
 from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
-from airflow.timetables.datasets import DatasetOrTimeSchedule
 from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable
 from airflow.timetables.simple import (
     ContinuousTimetable,
@@ -652,35 +651,31 @@ class DAG(LoggingMixin):
                 stacklevel=2,
             )
 
-        self.timetable: Timetable
+        if timetable is not None:
+            schedule = timetable
+        elif schedule_interval is not NOTSET:
+            schedule = schedule_interval
+
+        # Kept for compatibility. Do not use in new code.
         self.schedule_interval: ScheduleInterval
-        self.dataset_triggers: BaseDataset | None = None
-        if isinstance(schedule, BaseDataset):
-            self.dataset_triggers = schedule
-        elif isinstance(schedule, Collection) and not isinstance(schedule, str):
-            if not all(isinstance(x, Dataset) for x in schedule):
-                raise ValueError("All elements in 'schedule' should be 
datasets")
-            self.dataset_triggers = DatasetAll(*schedule)
-        elif isinstance(schedule, Timetable):
-            timetable = schedule
-        elif schedule is not NOTSET and not isinstance(schedule, BaseDataset):
-            schedule_interval = schedule
 
-        if isinstance(schedule, DatasetOrTimeSchedule):
+        if isinstance(schedule, Timetable):
             self.timetable = schedule
-            self.dataset_triggers = self.timetable.datasets
+            self.schedule_interval = schedule.summary
+        elif isinstance(schedule, BaseDataset):
+            self.timetable = DatasetTriggeredTimetable(schedule)
             self.schedule_interval = self.timetable.summary
-        elif self.dataset_triggers:
-            self.timetable = DatasetTriggeredTimetable()
-            self.schedule_interval = self.timetable.summary
-        elif timetable:
-            self.timetable = timetable
+        elif isinstance(schedule, Collection) and not isinstance(schedule, str):
+            if not all(isinstance(x, Dataset) for x in schedule):
+                raise ValueError("All elements in 'schedule' should be 
datasets")
+            self.timetable = DatasetTriggeredTimetable(DatasetAll(*schedule))
             self.schedule_interval = self.timetable.summary
+        elif isinstance(schedule, ArgNotSet):
+            self.timetable = create_timetable(schedule, self.timezone)
+            self.schedule_interval = DEFAULT_SCHEDULE_INTERVAL
         else:
-            if isinstance(schedule_interval, ArgNotSet):
-                schedule_interval = DEFAULT_SCHEDULE_INTERVAL
-            self.schedule_interval = schedule_interval
-            self.timetable = create_timetable(schedule_interval, self.timezone)
+            self.timetable = create_timetable(schedule, self.timezone)
+            self.schedule_interval = schedule
 
         if isinstance(template_searchpath, str):
             template_searchpath = [template_searchpath]
@@ -3250,10 +3245,7 @@ class DAG(LoggingMixin):
             )
             orm_dag.schedule_interval = dag.schedule_interval
             orm_dag.timetable_description = dag.timetable.description
-            if (dataset_triggers := dag.dataset_triggers) is None:
-                orm_dag.dataset_expression = None
-            else:
-                orm_dag.dataset_expression = dataset_triggers.as_expression()
+            orm_dag.dataset_expression = dag.timetable.dataset_condition.as_expression()
 
             orm_dag.processor_subdir = processor_subdir
 
@@ -3309,11 +3301,11 @@ class DAG(LoggingMixin):
         # later we'll persist them to the database.
         for dag in dags:
             curr_orm_dag = existing_dags.get(dag.dag_id)
-            if dag.dataset_triggers is None:
+            if not (dataset_condition := dag.timetable.dataset_condition):
                 if curr_orm_dag and curr_orm_dag.schedule_dataset_references:
                     curr_orm_dag.schedule_dataset_references = []
             else:
-                for _, dataset in dag.dataset_triggers.iter_datasets():
+                for _, dataset in dataset_condition.iter_datasets():
                     dag_references[dag.dag_id].add(dataset.uri)
                     input_datasets[DatasetModel.from_public(dataset)] = None
             curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
@@ -3967,7 +3959,7 @@ class DagModel(Base):
         for ser_dag in ser_dags:
             dag_id = ser_dag.dag_id
             statuses = dag_statuses[dag_id]
-            if not dag_ready(dag_id, cond=ser_dag.dag.dataset_triggers, statuses=statuses):
+            if not dag_ready(dag_id, cond=ser_dag.dag.timetable.dataset_condition, statuses=statuses):
                 del by_dag[dag_id]
                 del dag_statuses[dag_id]
         del dag_statuses
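
The rewritten constructor above funnels every schedule form into a timetable.
A rough sketch of the resulting mapping (assuming the constructor exactly as
patched; dag ids and URIs are invented):

    from airflow.datasets import Dataset, DatasetAny
    from airflow.models.dag import DAG
    from airflow.timetables.simple import DatasetTriggeredTimetable

    # A condition object is wrapped into the timetable directly...
    dag1 = DAG("d1", schedule=DatasetAny(Dataset("a"), Dataset("b")))
    assert isinstance(dag1.timetable, DatasetTriggeredTimetable)

    # ...while a plain collection of datasets becomes DatasetAll(*schedule).
    dag2 = DAG("d2", schedule=[Dataset("a"), Dataset("b")])
    assert dag2.schedule_interval == "Dataset"  # the timetable's summary
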
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index 76ae3e36ba..84b2e2ed4a 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -148,10 +148,6 @@
             { "$ref": "#/definitions/typed_relativedelta" }
           ]
         },
-        "dataset_triggers": {
-        "$ref": "#/definitions/typed_dataset_cond"
-
-},
         "owner_links": { "type": "object" },
         "timetable": {
           "type": "object",
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index a73684b9a5..95848429e6 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -37,7 +37,7 @@ from pendulum.tz.timezone import FixedTimezone, Timezone
 
 from airflow.compat.functools import cache
 from airflow.configuration import conf
-from airflow.datasets import Dataset, DatasetAll, DatasetAny
+from airflow.datasets import BaseDataset, Dataset, DatasetAll, DatasetAny
 from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError, TaskDeferred
 from airflow.jobs.job import Job
 from airflow.models.baseoperator import BaseOperator
@@ -228,6 +228,35 @@ class _PriorityWeightStrategyNotRegistered(AirflowException):
         )
 
 
+def encode_dataset_condition(var: BaseDataset) -> dict[str, Any]:
+    """Encode a dataset condition.
+
+    :meta private:
+    """
+    if isinstance(var, Dataset):
+        return {"__type": DAT.DATASET, "uri": var.uri, "extra": var.extra}
+    if isinstance(var, DatasetAll):
+        return {"__type": DAT.DATASET_ALL, "objects": 
[encode_dataset_condition(x) for x in var.objects]}
+    if isinstance(var, DatasetAny):
+        return {"__type": DAT.DATASET_ANY, "objects": 
[encode_dataset_condition(x) for x in var.objects]}
+    raise ValueError(f"serialization not implemented for 
{type(var).__name__!r}")
+
+
+def decode_dataset_condition(var: dict[str, Any]) -> BaseDataset:
+    """Decode a previously serialized dataset condition.
+
+    :meta private:
+    """
+    dat = var["__type"]
+    if dat == DAT.DATASET:
+        return Dataset(var["uri"], extra=var["extra"])
+    if dat == DAT.DATASET_ALL:
+        return DatasetAll(*(decode_dataset_condition(x) for x in var["objects"]))
+    if dat == DAT.DATASET_ANY:
+        return DatasetAny(*(decode_dataset_condition(x) for x in var["objects"]))
+    raise ValueError(f"deserialization not implemented for DAT {dat!r}")
+
+
 def encode_timetable(var: Timetable) -> dict[str, Any]:
     """
     Encode a timetable instance.
@@ -488,8 +517,6 @@ class BaseSerialization:
                 serialized_object[key] = encode_timetable(value)
             elif key == "weight_rule" and value is not None:
                 serialized_object[key] = encode_priority_weight_strategy(value)
-            elif key == "dataset_triggers":
-                serialized_object[key] = cls.serialize(value)
             else:
                 value = cls.serialize(value)
                 if isinstance(value, dict) and Encoding.TYPE in value:
@@ -607,24 +634,9 @@ class BaseSerialization:
             return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
         elif isinstance(var, XComArg):
             return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
-        elif isinstance(var, Dataset):
-            return cls._encode({"uri": var.uri, "extra": var.extra}, 
type_=DAT.DATASET)
-        elif isinstance(var, DatasetAll):
-            return cls._encode(
-                [
-                    cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
-                    for x in var.objects
-                ],
-                type_=DAT.DATASET_ALL,
-            )
-        elif isinstance(var, DatasetAny):
-            return cls._encode(
-                [
-                    cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
-                    for x in var.objects
-                ],
-                type_=DAT.DATASET_ANY,
-            )
+        elif isinstance(var, BaseDataset):
+            serialized_dataset = encode_dataset_condition(var)
+            return cls._encode(serialized_dataset, type_=serialized_dataset.pop("__type"))
         elif isinstance(var, SimpleTaskInstance):
             return cls._encode(
                 cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models),
@@ -740,9 +752,9 @@ class BaseSerialization:
         elif type_ == DAT.DATASET:
             return Dataset(**var)
         elif type_ == DAT.DATASET_ANY:
-            return DatasetAny(*(cls.deserialize(x) for x in var))
+            return DatasetAny(*(decode_dataset_condition(x) for x in var["objects"]))
         elif type_ == DAT.DATASET_ALL:
-            return DatasetAll(*(cls.deserialize(x) for x in var))
+            return DatasetAll(*(decode_dataset_condition(x) for x in var["objects"]))
         elif type_ == DAT.SIMPLE_TASK_INSTANCE:
             return SimpleTaskInstance(**cls.deserialize(var))
         elif type_ == DAT.CONNECTION:
@@ -914,9 +926,7 @@ class DependencyDetector:
         """Detect dependencies set directly on the DAG object."""
         if not dag:
             return
-        if not dag.dataset_triggers:
-            return
-        for uri, _ in dag.dataset_triggers.iter_datasets():
+        for uri, _ in dag.timetable.dataset_condition.iter_datasets():
             yield DagDependency(
                 source="dataset",
                 target=dag.dag_id,
@@ -1562,8 +1572,6 @@ class SerializedDAG(DAG, BaseSerialization):
                 v = cls.deserialize(v)
             elif k == "params":
                 v = cls._deserialize_params_dict(v)
-            elif k == "dataset_triggers":
-                v = cls.deserialize(v)
             # else use v as it is
 
             setattr(dag, k, v)
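
The new helpers above round-trip any dataset condition. A short sketch
(assuming the DAT enum values shown in the tests below, e.g. "dataset_all";
the URIs and extra payload are invented):

    from airflow.datasets import Dataset, DatasetAll
    from airflow.serialization.serialized_objects import (
        decode_dataset_condition,
        encode_dataset_condition,
    )

    cond = DatasetAll(Dataset("a"), Dataset("b", extra={"team": "data"}))
    data = encode_dataset_condition(cond)
    # data == {"__type": "dataset_all", "objects": [...]}

    # Composite conditions may not define __eq__, so compare encoded forms.
    assert encode_dataset_condition(decode_dataset_condition(data)) == data
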
diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py
index b5e95ef5f4..0b1a4f6de9 100644
--- a/airflow/timetables/base.py
+++ b/airflow/timetables/base.py
@@ -16,17 +16,47 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Any, NamedTuple, Sequence
+from typing import TYPE_CHECKING, Any, Iterator, NamedTuple, Sequence
 from warnings import warn
 
+from airflow.datasets import BaseDataset
 from airflow.typing_compat import Protocol, runtime_checkable
 
 if TYPE_CHECKING:
     from pendulum import DateTime
 
+    from airflow.datasets import Dataset
     from airflow.utils.types import DagRunType
 
 
+class _NullDataset(BaseDataset):
+    """Sentinel type that represents "no datasets".
+
+    This is only implemented to make typing easier in timetables, and not
+    expected to be used anywhere else.
+
+    :meta private:
+    """
+
+    def __bool__(self) -> bool:
+        return False
+
+    def __or__(self, other: BaseDataset) -> BaseDataset:
+        return NotImplemented
+
+    def __and__(self, other: BaseDataset) -> BaseDataset:
+        return NotImplemented
+
+    def as_expression(self) -> Any:
+        return None
+
+    def evaluate(self, statuses: dict[str, bool]) -> bool:
+        return False
+
+    def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
+        return iter(())
+
+
 class DataInterval(NamedTuple):
     """A data interval for a DagRun to operate over.
 
@@ -127,6 +157,12 @@ class Timetable(Protocol):
 
     @property
     def can_be_scheduled(self):
+        """Whether this timetable can actually schedule runs in an automated 
manner.
+
+        This defaults to and should generally be *True* (including non periodic
+        execution types like *@once* and data triggered tables), but
+        ``NullTimetable`` sets this to *False*.
+        """
         if hasattr(self, "can_run"):
             warn(
                'can_run class variable is deprecated. Use "can_be_scheduled" instead.',
@@ -136,13 +172,6 @@ class Timetable(Protocol):
             return self.can_run
         return self._can_be_scheduled
 
-    """Whether this timetable can actually schedule runs in an automated 
manner.
-
-    This defaults to and should generally be *True* (including non periodic
-    execution types like *@once* and data triggered tables), but
-    ``NullTimetable`` sets this to *False*.
-    """
-
     run_ordering: Sequence[str] = ("data_interval_end", "execution_date")
     """How runs triggered from this timetable should be ordered in UI.
 
@@ -150,11 +179,19 @@ class Timetable(Protocol):
     """
 
     active_runs_limit: int | None = None
-    """Override the max_active_runs parameter of any DAGs using this timetable.
-    This is called during DAG initializing, and will set the max_active_runs if
-    it returns a value. In most cases this should return None, but in some cases
-    (for example, the ContinuousTimetable) there are good reasons for limiting
-    the DAGRun parallelism.
+    """Maximum active runs that can be active at one time for a DAG.
+
+    This is called during DAG initialization, and the return value is used as
+    the DAG's default ``max_active_runs``. This should generally return *None*,
+    but there are good reasons to limit DAG run parallelism in some cases, such
+    as for :class:`~airflow.timetable.simple.ContinuousTimetable`.
+    """
+
+    dataset_condition: BaseDataset = _NullDataset()
+    """The dataset condition that triggers a DAG using this timetable.
+
+    This should be a dataset, or a combination of datasets, that controls the
+    DAG's dataset triggers; the falsy ``_NullDataset`` default represents "no
+    datasets".
     """
 
     @classmethod
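
Because dataset_condition defaults to the falsy _NullDataset sentinel, callers
can use plain truthiness instead of a *None* check. A sketch (assuming a
trivial Timetable subclass purely for illustration):

    from airflow.timetables.base import Timetable

    class MyTimetable(Timetable):
        pass

    tt = MyTimetable()
    assert not tt.dataset_condition              # the sentinel is falsy
    assert tt.dataset_condition.as_expression() is None  # dataset_expression stays None
    assert list(tt.dataset_condition.iter_datasets()) == []
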
diff --git a/airflow/timetables/datasets.py b/airflow/timetables/datasets.py
index a8f4a7f22f..4c27f39b26 100644
--- a/airflow/timetables/datasets.py
+++ b/airflow/timetables/datasets.py
@@ -44,9 +44,9 @@ class DatasetOrTimeSchedule(DatasetTriggeredSchedule):
     ) -> None:
         self.timetable = timetable
         if isinstance(datasets, BaseDataset):
-            self.datasets = datasets
+            self.dataset_condition = datasets
         else:
-            self.datasets = DatasetAll(*datasets)
+            self.dataset_condition = DatasetAll(*datasets)
 
         self.description = f"Triggered by datasets or {timetable.description}"
         self.periodic = timetable.periodic
@@ -55,25 +55,25 @@ class DatasetOrTimeSchedule(DatasetTriggeredSchedule):
 
     @classmethod
     def deserialize(cls, data: dict[str, typing.Any]) -> Timetable:
-        from airflow.serialization.serialized_objects import decode_timetable
+        from airflow.serialization.serialized_objects import decode_dataset_condition, decode_timetable
 
         return cls(
+            datasets=decode_dataset_condition(data["dataset_condition"]),
             timetable=decode_timetable(data["timetable"]),
-            # don't need the datasets after deserialization
-            # they are already stored on dataset_triggers attr on DAG
-            # and this is what scheduler looks at
-            datasets=[],
         )
 
     def serialize(self) -> dict[str, typing.Any]:
-        from airflow.serialization.serialized_objects import encode_timetable
+        from airflow.serialization.serialized_objects import encode_dataset_condition, encode_timetable
 
-        return {"timetable": encode_timetable(self.timetable)}
+        return {
+            "dataset_condition": 
encode_dataset_condition(self.dataset_condition),
+            "timetable": encode_timetable(self.timetable),
+        }
 
     def validate(self) -> None:
         if isinstance(self.timetable, DatasetTriggeredSchedule):
             raise AirflowTimetableInvalid("cannot nest dataset timetables")
-        if not isinstance(self.datasets, BaseDataset):
+        if not isinstance(self.dataset_condition, BaseDataset):
             raise AirflowTimetableInvalid("all elements in 'datasets' must be 
datasets")
 
     @property
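
With the condition now serialized alongside the wrapped timetable, the
schedule round-trips without relying on the removed DAG.dataset_triggers
attribute. A sketch (CronTriggerTimetable is assumed purely for illustration):

    from airflow.datasets import Dataset
    from airflow.timetables.datasets import DatasetOrTimeSchedule
    from airflow.timetables.trigger import CronTriggerTimetable

    schedule = DatasetOrTimeSchedule(
        timetable=CronTriggerTimetable("0 0 * * *", timezone="UTC"),
        datasets=[Dataset("a")],
    )
    restored = DatasetOrTimeSchedule.deserialize(schedule.serialize())
    # Compare the JSON-able expressions; conditions may not define __eq__.
    expr = schedule.dataset_condition.as_expression()
    assert restored.dataset_condition.as_expression() == expr
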
diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py
index 6452244262..98bf7835f6 100644
--- a/airflow/timetables/simple.py
+++ b/airflow/timetables/simple.py
@@ -25,6 +25,7 @@ if TYPE_CHECKING:
     from pendulum import DateTime
     from sqlalchemy import Session
 
+    from airflow.datasets import BaseDataset
     from airflow.models.dataset import DatasetEvent
     from airflow.timetables.base import TimeRestriction
     from airflow.utils.types import DagRunType
@@ -156,10 +157,25 @@ class DatasetTriggeredTimetable(_TrivialTimetable):
 
     description: str = "Triggered by datasets"
 
+    def __init__(self, datasets: BaseDataset) -> None:
+        super().__init__()
+        self.dataset_condition = datasets
+
+    @classmethod
+    def deserialize(cls, data: dict[str, Any]) -> Timetable:
+        from airflow.serialization.serialized_objects import decode_dataset_condition
+
+        return cls(decode_dataset_condition(data["dataset_condition"]))
+
     @property
     def summary(self) -> str:
         return "Dataset"
 
+    def serialize(self) -> dict[str, Any]:
+        from airflow.serialization.serialized_objects import encode_dataset_condition
+
+        return {"dataset_condition": 
encode_dataset_condition(self.dataset_condition)}
+
     def generate_run_id(
         self,
         *,
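
A matching sketch for the constructor and (de)serialization hooks added above
(the URIs are invented):

    from airflow.datasets import Dataset, DatasetAny
    from airflow.timetables.simple import DatasetTriggeredTimetable

    tt = DatasetTriggeredTimetable(DatasetAny(Dataset("a"), Dataset("b")))
    assert tt.summary == "Dataset"

    data = tt.serialize()  # {"dataset_condition": {"__type": "dataset_any", ...}}
    restored = DatasetTriggeredTimetable.deserialize(data)
    expected = tt.dataset_condition.as_expression()
    assert restored.dataset_condition.as_expression() == expected
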
diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py
index de8788ada8..535a197eab 100644
--- a/tests/cli/commands/test_dag_command.py
+++ b/tests/cli/commands/test_dag_command.py
@@ -970,7 +970,7 @@ class TestCliDags:
         mock_render_dag.assert_has_calls([mock.call(mock_get_dag.return_value, tis=[])])
         assert "SOURCE" in output
 
-    @mock.patch("workday.AfterWorkdayTimetable")
+    @mock.patch("workday.AfterWorkdayTimetable", side_effect=lambda: 
mock.MagicMock(active_runs_limit=None))
     @mock.patch("airflow.models.dag._get_or_create_dagrun")
     def test_dag_test_with_custom_timetable(self, mock__get_or_create_dagrun, _):
         """
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index fcd4ae9a1e..4ce9b01ada 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -258,22 +258,21 @@ def test_dataset_trigger_setup_and_serialization(session, dag_maker, create_test
     with dag_maker(schedule=DatasetAny(*datasets)) as dag:
         EmptyOperator(task_id="hello")
 
-    # Verify dataset triggers are set up correctly
+    # Verify datasets are set up correctly
     assert isinstance(
-        dag.dataset_triggers, DatasetAny
-    ), "DAG dataset triggers should be an instance of DatasetAny"
+        dag.timetable.dataset_condition, DatasetAny
+    ), "DAG datasets should be an instance of DatasetAny"
 
-    # Serialize and deserialize DAG dataset triggers
-    serialized_trigger = SerializedDAG.serialize(dag.dataset_triggers)
-    deserialized_trigger = SerializedDAG.deserialize(serialized_trigger)
+    # Round-trip the DAG through serialization
+    deserialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
 
     # Verify serialization and deserialization integrity
     assert isinstance(
-        deserialized_trigger, DatasetAny
-    ), "Deserialized trigger should maintain type DatasetAny"
+        deserialized_dag.timetable.dataset_condition, DatasetAny
+    ), "Deserialized datasets should maintain type DatasetAny"
     assert (
-        deserialized_trigger.objects == dag.dataset_triggers.objects
-    ), "Deserialized trigger objects should match original"
+        deserialized_dag.timetable.dataset_condition.objects == dag.timetable.dataset_condition.objects
+    ), "Deserialized datasets should match original"
 
 
 @pytest.mark.db_test
@@ -303,12 +302,13 @@ def test_dataset_dag_run_queue_processing(session, clear_datasets, dag_maker, cr
     for (serialized_dag,) in serialized_dags:
         dag = SerializedDAG.deserialize(serialized_dag.data)
         for dataset_uri, status in dag_statuses[dag.dag_id].items():
-            assert dag.dataset_triggers.evaluate({dataset_uri: status}), "DAG trigger evaluation failed"
+            cond = dag.timetable.dataset_condition
+            assert cond.evaluate({dataset_uri: status}), "DAG trigger evaluation failed"
 
 
 @pytest.mark.db_test
 @pytest.mark.usefixtures("clear_datasets")
-def test_dag_with_complex_dataset_triggers(session, dag_maker):
+def test_dag_with_complex_dataset_condition(session, dag_maker):
     # Create Dataset instances
     d1 = Dataset(uri="hello1")
     d2 = Dataset(uri="hello2")
@@ -324,13 +324,13 @@ def test_dag_with_complex_dataset_triggers(session, dag_maker):
         EmptyOperator(task_id="hello")
 
     assert isinstance(
-        dag.dataset_triggers, DatasetAny
+        dag.timetable.dataset_condition, DatasetAny
     ), "DAG's dataset trigger should be an instance of DatasetAny"
     assert any(
-        isinstance(trigger, DatasetAll) for trigger in dag.dataset_triggers.objects
+        isinstance(trigger, DatasetAll) for trigger in dag.timetable.dataset_condition.objects
     ), "DAG's dataset trigger should include DatasetAll"
 
-    serialized_triggers = SerializedDAG.serialize(dag.dataset_triggers)
+    serialized_triggers = SerializedDAG.serialize(dag.timetable.dataset_condition)
 
     deserialized_triggers = SerializedDAG.deserialize(serialized_triggers)
 
@@ -341,11 +341,13 @@ def test_dag_with_complex_dataset_triggers(session, dag_maker):
         isinstance(trigger, DatasetAll) for trigger in deserialized_triggers.objects
     ), "Deserialized triggers should include DatasetAll"
 
-    serialized_dag_dict = SerializedDAG.to_dict(dag)["dag"]
-    assert "dataset_triggers" in serialized_dag_dict, "Serialized DAG should 
contain 'dataset_triggers'"
+    serialized_timetable_dict = SerializedDAG.to_dict(dag)["dag"]["timetable"]["__var"]
+    assert (
+        "dataset_condition" in serialized_timetable_dict
+    ), "Serialized timetable should contain 'dataset_condition'"
     assert isinstance(
-        serialized_dag_dict["dataset_triggers"], dict
-    ), "Serialized 'dataset_triggers' should be a dict"
+        serialized_timetable_dict["dataset_condition"], dict
+    ), "Serialized 'dataset_condition' should be a dict"
 
 
 def datasets_equal(d1: BaseDataset, d2: BaseDataset) -> bool:
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 0a1247d4ed..52001bef91 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1968,7 +1968,7 @@ class TestDag:
 
     def test_timetable_and_description_from_dataset(self):
         dag = DAG("test_schedule_interval_arg", 
schedule=[Dataset(uri="hello")], start_date=TEST_DATE)
-        assert dag.timetable == DatasetTriggeredTimetable()
+        assert dag.timetable == DatasetTriggeredTimetable(Dataset(uri="hello"))
         assert dag.schedule_interval == "Dataset"
         assert dag.timetable.description == "Triggered by datasets"
 
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index a1e76a13a8..d95c45ed64 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -591,16 +591,10 @@ class TestStringifiedDAGs:
             "params",
             "_processor_dags_folder",
         }
-        compare_serialization_list = {
-            "dataset_triggers",
-        }
         fields_to_check = dag.get_serialized_fields() - exclusion_list
         for field in fields_to_check:
             actual = getattr(serialized_dag, field)
             expected = getattr(dag, field)
-            if field in compare_serialization_list:
-                actual = BaseSerialization.serialize(actual)
-                expected = BaseSerialization.serialize(expected)
             assert actual == expected, f"{dag.dag_id}.{field} does not match"
         # _processor_dags_folder is only populated at serialization time
         # it's only used when relying on serialized dag to determine a dag's relative path
diff --git a/tests/timetables/test_datasets_timetable.py b/tests/timetables/test_datasets_timetable.py
index 32a5c89fd3..68928a568d 100644
--- a/tests/timetables/test_datasets_timetable.py
+++ b/tests/timetables/test_datasets_timetable.py
@@ -129,6 +129,10 @@ def test_serialization(dataset_timetable: DatasetOrTimeSchedule, monkeypatch: An
     serialized = dataset_timetable.serialize()
     assert serialized == {
         "timetable": "mock_serialized_timetable",
+        "dataset_condition": {
+            "__type": "dataset_all",
+            "objects": [{"__type": "dataset", "uri": "test_dataset", "extra": 
None}],
+        },
     }
 
 
@@ -141,7 +145,13 @@ def test_deserialization(monkeypatch: Any) -> None:
     monkeypatch.setattr(
         "airflow.serialization.serialized_objects.decode_timetable", lambda x: 
MockTimetable()
     )
-    mock_serialized_data = {"timetable": "mock_serialized_timetable", 
"datasets": [{"uri": "test_dataset"}]}
+    mock_serialized_data = {
+        "timetable": "mock_serialized_timetable",
+        "dataset_condition": {
+            "__type": "dataset_all",
+            "objects": [{"__type": "dataset", "uri": "test_dataset", "extra": 
None}],
+        },
+    }
     deserialized = DatasetOrTimeSchedule.deserialize(mock_serialized_data)
     assert isinstance(deserialized, DatasetOrTimeSchedule)
 
