Copilot commented on code in PR #64571:
URL: https://github.com/apache/airflow/pull/64571#discussion_r3025334878
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -1777,25 +1778,37 @@ def _create_dagruns_for_partitioned_asset_dags(self,
session: Session) -> set[st
self.log.error("Dag '%s' not found in serialized_dag table",
apdr.target_dag_id)
continue
- asset_models = session.scalars(
- select(AssetModel).where(
- exists(
- select(1).where(
- PartitionedAssetKeyLog.asset_id == AssetModel.id,
- PartitionedAssetKeyLog.asset_partition_dag_run_id
== apdr.id,
- PartitionedAssetKeyLog.target_partition_key ==
apdr.partition_key,
+ asset_models = list(
+ session.scalars(
+ select(AssetModel).where(
+ exists(
+ select(1).where(
+ PartitionedAssetKeyLog.asset_id ==
AssetModel.id,
+
PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+ PartitionedAssetKeyLog.target_partition_key ==
apdr.partition_key,
+ )
)
)
)
)
- statuses: dict[SerializedAssetUniqueKey, bool] = {
- SerializedAssetUniqueKey.from_asset(a): True for a in
asset_models
- }
- # todo: AIP-76 so, this basically works when we only require one
partition from each asset to be there
- # but, we ultimately need rollup ability
- # that is, we need to ensure that whenever it is many -> one
partitions, then we need to ensure
- # that all the required keys are there
- # one way to do this would be just to figure out what the count
should be
+ timetable = dag.timetable
+ statuses: dict[SerializedAssetUniqueKey, bool] = {}
+ for asset_model in asset_models:
+ if timetable.partitioned:
+ mapper =
timetable.get_partition_mapper(name=asset_model.name, uri=asset_model.uri)
+ if isinstance(mapper, RollupMapper):
+ expected = mapper.to_upstream(apdr.partition_key)
Review Comment:
`mapper` returned by `timetable.get_partition_mapper(...)` is typically a
Task SDK mapper (e.g. `airflow.sdk.definitions.partition_mappers.*`, as used
throughout existing scheduler tests). Those mappers will not be instances of
core `airflow.partition_mappers.base.RollupMapper`, so this `isinstance(mapper,
RollupMapper)` check will almost always be false and rollup readiness will
never be enforced. Consider switching to a shared interface between core and
SDK (single RollupMapper type), or using duck-typing (e.g. detect/call a
`to_upstream` method) / checking against both core and SDK RollupMapper types.
```suggestion
                    to_upstream = getattr(mapper, "to_upstream", None)
                    if callable(to_upstream):
                        expected = to_upstream(apdr.partition_key)
```
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -1777,25 +1778,37 @@ def _create_dagruns_for_partitioned_asset_dags(self,
session: Session) -> set[st
self.log.error("Dag '%s' not found in serialized_dag table",
apdr.target_dag_id)
continue
- asset_models = session.scalars(
- select(AssetModel).where(
- exists(
- select(1).where(
- PartitionedAssetKeyLog.asset_id == AssetModel.id,
- PartitionedAssetKeyLog.asset_partition_dag_run_id
== apdr.id,
- PartitionedAssetKeyLog.target_partition_key ==
apdr.partition_key,
+ asset_models = list(
+ session.scalars(
+ select(AssetModel).where(
+ exists(
+ select(1).where(
+ PartitionedAssetKeyLog.asset_id ==
AssetModel.id,
+
PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+ PartitionedAssetKeyLog.target_partition_key ==
apdr.partition_key,
+ )
)
)
)
)
- statuses: dict[SerializedAssetUniqueKey, bool] = {
- SerializedAssetUniqueKey.from_asset(a): True for a in
asset_models
- }
- # todo: AIP-76 so, this basically works when we only require one
partition from each asset to be there
- # but, we ultimately need rollup ability
- # that is, we need to ensure that whenever it is many -> one
partitions, then we need to ensure
- # that all the required keys are there
- # one way to do this would be just to figure out what the count
should be
+ timetable = dag.timetable
+ statuses: dict[SerializedAssetUniqueKey, bool] = {}
+ for asset_model in asset_models:
+ if timetable.partitioned:
+ mapper =
timetable.get_partition_mapper(name=asset_model.name, uri=asset_model.uri)
+ if isinstance(mapper, RollupMapper):
+ expected = mapper.to_upstream(apdr.partition_key)
+ actual = set(
+ session.scalars(
+
select(PartitionedAssetKeyLog.source_partition_key).where(
+ PartitionedAssetKeyLog.asset_id ==
asset_model.id,
+
PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
Review Comment:
The `actual` query does not filter by
`PartitionedAssetKeyLog.target_partition_key == apdr.partition_key`. Earlier in
this function the existence check includes `target_partition_key`, so omitting
it here can pull in source keys for other target partitions and incorrectly
mark a rollup as satisfied. Filter `actual` by the current `apdr.partition_key`
to keep the readiness check consistent.
```suggestion
PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
PartitionedAssetKeyLog.target_partition_key == apdr.partition_key,
```
##########
airflow-core/src/airflow/partition_mappers/temporal.py:
##########
@@ -98,18 +98,23 @@ def normalize(self, dt: datetime) -> datetime:
return dt.replace(hour=0, minute=0, second=0, microsecond=0)
-class StartOfWeekMapper(_BaseTemporalMapper):
- """Map a time-based partition key to week."""
+class StartOfWeekMapper(_BaseTemporalMapper, RollupMapper):
+ """Map a time-based partition key to week (Mon–Sun), requiring all 7 daily
keys."""
default_output_format = "%Y-%m-%d (W%V)"
def normalize(self, dt: datetime) -> datetime:
start = dt - timedelta(days=dt.weekday())
return start.replace(hour=0, minute=0, second=0, microsecond=0)
+ def to_upstream(self, downstream_key: str) -> frozenset[str]:
+ # The output format starts with %Y-%m-%d which is always the Monday of
the week.
+ week_start = datetime.strptime(downstream_key[:10], "%Y-%m-%d")
+ return frozenset((week_start +
timedelta(days=i)).strftime(self.input_format) for i in range(7))
Review Comment:
`to_upstream()` builds a naive `datetime` from the downstream key and
formats it with `self.input_format`. This ignores the mapper timezone
(`self._timezone`) and will also generate incorrect keys when `input_format`
includes timezone info (e.g. `%z`), since formatting a naive datetime won’t
include an offset. Consider making `week_start` timezone-aware in
`self._timezone` (and keeping it aware while adding days) before formatting, so
the generated upstream keys match what `to_downstream()` expects.
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -1777,25 +1778,37 @@ def _create_dagruns_for_partitioned_asset_dags(self,
session: Session) -> set[st
self.log.error("Dag '%s' not found in serialized_dag table",
apdr.target_dag_id)
continue
- asset_models = session.scalars(
- select(AssetModel).where(
- exists(
- select(1).where(
- PartitionedAssetKeyLog.asset_id == AssetModel.id,
- PartitionedAssetKeyLog.asset_partition_dag_run_id
== apdr.id,
- PartitionedAssetKeyLog.target_partition_key ==
apdr.partition_key,
+ asset_models = list(
+ session.scalars(
+ select(AssetModel).where(
+ exists(
+ select(1).where(
+ PartitionedAssetKeyLog.asset_id ==
AssetModel.id,
+
PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+ PartitionedAssetKeyLog.target_partition_key ==
apdr.partition_key,
+ )
)
)
)
)
- statuses: dict[SerializedAssetUniqueKey, bool] = {
- SerializedAssetUniqueKey.from_asset(a): True for a in
asset_models
- }
- # todo: AIP-76 so, this basically works when we only require one
partition from each asset to be there
- # but, we ultimately need rollup ability
- # that is, we need to ensure that whenever it is many -> one
partitions, then we need to ensure
- # that all the required keys are there
- # one way to do this would be just to figure out what the count
should be
+ timetable = dag.timetable
+ statuses: dict[SerializedAssetUniqueKey, bool] = {}
+ for asset_model in asset_models:
+ if timetable.partitioned:
+ mapper =
timetable.get_partition_mapper(name=asset_model.name, uri=asset_model.uri)
+ if isinstance(mapper, RollupMapper):
+ expected = mapper.to_upstream(apdr.partition_key)
+ actual = set(
+ session.scalars(
+
select(PartitionedAssetKeyLog.source_partition_key).where(
+ PartitionedAssetKeyLog.asset_id ==
asset_model.id,
+
PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+ )
+ )
+ )
+
statuses[SerializedAssetUniqueKey.from_asset(asset_model)] =
expected.issubset(actual)
Review Comment:
`get_partition_mapper(...)` and `to_upstream(...)` are called without any
error handling. If a mapper raises (e.g. due to an unexpected partition key
format), this can abort the scheduler loop instead of treating the asset as
not-yet-satisfied and logging an actionable message. Consider wrapping this
block in a broad `try/except Exception` (similar to
`airflow/assets/manager.py`) and setting the status to `False` (or skipping) on
failure.
```suggestion
try:
mapper = timetable.get_partition_mapper(
name=asset_model.name, uri=asset_model.uri
)
if isinstance(mapper, RollupMapper):
expected = mapper.to_upstream(apdr.partition_key)
actual = set(
session.scalars(
select(PartitionedAssetKeyLog.source_partition_key).where(
PartitionedAssetKeyLog.asset_id ==
asset_model.id,
PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
)
)
)
statuses[SerializedAssetUniqueKey.from_asset(asset_model)] = expected.issubset(
actual
)
continue
except Exception:
self.log.exception(
"Failed to evaluate partition mapping for asset
%s (uri=%s, partition_key=%s); "
"treating asset as not satisfied.",
asset_model.name,
asset_model.uri,
apdr.partition_key,
)
                            statuses[SerializedAssetUniqueKey.from_asset(asset_model)] = False
                            continue
```
##########
airflow-core/src/airflow/partition_mappers/temporal.py:
##########
@@ -98,18 +98,23 @@ def normalize(self, dt: datetime) -> datetime:
return dt.replace(hour=0, minute=0, second=0, microsecond=0)
-class StartOfWeekMapper(_BaseTemporalMapper):
- """Map a time-based partition key to week."""
+class StartOfWeekMapper(_BaseTemporalMapper, RollupMapper):
+ """Map a time-based partition key to week (Mon–Sun), requiring all 7 daily
keys."""
default_output_format = "%Y-%m-%d (W%V)"
def normalize(self, dt: datetime) -> datetime:
start = dt - timedelta(days=dt.weekday())
return start.replace(hour=0, minute=0, second=0, microsecond=0)
+ def to_upstream(self, downstream_key: str) -> frozenset[str]:
+ # The output format starts with %Y-%m-%d which is always the Monday of
the week.
+ week_start = datetime.strptime(downstream_key[:10], "%Y-%m-%d")
+ return frozenset((week_start +
timedelta(days=i)).strftime(self.input_format) for i in range(7))
Review Comment:
`to_upstream()` assumes the downstream key always starts with an ISO date
(`downstream_key[:10]`) and parses it with a hard-coded format. Since
`_BaseTemporalMapper` allows `output_format` overrides, this can silently break
for customized week output formats. Consider either deriving the parse format
from `self.output_format`, or explicitly validating that it starts with
`%Y-%m-%d` and raising a clear error if not, so rollup mapping remains
consistent with the configured output format.
##########
airflow-core/src/airflow/partition_mappers/temporal.py:
##########
@@ -122,6 +127,15 @@ def normalize(self, dt: datetime) -> datetime:
microsecond=0,
)
+ def to_upstream(self, downstream_key: str) -> frozenset[str]:
+ import calendar
+
+ month_start = datetime.strptime(downstream_key, self.output_format)
+ days_in_month = calendar.monthrange(month_start.year,
month_start.month)[1]
Review Comment:
`import calendar` is inside `to_upstream()`. This doesn’t appear to be for
circular-import avoidance or lazy loading, so it should be moved to the module
level per project guidelines to keep imports consistent and discoverable.
##########
airflow-core/src/airflow/partition_mappers/temporal.py:
##########
@@ -122,6 +127,15 @@ def normalize(self, dt: datetime) -> datetime:
microsecond=0,
)
+ def to_upstream(self, downstream_key: str) -> frozenset[str]:
+ import calendar
+
+ month_start = datetime.strptime(downstream_key, self.output_format)
+ days_in_month = calendar.monthrange(month_start.year,
month_start.month)[1]
+ return frozenset(
+ (month_start + timedelta(days=i)).strftime(self.input_format) for
i in range(days_in_month)
+ )
Review Comment:
Similar to `StartOfWeekMapper.to_upstream()`, this uses a naive `datetime`
for `month_start`, so generated upstream keys can be wrong when the mapper
timezone is not UTC or when `input_format` includes `%z`. Consider making
`month_start` timezone-aware in `self._timezone` before formatting upstream
keys so the scheduler waits for the correct keys.
##########
task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py:
##########
@@ -43,17 +45,31 @@ class StartOfDayMapper(_BaseTemporalMapper):
default_output_format = "%Y-%m-%d"
-class StartOfWeekMapper(_BaseTemporalMapper):
- """Map a time-based partition key to week."""
+class StartOfWeekMapper(_BaseTemporalMapper, RollupMapper):
+ """Map a time-based partition key to week (Mon–Sun), requiring all 7 daily
keys."""
default_output_format = "%Y-%m-%d (W%V)"
+ def to_upstream(self, downstream_key: str) -> frozenset[str]:
+ # The output format starts with %Y-%m-%d which is always the Monday of
the week.
+ week_start = datetime.strptime(downstream_key[:10], "%Y-%m-%d")
+ return frozenset((week_start +
timedelta(days=i)).strftime(self.input_format) for i in range(7))
+
-class StartOfMonthMapper(_BaseTemporalMapper):
- """Map a time-based partition key to month."""
+class StartOfMonthMapper(_BaseTemporalMapper, RollupMapper):
+ """Map a time-based partition key to month, requiring all daily keys in
that month."""
default_output_format = "%Y-%m"
+ def to_upstream(self, downstream_key: str) -> frozenset[str]:
+ import calendar
+
+ month_start = datetime.strptime(downstream_key, self.output_format)
+ days_in_month = calendar.monthrange(month_start.year,
month_start.month)[1]
Review Comment:
`import calendar` is inside `to_upstream()`. This doesn’t look necessary for
lazy loading or circular imports, so it should be moved to the module level for
consistency and to satisfy import placement guidelines.
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -1777,25 +1778,37 @@ def _create_dagruns_for_partitioned_asset_dags(self,
session: Session) -> set[st
self.log.error("Dag '%s' not found in serialized_dag table",
apdr.target_dag_id)
continue
- asset_models = session.scalars(
- select(AssetModel).where(
- exists(
- select(1).where(
- PartitionedAssetKeyLog.asset_id == AssetModel.id,
- PartitionedAssetKeyLog.asset_partition_dag_run_id
== apdr.id,
- PartitionedAssetKeyLog.target_partition_key ==
apdr.partition_key,
+ asset_models = list(
+ session.scalars(
+ select(AssetModel).where(
+ exists(
+ select(1).where(
+ PartitionedAssetKeyLog.asset_id ==
AssetModel.id,
+
PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+ PartitionedAssetKeyLog.target_partition_key ==
apdr.partition_key,
+ )
)
)
)
)
- statuses: dict[SerializedAssetUniqueKey, bool] = {
- SerializedAssetUniqueKey.from_asset(a): True for a in
asset_models
- }
- # todo: AIP-76 so, this basically works when we only require one
partition from each asset to be there
- # but, we ultimately need rollup ability
- # that is, we need to ensure that whenever it is many -> one
partitions, then we need to ensure
- # that all the required keys are there
- # one way to do this would be just to figure out what the count
should be
+ timetable = dag.timetable
+ statuses: dict[SerializedAssetUniqueKey, bool] = {}
+ for asset_model in asset_models:
+ if timetable.partitioned:
+ mapper =
timetable.get_partition_mapper(name=asset_model.name, uri=asset_model.uri)
+ if isinstance(mapper, RollupMapper):
+ expected = mapper.to_upstream(apdr.partition_key)
+ actual = set(
+ session.scalars(
+
select(PartitionedAssetKeyLog.source_partition_key).where(
+ PartitionedAssetKeyLog.asset_id ==
asset_model.id,
+
PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+ )
+ )
+ )
Review Comment:
This introduces a per-asset query against `PartitionedAssetKeyLog` inside
the `for asset_model in asset_models` loop (N+1 pattern). Since this runs in
the scheduler loop, it can become a bottleneck for partitioned DAGs with many
upstream assets. Consider fetching all `(asset_id, source_partition_key)` rows
for the `apdr.id` in one query and grouping in Python before the loop.
```suggestion
# Batch-load all source partition keys for this asset partition
dag run to avoid
# issuing a per-asset query (N+1 pattern) inside the loop below.
            partition_keys_by_asset_id: dict[int, set[str]] = defaultdict(set)
for asset_id, source_partition_key in session.execute(
select(
PartitionedAssetKeyLog.asset_id,
PartitionedAssetKeyLog.source_partition_key,
).where(
PartitionedAssetKeyLog.asset_partition_dag_run_id ==
apdr.id,
)
):
partition_keys_by_asset_id[asset_id].add(source_partition_key)
for asset_model in asset_models:
if timetable.partitioned:
mapper =
timetable.get_partition_mapper(name=asset_model.name, uri=asset_model.uri)
if isinstance(mapper, RollupMapper):
expected = mapper.to_upstream(apdr.partition_key)
actual =
partition_keys_by_asset_id.get(asset_model.id, set())
```
##########
task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py:
##########
@@ -43,17 +45,31 @@ class StartOfDayMapper(_BaseTemporalMapper):
default_output_format = "%Y-%m-%d"
-class StartOfWeekMapper(_BaseTemporalMapper):
- """Map a time-based partition key to week."""
+class StartOfWeekMapper(_BaseTemporalMapper, RollupMapper):
+ """Map a time-based partition key to week (Mon–Sun), requiring all 7 daily
keys."""
default_output_format = "%Y-%m-%d (W%V)"
+ def to_upstream(self, downstream_key: str) -> frozenset[str]:
+ # The output format starts with %Y-%m-%d which is always the Monday of
the week.
+ week_start = datetime.strptime(downstream_key[:10], "%Y-%m-%d")
+ return frozenset((week_start +
timedelta(days=i)).strftime(self.input_format) for i in range(7))
Review Comment:
`to_upstream()` assumes the downstream week key always begins with
`%Y-%m-%d` and parses `downstream_key[:10]` with a hard-coded format. Since
`_BaseTemporalMapper` allows `output_format` overrides, customized output
formats can break rollup mapping. Consider validating that `self.output_format`
starts with `%Y-%m-%d` (and raising a clear error if not), or otherwise
deriving the date portion consistently from the configured `output_format`.
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -1777,25 +1778,37 @@ def _create_dagruns_for_partitioned_asset_dags(self,
session: Session) -> set[st
self.log.error("Dag '%s' not found in serialized_dag table",
apdr.target_dag_id)
continue
- asset_models = session.scalars(
- select(AssetModel).where(
- exists(
- select(1).where(
- PartitionedAssetKeyLog.asset_id == AssetModel.id,
- PartitionedAssetKeyLog.asset_partition_dag_run_id
== apdr.id,
- PartitionedAssetKeyLog.target_partition_key ==
apdr.partition_key,
+ asset_models = list(
+ session.scalars(
+ select(AssetModel).where(
+ exists(
+ select(1).where(
+ PartitionedAssetKeyLog.asset_id ==
AssetModel.id,
+
PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+ PartitionedAssetKeyLog.target_partition_key ==
apdr.partition_key,
+ )
)
)
)
)
- statuses: dict[SerializedAssetUniqueKey, bool] = {
- SerializedAssetUniqueKey.from_asset(a): True for a in
asset_models
- }
- # todo: AIP-76 so, this basically works when we only require one
partition from each asset to be there
- # but, we ultimately need rollup ability
- # that is, we need to ensure that whenever it is many -> one
partitions, then we need to ensure
- # that all the required keys are there
- # one way to do this would be just to figure out what the count
should be
+ timetable = dag.timetable
+ statuses: dict[SerializedAssetUniqueKey, bool] = {}
+ for asset_model in asset_models:
+ if timetable.partitioned:
+ mapper =
timetable.get_partition_mapper(name=asset_model.name, uri=asset_model.uri)
+ if isinstance(mapper, RollupMapper):
+ expected = mapper.to_upstream(apdr.partition_key)
+ actual = set(
+ session.scalars(
+
select(PartitionedAssetKeyLog.source_partition_key).where(
+ PartitionedAssetKeyLog.asset_id ==
asset_model.id,
+
PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+ )
+ )
+ )
+
statuses[SerializedAssetUniqueKey.from_asset(asset_model)] =
expected.issubset(actual)
+ continue
+ statuses[SerializedAssetUniqueKey.from_asset(asset_model)] =
True
Review Comment:
This adds new rollup readiness behavior (waiting for a complete upstream key
set via `to_upstream()`), but there doesn’t appear to be a regression test
covering the many→one case (e.g. weekly or monthly rollup) in the scheduler’s
partitioned asset DAG run creation tests. Adding a unit test that proves the
scheduler does *not* create a DagRun until all expected upstream keys are
present would help prevent regressions.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]