Copilot commented on code in PR #64571:
URL: https://github.com/apache/airflow/pull/64571#discussion_r3025334878


##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -1777,25 +1778,37 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st
                 self.log.error("Dag '%s' not found in serialized_dag table", apdr.target_dag_id)
                 continue
 
-            asset_models = session.scalars(
-                select(AssetModel).where(
-                    exists(
-                        select(1).where(
-                            PartitionedAssetKeyLog.asset_id == AssetModel.id,
-                            PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
-                            PartitionedAssetKeyLog.target_partition_key == apdr.partition_key,
+            asset_models = list(
+                session.scalars(
+                    select(AssetModel).where(
+                        exists(
+                            select(1).where(
+                                PartitionedAssetKeyLog.asset_id == AssetModel.id,
+                                PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+                                PartitionedAssetKeyLog.target_partition_key == apdr.partition_key,
+                            )
                         )
                     )
                 )
             )
-            statuses: dict[SerializedAssetUniqueKey, bool] = {
-                SerializedAssetUniqueKey.from_asset(a): True for a in asset_models
-            }
-            # todo: AIP-76 so, this basically works when we only require one partition from each asset to be there
-            #  but, we ultimately need rollup ability
-            #  that is, we need to ensure that whenever it is many -> one partitions, then we need to ensure
-            #  that all the required keys are there
-            #  one way to do this would be just to figure out what the count should be
+            timetable = dag.timetable
+            statuses: dict[SerializedAssetUniqueKey, bool] = {}
+            for asset_model in asset_models:
+                if timetable.partitioned:
+                    mapper = timetable.get_partition_mapper(name=asset_model.name, uri=asset_model.uri)
+                    if isinstance(mapper, RollupMapper):
+                        expected = mapper.to_upstream(apdr.partition_key)

Review Comment:
   `mapper` returned by `timetable.get_partition_mapper(...)` is typically a
Task SDK mapper (e.g. `airflow.sdk.definitions.partition_mappers.*`, as used
throughout existing scheduler tests). Those mappers will not be instances of
the core `airflow.partition_mappers.base.RollupMapper`, so this
`isinstance(mapper, RollupMapper)` check will almost always be false and
rollup readiness will never be enforced. Consider a shared interface between
core and SDK (a single RollupMapper type), duck-typing (detecting and calling
a `to_upstream` method), or checking against both the core and SDK
RollupMapper types.
   ```suggestion
                    to_upstream = getattr(mapper, "to_upstream", None)
                    if isinstance(mapper, RollupMapper) or callable(to_upstream):
                        expected = to_upstream(apdr.partition_key)
   ```



##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -1777,25 +1778,37 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st
                 self.log.error("Dag '%s' not found in serialized_dag table", apdr.target_dag_id)
                 continue
 
-            asset_models = session.scalars(
-                select(AssetModel).where(
-                    exists(
-                        select(1).where(
-                            PartitionedAssetKeyLog.asset_id == AssetModel.id,
-                            PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
-                            PartitionedAssetKeyLog.target_partition_key == apdr.partition_key,
+            asset_models = list(
+                session.scalars(
+                    select(AssetModel).where(
+                        exists(
+                            select(1).where(
+                                PartitionedAssetKeyLog.asset_id == AssetModel.id,
+                                PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+                                PartitionedAssetKeyLog.target_partition_key == apdr.partition_key,
+                            )
                         )
                     )
                 )
             )
-            statuses: dict[SerializedAssetUniqueKey, bool] = {
-                SerializedAssetUniqueKey.from_asset(a): True for a in asset_models
-            }
-            # todo: AIP-76 so, this basically works when we only require one partition from each asset to be there
-            #  but, we ultimately need rollup ability
-            #  that is, we need to ensure that whenever it is many -> one partitions, then we need to ensure
-            #  that all the required keys are there
-            #  one way to do this would be just to figure out what the count should be
+            timetable = dag.timetable
+            statuses: dict[SerializedAssetUniqueKey, bool] = {}
+            for asset_model in asset_models:
+                if timetable.partitioned:
+                    mapper = timetable.get_partition_mapper(name=asset_model.name, uri=asset_model.uri)
+                    if isinstance(mapper, RollupMapper):
+                        expected = mapper.to_upstream(apdr.partition_key)
+                        actual = set(
+                            session.scalars(
+                                select(PartitionedAssetKeyLog.source_partition_key).where(
+                                    PartitionedAssetKeyLog.asset_id == asset_model.id,
+                                    PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,

Review Comment:
   The `actual` query does not filter by 
`PartitionedAssetKeyLog.target_partition_key == apdr.partition_key`. Earlier in 
this function the existence check includes `target_partition_key`, so omitting 
it here can pull in source keys for other target partitions and incorrectly 
mark a rollup as satisfied. Filter `actual` by the current `apdr.partition_key` 
to keep the readiness check consistent.
   ```suggestion
                                    PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
                                    PartitionedAssetKeyLog.target_partition_key == apdr.partition_key,
   ```



##########
airflow-core/src/airflow/partition_mappers/temporal.py:
##########
@@ -98,18 +98,23 @@ def normalize(self, dt: datetime) -> datetime:
         return dt.replace(hour=0, minute=0, second=0, microsecond=0)
 
 
-class StartOfWeekMapper(_BaseTemporalMapper):
-    """Map a time-based partition key to week."""
+class StartOfWeekMapper(_BaseTemporalMapper, RollupMapper):
+    """Map a time-based partition key to week (Mon–Sun), requiring all 7 daily keys."""
 
     default_output_format = "%Y-%m-%d (W%V)"
 
     def normalize(self, dt: datetime) -> datetime:
         start = dt - timedelta(days=dt.weekday())
         return start.replace(hour=0, minute=0, second=0, microsecond=0)
 
+    def to_upstream(self, downstream_key: str) -> frozenset[str]:
+        # The output format starts with %Y-%m-%d which is always the Monday of the week.
+        week_start = datetime.strptime(downstream_key[:10], "%Y-%m-%d")
+        return frozenset((week_start + timedelta(days=i)).strftime(self.input_format) for i in range(7))

Review Comment:
   `to_upstream()` builds a naive `datetime` from the downstream key and 
formats it with `self.input_format`. This ignores the mapper timezone 
(`self._timezone`) and will also generate incorrect keys when `input_format` 
includes timezone info (e.g. `%z`), since formatting a naive datetime won’t 
include an offset. Consider making `week_start` timezone-aware in 
`self._timezone` (and keeping it aware while adding days) before formatting, so 
the generated upstream keys match what `to_downstream()` expects.
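
   A minimal sketch of the timezone-aware approach this comment suggests. The
standalone helper and the `tz` parameter are illustrative, not the mapper's
real API; the point is attaching the mapper timezone before formatting:

   ```python
from datetime import datetime, timedelta, timezone


def week_keys_aware(downstream_key: str, tz: timezone, input_format: str) -> frozenset[str]:
    # Attach the mapper timezone before formatting, so %z in input_format
    # emits a real offset instead of an empty string on a naive datetime.
    week_start = datetime.strptime(downstream_key[:10], "%Y-%m-%d").replace(tzinfo=tz)
    return frozenset((week_start + timedelta(days=i)).strftime(input_format) for i in range(7))
   ```

   With `timezone.utc` and an `input_format` of `"%Y-%m-%d%z"`, the generated
keys carry a `+0000` offset, matching what an aware `to_downstream()` would
produce.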



##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -1777,25 +1778,37 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st
                 self.log.error("Dag '%s' not found in serialized_dag table", apdr.target_dag_id)
                 continue
 
-            asset_models = session.scalars(
-                select(AssetModel).where(
-                    exists(
-                        select(1).where(
-                            PartitionedAssetKeyLog.asset_id == AssetModel.id,
-                            PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
-                            PartitionedAssetKeyLog.target_partition_key == apdr.partition_key,
+            asset_models = list(
+                session.scalars(
+                    select(AssetModel).where(
+                        exists(
+                            select(1).where(
+                                PartitionedAssetKeyLog.asset_id == AssetModel.id,
+                                PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+                                PartitionedAssetKeyLog.target_partition_key == apdr.partition_key,
+                            )
                         )
                     )
                 )
             )
-            statuses: dict[SerializedAssetUniqueKey, bool] = {
-                SerializedAssetUniqueKey.from_asset(a): True for a in asset_models
-            }
-            # todo: AIP-76 so, this basically works when we only require one partition from each asset to be there
-            #  but, we ultimately need rollup ability
-            #  that is, we need to ensure that whenever it is many -> one partitions, then we need to ensure
-            #  that all the required keys are there
-            #  one way to do this would be just to figure out what the count should be
+            timetable = dag.timetable
+            statuses: dict[SerializedAssetUniqueKey, bool] = {}
+            for asset_model in asset_models:
+                if timetable.partitioned:
+                    mapper = timetable.get_partition_mapper(name=asset_model.name, uri=asset_model.uri)
+                    if isinstance(mapper, RollupMapper):
+                        expected = mapper.to_upstream(apdr.partition_key)
+                        actual = set(
+                            session.scalars(
+                                select(PartitionedAssetKeyLog.source_partition_key).where(
+                                    PartitionedAssetKeyLog.asset_id == asset_model.id,
+                                    PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+                                )
+                            )
+                        )
+                        statuses[SerializedAssetUniqueKey.from_asset(asset_model)] = expected.issubset(actual)

Review Comment:
   `get_partition_mapper(...)` and `to_upstream(...)` are called without any 
error handling. If a mapper raises (e.g. due to an unexpected partition key 
format), this can abort the scheduler loop instead of treating the asset as 
not-yet-satisfied and logging an actionable message. Consider wrapping this 
block in a broad `try/except Exception` (similar to 
`airflow/assets/manager.py`) and setting the status to `False` (or skipping) on 
failure.
   ```suggestion
                    try:
                        mapper = timetable.get_partition_mapper(
                            name=asset_model.name, uri=asset_model.uri
                        )
                        if isinstance(mapper, RollupMapper):
                            expected = mapper.to_upstream(apdr.partition_key)
                            actual = set(
                                session.scalars(
                                    select(PartitionedAssetKeyLog.source_partition_key).where(
                                        PartitionedAssetKeyLog.asset_id == asset_model.id,
                                        PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
                                    )
                                )
                            )
                            statuses[SerializedAssetUniqueKey.from_asset(asset_model)] = expected.issubset(
                                actual
                            )
                            continue
                    except Exception:
                        self.log.exception(
                            "Failed to evaluate partition mapping for asset %s (uri=%s, partition_key=%s); "
                            "treating asset as not satisfied.",
                            asset_model.name,
                            asset_model.uri,
                            apdr.partition_key,
                        )
                        statuses[SerializedAssetUniqueKey.from_asset(asset_model)] = False
   ```



##########
airflow-core/src/airflow/partition_mappers/temporal.py:
##########
@@ -98,18 +98,23 @@ def normalize(self, dt: datetime) -> datetime:
         return dt.replace(hour=0, minute=0, second=0, microsecond=0)
 
 
-class StartOfWeekMapper(_BaseTemporalMapper):
-    """Map a time-based partition key to week."""
+class StartOfWeekMapper(_BaseTemporalMapper, RollupMapper):
+    """Map a time-based partition key to week (Mon–Sun), requiring all 7 daily keys."""
 
     default_output_format = "%Y-%m-%d (W%V)"
 
     def normalize(self, dt: datetime) -> datetime:
         start = dt - timedelta(days=dt.weekday())
         return start.replace(hour=0, minute=0, second=0, microsecond=0)
 
+    def to_upstream(self, downstream_key: str) -> frozenset[str]:
+        # The output format starts with %Y-%m-%d which is always the Monday of the week.
+        week_start = datetime.strptime(downstream_key[:10], "%Y-%m-%d")
+        return frozenset((week_start + timedelta(days=i)).strftime(self.input_format) for i in range(7))

Review Comment:
   `to_upstream()` assumes the downstream key always starts with an ISO date 
(`downstream_key[:10]`) and parses it with a hard-coded format. Since 
`_BaseTemporalMapper` allows `output_format` overrides, this can silently break 
for customized week output formats. Consider either deriving the parse format 
from `self.output_format` (or explicitly validating that it starts with 
`%Y-%m-%d` and raising a clear error if not) so rollup mapping remains 
consistent with the configured output format.
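
   One way to make the suggested validation concrete. This is a sketch with a
hypothetical standalone helper; in the real mapper the check would live in
`to_upstream()` against `self.output_format`:

   ```python
from datetime import datetime


def parse_week_start(downstream_key: str, output_format: str) -> datetime:
    # Rollup mapping slices the first 10 characters as an ISO date, so fail
    # loudly when a customized output_format breaks that assumption.
    if not output_format.startswith("%Y-%m-%d"):
        raise ValueError(
            f"rollup requires output_format to start with %Y-%m-%d, got {output_format!r}"
        )
    return datetime.strptime(downstream_key[:10], "%Y-%m-%d")
   ```

   Validating up front turns a silently wrong upstream key set into an
actionable configuration error.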



##########
airflow-core/src/airflow/partition_mappers/temporal.py:
##########
@@ -122,6 +127,15 @@ def normalize(self, dt: datetime) -> datetime:
             microsecond=0,
         )
 
+    def to_upstream(self, downstream_key: str) -> frozenset[str]:
+        import calendar
+
+        month_start = datetime.strptime(downstream_key, self.output_format)
+        days_in_month = calendar.monthrange(month_start.year, month_start.month)[1]

Review Comment:
   `import calendar` is inside `to_upstream()`. This doesn’t appear to be for 
circular-import avoidance or lazy loading, so it should be moved to the module 
level per project guidelines to keep imports consistent and discoverable.



##########
airflow-core/src/airflow/partition_mappers/temporal.py:
##########
@@ -122,6 +127,15 @@ def normalize(self, dt: datetime) -> datetime:
             microsecond=0,
         )
 
+    def to_upstream(self, downstream_key: str) -> frozenset[str]:
+        import calendar
+
+        month_start = datetime.strptime(downstream_key, self.output_format)
+        days_in_month = calendar.monthrange(month_start.year, month_start.month)[1]
+        return frozenset(
+            (month_start + timedelta(days=i)).strftime(self.input_format) for i in range(days_in_month)
+        )

Review Comment:
   Similar to `StartOfWeekMapper.to_upstream()`, this uses a naive `datetime` 
for `month_start`, so generated upstream keys can be wrong when the mapper 
timezone is not UTC or when `input_format` includes `%z`. Consider making 
`month_start` timezone-aware in `self._timezone` before formatting upstream 
keys so the scheduler waits for the correct keys.



##########
task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py:
##########
@@ -43,17 +45,31 @@ class StartOfDayMapper(_BaseTemporalMapper):
     default_output_format = "%Y-%m-%d"
 
 
-class StartOfWeekMapper(_BaseTemporalMapper):
-    """Map a time-based partition key to week."""
+class StartOfWeekMapper(_BaseTemporalMapper, RollupMapper):
+    """Map a time-based partition key to week (Mon–Sun), requiring all 7 daily keys."""
 
     default_output_format = "%Y-%m-%d (W%V)"
 
+    def to_upstream(self, downstream_key: str) -> frozenset[str]:
+        # The output format starts with %Y-%m-%d which is always the Monday of the week.
+        week_start = datetime.strptime(downstream_key[:10], "%Y-%m-%d")
+        return frozenset((week_start + timedelta(days=i)).strftime(self.input_format) for i in range(7))
+
 
-class StartOfMonthMapper(_BaseTemporalMapper):
-    """Map a time-based partition key to month."""
+class StartOfMonthMapper(_BaseTemporalMapper, RollupMapper):
+    """Map a time-based partition key to month, requiring all daily keys in that month."""
 
     default_output_format = "%Y-%m"
 
+    def to_upstream(self, downstream_key: str) -> frozenset[str]:
+        import calendar
+
+        month_start = datetime.strptime(downstream_key, self.output_format)
+        days_in_month = calendar.monthrange(month_start.year, month_start.month)[1]

Review Comment:
   `import calendar` is inside `to_upstream()`. This doesn’t look necessary for 
lazy loading or circular imports, so it should be moved to the module level for 
consistency and to satisfy import placement guidelines.



##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -1777,25 +1778,37 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st
                 self.log.error("Dag '%s' not found in serialized_dag table", apdr.target_dag_id)
                 continue
 
-            asset_models = session.scalars(
-                select(AssetModel).where(
-                    exists(
-                        select(1).where(
-                            PartitionedAssetKeyLog.asset_id == AssetModel.id,
-                            PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
-                            PartitionedAssetKeyLog.target_partition_key == apdr.partition_key,
+            asset_models = list(
+                session.scalars(
+                    select(AssetModel).where(
+                        exists(
+                            select(1).where(
+                                PartitionedAssetKeyLog.asset_id == AssetModel.id,
+                                PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+                                PartitionedAssetKeyLog.target_partition_key == apdr.partition_key,
+                            )
                         )
                     )
                 )
             )
-            statuses: dict[SerializedAssetUniqueKey, bool] = {
-                SerializedAssetUniqueKey.from_asset(a): True for a in asset_models
-            }
-            # todo: AIP-76 so, this basically works when we only require one partition from each asset to be there
-            #  but, we ultimately need rollup ability
-            #  that is, we need to ensure that whenever it is many -> one partitions, then we need to ensure
-            #  that all the required keys are there
-            #  one way to do this would be just to figure out what the count should be
+            timetable = dag.timetable
+            statuses: dict[SerializedAssetUniqueKey, bool] = {}
+            for asset_model in asset_models:
+                if timetable.partitioned:
+                    mapper = timetable.get_partition_mapper(name=asset_model.name, uri=asset_model.uri)
+                    if isinstance(mapper, RollupMapper):
+                        expected = mapper.to_upstream(apdr.partition_key)
+                        actual = set(
+                            session.scalars(
+                                select(PartitionedAssetKeyLog.source_partition_key).where(
+                                    PartitionedAssetKeyLog.asset_id == asset_model.id,
+                                    PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+                                )
+                            )
+                        )

Review Comment:
   This introduces a per-asset query against `PartitionedAssetKeyLog` inside 
the `for asset_model in asset_models` loop (N+1 pattern). Since this runs in 
the scheduler loop, it can become a bottleneck for partitioned DAGs with many 
upstream assets. Consider fetching all `(asset_id, source_partition_key)` rows 
for the `apdr.id` in one query and grouping in Python before the loop.
   ```suggestion
            # Batch-load all source partition keys for this asset partition dag run to avoid
            # issuing a per-asset query (N+1 pattern) inside the loop below.
            partition_keys_by_asset_id: dict[int, set] = defaultdict(set)
            for asset_id, source_partition_key in session.execute(
                select(
                    PartitionedAssetKeyLog.asset_id,
                    PartitionedAssetKeyLog.source_partition_key,
                ).where(
                    PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
                )
            ):
                partition_keys_by_asset_id[asset_id].add(source_partition_key)
            for asset_model in asset_models:
                if timetable.partitioned:
                    mapper = timetable.get_partition_mapper(name=asset_model.name, uri=asset_model.uri)
                    if isinstance(mapper, RollupMapper):
                        expected = mapper.to_upstream(apdr.partition_key)
                        actual = partition_keys_by_asset_id.get(asset_model.id, set())
   ```



##########
task-sdk/src/airflow/sdk/definitions/partition_mappers/temporal.py:
##########
@@ -43,17 +45,31 @@ class StartOfDayMapper(_BaseTemporalMapper):
     default_output_format = "%Y-%m-%d"
 
 
-class StartOfWeekMapper(_BaseTemporalMapper):
-    """Map a time-based partition key to week."""
+class StartOfWeekMapper(_BaseTemporalMapper, RollupMapper):
+    """Map a time-based partition key to week (Mon–Sun), requiring all 7 daily keys."""
 
     default_output_format = "%Y-%m-%d (W%V)"
 
+    def to_upstream(self, downstream_key: str) -> frozenset[str]:
+        # The output format starts with %Y-%m-%d which is always the Monday of the week.
+        week_start = datetime.strptime(downstream_key[:10], "%Y-%m-%d")
+        return frozenset((week_start + timedelta(days=i)).strftime(self.input_format) for i in range(7))

Review Comment:
   `to_upstream()` assumes the downstream week key always begins with 
`%Y-%m-%d` and parses `downstream_key[:10]` with a hard-coded format. Since 
`_BaseTemporalMapper` allows `output_format` overrides, customized output 
formats can break rollup mapping. Consider validating that `self.output_format` 
starts with `%Y-%m-%d` (and raising a clear error if not), or otherwise 
deriving the date portion consistently from the configured `output_format`.



##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -1777,25 +1778,37 @@ def _create_dagruns_for_partitioned_asset_dags(self, session: Session) -> set[st
                 self.log.error("Dag '%s' not found in serialized_dag table", apdr.target_dag_id)
                 continue
 
-            asset_models = session.scalars(
-                select(AssetModel).where(
-                    exists(
-                        select(1).where(
-                            PartitionedAssetKeyLog.asset_id == AssetModel.id,
-                            PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
-                            PartitionedAssetKeyLog.target_partition_key == apdr.partition_key,
+            asset_models = list(
+                session.scalars(
+                    select(AssetModel).where(
+                        exists(
+                            select(1).where(
+                                PartitionedAssetKeyLog.asset_id == AssetModel.id,
+                                PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+                                PartitionedAssetKeyLog.target_partition_key == apdr.partition_key,
+                            )
                         )
                     )
                 )
             )
-            statuses: dict[SerializedAssetUniqueKey, bool] = {
-                SerializedAssetUniqueKey.from_asset(a): True for a in asset_models
-            }
-            # todo: AIP-76 so, this basically works when we only require one partition from each asset to be there
-            #  but, we ultimately need rollup ability
-            #  that is, we need to ensure that whenever it is many -> one partitions, then we need to ensure
-            #  that all the required keys are there
-            #  one way to do this would be just to figure out what the count should be
+            timetable = dag.timetable
+            statuses: dict[SerializedAssetUniqueKey, bool] = {}
+            for asset_model in asset_models:
+                if timetable.partitioned:
+                    mapper = timetable.get_partition_mapper(name=asset_model.name, uri=asset_model.uri)
+                    if isinstance(mapper, RollupMapper):
+                        expected = mapper.to_upstream(apdr.partition_key)
+                        actual = set(
+                            session.scalars(
+                                select(PartitionedAssetKeyLog.source_partition_key).where(
+                                    PartitionedAssetKeyLog.asset_id == asset_model.id,
+                                    PartitionedAssetKeyLog.asset_partition_dag_run_id == apdr.id,
+                                )
+                            )
+                        )
+                        statuses[SerializedAssetUniqueKey.from_asset(asset_model)] = expected.issubset(actual)
+                        continue
+                statuses[SerializedAssetUniqueKey.from_asset(asset_model)] = True

Review Comment:
   This adds new rollup readiness behavior (waiting for a complete upstream key 
set via `to_upstream()`), but there doesn’t appear to be a regression test 
covering the many→one case (e.g. weekly or monthly rollup) in the scheduler’s 
partitioned asset DAG run creation tests. Adding a unit test that proves the 
scheduler does *not* create a DagRun until all expected upstream keys are 
present would help prevent regressions.
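
   A hedged sketch of the kind of assertion such a test could make, using a
stub mapper and a standalone readiness helper instead of the real scheduler
fixtures (all names here are illustrative):

   ```python
from datetime import datetime, timedelta


class StubWeekMapper:
    """Stub standing in for a RollupMapper with daily upstream keys."""

    input_format = "%Y-%m-%d"

    def to_upstream(self, downstream_key: str) -> frozenset[str]:
        week_start = datetime.strptime(downstream_key[:10], "%Y-%m-%d")
        return frozenset((week_start + timedelta(days=i)).strftime(self.input_format) for i in range(7))


def rollup_ready(mapper, partition_key: str, logged_keys: set[str]) -> bool:
    # Mirrors the readiness condition in the PR: every expected upstream key
    # must already be logged before a DagRun may be created.
    return mapper.to_upstream(partition_key).issubset(logged_keys)
   ```

   The scheduler-level test would assert readiness is `False` with six of the
seven daily keys logged and flips to `True` only when the seventh arrives.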



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
