This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 953dd94c6050510a39582dd223113a8994c1312b
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Mon Sep 22 16:30:43 2025 +0100

    Add tests for DBManager upgrade and downgrade (#55940)
    
    This commit adds more tests to cover various aspects of DB manager
    upgrades and downgrades.
    
    (cherry picked from commit 166a937d3be980555173918c3c82e54d3289efea)
---
 .../tests/unit/cli/commands/test_db_command.py     | 276 +++++++++++++++++++++
 .../unit/cli/commands/test_db_manager_command.py   | 191 ++++++++++++--
 2 files changed, 448 insertions(+), 19 deletions(-)

diff --git a/airflow-core/tests/unit/cli/commands/test_db_command.py b/airflow-core/tests/unit/cli/commands/test_db_command.py
index 8a6ff6dd4ca..cd7d61f5f24 100644
--- a/airflow-core/tests/unit/cli/commands/test_db_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_db_command.py
@@ -47,6 +47,118 @@ class TestCliDb:
         db_command.resetdb(self.parser.parse_args(["db", "reset", "--yes", 
"--skip-init"]))
         mock_resetdb.assert_called_once_with(skip_init=True)
 
+    def test_run_db_migrate_command_success_and_messages(self, capsys):
+        class Args:
+            to_revision = None
+            to_version = None
+            from_revision = None
+            from_version = None
+            show_sql_only = False
+
+        called = {}
+
+        def fake_command(**kwargs):
+            called.update(kwargs)
+
+        heads = {"2.10.0": "22ed7efa9da2"}
+
+        db_command.run_db_migrate_command(Args(), fake_command, heads)
+        out = capsys.readouterr().out
+        assert "Performing upgrade" in out
+        assert "Database migrating done!" in out
+        assert called == {"to_revision": None, "from_revision": None, 
"show_sql_only": False}
+
+    def test_run_db_migrate_command_offline_generation(self, capsys):
+        class Args:
+            to_revision = None
+            to_version = None
+            from_revision = None
+            from_version = None
+            show_sql_only = True
+
+        called = {}
+
+        def fake_command(**kwargs):
+            called.update(kwargs)
+
+        heads = {"2.10.0": "22ed7efa9da2"}
+
+        db_command.run_db_migrate_command(Args(), fake_command, heads)
+        out = capsys.readouterr().out
+        assert "Generating sql for upgrade" in out
+        assert called == {"to_revision": None, "from_revision": None, 
"show_sql_only": True}
+
+    @pytest.mark.parametrize(
+        "args, match",
+        [
+            (
+                {
+                    "to_revision": "abc",
+                    "to_version": "2.10.0",
+                    "from_revision": None,
+                    "from_version": None,
+                    "show_sql_only": False,
+                },
+                "Cannot supply both",
+            ),
+            (
+                {
+                    "to_revision": None,
+                    "to_version": None,
+                    "from_revision": "abc",
+                    "from_version": "2.10.0",
+                    "show_sql_only": True,
+                },
+                "Cannot supply both",
+            ),
+            (
+                {
+                    "to_revision": None,
+                    "to_version": None,
+                    "from_revision": "abc",
+                    "from_version": None,
+                    "show_sql_only": False,
+                },
+                "only .* with `--show-sql-only`",
+            ),
+            (
+                {
+                    "to_revision": None,
+                    "to_version": "abc",
+                    "from_revision": None,
+                    "from_version": None,
+                    "show_sql_only": False,
+                },
+                "Invalid version",
+            ),
+            (
+                {
+                    "to_revision": None,
+                    "to_version": "2.1.25",
+                    "from_revision": None,
+                    "from_version": None,
+                    "show_sql_only": False,
+                },
+                "Unknown version",
+            ),
+        ],
+    )
+    def test_run_db_migrate_command_validation_errors(self, args, match):
+        class Args:
+            to_revision = args["to_revision"]
+            to_version = args["to_version"]
+            from_revision = args["from_revision"]
+            from_version = args["from_version"]
+            show_sql_only = args["show_sql_only"]
+
+        def fake_command(**kwargs):
+            pass
+
+        heads = {"2.10.0": "22ed7efa9da2"}
+
+        with pytest.raises(SystemExit, match=match):
+            db_command.run_db_migrate_command(Args(), fake_command, heads)
+
     @mock.patch("airflow.cli.commands.db_command.db.check_migrations")
     def test_cli_check_migrations(self, mock_wait_for_migrations):
         db_command.check_migrations(self.parser.parse_args(["db", 
"check-migrations"]))
@@ -317,6 +429,170 @@ class TestCliDb:
         with pytest.raises(AirflowException, match=r"Unknown driver: 
invalid\+psycopg"):
             db_command.shell(self.parser.parse_args(["db", "shell"]))
 
+    def test_run_db_downgrade_command_success_and_messages(self, capsys):
+        class Args:
+            to_revision = "abc"
+            to_version = None
+            from_revision = None
+            from_version = None
+            show_sql_only = False
+            yes = True
+
+        called = {}
+
+        def fake_command(**kwargs):
+            called.update(kwargs)
+
+        heads = {"2.10.0": "22ed7efa9da2"}
+
+        db_command.run_db_downgrade_command(Args(), fake_command, heads)
+        out = capsys.readouterr().out
+        assert "Performing downgrade" in out
+        assert "Downgrade complete" in out
+        assert called == {"to_revision": "abc", "from_revision": None, 
"show_sql_only": False}
+
+    def test_run_db_downgrade_command_offline_generation(self, capsys):
+        class Args:
+            to_revision = None
+            to_version = "2.10.0"
+            from_revision = None
+            from_version = None
+            show_sql_only = True
+            yes = False
+
+        called = {}
+
+        def fake_command(**kwargs):
+            called.update(kwargs)
+
+        heads = {"2.10.0": "22ed7efa9da2"}
+
+        db_command.run_db_downgrade_command(Args(), fake_command, heads)
+        out = capsys.readouterr().out
+        assert "Generating sql for downgrade" in out
+        assert called == {"to_revision": "22ed7efa9da2", "from_revision": 
None, "show_sql_only": True}
+
+    @pytest.mark.parametrize(
+        "args, match",
+        [
+            (
+                {
+                    "to_revision": None,
+                    "to_version": None,
+                    "from_revision": None,
+                    "from_version": None,
+                    "show_sql_only": False,
+                    "yes": False,
+                },
+                "Must provide either",
+            ),
+            (
+                {
+                    "to_revision": "abc",
+                    "to_version": "2.10.0",
+                    "from_revision": None,
+                    "from_version": None,
+                    "show_sql_only": False,
+                    "yes": True,
+                },
+                "Cannot supply both",
+            ),
+            (
+                {
+                    "to_revision": "abc",
+                    "to_version": None,
+                    "from_revision": "abc1",
+                    "from_version": "2.10.0",
+                    "show_sql_only": True,
+                    "yes": True,
+                },
+                "may not be combined",
+            ),
+            (
+                {
+                    "to_revision": None,
+                    "to_version": "2.1.25",
+                    "from_revision": None,
+                    "from_version": None,
+                    "show_sql_only": False,
+                    "yes": True,
+                },
+                "not supported",
+            ),
+            (
+                {
+                    "to_revision": None,
+                    "to_version": None,
+                    "from_revision": "abc",
+                    "from_version": None,
+                    "show_sql_only": False,
+                    "yes": True,
+                },
+                "only .* with `--show-sql-only`",
+            ),
+        ],
+    )
+    def test_run_db_downgrade_command_validation_errors(self, args, match):
+        class Args:
+            to_revision = args["to_revision"]
+            to_version = args["to_version"]
+            from_revision = args["from_revision"]
+            from_version = args["from_version"]
+            show_sql_only = args["show_sql_only"]
+            yes = args["yes"]
+
+        def fake_command(**kwargs):
+            pass
+
+        heads = {"2.10.0": "22ed7efa9da2"}
+
+        with pytest.raises(SystemExit, match=match):
+            db_command.run_db_downgrade_command(Args(), fake_command, heads)
+
+    @mock.patch("airflow.cli.commands.db_command.input")
+    def test_run_db_downgrade_command_confirmation_yes_calls_command(self, 
mock_input, capsys):
+        mock_input.return_value = "Y"
+
+        class Args:
+            to_revision = "abc"
+            to_version = None
+            from_revision = None
+            from_version = None
+            show_sql_only = False
+            yes = False
+
+        called = {}
+
+        def fake_command(**kwargs):
+            called.update(kwargs)
+
+        heads = {"2.10.0": "22ed7efa9da2"}
+
+        db_command.run_db_downgrade_command(Args(), fake_command, heads)
+        out = capsys.readouterr().out
+        assert "Performing downgrade" in out
+        assert called == {"to_revision": "abc", "from_revision": None, 
"show_sql_only": False}
+
+    @mock.patch("airflow.cli.commands.db_command.input")
+    def test_run_db_downgrade_command_confirmation_no_cancels(self, 
mock_input):
+        mock_input.return_value = "n"
+
+        class Args:
+            to_revision = "abc"
+            to_version = None
+            from_revision = None
+            from_version = None
+            show_sql_only = False
+            yes = False
+
+        def fake_command(**kwargs):
+            raise AssertionError("Command should not be called when cancelled")
+
+        heads = {"2.10.0": "22ed7efa9da2"}
+
+        with pytest.raises(SystemExit, match="Cancelled"):
+            db_command.run_db_downgrade_command(Args(), fake_command, heads)
+
     @pytest.mark.parametrize(
         "args, match",
         [
diff --git a/airflow-core/tests/unit/cli/commands/test_db_manager_command.py b/airflow-core/tests/unit/cli/commands/test_db_manager_command.py
index a5e2ae58d10..df202594205 100644
--- a/airflow-core/tests/unit/cli/commands/test_db_manager_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_db_manager_command.py
@@ -22,48 +22,201 @@ import pytest
 
 from airflow.cli import cli_parser
 from airflow.cli.commands import db_manager_command
+from airflow.utils.db_manager import BaseDBManager
 
 from tests_common.test_utils.config import conf_vars
 
 pytestmark = pytest.mark.db_test
 
 
+class FakeDBManager(BaseDBManager):
+    metadata = mock.MagicMock()
+    migration_dir = "migrations"
+    alembic_file = "alembic.ini"
+    version_table_name = "alembic_version_ext"
+    revision_heads_map = {}
+
+    # Test controls
+    raise_on_init = False
+    instances: list[FakeDBManager] = []
+    last_instance: FakeDBManager | None = None
+
+    def __init__(self, session):
+        if self.raise_on_init:
+            raise AssertionError("Should not instantiate manager when 
cancelled")
+        super().__init__(session)
+        self._resetdb_mock = mock.MagicMock(name="resetdb")
+        self._upgradedb_mock = mock.MagicMock(name="upgradedb")
+        self._downgrade_mock = mock.MagicMock(name="downgrade")
+        FakeDBManager.instances.append(self)
+        FakeDBManager.last_instance = self
+
+    def resetdb(self, skip_init=False):
+        return self._resetdb_mock(skip_init=skip_init)
+
+    def upgradedb(self, to_revision=None, from_revision=None, 
show_sql_only=False):
+        return self._upgradedb_mock(
+            to_revision=to_revision, from_revision=from_revision, 
show_sql_only=show_sql_only
+        )
+
+    def downgrade(self, to_revision, from_revision=None, show_sql_only=False):
+        return self._downgrade_mock(
+            to_revision=to_revision, from_revision=from_revision, 
show_sql_only=show_sql_only
+        )
+
+
+@pytest.fixture(autouse=True)
+def _reset_fake_db_manager():
+    FakeDBManager.revision_heads_map = {}
+    FakeDBManager.raise_on_init = False
+    FakeDBManager.instances = []
+    FakeDBManager.last_instance = None
+    return None
+
+
 class TestCliDbManager:
     @classmethod
     def setup_class(cls):
         cls.parser = cli_parser.get_parser()
 
+    @mock.patch("airflow.cli.commands.db_manager_command.settings.Session", 
autospec=True)
     @mock.patch("airflow.cli.commands.db_manager_command._get_db_manager")
-    def test_cli_resetdb(self, mock_get_db_manager):
-        manager_name = "path.to.TestDBManager"
-        db_manager_command.resetdb(self.parser.parse_args(["db-manager", 
"reset", manager_name, "--yes"]))
-        mock_get_db_manager.assert_called_once_with("path.to.TestDBManager")
-        mock_get_db_manager.return_value.resetdb.asset_called_once()
+    def test_cli_resetdb_yes_calls_reset(self, mock_get_db_manager, 
mock_session):
+        manager_name = "path.to.FakeDBManager"
+        mock_get_db_manager.return_value = FakeDBManager
+
+        args = self.parser.parse_args(["db-manager", "reset", manager_name, 
"--yes"])
+        db_manager_command.resetdb(args)
+
+        mock_get_db_manager.assert_called_once_with(manager_name)
+        assert len(FakeDBManager.instances) == 1
+        FakeDBManager.last_instance._resetdb_mock.assert_called_once_with(skip_init=False)
 
+    @mock.patch("airflow.cli.commands.db_manager_command.settings.Session", 
autospec=True)
     @mock.patch("airflow.cli.commands.db_manager_command._get_db_manager")
-    def test_cli_resetdb_skip_init(self, mock_get_db_manager):
-        manager_name = "path.to.TestDBManager"
-        db_manager_command.resetdb(
-            self.parser.parse_args(["db-manager", "reset", manager_name, 
"--yes", "--skip-init"])
-        )
+    def test_cli_resetdb_skip_init(self, mock_get_db_manager, mock_session):
+        manager_name = "path.to.FakeDBManager"
+        mock_get_db_manager.return_value = FakeDBManager
+
+        args = self.parser.parse_args(["db-manager", "reset", manager_name, 
"--yes", "--skip-init"])
+        db_manager_command.resetdb(args)
         mock_get_db_manager.assert_called_once_with(manager_name)
-        
mock_get_db_manager.return_value.resetdb.asset_called_once_with(skip_init=True)
+        assert len(FakeDBManager.instances) == 1
+        FakeDBManager.last_instance._resetdb_mock.assert_called_once_with(skip_init=True)
+
+    @mock.patch("airflow.cli.commands.db_manager_command.input")
+    @mock.patch("airflow.cli.commands.db_manager_command._get_db_manager")
+    def test_cli_resetdb_prompt_yes(self, mock_get_db_manager, mock_input):
+        mock_input.return_value = "Y"
+        manager_name = "path.to.FakeDBManager"
+        mock_get_db_manager.return_value = FakeDBManager
+        args = self.parser.parse_args(["db-manager", "reset", manager_name])
+        db_manager_command.resetdb(args)
+        assert len(FakeDBManager.instances) == 1
+        FakeDBManager.last_instance._resetdb_mock.assert_called_once_with(skip_init=False)
 
+    @mock.patch("airflow.cli.commands.db_manager_command.input")
+    @mock.patch("airflow.cli.commands.db_manager_command._get_db_manager")
+    def test_cli_resetdb_prompt_cancel(self, mock_get_db_manager, mock_input):
+        mock_input.return_value = "n"
+        manager_name = "path.to.FakeDBManager"
+        FakeDBManager.raise_on_init = True
+        mock_get_db_manager.return_value = FakeDBManager
+        args = self.parser.parse_args(["db-manager", "reset", manager_name])
+        with pytest.raises(SystemExit, match="Cancelled"):
+            db_manager_command.resetdb(args)
+        assert FakeDBManager.instances == []
+
+    @mock.patch("airflow.cli.commands.db_manager_command.settings.Session", 
autospec=True)
     @mock.patch("airflow.cli.commands.db_manager_command._get_db_manager")
     
@mock.patch("airflow.cli.commands.db_manager_command.run_db_migrate_command")
-    def test_cli_migrate_db(self, mock_run_db_migrate_cmd, 
mock_get_db_manager):
-        manager_name = "path.to.TestDBManager"
-        db_manager_command.migratedb(self.parser.parse_args(["db-manager", 
"migrate", manager_name]))
+    def test_cli_migrate_db(self, mock_run_db_migrate_cmd, 
mock_get_db_manager, mock_session):
+        manager_name = "path.to.FakeDBManager"
+        FakeDBManager.revision_heads_map = {"2.10.0": "22ed7efa9da2"}
+        mock_get_db_manager.return_value = FakeDBManager
+
+        args = self.parser.parse_args(["db-manager", "migrate", manager_name])
+        db_manager_command.migratedb(args)
+
         mock_get_db_manager.assert_called_once_with(manager_name)
-        mock_run_db_migrate_cmd.assert_called_once()
+        assert len(FakeDBManager.instances) == 1
+        # Validate run_db_migrate_command was called with the instance's 
upgradedb and correct heads map
+        called_args, called_kwargs = mock_run_db_migrate_cmd.call_args
+        assert called_args[0] is args
+        # Verify the bound method refers to the instance's upgradedb 
implementation
+        assert called_args[1].__self__ is FakeDBManager.last_instance
+        assert called_args[1].__func__ is FakeDBManager.upgradedb
+        assert called_kwargs["revision_heads_map"] == {"2.10.0": 
"22ed7efa9da2"}
+
+    @mock.patch("airflow.cli.commands.db_manager_command.settings.Session", 
autospec=True)
+    @mock.patch("airflow.cli.commands.db_manager_command._get_db_manager")
+    def test_cli_migrate_db_calls_upgradedb_with_args(self, 
mock_get_db_manager, mock_session):
+        manager_name = "path.to.FakeDBManager"
+        mock_get_db_manager.return_value = FakeDBManager
+
+        args = self.parser.parse_args(
+            [
+                "db-manager",
+                "migrate",
+                manager_name,
+                "--to-revision",
+                "abc",
+                "--from-revision",
+                "def",
+                "--show-sql-only",
+            ]
+        )
+        db_manager_command.migratedb(args)
+
+        assert FakeDBManager.last_instance is not None
+        FakeDBManager.last_instance._upgradedb_mock.assert_called_once_with(
+            to_revision="abc", from_revision="def", show_sql_only=True
+        )
 
+    @mock.patch("airflow.cli.commands.db_manager_command.settings.Session", 
autospec=True)
+    @mock.patch("airflow.cli.commands.db_manager_command._get_db_manager")
+    def test_cli_downgrade_db_calls_downgrade_with_args(self, 
mock_get_db_manager, mock_session):
+        manager_name = "path.to.FakeDBManager"
+        mock_get_db_manager.return_value = FakeDBManager
+
+        args = self.parser.parse_args(
+            [
+                "db-manager",
+                "downgrade",
+                manager_name,
+                "--to-revision",
+                "abc",
+                "--from-revision",
+                "def",
+                "--show-sql-only",
+            ]
+        )
+        db_manager_command.downgrade(args)
+
+        assert FakeDBManager.last_instance is not None
+        FakeDBManager.last_instance._downgrade_mock.assert_called_once_with(
+            to_revision="abc", from_revision="def", show_sql_only=True
+        )
+
+    @mock.patch("airflow.cli.commands.db_manager_command.settings.Session", 
autospec=True)
     @mock.patch("airflow.cli.commands.db_manager_command._get_db_manager")
     
@mock.patch("airflow.cli.commands.db_manager_command.run_db_downgrade_command")
-    def test_cli_downgrade_db(self, mock_run_db_downgrade_cmd, 
mock_get_db_manager):
-        manager_name = "path.to.TestDBManager"
-        db_manager_command.downgrade(self.parser.parse_args(["db-manager", 
"downgrade", manager_name]))
+    def test_cli_downgrade_db(self, mock_run_db_downgrade_cmd, 
mock_get_db_manager, mock_session):
+        manager_name = "path.to.FakeDBManager"
+        FakeDBManager.revision_heads_map = {"2.10.0": "22ed7efa9da2"}
+        mock_get_db_manager.return_value = FakeDBManager
+
+        args = self.parser.parse_args(["db-manager", "downgrade", 
manager_name])
+        db_manager_command.downgrade(args)
+
         mock_get_db_manager.assert_called_once_with(manager_name)
-        mock_run_db_downgrade_cmd.assert_called_once()
+        assert len(FakeDBManager.instances) == 1
+        called_args, called_kwargs = mock_run_db_downgrade_cmd.call_args
+        assert called_args[0] is args
+        # Verify the bound method refers to the instance's downgrade 
implementation
+        assert called_args[1].__self__ is FakeDBManager.last_instance
+        assert called_args[1].__func__ is FakeDBManager.downgrade
+        assert called_kwargs["revision_heads_map"] == {"2.10.0": 
"22ed7efa9da2"}
 
     @conf_vars({("database", "external_db_managers"): 
"path.to.manager.TestDBManager"})
     @mock.patch("airflow.cli.commands.db_manager_command.import_string")

Reply via email to