This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new cabd768309 Merge nowait and skip_locked into with_row_locks (#36889)
cabd768309 is described below
commit cabd768309296f5a9c92604d704307f816ff8786
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Sat Jan 20 07:13:05 2024 +0800
Merge nowait and skip_locked into with_row_locks (#36889)
Since the two functions are always used in conjunction with the latter, we
can simply handle the two arguments specially in with_row_locks, instead
of doing the same checks over and over again.
The two functions are removed outright since they are not documented and
thus technically not subject to backward compatibility. I highly doubt
anyone is using them directly due to their highly specific nature.
---
airflow/dag_processing/manager.py | 6 +--
airflow/jobs/scheduler_job_runner.py | 21 ++---------
airflow/models/abstractoperator.py | 4 +-
airflow/models/dag.py | 3 +-
airflow/models/dagrun.py | 4 +-
airflow/models/pool.py | 4 +-
airflow/utils/sqlalchemy.py | 72 ++++++++++++++----------------------
tests/utils/test_sqlalchemy.py | 66 ---------------------------------
8 files changed, 40 insertions(+), 140 deletions(-)
diff --git a/airflow/dag_processing/manager.py
b/airflow/dag_processing/manager.py
index b82a26f376..e1fa7a43bd 100644
--- a/airflow/dag_processing/manager.py
+++ b/airflow/dag_processing/manager.py
@@ -63,7 +63,7 @@ from airflow.utils.process_utils import (
)
from airflow.utils.retries import retry_db_transaction
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import prohibit_commit, skip_locked,
with_row_locks
+from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks
if TYPE_CHECKING:
from multiprocessing.connection import Connection as
MultiprocessingConnection
@@ -681,9 +681,7 @@ class DagFileProcessorManager(LoggingMixin):
DbCallbackRequest.processor_subdir ==
self.get_dag_directory(),
)
query =
query.order_by(DbCallbackRequest.priority_weight.asc()).limit(max_callbacks)
- query = with_row_locks(
- query, of=DbCallbackRequest, session=session,
**skip_locked(session=session)
- )
+ query = with_row_locks(query, of=DbCallbackRequest,
session=session, skip_locked=True)
callbacks = session.scalars(query)
for callback in callbacks:
try:
diff --git a/airflow/jobs/scheduler_job_runner.py
b/airflow/jobs/scheduler_job_runner.py
index 85dccbb26a..627e0d1468 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -68,7 +68,6 @@ from airflow.utils.session import NEW_SESSION,
create_session, provide_session
from airflow.utils.sqlalchemy import (
is_lock_not_available_error,
prohibit_commit,
- skip_locked,
tuple_in_condition,
with_row_locks,
)
@@ -399,12 +398,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
timer.start()
try:
- query = with_row_locks(
- query,
- of=TI,
- session=session,
- **skip_locked(session=session),
- )
+ query = with_row_locks(query, of=TI, session=session,
skip_locked=True)
task_instances_to_examine: list[TI] =
session.scalars(query).all()
timer.stop(send=True)
@@ -706,12 +700,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
query =
select(TI).where(filter_for_tis).options(selectinload(TI.dag_model))
# row lock this entire set of taskinstances to make sure the scheduler
doesn't fail when we have
# multi-schedulers
- tis_query: Query = with_row_locks(
- query,
- of=TI,
- session=session,
- **skip_locked(session=session),
- )
+ tis_query: Query = with_row_locks(query, of=TI, session=session,
skip_locked=True)
tis: Iterator[TI] = session.scalars(tis_query)
for ti in tis:
try_number = ti_primary_key_to_try_number_map[ti.key.primary]
@@ -1434,7 +1423,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
select(DagModel).where(DagModel.dag_id ==
dag_run.dag_id).options(joinedload(DagModel.parent_dag))
)
dag_model = session.scalars(
- with_row_locks(query, of=DagModel, session=session,
**skip_locked(session=session))
+ with_row_locks(query, of=DagModel, session=session,
skip_locked=True)
).one_or_none()
if not dag:
@@ -1660,9 +1649,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
)
# Lock these rows, so that another scheduler can't try and
adopt these too
- tis_to_adopt_or_reset = with_row_locks(
- query, of=TI, session=session,
**skip_locked(session=session)
- )
+ tis_to_adopt_or_reset = with_row_locks(query, of=TI,
session=session, skip_locked=True)
tis_to_adopt_or_reset =
session.scalars(tis_to_adopt_or_reset).all()
to_reset =
self.job.executor.try_adopt_task_instances(tis_to_adopt_or_reset)
diff --git a/airflow/models/abstractoperator.py
b/airflow/models/abstractoperator.py
index f5a266f4b1..4ec8335255 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -35,7 +35,7 @@ from airflow.utils.db import exists_query
from airflow.utils.log.secrets_masker import redact
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.setup_teardown import SetupTeardownContext
-from airflow.utils.sqlalchemy import skip_locked, with_row_locks
+from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.task_group import MappedTaskGroup
from airflow.utils.trigger_rule import TriggerRule
@@ -625,7 +625,7 @@ class AbstractOperator(Templater, DAGNode):
TaskInstance.run_id == run_id,
TaskInstance.map_index >= total_expanded_ti_count,
)
- query = with_row_locks(query, of=TaskInstance, session=session,
**skip_locked(session=session))
+ query = with_row_locks(query, of=TaskInstance, session=session,
skip_locked=True)
to_update = session.scalars(query)
for ti in to_update:
ti.state = TaskInstanceState.REMOVED
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 4e70b87817..9ee3409c0d 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -128,7 +128,6 @@ from airflow.utils.sqlalchemy import (
Interval,
UtcDateTime,
lock_rows,
- skip_locked,
tuple_in_condition,
with_row_locks,
)
@@ -3789,7 +3788,7 @@ class DagModel(Base):
)
return (
- session.scalars(with_row_locks(query, of=cls, session=session,
**skip_locked(session=session))),
+ session.scalars(with_row_locks(query, of=cls, session=session,
skip_locked=True)),
dataset_triggered_dag_info,
)
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 6a1e71d4d7..501470fd56 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -65,7 +65,7 @@ from airflow.utils import timezone
from airflow.utils.helpers import chunks, is_container, prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked,
tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime, nulls_first,
tuple_in_condition, with_row_locks
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import NOTSET, DagRunType
@@ -365,7 +365,7 @@ class DagRun(Base, LoggingMixin):
query = query.where(DagRun.execution_date <= func.now())
return session.scalars(
- with_row_locks(query.limit(max_number), of=cls, session=session,
**skip_locked(session=session))
+ with_row_locks(query.limit(max_number), of=cls, session=session,
skip_locked=True)
)
@classmethod
diff --git a/airflow/models/pool.py b/airflow/models/pool.py
index 1960c0a867..3ca7293ffe 100644
--- a/airflow/models/pool.py
+++ b/airflow/models/pool.py
@@ -27,7 +27,7 @@ from airflow.ti_deps.dependencies_states import
EXECUTION_STATES
from airflow.typing_compat import TypedDict
from airflow.utils.db import exists_query
from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import nowait, with_row_locks
+from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import TaskInstanceState
if TYPE_CHECKING:
@@ -172,7 +172,7 @@ class Pool(Base):
query = select(Pool.pool, Pool.slots, Pool.include_deferred)
if lock_rows:
- query = with_row_locks(query, session=session, **nowait(session))
+ query = with_row_locks(query, session=session, nowait=True)
pool_rows = session.execute(query)
for pool_name, total_slots, include_deferred in pool_rows:
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 9d9b248ec7..2dc495811a 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -334,46 +334,6 @@ class Interval(TypeDecorator):
return data
-def skip_locked(session: Session) -> dict[str, Any]:
- """
- Return kwargs for passing to `with_for_update()` suitable for the current
DB engine version.
-
- We do this as we document the fact that on DB engines that don't support
this construct, we do not
- support/recommend running HA scheduler. If a user ignores this and tries
anyway everything will still
- work, just slightly slower in some circumstances.
-
- Specifically don't emit SKIP LOCKED for MySQL < 8, or MariaDB, neither of
which support this construct
-
- See https://jira.mariadb.org/browse/MDEV-13115
- """
- dialect = session.bind.dialect
-
- if dialect.name != "mysql" or dialect.supports_for_update_of:
- return {"skip_locked": True}
- else:
- return {}
-
-
-def nowait(session: Session) -> dict[str, Any]:
- """
- Return kwargs for passing to `with_for_update()` suitable for the current
DB engine version.
-
- We do this as we document the fact that on DB engines that don't support
this construct, we do not
- support/recommend running HA scheduler. If a user ignores this and tries
anyway everything will still
- work, just slightly slower in some circumstances.
-
- Specifically don't emit NOWAIT for MySQL < 8, or MariaDB, neither of which
support this construct
-
- See https://jira.mariadb.org/browse/MDEV-13115
- """
- dialect = session.bind.dialect
-
- if dialect.name != "mysql" or dialect.supports_for_update_of:
- return {"nowait": True}
- else:
- return {}
-
-
def nulls_first(col, session: Session) -> dict[str, Any]:
"""Specify *NULLS FIRST* to the column ordering.
@@ -390,22 +350,44 @@ def nulls_first(col, session: Session) -> dict[str, Any]:
USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler",
"use_row_level_locking", fallback=True)
-def with_row_locks(query: Query, session: Session, **kwargs) -> Query:
+def with_row_locks(
+ query: Query,
+ session: Session,
+ *,
+ nowait: bool = False,
+ skip_locked: bool = False,
+ **kwargs,
+) -> Query:
"""
- Apply with_for_update to an SQLAlchemy query, if row level locking is in
use.
+ Apply with_for_update to the SQLAlchemy query if row level locking is in
use.
+
+ This wrapper is needed so we don't use the syntax on unsupported database
+ engines. In particular, MySQL (prior to 8.0) and MariaDB do not support
+ row locking, where we do not support nor recommend running HA scheduler. If
+ a user ignores this and tries anyway, everything will still work, just
+ slightly slower in some circumstances.
+
+ See https://jira.mariadb.org/browse/MDEV-13115
:param query: An SQLAlchemy Query object
:param session: ORM Session
+ :param nowait: If set to True, will pass NOWAIT to supported database
backends.
+ :param skip_locked: If set to True, will pass SKIP LOCKED to supported
database backends.
:param kwargs: Extra kwargs to pass to with_for_update (of, nowait,
skip_locked, etc)
:return: updated query
"""
dialect = session.bind.dialect
# Don't use row level locks if the MySQL dialect (Mariadb & MySQL < 8)
does not support it.
- if USE_ROW_LEVEL_LOCKING and (dialect.name != "mysql" or
dialect.supports_for_update_of):
- return query.with_for_update(**kwargs)
- else:
+ if not USE_ROW_LEVEL_LOCKING:
+ return query
+ if dialect.name == "mysql" and not dialect.supports_for_update_of:
return query
+ if nowait:
+ kwargs["nowait"] = True
+ if skip_locked:
+ kwargs["skip_locked"] = True
+ return query.with_for_update(**kwargs)
@contextlib.contextmanager
diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py
index e01d0904ad..16ba6b392d 100644
--- a/tests/utils/test_sqlalchemy.py
+++ b/tests/utils/test_sqlalchemy.py
@@ -36,9 +36,7 @@ from airflow.settings import Session
from airflow.utils.sqlalchemy import (
ExecutorConfigType,
ensure_pod_is_valid_after_unpickling,
- nowait,
prohibit_commit,
- skip_locked,
with_row_locks,
)
from airflow.utils.state import State
@@ -117,70 +115,6 @@ class TestSqlAlchemyUtils:
)
dag.clear()
- @pytest.mark.parametrize(
- "dialect, supports_for_update_of, expected_return_value",
- [
- (
- "postgresql",
- True,
- {"skip_locked": True},
- ),
- (
- "mysql",
- False,
- {},
- ),
- (
- "mysql",
- True,
- {"skip_locked": True},
- ),
- (
- "sqlite",
- False,
- {"skip_locked": True},
- ),
- ],
- )
- def test_skip_locked(self, dialect, supports_for_update_of,
expected_return_value):
- session = mock.Mock()
- session.bind.dialect.name = dialect
- session.bind.dialect.supports_for_update_of = supports_for_update_of
- assert skip_locked(session=session) == expected_return_value
-
- @pytest.mark.parametrize(
- "dialect, supports_for_update_of, expected_return_value",
- [
- (
- "postgresql",
- True,
- {"nowait": True},
- ),
- (
- "mysql",
- False,
- {},
- ),
- (
- "mysql",
- True,
- {"nowait": True},
- ),
- (
- "sqlite",
- False,
- {
- "nowait": True,
- },
- ),
- ],
- )
- def test_nowait(self, dialect, supports_for_update_of,
expected_return_value):
- session = mock.Mock()
- session.bind.dialect.name = dialect
- session.bind.dialect.supports_for_update_of = supports_for_update_of
- assert nowait(session=session) == expected_return_value
-
@pytest.mark.parametrize(
"dialect, supports_for_update_of, use_row_level_lock_conf,
expected_use_row_level_lock",
[