kaxil commented on code in PR #54383:
URL: https://github.com/apache/airflow/pull/54383#discussion_r2301716045
##########
airflow-core/src/airflow/serialization/serialized_objects.py:
##########
@@ -2163,6 +2240,829 @@ def from_dict(cls, serialized_obj: dict) -> SerializedDAG:
cls.conversion_v1_to_v2(serialized_obj)
return cls.deserialize_dag(serialized_obj["dag"])
+ @classmethod
+ @provide_session
+ def bulk_write_to_db(
+ cls,
+ bundle_name: str,
+ bundle_version: str | None,
+ dags: Collection[MaybeSerializedDAG],
+ session: Session = NEW_SESSION,
+ ) -> None:
+ """
+ Ensure the DagModel rows for the given dags are up-to-date in the dag
table in the DB.
+
+ :param dags: the DAG objects to save to the DB
+ :return: None
+ """
+ if not dags:
+ return
+
+ from airflow.dag_processing.collection import AssetModelOperation,
DagModelOperation
+
+ log.info("Sync %s DAGs", len(dags))
+ dag_op = DagModelOperation(
+ bundle_name=bundle_name, bundle_version=bundle_version,
dags={d.dag_id: d for d in dags}
+ )
+
+ orm_dags = dag_op.add_dags(session=session)
+ dag_op.update_dags(orm_dags, session=session)
+
+ asset_op = AssetModelOperation.collect(dag_op.dags)
+
+ orm_assets = asset_op.sync_assets(session=session)
+ orm_asset_aliases = asset_op.sync_asset_aliases(session=session)
+ session.flush() # This populates id so we can create fks in later
calls.
+
+ orm_dags = dag_op.find_orm_dags(session=session) # Refetch so
relationship is up to date.
+ asset_op.add_dag_asset_references(orm_dags, orm_assets,
session=session)
+ asset_op.add_dag_asset_alias_references(orm_dags, orm_asset_aliases,
session=session)
+ asset_op.add_dag_asset_name_uri_references(session=session)
+ asset_op.add_task_asset_references(orm_dags, orm_assets,
session=session)
+ asset_op.activate_assets_if_possible(orm_assets.values(),
session=session)
+ session.flush() # Activation is needed when we add trigger references.
+
+ asset_op.add_asset_trigger_references(orm_assets, session=session)
+ dag_op.update_dag_asset_expression(orm_dags=orm_dags,
orm_assets=orm_assets)
+ session.flush()
+
+ @provide_session
+ def sync_to_db(self, session: Session = NEW_SESSION) -> None:
+ """
+ Save attributes about this DAG to the DB.
+
+ :return: None
+ """
+ bundle_name, bundle_version = session.execute(
+ select(DagModel.bundle_name,
DagModel.bundle_version).where(DagModel.dag_id == self.dag_id)
+ ).one()
+ self.bulk_write_to_db(bundle_name, bundle_version, [self],
session=session)
+
+ @cached_property
+ def _time_restriction(self) -> TimeRestriction:
+ start_dates = [t.start_date for t in self.tasks if t.start_date]
+ if self.start_date is not None:
+ start_dates.append(self.start_date)
+ earliest = None
+ if start_dates:
+ earliest = coerce_datetime(min(start_dates))
+ latest = coerce_datetime(self.end_date)
+ end_dates = [t.end_date for t in self.tasks if t.end_date]
+ if len(end_dates) == len(self.tasks): # not exists null end_date
+ if self.end_date is not None:
+ end_dates.append(self.end_date)
+ if end_dates:
+ latest = coerce_datetime(max(end_dates))
+ return TimeRestriction(earliest, latest, self.catchup)
+
+ def infer_automated_data_interval(self, logical_date: datetime.datetime)
-> DataInterval:
+ """
+ Infer a data interval for a run against this DAG.
+
+ This method is used to bridge runs created prior to AIP-39
+ implementation, which do not have an explicit data interval. Therefore,
+ this method only considers ``schedule_interval`` values valid prior to
+ Airflow 2.2.
+
+ DO NOT call this method if there is a known data interval.
+
+ :meta private:
+ """
+ timetable_type = type(self.timetable)
+ if issubclass(timetable_type, (NullTimetable, OnceTimetable,
AssetTriggeredTimetable)):
+ return DataInterval.exact(coerce_datetime(logical_date))
+ start = coerce_datetime(logical_date)
+ if issubclass(timetable_type, CronDataIntervalTimetable):
+ end = cast("CronDataIntervalTimetable",
self.timetable)._get_next(start)
+ elif issubclass(timetable_type, DeltaDataIntervalTimetable):
+ end = cast("DeltaDataIntervalTimetable",
self.timetable)._get_next(start)
+ # Contributors: When the exception below is raised, you might want to
+ # add an 'elif' block here to handle custom timetables. Stop! The bug
+ # you're looking for is instead at when the DAG run (represented by
+ # logical_date) was created. See GH-31969 for an example:
+ # * Wrong fix: GH-32074 (modifies this function).
+ # * Correct fix: GH-32118 (modifies the DAG run creation code).
+ else:
+ raise ValueError(f"Not a valid timetable: {self.timetable!r}")
+ return DataInterval(start, end)
+
+ def get_run_data_interval(self, run: DagRun) -> DataInterval:
+ """
+ Get the data interval of this run.
+
+ For compatibility, this method infers the data interval from the DAG's
+ schedule if the run does not have an explicit one set, which is
possible for
+ runs created prior to AIP-39.
+
+ This function is private to Airflow core and should not be depended on
as a
+ part of the Python API.
+
+ :meta private:
+ """
+ if run.dag_id is not None and run.dag_id != self.dag_id:
+ raise ValueError(f"Arguments refer to different DAGs:
{self.dag_id} != {run.dag_id}")
+ data_interval = _get_model_data_interval(run, "data_interval_start",
"data_interval_end")
+ if data_interval is not None:
+ return data_interval
+ # Compatibility: runs created before AIP-39 implementation don't have
an
+ # explicit data interval. Try to infer from the logical date.
+ return self.infer_automated_data_interval(run.logical_date)
+
+ def get_next_data_interval(self, dag_model: DagModel) -> DataInterval |
None:
+ """
+ Get the data interval of the next scheduled run.
+
+ For compatibility, this method infers the data interval from the DAG's
+ schedule if the run does not have an explicit one set, which is
possible
+ for runs created prior to AIP-39.
+
+ This function is private to Airflow core and should not be depended on
as a
+ part of the Python API.
+
+ :meta private:
+ """
+ if self.dag_id != dag_model.dag_id:
+ raise ValueError(f"Arguments refer to different DAGs:
{self.dag_id} != {dag_model.dag_id}")
+ if dag_model.next_dagrun is None: # Next run not scheduled.
+ return None
+ data_interval = dag_model.next_dagrun_data_interval
+ if data_interval is not None:
+ return data_interval
+
+ # Compatibility: A run was scheduled without an explicit data interval.
+ # This means the run was scheduled before AIP-39 implementation. Try to
+ # infer from the logical date.
+ return self.infer_automated_data_interval(dag_model.next_dagrun)
+
+ def next_dagrun_info(
+ self,
+ last_automated_dagrun: None | DataInterval,
+ *,
+ restricted: bool = True,
+ ) -> DagRunInfo | None:
+ """
+ Get information about the next DagRun of this dag after
``date_last_automated_dagrun``.
+
+ This calculates what time interval the next DagRun should operate on
+ (its logical date) and when it can be scheduled, according to the
+ dag's timetable, start_date, end_date, etc. This doesn't check max
+ active run or any other "max_active_tasks" type limits, but only
+ performs calculations based on the various date and interval fields of
+ this dag and its tasks.
+
+ :param last_automated_dagrun: The ``max(logical_date)`` of
+ existing "automated" DagRuns for this dag (scheduled or backfill,
+ but not manual).
+ :param restricted: If set to *False* (default is *True*), ignore
+ ``start_date``, ``end_date``, and ``catchup`` specified on the DAG
+ or tasks.
+ :return: DagRunInfo of the next dagrun, or None if a dagrun is not
+ going to be scheduled.
+ """
+ if restricted:
+ restriction = self._time_restriction
+ else:
+ restriction = TimeRestriction(earliest=None, latest=None,
catchup=True)
+ try:
+ info = self.timetable.next_dagrun_info(
+ last_automated_data_interval=last_automated_dagrun,
+ restriction=restriction,
+ )
+ except Exception:
+ log.exception(
+ "Failed to fetch run info after data interval %s for DAG %r",
+ last_automated_dagrun,
+ self.dag_id,
+ )
+ info = None
+ return info
+
+ def iter_dagrun_infos_between(
+ self,
+ earliest: datetime.datetime | None,
+ latest: datetime.datetime,
+ *,
+ align: bool = True,
+ ) -> Iterable[DagRunInfo]:
+ """
+ Yield DagRunInfo using this DAG's timetable between given interval.
+
+ DagRunInfo instances yielded if their ``logical_date`` is not earlier
+ than ``earliest``, nor later than ``latest``. The instances are ordered
+ by their ``logical_date`` from earliest to latest.
+
+ If ``align`` is ``False``, the first run will happen immediately on
+ ``earliest``, even if it does not fall on the logical timetable
schedule.
+ The default is ``True``.
+
+ Example: A DAG is scheduled to run every midnight (``0 0 * * *``). If
+ ``earliest`` is ``2021-06-03 23:00:00``, the first DagRunInfo would be
+ ``2021-06-03 23:00:00`` if ``align=False``, and ``2021-06-04 00:00:00``
+ if ``align=True``.
+ """
+ if earliest is None:
+ earliest = self._time_restriction.earliest
+ if earliest is None:
+ raise ValueError("earliest was None and we had no value in
time_restriction to fallback on")
+ earliest = coerce_datetime(earliest)
+ latest = coerce_datetime(latest)
+
+ restriction = TimeRestriction(earliest, latest, catchup=True)
+
+ try:
+ info = self.timetable.next_dagrun_info(
+ last_automated_data_interval=None,
+ restriction=restriction,
+ )
+ except Exception:
+ log.exception(
+ "Failed to fetch run info after data interval %s for DAG %r",
+ None,
+ self.dag_id,
+ )
+ info = None
+
+ if info is None:
+ # No runs to be scheduled between the user-supplied timeframe. But
+ # if align=False, "invent" a data interval for the timeframe
itself.
+ if not align:
+ yield DagRunInfo.interval(earliest, latest)
+ return
+
+ # If align=False and earliest does not fall on the timetable's logical
+ # schedule, "invent" a data interval for it.
+ if not align and info.logical_date != earliest:
+ yield DagRunInfo.interval(earliest, info.data_interval.start)
+
+ # Generate naturally according to schedule.
+ while info is not None:
+ yield info
+ try:
+ info = self.timetable.next_dagrun_info(
+ last_automated_data_interval=info.data_interval,
+ restriction=restriction,
+ )
+ except Exception:
+ log.exception(
+ "Failed to fetch run info after data interval %s for DAG
%r",
+ info.data_interval if info else "<NONE>",
+ self.dag_id,
+ )
+ break
+
+ @provide_session
+ def get_concurrency_reached(self, session=NEW_SESSION) -> bool:
+ """Return a boolean indicating whether the max_active_tasks limit for
this DAG has been reached."""
+ from airflow.models.taskinstance import TaskInstance
+
+ total_tasks = session.scalar(
+ select(func.count(TaskInstance.task_id)).where(
+ TaskInstance.dag_id == self.dag_id,
+ TaskInstance.state == TaskInstanceState.RUNNING,
+ )
+ )
+ return total_tasks >= self.max_active_tasks
+
+ @provide_session
+ def create_dagrun(
+ self,
+ *,
+ run_id: str,
+ logical_date: datetime.datetime | None = None,
+ data_interval: tuple[datetime.datetime, datetime.datetime] | None =
None,
+ run_after: datetime.datetime,
+ conf: dict | None = None,
+ run_type: DagRunType,
+ triggered_by: DagRunTriggeredByType,
+ triggering_user_name: str | None = None,
+ state: DagRunState,
+ start_date: datetime.datetime | None = None,
+ creating_job_id: int | None = None,
+ backfill_id: NonNegativeInt | None = None,
+ session: Session = NEW_SESSION,
+ ) -> DagRun:
+ """
+ Create a run for this DAG to run its tasks.
+
+ :param run_id: ID of the dag_run
+ :param logical_date: date of execution
+ :param run_after: the datetime before which dag won't run
+ :param conf: Dict containing configuration/parameters to pass to the
DAG
+ :param triggered_by: the entity which triggers the dag_run
+ :param triggering_user_name: the user name who triggers the dag_run
+ :param start_date: the date this dag run should be evaluated
+ :param creating_job_id: ID of the job creating this DagRun
+ :param backfill_id: ID of the backfill run if one exists
+ :param session: Unused. Only added in compatibility with database
isolation mode
+ :return: The created DAG run.
+
+ :meta private:
+ """
+ logical_date = coerce_datetime(logical_date)
+ # For manual runs where logical_date is None, ensure no data_interval
is set.
+ if logical_date is None and data_interval is not None:
+ raise ValueError("data_interval must be None when logical_date is
None")
+
+ if data_interval and not isinstance(data_interval, DataInterval):
+ data_interval = DataInterval(*map(coerce_datetime, data_interval))
+
+ if isinstance(run_type, DagRunType):
+ pass
+ elif isinstance(run_type, str): # Ensure the input value is valid.
+ run_type = DagRunType(run_type)
+ else:
+ raise ValueError(f"run_type should be a DagRunType, not
{type(run_type)}")
+
+ if not isinstance(run_id, str):
+ raise ValueError(f"`run_id` should be a str, not {type(run_id)}")
+
+ # This is also done on the DagRun model class, but SQLAlchemy column
+ # validator does not work well for some reason.
+ if not re.match(RUN_ID_REGEX, run_id):
+ regex = airflow_conf.get("scheduler",
"allowed_run_id_pattern").strip()
+ if not regex or not re.match(regex, run_id):
+ raise ValueError(
+ f"The run_id provided '{run_id}' does not match regex
pattern "
+ f"'{regex}' or '{RUN_ID_REGEX}'"
+ )
+
+ # Prevent a manual run from using an ID that looks like a scheduled
run.
+ if run_type == DagRunType.MANUAL:
+ if (inferred_run_type := DagRunType.from_run_id(run_id)) !=
DagRunType.MANUAL:
+ raise ValueError(
+ f"A {run_type.value} DAG run cannot use ID {run_id!r}
since it "
+ f"is reserved for {inferred_run_type.value} runs"
+ )
+
+ # todo: AIP-78 add verification that if run type is backfill then we
have a backfill id
+
+ # create a copy of params before validating
+ copied_params = copy.deepcopy(self.params)
+ if conf:
+ copied_params.update(conf)
+ copied_params.validate()
+ orm_dagrun = _create_orm_dagrun(
+ dag=self,
+ run_id=run_id,
+ logical_date=logical_date,
+ data_interval=data_interval,
+ run_after=coerce_datetime(run_after),
+ start_date=coerce_datetime(start_date),
+ conf=conf,
+ state=state,
+ run_type=run_type,
+ creating_job_id=creating_job_id,
+ backfill_id=backfill_id,
+ triggered_by=triggered_by,
+ triggering_user_name=triggering_user_name,
+ session=session,
+ )
+
+ if self.deadline and isinstance(self.deadline.reference,
DeadlineReference.TYPES.DAGRUN):
+ session.add(
+ Deadline(
+ deadline_time=self.deadline.reference.evaluate_with(
+ session=session,
+ interval=self.deadline.interval,
+ dag_id=self.dag_id,
+ run_id=run_id,
+ ),
+ callback=self.deadline.callback,
+ dagrun_id=orm_dagrun.id,
+ )
+ )
+
+ return orm_dagrun
+
+ @provide_session
+ def set_task_instance_state(
+ self,
+ *,
+ task_id: str,
+ map_indexes: Collection[int] | None = None,
+ run_id: str | None = None,
+ state: TaskInstanceState,
+ upstream: bool = False,
+ downstream: bool = False,
+ future: bool = False,
+ past: bool = False,
+ commit: bool = True,
+ session=NEW_SESSION,
+ ) -> list[TaskInstance]:
+ """
+ Set the state of a TaskInstance and clear downstream tasks in failed
or upstream_failed state.
+
+ :param task_id: Task ID of the TaskInstance
+ :param map_indexes: Only set TaskInstance if its map_index matches.
+ If None (default), all mapped TaskInstances of the task are set.
+ :param run_id: The run_id of the TaskInstance
+ :param state: State to set the TaskInstance to
+ :param upstream: Include all upstream tasks of the given task_id
+ :param downstream: Include all downstream tasks of the given task_id
+ :param future: Include all future TaskInstances of the given task_id
+ :param commit: Commit changes
+ :param past: Include all past TaskInstances of the given task_id
+ """
+ from airflow.api.common.mark_tasks import set_state
+
+ # TODO (GH-52141): get_task in scheduler needs to return scheduler
types
+ # instead, but currently it inherits SDK's DAG.
+ task = cast("SchedulerOperator", self.get_task(task_id))
+ task.dag = self
+
+ tasks_to_set_state: list[SchedulerOperator | tuple[SchedulerOperator,
int]]
+ if map_indexes is None:
+ tasks_to_set_state = [task]
+ else:
+ tasks_to_set_state = [(task, map_index) for map_index in
map_indexes]
+
+ altered = set_state(
+ tasks=tasks_to_set_state,
+ run_id=run_id,
+ upstream=upstream,
+ downstream=downstream,
+ future=future,
+ past=past,
+ state=state,
+ commit=commit,
+ session=session,
+ )
+
+ if not commit:
+ return altered
+
+ # Clear downstream tasks that are in failed/upstream_failed state to
resume them.
+ # Flush the session so that the tasks marked success are reflected in
the db.
+ session.flush()
+ subset = self.partial_subset(
+ task_ids={task_id},
+ include_downstream=True,
+ include_upstream=False,
+ )
+
+ # Raises an error if not found
+ dr_id, logical_date = session.execute(
+ select(DagRun.id, DagRun.logical_date).where(
+ DagRun.run_id == run_id, DagRun.dag_id == self.dag_id
+ )
+ ).one()
+
+ # Now we want to clear downstreams of tasks that had their state set...
+ clear_kwargs = {
+ "only_failed": True,
+ "session": session,
+ # Exclude the task itself from being cleared.
+ "exclude_task_ids": frozenset((task_id,)),
+ }
+ if not future and not past: # Simple case 1: we're only dealing with
exactly one run.
+ clear_kwargs["run_id"] = run_id
+ subset.clear(**clear_kwargs)
+ elif future and past: # Simple case 2: we're clearing ALL runs.
+ subset.clear(**clear_kwargs)
+ else: # Complex cases: we may have more than one run, based on a date
range.
+ # Make 'future' and 'past' make some sense when multiple runs exist
+ # for the same logical date. We order runs by their id and only
+ # clear runs have larger/smaller ids.
+ exclude_run_id_stmt =
select(DagRun.run_id).where(DagRun.logical_date == logical_date)
+ if future:
+ clear_kwargs["start_date"] = logical_date
+ exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id >
dr_id)
+ else:
+ clear_kwargs["end_date"] = logical_date
+ exclude_run_id_stmt = exclude_run_id_stmt.where(DagRun.id <
dr_id)
+
subset.clear(exclude_run_ids=frozenset(session.scalars(exclude_run_id_stmt)),
**clear_kwargs)
+ return altered
+
+ def get_task_assets(
+ self,
+ inlets: bool = True,
+ outlets: bool = True,
+ of_type: type[AssetT] = Asset, # type: ignore[assignment]
+ ) -> Generator[tuple[str, AssetT], None, None]:
+ for task in self.task_dict.values():
+ directions = ["inlets"] if inlets else []
+ if outlets:
+ directions.append("outlets")
+ for direction in directions:
+ if not (ports := getattr(task, direction, None)):
+ continue
+ for port in ports:
+ if not isinstance(port, of_type):
+ continue
+ yield task.task_id, port
+
+ @overload
+ def _get_task_instances(
+ self,
+ *,
+ task_ids: Collection[str | tuple[str, int]] | None,
+ start_date: datetime.datetime | None,
+ end_date: datetime.datetime | None,
+ run_id: str | None,
+ state: TaskInstanceState | Sequence[TaskInstanceState],
+ exclude_task_ids: Collection[str | tuple[str, int]] | None,
+ exclude_run_ids: frozenset[str] | None,
+ session: Session,
+ ) -> Iterable[TaskInstance]: ... # pragma: no cover
+
+ @overload
+ def _get_task_instances(
+ self,
Review Comment:
Once we are done with this migration, I'd love to move this code somewhere
else -- as you mentioned, maybe SerializedDAG should live somewhere else.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]