This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d0e4b8d959 Refactor DAG.dataset_triggers into the timetable class (#39321)
d0e4b8d959 is described below
commit d0e4b8d95992936d8c89a5224107514392f918fb
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Jun 27 05:47:56 2024 -0400
Refactor DAG.dataset_triggers into the timetable class (#39321)
---
airflow/datasets/__init__.py | 11 +++--
airflow/models/dag.py | 54 ++++++++++------------
airflow/serialization/schema.json | 4 --
airflow/serialization/serialized_objects.py | 64 +++++++++++++++------------
airflow/timetables/base.py | 63 ++++++++++++++++++++------
airflow/timetables/datasets.py | 20 ++++-----
airflow/timetables/simple.py | 16 +++++++
tests/cli/commands/test_dag_command.py | 2 +-
tests/datasets/test_dataset.py | 40 +++++++++--------
tests/models/test_dag.py | 2 +-
tests/serialization/test_dag_serialization.py | 6 ---
tests/timetables/test_datasets_timetable.py | 12 ++++-
12 files changed, 176 insertions(+), 118 deletions(-)
diff --git a/airflow/datasets/__init__.py b/airflow/datasets/__init__.py
index 4de148de84..7d656e59dd 100644
--- a/airflow/datasets/__init__.py
+++ b/airflow/datasets/__init__.py
@@ -124,12 +124,15 @@ class BaseDataset:
:meta private:
"""
- def __or__(self, other: BaseDataset) -> DatasetAny:
+ def __bool__(self) -> bool:
+ return True
+
+ def __or__(self, other: BaseDataset) -> BaseDataset:
if not isinstance(other, BaseDataset):
return NotImplemented
return DatasetAny(self, other)
- def __and__(self, other: BaseDataset) -> DatasetAll:
+ def __and__(self, other: BaseDataset) -> BaseDataset:
if not isinstance(other, BaseDataset):
return NotImplemented
return DatasetAll(self, other)
@@ -216,7 +219,7 @@ class DatasetAny(_DatasetBooleanCondition):
agg_func = any
- def __or__(self, other: BaseDataset) -> DatasetAny:
+ def __or__(self, other: BaseDataset) -> BaseDataset:
if not isinstance(other, BaseDataset):
return NotImplemented
# Optimization: X | (Y | Z) is equivalent to X | Y | Z.
@@ -238,7 +241,7 @@ class DatasetAll(_DatasetBooleanCondition):
agg_func = all
- def __and__(self, other: BaseDataset) -> DatasetAll:
+ def __and__(self, other: BaseDataset) -> BaseDataset:
if not isinstance(other, BaseDataset):
return NotImplemented
# Optimization: X & (Y & Z) is equivalent to X & Y & Z.
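The widened __or__/__and__ return types let any BaseDataset subclasses compose, and the new __bool__ makes every concrete condition truthy. A minimal sketch of how the operators behave (illustrative only, not part of this commit; URIs are made up):

    from airflow.datasets import Dataset, DatasetAll, DatasetAny

    d1, d2, d3 = Dataset("s3://a"), Dataset("s3://b"), Dataset("s3://c")

    # (X | Y) | Z flattens into a single DatasetAny via the optimization above.
    any_cond = (d1 | d2) | d3
    assert isinstance(any_cond, DatasetAny) and len(any_cond.objects) == 3

    # Mixed composition nests: (a | b) & c.
    mixed = (d1 | d2) & d3
    assert isinstance(mixed, DatasetAll)
    assert bool(mixed)  # real conditions are truthy; contrast _NullDataset below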
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 97e39d4a93..919267ad52 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -115,7 +115,6 @@ from airflow.security import permissions
from airflow.settings import json
from airflow.stats import Stats
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
-from airflow.timetables.datasets import DatasetOrTimeSchedule
from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable
from airflow.timetables.simple import (
ContinuousTimetable,
@@ -652,35 +651,31 @@ class DAG(LoggingMixin):
stacklevel=2,
)
- self.timetable: Timetable
+ if timetable is not None:
+ schedule = timetable
+ elif schedule_interval is not NOTSET:
+ schedule = schedule_interval
+
+ # Kept for compatibility. Do not use in new code.
self.schedule_interval: ScheduleInterval
- self.dataset_triggers: BaseDataset | None = None
- if isinstance(schedule, BaseDataset):
- self.dataset_triggers = schedule
- elif isinstance(schedule, Collection) and not isinstance(schedule, str):
- if not all(isinstance(x, Dataset) for x in schedule):
- raise ValueError("All elements in 'schedule' should be datasets")
- self.dataset_triggers = DatasetAll(*schedule)
- elif isinstance(schedule, Timetable):
- timetable = schedule
- elif schedule is not NOTSET and not isinstance(schedule, BaseDataset):
- schedule_interval = schedule
- if isinstance(schedule, DatasetOrTimeSchedule):
+ if isinstance(schedule, Timetable):
self.timetable = schedule
- self.dataset_triggers = self.timetable.datasets
+ self.schedule_interval = schedule.summary
+ elif isinstance(schedule, BaseDataset):
+ self.timetable = DatasetTriggeredTimetable(schedule)
self.schedule_interval = self.timetable.summary
- elif self.dataset_triggers:
- self.timetable = DatasetTriggeredTimetable()
- self.schedule_interval = self.timetable.summary
- elif timetable:
- self.timetable = timetable
+ elif isinstance(schedule, Collection) and not isinstance(schedule, str):
+ if not all(isinstance(x, Dataset) for x in schedule):
+ raise ValueError("All elements in 'schedule' should be datasets")
+ self.timetable = DatasetTriggeredTimetable(DatasetAll(*schedule))
self.schedule_interval = self.timetable.summary
+ elif isinstance(schedule, ArgNotSet):
+ self.timetable = create_timetable(schedule, self.timezone)
+ self.schedule_interval = DEFAULT_SCHEDULE_INTERVAL
else:
- if isinstance(schedule_interval, ArgNotSet):
- schedule_interval = DEFAULT_SCHEDULE_INTERVAL
- self.schedule_interval = schedule_interval
- self.timetable = create_timetable(schedule_interval, self.timezone)
+ self.timetable = create_timetable(schedule, self.timezone)
+ self.schedule_interval = schedule
if isinstance(template_searchpath, str):
template_searchpath = [template_searchpath]
@@ -3250,10 +3245,7 @@ class DAG(LoggingMixin):
)
orm_dag.schedule_interval = dag.schedule_interval
orm_dag.timetable_description = dag.timetable.description
- if (dataset_triggers := dag.dataset_triggers) is None:
- orm_dag.dataset_expression = None
- else:
- orm_dag.dataset_expression = dataset_triggers.as_expression()
+ orm_dag.dataset_expression = dag.timetable.dataset_condition.as_expression()
orm_dag.processor_subdir = processor_subdir
@@ -3309,11 +3301,11 @@ class DAG(LoggingMixin):
# later we'll persist them to the database.
for dag in dags:
curr_orm_dag = existing_dags.get(dag.dag_id)
- if dag.dataset_triggers is None:
+ if not (dataset_condition := dag.timetable.dataset_condition):
if curr_orm_dag and curr_orm_dag.schedule_dataset_references:
curr_orm_dag.schedule_dataset_references = []
else:
- for _, dataset in dag.dataset_triggers.iter_datasets():
+ for _, dataset in dataset_condition.iter_datasets():
dag_references[dag.dag_id].add(dataset.uri)
input_datasets[DatasetModel.from_public(dataset)] = None
curr_outlet_references = curr_orm_dag and curr_orm_dag.task_outlet_dataset_references
@@ -3967,7 +3959,7 @@ class DagModel(Base):
for ser_dag in ser_dags:
dag_id = ser_dag.dag_id
statuses = dag_statuses[dag_id]
- if not dag_ready(dag_id, cond=ser_dag.dag.dataset_triggers, statuses=statuses):
+ if not dag_ready(dag_id, cond=ser_dag.dag.timetable.dataset_condition, statuses=statuses):
del by_dag[dag_id]
del dag_statuses[dag_id]
del dag_statuses
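With the reworked constructor branches above, every accepted form of ``schedule`` now normalizes into ``self.timetable``, and ``dataset_triggers`` disappears as a separate attribute. A rough usage sketch (illustrative; dag_ids and URIs are made up):

    import datetime

    from airflow.datasets import Dataset
    from airflow.models.dag import DAG
    from airflow.timetables.simple import DatasetTriggeredTimetable

    start = datetime.datetime(2024, 1, 1)

    # A list of datasets becomes DatasetTriggeredTimetable(DatasetAll(...)).
    dataset_dag = DAG("by_datasets", schedule=[Dataset("s3://x"), Dataset("s3://y")], start_date=start)
    assert isinstance(dataset_dag.timetable, DatasetTriggeredTimetable)
    assert dataset_dag.schedule_interval == "Dataset"  # the timetable's summary

    # Plain schedule values still go through create_timetable().
    cron_dag = DAG("by_cron", schedule="@daily", start_date=start)
    assert cron_dag.schedule_interval == "@daily"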
diff --git a/airflow/serialization/schema.json b/airflow/serialization/schema.json
index 76ae3e36ba..84b2e2ed4a 100644
--- a/airflow/serialization/schema.json
+++ b/airflow/serialization/schema.json
@@ -148,10 +148,6 @@
{ "$ref": "#/definitions/typed_relativedelta" }
]
},
- "dataset_triggers": {
- "$ref": "#/definitions/typed_dataset_cond"
-
-},
"owner_links": { "type": "object" },
"timetable": {
"type": "object",
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index a73684b9a5..95848429e6 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -37,7 +37,7 @@ from pendulum.tz.timezone import FixedTimezone, Timezone
from airflow.compat.functools import cache
from airflow.configuration import conf
-from airflow.datasets import Dataset, DatasetAll, DatasetAny
+from airflow.datasets import BaseDataset, Dataset, DatasetAll, DatasetAny
from airflow.exceptions import AirflowException, RemovedInAirflow3Warning, SerializationError, TaskDeferred
from airflow.jobs.job import Job
from airflow.models.baseoperator import BaseOperator
@@ -228,6 +228,35 @@ class _PriorityWeightStrategyNotRegistered(AirflowException):
)
+def encode_dataset_condition(var: BaseDataset) -> dict[str, Any]:
+ """Encode a dataset condition.
+
+ :meta private:
+ """
+ if isinstance(var, Dataset):
+ return {"__type": DAT.DATASET, "uri": var.uri, "extra": var.extra}
+ if isinstance(var, DatasetAll):
+ return {"__type": DAT.DATASET_ALL, "objects":
[encode_dataset_condition(x) for x in var.objects]}
+ if isinstance(var, DatasetAny):
+ return {"__type": DAT.DATASET_ANY, "objects":
[encode_dataset_condition(x) for x in var.objects]}
+ raise ValueError(f"serialization not implemented for
{type(var).__name__!r}")
+
+
+def decode_dataset_condition(var: dict[str, Any]) -> BaseDataset:
+ """Decode a previously serialized dataset condition.
+
+ :meta private:
+ """
+ dat = var["__type"]
+ if dat == DAT.DATASET:
+ return Dataset(var["uri"], extra=var["extra"])
+ if dat == DAT.DATASET_ALL:
+ return DatasetAll(*(decode_dataset_condition(x) for x in var["objects"]))
+ if dat == DAT.DATASET_ANY:
+ return DatasetAny(*(decode_dataset_condition(x) for x in var["objects"]))
+ raise ValueError(f"deserialization not implemented for DAT {dat!r}")
+
+
def encode_timetable(var: Timetable) -> dict[str, Any]:
"""
Encode a timetable instance.
@@ -488,8 +517,6 @@ class BaseSerialization:
serialized_object[key] = encode_timetable(value)
elif key == "weight_rule" and value is not None:
serialized_object[key] = encode_priority_weight_strategy(value)
- elif key == "dataset_triggers":
- serialized_object[key] = cls.serialize(value)
else:
value = cls.serialize(value)
if isinstance(value, dict) and Encoding.TYPE in value:
@@ -607,24 +634,9 @@ class BaseSerialization:
return cls._encode(cls._serialize_param(var), type_=DAT.PARAM)
elif isinstance(var, XComArg):
return cls._encode(serialize_xcom_arg(var), type_=DAT.XCOM_REF)
- elif isinstance(var, Dataset):
- return cls._encode({"uri": var.uri, "extra": var.extra},
type_=DAT.DATASET)
- elif isinstance(var, DatasetAll):
- return cls._encode(
- [
- cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
- for x in var.objects
- ],
- type_=DAT.DATASET_ALL,
- )
- elif isinstance(var, DatasetAny):
- return cls._encode(
- [
- cls.serialize(x, strict=strict, use_pydantic_models=use_pydantic_models)
- for x in var.objects
- ],
- type_=DAT.DATASET_ANY,
- )
+ elif isinstance(var, BaseDataset):
+ serialized_dataset = encode_dataset_condition(var)
+ return cls._encode(serialized_dataset, type_=serialized_dataset.pop("__type"))
elif isinstance(var, SimpleTaskInstance):
return cls._encode(
cls.serialize(var.__dict__, strict=strict, use_pydantic_models=use_pydantic_models),
@@ -740,9 +752,9 @@ class BaseSerialization:
elif type_ == DAT.DATASET:
return Dataset(**var)
elif type_ == DAT.DATASET_ANY:
- return DatasetAny(*(cls.deserialize(x) for x in var))
+ return DatasetAny(*(decode_dataset_condition(x) for x in var["objects"]))
elif type_ == DAT.DATASET_ALL:
- return DatasetAll(*(cls.deserialize(x) for x in var))
+ return DatasetAll(*(decode_dataset_condition(x) for x in var["objects"]))
elif type_ == DAT.SIMPLE_TASK_INSTANCE:
return SimpleTaskInstance(**cls.deserialize(var))
elif type_ == DAT.CONNECTION:
@@ -914,9 +926,7 @@ class DependencyDetector:
"""Detect dependencies set directly on the DAG object."""
if not dag:
return
- if not dag.dataset_triggers:
- return
- for uri, _ in dag.dataset_triggers.iter_datasets():
+ for uri, _ in dag.timetable.dataset_condition.iter_datasets():
yield DagDependency(
source="dataset",
target=dag.dag_id,
@@ -1562,8 +1572,6 @@ class SerializedDAG(DAG, BaseSerialization):
v = cls.deserialize(v)
elif k == "params":
v = cls._deserialize_params_dict(v)
- elif k == "dataset_triggers":
- v = cls.deserialize(v)
# else use v as it is
setattr(dag, k, v)
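The paired helpers give dataset conditions a recursive wire format keyed by ``__type``. A round-trip sketch (the ``dataset_all`` literal below mirrors the DAT constant used in the tests at the end of this commit):

    from airflow.datasets import Dataset, DatasetAll
    from airflow.serialization.serialized_objects import (
        decode_dataset_condition,
        encode_dataset_condition,
    )

    cond = DatasetAll(Dataset("s3://x"), Dataset("s3://y"))
    data = encode_dataset_condition(cond)
    # e.g. {"__type": "dataset_all", "objects": [{"__type": "dataset", ...}, ...]}
    restored = decode_dataset_condition(data)
    assert isinstance(restored, DatasetAll)
    assert [uri for uri, _ in restored.iter_datasets()] == ["s3://x", "s3://y"]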
diff --git a/airflow/timetables/base.py b/airflow/timetables/base.py
index b5e95ef5f4..0b1a4f6de9 100644
--- a/airflow/timetables/base.py
+++ b/airflow/timetables/base.py
@@ -16,17 +16,47 @@
# under the License.
from __future__ import annotations
-from typing import TYPE_CHECKING, Any, NamedTuple, Sequence
+from typing import TYPE_CHECKING, Any, Iterator, NamedTuple, Sequence
from warnings import warn
+from airflow.datasets import BaseDataset
from airflow.typing_compat import Protocol, runtime_checkable
if TYPE_CHECKING:
from pendulum import DateTime
+ from airflow.datasets import Dataset
from airflow.utils.types import DagRunType
+class _NullDataset(BaseDataset):
+ """Sentinel type that represents "no datasets".
+
+ This is only implemented to make typing easier in timetables, and not
+ expected to be used anywhere else.
+
+ :meta private:
+ """
+
+ def __bool__(self) -> bool:
+ return False
+
+ def __or__(self, other: BaseDataset) -> BaseDataset:
+ return NotImplemented
+
+ def __and__(self, other: BaseDataset) -> BaseDataset:
+ return NotImplemented
+
+ def as_expression(self) -> Any:
+ return None
+
+ def evaluate(self, statuses: dict[str, bool]) -> bool:
+ return False
+
+ def iter_datasets(self) -> Iterator[tuple[str, Dataset]]:
+ return iter(())
+
+
class DataInterval(NamedTuple):
"""A data interval for a DagRun to operate over.
@@ -127,6 +157,12 @@ class Timetable(Protocol):
@property
def can_be_scheduled(self):
+ """Whether this timetable can actually schedule runs in an automated
manner.
+
+ This defaults to and should generally be *True* (including non periodic
+ execution types like *@once* and data triggered tables), but
+ ``NullTimetable`` sets this to *False*.
+ """
if hasattr(self, "can_run"):
warn(
'can_run class variable is deprecated. Use "can_be_scheduled" instead.',
@@ -136,13 +172,6 @@ class Timetable(Protocol):
return self.can_run
return self._can_be_scheduled
- """Whether this timetable can actually schedule runs in an automated
manner.
-
- This defaults to and should generally be *True* (including non periodic
- execution types like *@once* and data triggered tables), but
- ``NullTimetable`` sets this to *False*.
- """
-
run_ordering: Sequence[str] = ("data_interval_end", "execution_date")
"""How runs triggered from this timetable should be ordered in UI.
@@ -150,11 +179,19 @@ class Timetable(Protocol):
"""
active_runs_limit: int | None = None
- """Override the max_active_runs parameter of any DAGs using this timetable.
- This is called during DAG initializing, and will set the max_active_runs if
- it returns a value. In most cases this should return None, but in some cases
- (for example, the ContinuousTimetable) there are good reasons for limiting
- the DAGRun parallelism.
+ """Maximum active runs that can be active at one time for a DAG.
+
+ This is called during DAG initialization, and the return value is used as
+ the DAG's default ``max_active_runs``. This should generally return *None*,
+ but there are good reasons to limit DAG run parallelism in some cases, such
+ as for :class:`~airflow.timetables.simple.ContinuousTimetable`.
+ """
+
+ dataset_condition: BaseDataset = _NullDataset()
+ """The dataset condition that triggers a DAG using this timetable.
+
+ If truthy, this should be a dataset, or a combination of datasets, that
+ controls the DAG's dataset triggers.
"""
@classmethod
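Because ``_NullDataset`` is falsy while every real condition is truthy, callers can branch on ``timetable.dataset_condition`` directly instead of comparing against *None*, which is exactly what the dag.py hunks above do. A minimal sketch of that pattern (the helper name is made up):

    def iter_trigger_uris(timetable):
        """Yield the URIs of the datasets that can trigger a DAG, if any."""
        condition = timetable.dataset_condition
        if not condition:  # _NullDataset() is falsy: no dataset triggers
            return
        for uri, _ in condition.iter_datasets():
            yield uri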
diff --git a/airflow/timetables/datasets.py b/airflow/timetables/datasets.py
index a8f4a7f22f..4c27f39b26 100644
--- a/airflow/timetables/datasets.py
+++ b/airflow/timetables/datasets.py
@@ -44,9 +44,9 @@ class DatasetOrTimeSchedule(DatasetTriggeredSchedule):
) -> None:
self.timetable = timetable
if isinstance(datasets, BaseDataset):
- self.datasets = datasets
+ self.dataset_condition = datasets
else:
- self.datasets = DatasetAll(*datasets)
+ self.dataset_condition = DatasetAll(*datasets)
self.description = f"Triggered by datasets or {timetable.description}"
self.periodic = timetable.periodic
@@ -55,25 +55,25 @@ class DatasetOrTimeSchedule(DatasetTriggeredSchedule):
@classmethod
def deserialize(cls, data: dict[str, typing.Any]) -> Timetable:
- from airflow.serialization.serialized_objects import decode_timetable
+ from airflow.serialization.serialized_objects import decode_dataset_condition, decode_timetable
return cls(
+ datasets=decode_dataset_condition(data["dataset_condition"]),
timetable=decode_timetable(data["timetable"]),
- # don't need the datasets after deserialization
- # they are already stored on dataset_triggers attr on DAG
- # and this is what scheduler looks at
- datasets=[],
)
def serialize(self) -> dict[str, typing.Any]:
- from airflow.serialization.serialized_objects import encode_timetable
+ from airflow.serialization.serialized_objects import encode_dataset_condition, encode_timetable
- return {"timetable": encode_timetable(self.timetable)}
+ return {
+ "dataset_condition":
encode_dataset_condition(self.dataset_condition),
+ "timetable": encode_timetable(self.timetable),
+ }
def validate(self) -> None:
if isinstance(self.timetable, DatasetTriggeredSchedule):
raise AirflowTimetableInvalid("cannot nest dataset timetables")
- if not isinstance(self.datasets, BaseDataset):
+ if not isinstance(self.dataset_condition, BaseDataset):
raise AirflowTimetableInvalid("all elements in 'datasets' must be
datasets")
@property
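Since the condition is now round-tripped instead of discarded, a deserialized DatasetOrTimeSchedule is self-contained. A hedged usage sketch (CronTriggerTimetable is just one possible wrapped timetable; the URI is made up):

    from airflow.datasets import Dataset
    from airflow.timetables.datasets import DatasetOrTimeSchedule
    from airflow.timetables.trigger import CronTriggerTimetable

    schedule = DatasetOrTimeSchedule(
        timetable=CronTriggerTimetable("0 0 * * *", timezone="UTC"),
        datasets=[Dataset("s3://x")],
    )
    data = schedule.serialize()  # now carries both "timetable" and "dataset_condition"
    restored = DatasetOrTimeSchedule.deserialize(data)
    assert restored.dataset_condition.evaluate({"s3://x": True})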
diff --git a/airflow/timetables/simple.py b/airflow/timetables/simple.py
index 6452244262..98bf7835f6 100644
--- a/airflow/timetables/simple.py
+++ b/airflow/timetables/simple.py
@@ -25,6 +25,7 @@ if TYPE_CHECKING:
from pendulum import DateTime
from sqlalchemy import Session
+ from airflow.datasets import BaseDataset
from airflow.models.dataset import DatasetEvent
from airflow.timetables.base import TimeRestriction
from airflow.utils.types import DagRunType
@@ -156,10 +157,25 @@ class DatasetTriggeredTimetable(_TrivialTimetable):
description: str = "Triggered by datasets"
+ def __init__(self, datasets: BaseDataset) -> None:
+ super().__init__()
+ self.dataset_condition = datasets
+
+ @classmethod
+ def deserialize(cls, data: dict[str, Any]) -> Timetable:
+ from airflow.serialization.serialized_objects import decode_dataset_condition
+
+ return cls(decode_dataset_condition(data["dataset_condition"]))
+
@property
def summary(self) -> str:
return "Dataset"
+ def serialize(self) -> dict[str, Any]:
+ from airflow.serialization.serialized_objects import encode_dataset_condition
+
+ return {"dataset_condition":
encode_dataset_condition(self.dataset_condition)}
+
def generate_run_id(
self,
*,
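DatasetTriggeredTimetable now carries the condition it fires on and serializes it itself. A small sketch (the ``dataset_any`` literal is assumed to match the DAT constant; URIs are made up):

    from airflow.datasets import Dataset
    from airflow.timetables.simple import DatasetTriggeredTimetable

    timetable = DatasetTriggeredTimetable(Dataset("s3://x") | Dataset("s3://y"))
    assert timetable.summary == "Dataset"

    data = timetable.serialize()  # {"dataset_condition": {"__type": "dataset_any", ...}}
    restored = DatasetTriggeredTimetable.deserialize(data)
    assert list(restored.dataset_condition.iter_datasets())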
diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py
index de8788ada8..535a197eab 100644
--- a/tests/cli/commands/test_dag_command.py
+++ b/tests/cli/commands/test_dag_command.py
@@ -970,7 +970,7 @@ class TestCliDags:
mock_render_dag.assert_has_calls([mock.call(mock_get_dag.return_value, tis=[])])
assert "SOURCE" in output
- @mock.patch("workday.AfterWorkdayTimetable")
+ @mock.patch("workday.AfterWorkdayTimetable", side_effect=lambda:
mock.MagicMock(active_runs_limit=None))
@mock.patch("airflow.models.dag._get_or_create_dagrun")
def test_dag_test_with_custom_timetable(self, mock__get_or_create_dagrun, _):
"""
diff --git a/tests/datasets/test_dataset.py b/tests/datasets/test_dataset.py
index fcd4ae9a1e..4ce9b01ada 100644
--- a/tests/datasets/test_dataset.py
+++ b/tests/datasets/test_dataset.py
@@ -258,22 +258,21 @@ def test_dataset_trigger_setup_and_serialization(session, dag_maker, create_test
with dag_maker(schedule=DatasetAny(*datasets)) as dag:
EmptyOperator(task_id="hello")
- # Verify dataset triggers are set up correctly
+ # Verify datasets are set up correctly
assert isinstance(
- dag.dataset_triggers, DatasetAny
- ), "DAG dataset triggers should be an instance of DatasetAny"
+ dag.timetable.dataset_condition, DatasetAny
+ ), "DAG datasets should be an instance of DatasetAny"
- # Serialize and deserialize DAG dataset triggers
- serialized_trigger = SerializedDAG.serialize(dag.dataset_triggers)
- deserialized_trigger = SerializedDAG.deserialize(serialized_trigger)
+ # Round-trip the DAG through serialization
+ deserialized_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))
# Verify serialization and deserialization integrity
assert isinstance(
- deserialized_trigger, DatasetAny
- ), "Deserialized trigger should maintain type DatasetAny"
+ deserialized_dag.timetable.dataset_condition, DatasetAny
+ ), "Deserialized datasets should maintain type DatasetAny"
assert (
- deserialized_trigger.objects == dag.dataset_triggers.objects
- ), "Deserialized trigger objects should match original"
+ deserialized_dag.timetable.dataset_condition.objects == dag.timetable.dataset_condition.objects
+ ), "Deserialized datasets should match original"
@pytest.mark.db_test
@@ -303,12 +302,13 @@ def test_dataset_dag_run_queue_processing(session, clear_datasets, dag_maker, cr
for (serialized_dag,) in serialized_dags:
dag = SerializedDAG.deserialize(serialized_dag.data)
for dataset_uri, status in dag_statuses[dag.dag_id].items():
- assert dag.dataset_triggers.evaluate({dataset_uri: status}), "DAG trigger evaluation failed"
+ cond = dag.timetable.dataset_condition
+ assert cond.evaluate({dataset_uri: status}), "DAG trigger evaluation failed"
@pytest.mark.db_test
@pytest.mark.usefixtures("clear_datasets")
-def test_dag_with_complex_dataset_triggers(session, dag_maker):
+def test_dag_with_complex_dataset_condition(session, dag_maker):
# Create Dataset instances
d1 = Dataset(uri="hello1")
d2 = Dataset(uri="hello2")
@@ -324,13 +324,13 @@ def test_dag_with_complex_dataset_triggers(session, dag_maker):
EmptyOperator(task_id="hello")
assert isinstance(
- dag.dataset_triggers, DatasetAny
+ dag.timetable.dataset_condition, DatasetAny
), "DAG's dataset trigger should be an instance of DatasetAny"
assert any(
- isinstance(trigger, DatasetAll) for trigger in dag.dataset_triggers.objects
+ isinstance(trigger, DatasetAll) for trigger in dag.timetable.dataset_condition.objects
), "DAG's dataset trigger should include DatasetAll"
- serialized_triggers = SerializedDAG.serialize(dag.dataset_triggers)
+ serialized_triggers = SerializedDAG.serialize(dag.timetable.dataset_condition)
deserialized_triggers = SerializedDAG.deserialize(serialized_triggers)
@@ -341,11 +341,13 @@ def test_dag_with_complex_dataset_triggers(session, dag_maker):
isinstance(trigger, DatasetAll) for trigger in deserialized_triggers.objects
), "Deserialized triggers should include DatasetAll"
- serialized_dag_dict = SerializedDAG.to_dict(dag)["dag"]
- assert "dataset_triggers" in serialized_dag_dict, "Serialized DAG should
contain 'dataset_triggers'"
+ serialized_timetable_dict =
SerializedDAG.to_dict(dag)["dag"]["timetable"]["__var"]
+ assert (
+ "dataset_condition" in serialized_timetable_dict
+ ), "Serialized timetable should contain 'dataset_condition'"
assert isinstance(
- serialized_dag_dict["dataset_triggers"], dict
- ), "Serialized 'dataset_triggers' should be a dict"
+ serialized_timetable_dict["dataset_condition"], dict
+ ), "Serialized 'dataset_condition' should be a dict"
def datasets_equal(d1: BaseDataset, d2: BaseDataset) -> bool:
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 0a1247d4ed..52001bef91 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -1968,7 +1968,7 @@ class TestDag:
def test_timetable_and_description_from_dataset(self):
dag = DAG("test_schedule_interval_arg",
schedule=[Dataset(uri="hello")], start_date=TEST_DATE)
- assert dag.timetable == DatasetTriggeredTimetable()
+ assert dag.timetable == DatasetTriggeredTimetable(Dataset(uri="hello"))
assert dag.schedule_interval == "Dataset"
assert dag.timetable.description == "Triggered by datasets"
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index a1e76a13a8..d95c45ed64 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -591,16 +591,10 @@ class TestStringifiedDAGs:
"params",
"_processor_dags_folder",
}
- compare_serialization_list = {
- "dataset_triggers",
- }
fields_to_check = dag.get_serialized_fields() - exclusion_list
for field in fields_to_check:
actual = getattr(serialized_dag, field)
expected = getattr(dag, field)
- if field in compare_serialization_list:
- actual = BaseSerialization.serialize(actual)
- expected = BaseSerialization.serialize(expected)
assert actual == expected, f"{dag.dag_id}.{field} does not match"
# _processor_dags_folder is only populated at serialization time
# it's only used when relying on serialized dag to determine a dag's relative path
diff --git a/tests/timetables/test_datasets_timetable.py b/tests/timetables/test_datasets_timetable.py
index 32a5c89fd3..68928a568d 100644
--- a/tests/timetables/test_datasets_timetable.py
+++ b/tests/timetables/test_datasets_timetable.py
@@ -129,6 +129,10 @@ def test_serialization(dataset_timetable: DatasetOrTimeSchedule, monkeypatch: An
serialized = dataset_timetable.serialize()
assert serialized == {
"timetable": "mock_serialized_timetable",
+ "dataset_condition": {
+ "__type": "dataset_all",
+ "objects": [{"__type": "dataset", "uri": "test_dataset", "extra":
None}],
+ },
}
@@ -141,7 +145,13 @@ def test_deserialization(monkeypatch: Any) -> None:
monkeypatch.setattr(
"airflow.serialization.serialized_objects.decode_timetable", lambda x:
MockTimetable()
)
- mock_serialized_data = {"timetable": "mock_serialized_timetable",
"datasets": [{"uri": "test_dataset"}]}
+ mock_serialized_data = {
+ "timetable": "mock_serialized_timetable",
+ "dataset_condition": {
+ "__type": "dataset_all",
+ "objects": [{"__type": "dataset", "uri": "test_dataset", "extra":
None}],
+ },
+ }
deserialized = DatasetOrTimeSchedule.deserialize(mock_serialized_data)
assert isinstance(deserialized, DatasetOrTimeSchedule)
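For reference, ``evaluate`` (exercised by the queue-processing test above) resolves a condition against a mapping of dataset URI to "has seen an event" status. A compact illustration with made-up URIs:

    from airflow.datasets import Dataset, DatasetAll, DatasetAny

    cond = DatasetAny(Dataset("s3://x"), DatasetAll(Dataset("s3://y"), Dataset("s3://z")))
    assert cond.evaluate({"s3://x": True})                  # x alone satisfies the ANY
    assert cond.evaluate({"s3://y": True, "s3://z": True})  # or both y and z together
    assert not cond.evaluate({"s3://y": True})              # y alone is not enough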