This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 57a889de35 Make docker operators always use `DockerHook` for API calls
(#28363)
57a889de35 is described below
commit 57a889de357b269ae104b721e2a4bb78b929cea9
Author: Andrey Anshin <[email protected]>
AuthorDate: Tue Jan 3 13:15:54 2023 +0400
Make docker operators always use `DockerHook` for API calls (#28363)
---
airflow/providers/docker/hooks/docker.py | 185 +++++++----
airflow/providers/docker/operators/docker.py | 83 ++---
airflow/providers/docker/operators/docker_swarm.py | 12 +-
tests/providers/conftest.py | 57 ++++
tests/providers/docker/conftest.py | 64 ++++
tests/providers/docker/hooks/test_docker.py | 362 ++++++++++++++-------
tests/providers/docker/operators/test_docker.py | 300 ++++++++---------
.../docker/operators/test_docker_swarm.py | 76 +++--
8 files changed, 725 insertions(+), 414 deletions(-)
diff --git a/airflow/providers/docker/hooks/docker.py
b/airflow/providers/docker/hooks/docker.py
index 981e4b6a48..f4c3e9c44c 100644
--- a/airflow/providers/docker/hooks/docker.py
+++ b/airflow/providers/docker/hooks/docker.py
@@ -17,23 +17,39 @@
# under the License.
from __future__ import annotations
-from typing import Any
+import json
+from typing import TYPE_CHECKING, Any
-from docker import APIClient # type: ignore[attr-defined]
-from docker.constants import DEFAULT_TIMEOUT_SECONDS # type:
ignore[attr-defined]
-from docker.errors import APIError # type: ignore[attr-defined]
+from docker import APIClient, TLSConfig
+from docker.constants import DEFAULT_TIMEOUT_SECONDS
+from docker.errors import APIError
-from airflow.exceptions import AirflowException
+from airflow.compat.functools import cached_property
+from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.hooks.base import BaseHook
-from airflow.utils.log.logging_mixin import LoggingMixin
+if TYPE_CHECKING:
+ from airflow.models import Connection
-class DockerHook(BaseHook, LoggingMixin):
+
+class DockerHook(BaseHook):
"""
- Interact with a Docker Daemon or Registry.
+ Interact with a Docker Daemon and Container Registry.
+
+ This class provides a thin wrapper around the ``docker.APIClient``.
+
+ .. seealso::
+ - :ref:`Docker Connection <howto/connection:docker>`
+ - `Docker SDK: Low-level API
<https://docker-py.readthedocs.io/en/stable/api.html?low-level-api>`_
- :param docker_conn_id: The :ref:`Docker connection id
<howto/connection:docker>`
- where credentials and extra configuration are stored
+ :param docker_conn_id: :ref:`Docker connection id
<howto/connection:docker>` where stored credentials
+ to Docker Registry. If set to ``None`` or empty then hook does not
login to Container Registry.
+ :param base_url: URL to the Docker server.
+ :param version: The version of the API to use. Use ``auto`` or ``None``
for automatically detect
+ the server's version.
+ :param tls: Whether the connection requires TLS; pass ``True`` to use
it with the default options,
+ or pass a `docker.tls.TLSConfig` object to use custom configurations.
+ :param timeout: Default timeout for API calls, in seconds.
"""
conn_name_attr = "docker_conn_id"
@@ -41,73 +57,128 @@ class DockerHook(BaseHook, LoggingMixin):
conn_type = "docker"
hook_name = "Docker"
- @staticmethod
- def get_ui_field_behaviour() -> dict[str, Any]:
- """Returns custom field behaviour"""
- return {
- "hidden_fields": ["schema"],
- "relabeling": {
- "host": "Registry URL",
- "login": "Username",
- },
- }
-
def __init__(
self,
docker_conn_id: str | None = default_conn_name,
base_url: str | None = None,
version: str | None = None,
- tls: str | None = None,
+ tls: TLSConfig | bool | None = None,
timeout: int = DEFAULT_TIMEOUT_SECONDS,
) -> None:
super().__init__()
if not base_url:
- raise AirflowException("No Docker base URL provided")
- if not version:
- raise AirflowException("No Docker API version provided")
-
- if not docker_conn_id:
- raise AirflowException("No Docker connection id provided")
-
- conn = self.get_connection(docker_conn_id)
-
- if not conn.host:
- raise AirflowException("No Docker URL provided")
- if not conn.login:
- raise AirflowException("No username provided")
- extra_options = conn.extra_dejson
+ raise AirflowException("URL to the Docker server not provided.")
+ elif tls:
+ if base_url.startswith("tcp://"):
+ base_url = base_url.replace("tcp://", "https://")
+ self.log.debug("Change `base_url` schema from 'tcp://' to
'https://'.")
+ if not base_url.startswith("https://"):
+ self.log.warning("When `tls` specified then `base_url`
expected 'https://' schema.")
+ self.docker_conn_id = docker_conn_id
self.__base_url = base_url
self.__version = version
- self.__tls = tls
+ self.__tls = tls or False
self.__timeout = timeout
- if conn.port:
- self.__registry = f"{conn.host}:{conn.port}"
- else:
- self.__registry = conn.host
- self.__username = conn.login
- self.__password = conn.password
- self.__email = extra_options.get("email")
- self.__reauth = extra_options.get("reauth") != "no"
+ self._client_created = False
- def get_conn(self) -> APIClient:
+ @staticmethod
+ def construct_tls_config(
+ ca_cert: str | None = None,
+ client_cert: str | None = None,
+ client_key: str | None = None,
+ assert_hostname: str | bool | None = None,
+ ssl_version: str | None = None,
+ ) -> TLSConfig | bool:
+ """
+ Construct TLSConfig object from parts.
+
+ :param ca_cert: Path to a PEM-encoded CA (Certificate Authority)
certificate file.
+ :param client_cert: Path to PEM-encoded certificate file.
+ :param client_key: Path to PEM-encoded key file.
+ :param assert_hostname: Hostname to match against the docker server
certificate
+ or ``False`` to disable the check.
+ :param ssl_version: Version of SSL to use when communicating with
docker daemon.
+ """
+ if ca_cert and client_cert and client_key:
+ # Ignore type error on SSL version here.
+ # It is deprecated and type annotation is wrong, and it should be
string.
+ return TLSConfig(
+ ca_cert=ca_cert,
+ client_cert=(client_cert, client_key),
+ verify=True,
+ ssl_version=ssl_version,
+ assert_hostname=assert_hostname,
+ )
+ return False
+
+ @cached_property
+ def api_client(self) -> APIClient:
+ """Create connection to docker host and return ``docker.APIClient``
(cached)."""
client = APIClient(
base_url=self.__base_url, version=self.__version, tls=self.__tls,
timeout=self.__timeout
)
- self.__login(client)
+ if self.docker_conn_id:
+ # Obtain connection and try to login to Container Registry only if
``docker_conn_id`` set.
+ self.__login(client, self.get_connection(self.docker_conn_id))
+
+ self._client_created = True
return client
- def __login(self, client) -> None:
- self.log.debug("Logging into Docker")
+ @property
+ def client_created(self) -> bool:
+ """Is api_client created or not."""
+ return self._client_created
+
+ def get_conn(self) -> APIClient:
+ """Create connection to docker host and return ``docker.APIClient``
(cached)."""
+ return self.api_client
+
+ def __login(self, client, conn: Connection) -> None:
+ if not conn.host:
+ raise AirflowNotFoundException("No Docker Registry URL provided.")
+ if not conn.login:
+ raise AirflowNotFoundException("No Docker Registry username
provided.")
+
+ registry = f"{conn.host}:{conn.port}" if conn.port else conn.host
+
+ # Parse additional optional parameters
+ email = conn.extra_dejson.get("email") or None
+ reauth = conn.extra_dejson.get("reauth", True)
+ if isinstance(reauth, str):
+ reauth = reauth.lower()
+ if reauth in ("y", "yes", "t", "true", "on", "1"):
+ reauth = True
+ elif reauth in ("n", "no", "f", "false", "off", "0"):
+ reauth = False
+ else:
+ raise ValueError(f"Unable parse `reauth` value {reauth!r} to
bool.")
+
try:
+ self.log.info("Login into Docker Registry: %s", registry)
client.login(
- username=self.__username,
- password=self.__password,
- registry=self.__registry,
- email=self.__email,
- reauth=self.__reauth,
+ username=conn.login, password=conn.password,
registry=registry, email=email, reauth=reauth
)
self.log.debug("Login successful")
- except APIError as docker_error:
- self.log.error("Docker login failed: %s", str(docker_error))
- raise AirflowException(f"Docker login failed: {docker_error}")
+ except APIError:
+ self.log.error("Login failed")
+ raise
+
+ @classmethod
+ def get_connection_form_widgets(cls) -> dict[str, Any]:
+ """Returns custom field behaviour"""
+ return {
+ "hidden_fields": ["schema"],
+ "relabeling": {
+ "host": "Registry URL",
+ "login": "Username",
+ },
+ "placeholders": {
+ "extra": json.dumps(
+ {
+ "reauth": False,
+ "email": "[email protected]",
+ }
+ )
+ },
+ }
diff --git a/airflow/providers/docker/operators/docker.py
b/airflow/providers/docker/operators/docker.py
index e02a813024..71fa447217 100644
--- a/airflow/providers/docker/operators/docker.py
+++ b/airflow/providers/docker/operators/docker.py
@@ -26,17 +26,20 @@ from io import BytesIO, StringIO
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Iterable, Sequence
-from docker import APIClient, tls # type: ignore[attr-defined]
-from docker.constants import DEFAULT_TIMEOUT_SECONDS # type:
ignore[attr-defined]
-from docker.errors import APIError # type: ignore[attr-defined]
-from docker.types import DeviceRequest, LogConfig, Mount # type:
ignore[attr-defined]
+from docker.constants import DEFAULT_TIMEOUT_SECONDS
+from docker.errors import APIError
+from docker.types import LogConfig, Mount
from dotenv import dotenv_values
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.models import BaseOperator
from airflow.providers.docker.hooks.docker import DockerHook
if TYPE_CHECKING:
+ from docker import APIClient
+ from docker.types import DeviceRequest
+
from airflow.utils.context import Context
@@ -258,8 +261,7 @@ class DockerOperator(BaseOperator):
self.cap_add = cap_add
self.extra_hosts = extra_hosts
- self.cli = None
- self.container = None
+ self.container: dict = None # type: ignore[assignment]
self.retrieve_output = retrieve_output
self.retrieve_output_path = retrieve_output_path
self.timeout = timeout
@@ -268,25 +270,35 @@ class DockerOperator(BaseOperator):
self.log_opts_max_file = log_opts_max_file
self.ipc_mode = ipc_mode
- def get_hook(self) -> DockerHook:
- """
- Retrieves hook for the operator.
-
- :return: The Docker Hook
- """
+ @cached_property
+ def hook(self) -> DockerHook:
+ """Create and return a DockerHook (cached)."""
+ tls_config = DockerHook.construct_tls_config(
+ ca_cert=self.tls_ca_cert,
+ client_cert=self.tls_client_cert,
+ client_key=self.tls_client_key,
+ assert_hostname=self.tls_hostname,
+ ssl_version=self.tls_ssl_version,
+ )
return DockerHook(
docker_conn_id=self.docker_conn_id,
base_url=self.docker_url,
version=self.api_version,
- tls=self.__get_tls_config(),
+ tls=tls_config,
timeout=self.timeout,
)
+ def get_hook(self) -> DockerHook:
+ """Create and return a DockerHook (cached)."""
+ return self.hook
+
+ @property
+ def cli(self) -> APIClient:
+ return self.hook.api_client
+
def _run_image(self) -> list[str] | str | None:
"""Run a Docker container with the provided image"""
self.log.info("Starting docker container from image %s", self.image)
- if not self.cli:
- raise Exception("The 'cli' should be initialized before!")
if self.mount_tmp_dir:
with TemporaryDirectory(prefix="airflowtmp",
dir=self.host_tmp_dir) as host_tmp_dir_generated:
tmp_mount = Mount(self.tmp_dir, host_tmp_dir_generated, "bind")
@@ -310,8 +322,6 @@ class DockerOperator(BaseOperator):
self.environment["AIRFLOW_TMP_DIR"] = self.tmp_dir
else:
self.environment.pop("AIRFLOW_TMP_DIR", None)
- if not self.cli:
- raise Exception("The 'cli' should be initialized before!")
docker_log_config = {}
if self.log_opts_max_size is not None:
docker_log_config["max-size"] = self.log_opts_max_size
@@ -407,16 +417,11 @@ class DockerOperator(BaseOperator):
except APIError:
return None
- def execute(self, context: Context) -> str | None:
- self.cli = self._get_cli()
- if not self.cli:
- raise Exception("The 'cli' should be initialized before!")
-
+ def execute(self, context: Context) -> list[str] | str | None:
# Pull the docker image if `force_pull` is set or image does not exist
locally
-
if self.force_pull or not self.cli.images(name=self.image):
self.log.info("Pulling docker image %s", self.image)
- latest_status = {}
+ latest_status: dict[str, str] = {}
for output in self.cli.pull(self.image, stream=True, decode=True):
if isinstance(output, str):
self.log.info("%s", output)
@@ -433,17 +438,8 @@ class DockerOperator(BaseOperator):
latest_status[output_id] = output_status
return self._run_image()
- def _get_cli(self) -> APIClient:
- if self.docker_conn_id:
- return self.get_hook().get_conn()
- else:
- tls_config = self.__get_tls_config()
- return APIClient(
- base_url=self.docker_url, version=self.api_version,
tls=tls_config, timeout=self.timeout
- )
-
@staticmethod
- def format_command(command: str | list[str]) -> list[str] | str:
+ def format_command(command: list[str] | str | None) -> list[str] | str |
None:
"""
Retrieve command(s). if command string starts with [, it returns the
command list)
@@ -452,32 +448,17 @@ class DockerOperator(BaseOperator):
:return: the command (or commands)
"""
if isinstance(command, str) and command.strip().find("[") == 0:
- return ast.literal_eval(command)
+ command = ast.literal_eval(command)
return command
def on_kill(self) -> None:
- if self.cli is not None:
+ if self.hook.client_created:
self.log.info("Stopping docker container")
if self.container is None:
self.log.info("Not attempting to kill container as it was not
created")
return
self.cli.stop(self.container["Id"])
- def __get_tls_config(self) -> tls.TLSConfig | None:
- tls_config = None
- if self.tls_ca_cert and self.tls_client_cert and self.tls_client_key:
- # Ignore type error on SSL version here - it is deprecated and
type annotation is wrong
- # it should be string
- tls_config = tls.TLSConfig(
- ca_cert=self.tls_ca_cert,
- client_cert=(self.tls_client_cert, self.tls_client_key),
- verify=True,
- ssl_version=self.tls_ssl_version,
- assert_hostname=self.tls_hostname,
- )
- self.docker_url = self.docker_url.replace("tcp://", "https://")
- return tls_config
-
@staticmethod
def unpack_environment_variables(env_str: str) -> dict:
r"""
diff --git a/airflow/providers/docker/operators/docker_swarm.py
b/airflow/providers/docker/operators/docker_swarm.py
index d92b0c036c..262111e104 100644
--- a/airflow/providers/docker/operators/docker_swarm.py
+++ b/airflow/providers/docker/operators/docker_swarm.py
@@ -105,7 +105,6 @@ class DockerSwarmOperator(DockerOperator):
**kwargs,
) -> None:
super().__init__(image=image, **kwargs)
-
self.enable_logging = enable_logging
self.service = None
self.configs = configs
@@ -115,16 +114,11 @@ class DockerSwarmOperator(DockerOperator):
self.placement = placement
def execute(self, context: Context) -> None:
- self.cli = self._get_cli()
-
self.environment["AIRFLOW_TMP_DIR"] = self.tmp_dir
-
return self._run_service()
def _run_service(self) -> None:
self.log.info("Starting docker service from image %s", self.image)
- if not self.cli:
- raise Exception("The 'cli' should be initialized before!")
self.service = self.cli.create_service(
types.TaskTemplate(
container_spec=types.ContainerSpec(
@@ -173,8 +167,6 @@ class DockerSwarmOperator(DockerOperator):
self.cli.remove_service(self.service["ID"])
def _service_status(self) -> str | None:
- if not self.cli:
- raise Exception("The 'cli' should be initialized before!")
if not self.service:
raise Exception("The 'service' should be initialized before!")
return self.cli.tasks(filters={"service":
self.service["ID"]})[0]["Status"]["State"]
@@ -184,8 +176,6 @@ class DockerSwarmOperator(DockerOperator):
return status in ["complete", "failed", "shutdown", "rejected",
"orphaned", "remove"]
def _stream_logs_to_output(self) -> None:
- if not self.cli:
- raise Exception("The 'cli' should be initialized before!")
if not self.service:
raise Exception("The 'service' should be initialized before!")
logs = self.cli.service_logs(
@@ -213,6 +203,6 @@ class DockerSwarmOperator(DockerOperator):
self.log.info(line)
def on_kill(self) -> None:
- if self.cli is not None and self.service is not None:
+ if self.hook.client_created and self.service is not None:
self.log.info("Removing docker service: %s", self.service["ID"])
self.cli.remove_service(self.service["ID"])
diff --git a/tests/providers/conftest.py b/tests/providers/conftest.py
new file mode 100644
index 0000000000..7dd0079ae6
--- /dev/null
+++ b/tests/providers/conftest.py
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.models import Connection
+
+
[email protected]
+def hook_conn(request):
+ """
+ Patch ``BaseHook.get_connection()`` by mock value.
+
+ This fixture is optionally parametrized; if ``param`` is not set or empty it
just mocks the method.
+ If param is a dictionary or :class:`~airflow.models.Connection` then return
it,
+ if param is an exception then add it as a side effect.
+ Otherwise, it raises an error
+ """
+ try:
+ conn = request.param
+ except AttributeError:
+ conn = None
+
+ with mock.patch("airflow.hooks.base.BaseHook.get_connection") as m:
+ if not conn:
+ pass # Don't do anything if param not specified or empty
+ elif isinstance(conn, dict):
+ m.return_value = Connection(**conn)
+ elif isinstance(conn, Connection):
+ m.return_value = conn
+ elif isinstance(conn, Exception):
+ m.side_effect = conn
+ else:
+ raise TypeError(
+ f"{request.node.name!r}: expected dict, Connection object or
Exception, "
+ f"but got {type(conn).__name__}"
+ )
+
+ yield m
diff --git a/tests/providers/docker/conftest.py
b/tests/providers/docker/conftest.py
new file mode 100644
index 0000000000..c698a23de5
--- /dev/null
+++ b/tests/providers/docker/conftest.py
@@ -0,0 +1,64 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from contextlib import AbstractContextManager, contextmanager
+from unittest import mock
+
+import pytest
+
+
+@contextmanager
+def _mocker_context(o, additional_modules: list | None = None) ->
AbstractContextManager[mock.MagicMock]:
+ """
+ Helper context for mocking multiple reference of same object
+ :param o: Object/Class for mocking.
+ :param additional_modules: additional modules where ``o`` exists.
+ """
+ patched = []
+ object_name = o.__name__
+ mocked_object = mock.MagicMock(name=f"Mocked.{object_name}", spec=o)
+ additional_modules = additional_modules or []
+ try:
+ for mdl in [o.__module__, *additional_modules]:
+ mocker = mock.patch(f"{mdl}.{object_name}", mocked_object)
+ mocker.start()
+ patched.append(mocker)
+
+ yield mocked_object
+ finally:
+ for mocker in reversed(patched):
+ mocker.stop()
+
+
[email protected]
+def docker_api_client_patcher():
+ """Patch ``docker.APIClient`` by mock value."""
+ from airflow.providers.docker.hooks.docker import APIClient
+
+ with _mocker_context(APIClient, ["airflow.providers.docker.hooks.docker"])
as m:
+ yield m
+
+
[email protected]
+def docker_hook_patcher():
+ """Patch DockerHook by mock value."""
+ from airflow.providers.docker.operators.docker import DockerHook
+
+ with _mocker_context(DockerHook,
["airflow.providers.docker.operators.docker"]) as m:
+ yield m
diff --git a/tests/providers/docker/hooks/test_docker.py
b/tests/providers/docker/hooks/test_docker.py
index b7d24aa908..090ea2c521 100644
--- a/tests/providers/docker/hooks/test_docker.py
+++ b/tests/providers/docker/hooks/test_docker.py
@@ -17,134 +17,256 @@
# under the License.
from __future__ import annotations
+import logging
+import ssl
from unittest import mock
import pytest
+from docker import TLSConfig
+from docker.errors import APIError
-from airflow.exceptions import AirflowException
-from airflow.models import Connection
+from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.providers.docker.hooks.docker import DockerHook
-from airflow.utils import db
-
-
[email protected]("airflow.providers.docker.hooks.docker.APIClient", autospec=True)
-class TestDockerHook:
- def setup_method(self):
- db.merge_conn(
- Connection(
- conn_id="docker_default",
- conn_type="docker",
- host="some.docker.registry.com",
- login="some_user",
- password="some_p4$$w0rd",
- )
- )
- db.merge_conn(
- Connection(
- conn_id="docker_with_extras",
- conn_type="docker",
- host="another.docker.registry.com",
- port=9876,
- login="some_user",
- password="some_p4$$w0rd",
- extra='{"email": "[email protected]", "reauth": "no"}',
- )
- )
- def test_init_fails_when_no_base_url_given(self, _):
- with pytest.raises(AirflowException):
- DockerHook(docker_conn_id="docker_default", version="auto",
tls=None)
-
- def test_init_fails_when_no_api_version_given(self, _):
- with pytest.raises(AirflowException):
- DockerHook(docker_conn_id="docker_default",
base_url="unix://var/run/docker.sock", tls=None)
-
- def test_get_conn_override_defaults(self, docker_client_mock):
- hook = DockerHook(
- docker_conn_id="docker_default",
- base_url="https://index.docker.io/v1/",
- version="1.23",
- tls="someconfig",
- timeout=100,
- )
+TEST_CONN_ID = "docker_test_connection"
+TEST_BASE_URL = "unix://var/run/docker.sock"
+TEST_TLS_BASE_URL = "tcp://localhost.foo.bar"
+TEST_HTTPS_BASE_URL = "https://localhost.foo.bar"
+TEST_VERSION = "3.14"
+TEST_CONN = {"host": "some.docker.registry.com", "login": "some_user",
"password": "some_p4$$w0rd"}
+MOCK_CONNECTION_NOT_EXIST_MSG = "Testing connection not exists"
+MOCK_CONNECTION_NOT_EXISTS_EX =
AirflowNotFoundException(MOCK_CONNECTION_NOT_EXIST_MSG)
+HOOK_LOGGER_NAME = "airflow.providers.docker.hooks.docker.DockerHook"
+
+
[email protected]
+def hook_kwargs():
+ """Valid attributes for DockerHook."""
+ return {
+ "base_url": TEST_BASE_URL,
+ "docker_conn_id": "docker_default",
+ "tls": False,
+ "version": TEST_VERSION,
+ "timeout": 42,
+ }
+
+
+def test_no_connection_during_initialisation(hook_conn,
docker_api_client_patcher, hook_kwargs):
+ """Hook shouldn't create client during initialisation and retrieve Airflow
connection."""
+ DockerHook(**hook_kwargs)
+ hook_conn.assert_not_called()
+ docker_api_client_patcher.assert_not_called()
+
+
+def test_init_fails_when_no_base_url_given(hook_kwargs):
+ """Test mandatory `base_url` Hook argument."""
+ hook_kwargs.pop("base_url")
+ with pytest.raises(AirflowException, match=r"URL to the Docker server not
provided\."):
+ DockerHook(**hook_kwargs)
+
+
[email protected]("base_url", ["http://foo.bar", TEST_BASE_URL])
[email protected](
+ "tls_config", [pytest.param(True, id="bool"), pytest.param(TLSConfig(),
id="TLSConfig-object")]
+)
+def test_init_warn_on_non_https_host_with_enabled_tls(base_url, tls_config,
hook_kwargs, caplog):
+ """Test warning if user specified tls but use non-https scheme."""
+ caplog.set_level(logging.WARNING, logger=HOOK_LOGGER_NAME)
+ hook_kwargs["base_url"] = base_url
+ hook_kwargs["tls"] = tls_config
+ DockerHook(**hook_kwargs)
+ assert "When `tls` specified then `base_url` expected 'https://' schema."
in caplog.messages
+
+
[email protected]("hook_attr", ["docker_conn_id", "version", "tls",
"timeout"])
+def test_optional_hook_attributes(hook_attr, hook_kwargs):
+ """Test that Hook init does not fail when optional arguments are not provided."""
+ hook_kwargs.pop(hook_attr)
+ DockerHook(**hook_kwargs)
+
+
[email protected](
+ "conn_id, hook_conn",
+ [
+ pytest.param(TEST_CONN_ID, None, id="conn-specified"),
+ pytest.param(None, MOCK_CONNECTION_NOT_EXISTS_EX,
id="conn-not-specified"),
+ ],
+ indirect=["hook_conn"],
+)
+def test_create_api_client(conn_id, hook_conn, docker_api_client_patcher,
caplog):
+ """
+ Test creation ``docker.APIClient`` from hook arguments.
+ Additionally check:
+ - Is tls:// changed to https://
+ - Is ``api_client`` property and ``get_conn`` method cacheable.
+ - If `docker_conn_id` is not provided then the hook doesn't try to access
Airflow Connections.
+ """
+ caplog.set_level(logging.DEBUG, logger=HOOK_LOGGER_NAME)
+ hook = DockerHook(
+ docker_conn_id=conn_id, base_url=TEST_TLS_BASE_URL,
version=TEST_VERSION, tls=True, timeout=42
+ )
+ assert "Change `base_url` schema from 'tcp://' to 'https://'." in
caplog.messages
+ caplog.clear()
+ assert hook.client_created is False
+ api_client = hook.api_client
+ assert api_client is hook.get_conn(), "Docker API Client not cacheable"
+ docker_api_client_patcher.assert_called_once_with(
+ base_url=TEST_HTTPS_BASE_URL, version=TEST_VERSION, tls=True,
timeout=42
+ )
+ assert hook.client_created is True
+
+
+def test_failed_create_api_client(docker_api_client_patcher):
+ """Test failures during creation ``docker.APIClient`` from hook
arguments."""
+ hook = DockerHook(base_url=TEST_BASE_URL)
+ docker_api_client_patcher.side_effect = Exception("Fake Exception")
+ with pytest.raises(Exception, match="Fake Exception"):
hook.get_conn()
- docker_client_mock.assert_called_once_with(
- base_url="https://index.docker.io/v1/",
- version="1.23",
- tls="someconfig",
- timeout=100,
- )
+ assert hook.client_created is False
- def test_get_conn_with_standard_config(self, _):
- try:
- hook = DockerHook(
- docker_conn_id="docker_default",
base_url="unix://var/run/docker.sock", version="auto"
- )
- client = hook.get_conn()
- assert client is not None
- except Exception:
- self.fail("Could not get connection from Airflow")
-
- def test_get_conn_with_extra_config(self, _):
- try:
- hook = DockerHook(
- docker_conn_id="docker_with_extras",
base_url="unix://var/run/docker.sock", version="auto"
- )
- client = hook.get_conn()
- assert client is not None
- except Exception:
- self.fail("Could not get connection from Airflow")
-
- def test_conn_with_standard_config_passes_parameters(self, _):
- hook = DockerHook(
- docker_conn_id="docker_default",
base_url="unix://var/run/docker.sock", version="auto"
- )
- client = hook.get_conn()
- client.login.assert_called_once_with(
- username="some_user",
- password="some_p4$$w0rd",
- registry="some.docker.registry.com",
- reauth=True,
- email=None,
- )
- def test_conn_with_extra_config_passes_parameters(self, _):
- hook = DockerHook(
- docker_conn_id="docker_with_extras",
base_url="unix://var/run/docker.sock", version="auto"
- )
- client = hook.get_conn()
- client.login.assert_called_once_with(
- username="some_user",
- password="some_p4$$w0rd",
- registry="another.docker.registry.com:9876",
- reauth=False,
- email="[email protected]",
- )
[email protected](
+ "hook_conn, expected",
+ [
+ pytest.param(
+ TEST_CONN,
+ {
+ "username": "some_user",
+ "password": "some_p4$$w0rd",
+ "registry": "some.docker.registry.com",
+ "email": None,
+ "reauth": True,
+ },
+ id="host-login-password",
+ ),
+ pytest.param(
+ {
+ "host": "another.docker.registry.com",
+ "login": "another_user",
+ "password": "insecure_password",
+ "extra": {"email": "[email protected]", "reauth": "no"},
+ },
+ {
+ "username": "another_user",
+ "password": "insecure_password",
+ "registry": "another.docker.registry.com",
+ "email": "[email protected]",
+ "reauth": False,
+ },
+ id="host-login-password-email-noreauth",
+ ),
+ pytest.param(
+ {
+ "host": "localhost",
+ "port": 8080,
+ "login": "user",
+ "password": "pass",
+ "extra": {"email": "", "reauth": "TrUe"},
+ },
+ {
+ "username": "user",
+ "password": "pass",
+ "registry": "localhost:8080",
+ "email": None,
+ "reauth": True,
+ },
+ id="host-port-login-password-reauth",
+ ),
+ ],
+ indirect=["hook_conn"],
+)
+def test_success_login_to_registry(hook_conn, docker_api_client_patcher,
expected: dict):
+ """Test success login to Docker Registry with provided connection."""
+ mock_login = mock.MagicMock()
+ docker_api_client_patcher.return_value.login = mock_login
- def test_conn_with_broken_config_missing_username_fails(self, _):
- db.merge_conn(
- Connection(
- conn_id="docker_without_username",
- conn_type="docker",
- host="some.docker.registry.com",
- password="some_p4$$w0rd",
- extra='{"email": "[email protected]"}',
- )
- )
- with pytest.raises(AirflowException):
- DockerHook(
- docker_conn_id="docker_without_username",
- base_url="unix://var/run/docker.sock",
- version="auto",
- )
-
- def test_conn_with_broken_config_missing_host_fails(self, _):
- db.merge_conn(
- Connection(
- conn_id="docker_without_host", conn_type="docker",
login="some_user", password="some_p4$$w0rd"
- )
+ hook = DockerHook(docker_conn_id=TEST_CONN_ID, base_url=TEST_BASE_URL)
+ hook.get_conn()
+ mock_login.assert_called_once_with(**expected)
+
+
+def test_failed_login_to_registry(hook_conn, docker_api_client_patcher,
caplog):
+ """Test error during Docker Registry login."""
+ caplog.set_level(logging.ERROR, logger=HOOK_LOGGER_NAME)
+ docker_api_client_patcher.return_value.login.side_effect = APIError("Fake
Error")
+
+ hook = DockerHook(docker_conn_id=TEST_CONN_ID, base_url=TEST_BASE_URL)
+ with pytest.raises(APIError, match="Fake Error"):
+ hook.get_conn()
+ assert "Login failed" in caplog.messages
+
+
[email protected](
+ "hook_conn, ex, error_message",
+ [
+ pytest.param(
+ {k: v for k, v in TEST_CONN.items() if k != "login"},
+ AirflowNotFoundException,
+ r"No Docker Registry username provided\.",
+ id="missing-username",
+ ),
+ pytest.param(
+ {k: v for k, v in TEST_CONN.items() if k != "host"},
+ AirflowNotFoundException,
+ r"No Docker Registry URL provided\.",
+ id="missing-registry-host",
+ ),
+ pytest.param(
+ {**TEST_CONN, **{"extra": {"reauth": "enabled"}}},
+ ValueError,
+ r"Unable parse `reauth` value '.*' to bool\.",
+ id="wrong-reauth",
+ ),
+ pytest.param(
+ {**TEST_CONN, **{"extra": {"reauth": "disabled"}}},
+ ValueError,
+ r"Unable parse `reauth` value '.*' to bool\.",
+ id="wrong-noreauth",
+ ),
+ ],
+ indirect=["hook_conn"],
+)
+def test_invalid_conn_parameters(hook_conn, docker_api_client_patcher, ex,
error_message):
+ """Test invalid/missing connection parameters."""
+ hook = DockerHook(docker_conn_id=TEST_CONN_ID, base_url=TEST_BASE_URL)
+ with pytest.raises(ex, match=error_message):
+ hook.get_conn()
+
+
[email protected](
+ "tls_params",
+ [
+ pytest.param({}, id="empty-config"),
+ pytest.param({"client_cert": "foo-bar", "client_key": "spam-egg"},
id="missing-ca-cert"),
+ pytest.param({"ca_cert": "foo-bar", "client_key": "spam-egg"},
id="missing-client-cert"),
+ pytest.param({"ca_cert": "foo-bar", "client_cert": "spam-egg"},
id="missing-client-key"),
+ ],
+)
+def test_construct_tls_config_missing_certs_args(tls_params: dict):
+ """Test that it returns False on missing cert/key arguments."""
+ assert DockerHook.construct_tls_config(**tls_params) is False
+
+
[email protected]("assert_hostname", ["foo.bar", None, False])
[email protected](
+ "ssl_version",
+ [
+ pytest.param(ssl.PROTOCOL_TLSv1, id="TLSv1"),
+ pytest.param(ssl.PROTOCOL_TLSv1_2, id="TLSv1_2"),
+ None,
+ ],
+)
+def test_construct_tls_config(assert_hostname, ssl_version):
+ """Test construct ``docker.tls.TLSConfig`` object."""
+ tls_params = {"ca_cert": "test-ca", "client_cert": "foo-bar",
"client_key": "spam-egg"}
+ expected_call_args = {"ca_cert": "test-ca", "client_cert": ("foo-bar",
"spam-egg"), "verify": True}
+ if assert_hostname is not None:
+ tls_params["assert_hostname"] = assert_hostname
+ if ssl_version is not None:
+ tls_params["ssl_version"] = ssl_version
+
+ with mock.patch.object(TLSConfig, "__init__", return_value=None) as
mock_tls_config:
+ DockerHook.construct_tls_config(**tls_params)
+ mock_tls_config.assert_called_once_with(
+ **expected_call_args, assert_hostname=assert_hostname,
ssl_version=ssl_version
)
- with pytest.raises(AirflowException):
- DockerHook(
- docker_conn_id="docker_without_host",
base_url="unix://var/run/docker.sock", version="auto"
- )
diff --git a/tests/providers/docker/operators/test_docker.py
b/tests/providers/docker/operators/test_docker.py
index 02734eede8..e37999e40f 100644
--- a/tests/providers/docker/operators/test_docker.py
+++ b/tests/providers/docker/operators/test_docker.py
@@ -23,19 +23,120 @@ from unittest.mock import call
import pytest
from docker import APIClient
-from docker.constants import DEFAULT_TIMEOUT_SECONDS
from docker.errors import APIError
from docker.types import DeviceRequest, LogConfig, Mount
from airflow.exceptions import AirflowException
-from airflow.providers.docker.hooks.docker import DockerHook
from airflow.providers.docker.operators.docker import DockerOperator
+TEST_CONN_ID = "docker_test_connection"
+TEST_DOCKER_URL = "unix://var/run/docker.test.sock"
+TEST_API_VERSION = "1.19"  # Keep it as low version might prevent call non-mocked docker api
+TEST_IMAGE = "apache/airflow:latest"
+TEST_CONTAINER_HOSTNAME = "test.container.host"
+TEST_HOST_TEMP_DIRECTORY = "/tmp/host/dir"
+TEST_AIRFLOW_TEMP_DIRECTORY = "/tmp/airflow/dir"
+TEST_ENTRYPOINT = '["sh", "-c"]'
+
TEMPDIR_MOCK_RETURN_VALUE = "/mkdtemp"
[email protected]("docker_conn_id", [pytest.param(None, id="empty-conn-id"), TEST_CONN_ID])
[email protected](
+ "tls_params",
+ [
+ pytest.param({}, id="empty-tls-params"),
+ pytest.param(
+ {
+ "tls_ca_cert": "foo",
+ "tls_client_cert": "bar",
+ "tls_client_key": "spam",
+ "tls_hostname": "egg",
+ "tls_ssl_version": "super-secure",
+ },
+ id="all-tls-params",
+ ),
+ ],
+)
+def test_hook_usage(docker_hook_patcher, docker_conn_id, tls_params: dict):
+ """Test that operator use DockerHook."""
+ docker_hook_patcher.construct_tls_config.return_value = "MOCK-TLS-VALUE"
+ expected_tls_call_args = {
+ "ca_cert": tls_params.get("tls_ca_cert"),
+ "client_cert": tls_params.get("tls_client_cert"),
+ "client_key": tls_params.get("tls_client_key"),
+ "assert_hostname": tls_params.get("tls_hostname"),
+ "ssl_version": tls_params.get("tls_ssl_version"),
+ }
+
+ op = DockerOperator(
+ task_id="test_hook_usage_without_tls",
+ api_version=TEST_API_VERSION,
+ docker_conn_id=docker_conn_id,
+ image=TEST_IMAGE,
+ docker_url=TEST_DOCKER_URL,
+ timeout=42,
+ **tls_params,
+ )
+ hook = op.hook
+ assert hook is op.get_hook()
+
+ docker_hook_patcher.assert_called_once_with(
+ docker_conn_id=docker_conn_id,
+ base_url=TEST_DOCKER_URL,
+ version=TEST_API_VERSION,
+ tls="MOCK-TLS-VALUE",
+ timeout=42,
+ )
+    docker_hook_patcher.construct_tls_config.assert_called_once_with(**expected_tls_call_args)
+
+ # Check that ``DockerOperator.cli`` property return the same object as
``hook.api_client``.
+ assert op.cli is hook.api_client
+
+
[email protected](
+ "env_str, expected",
+ [
+ pytest.param("FOO=BAR\nSPAM=EGG", {"FOO": "BAR", "SPAM": "EGG"},
id="parsable-string"),
+ pytest.param("", {}, id="empty-string"),
+ ],
+)
+def test_unpack_environment_variables(env_str, expected):
+ assert DockerOperator.unpack_environment_variables(env_str) == expected
+
+
[email protected]("container_exists", [True, False])
+def test_on_kill_client_created(docker_api_client_patcher, container_exists):
+ """Test operator on_kill method if APIClient created."""
+ op = DockerOperator(image=TEST_IMAGE, hostname=TEST_DOCKER_URL,
task_id="test_on_kill")
+ op.container = {"Id": "some_id"} if container_exists else None
+
+ op.hook.get_conn() # Try to create APIClient
+ op.on_kill()
+ if container_exists:
+        docker_api_client_patcher.return_value.stop.assert_called_once_with("some_id")
+ else:
+ docker_api_client_patcher.return_value.stop.assert_not_called()
+
+
+def test_on_kill_client_not_created(docker_api_client_patcher):
+ """Test operator on_kill method if APIClient not created in case of
error."""
+ docker_api_client_patcher.side_effect = APIError("Fake Client Error")
+ mock_container = mock.MagicMock()
+
+ op = DockerOperator(image=TEST_IMAGE, hostname=TEST_DOCKER_URL,
task_id="test_on_kill")
+ op.container = mock_container
+
+ with pytest.raises(APIError, match="Fake Client Error"):
+ op.hook.get_conn()
+ op.on_kill()
+ docker_api_client_patcher.return_value.stop.assert_not_called()
+ mock_container.assert_not_called()
+
+
class TestDockerOperator:
- def setup_method(self):
+ @pytest.fixture(autouse=True, scope="function")
+ def setup_patchers(self, docker_api_client_patcher):
self.tempdir_patcher =
mock.patch("airflow.providers.docker.operators.docker.TemporaryDirectory")
self.tempdir_mock = self.tempdir_patcher.start()
self.tempdir_mock.return_value.__enter__.return_value =
TEMPDIR_MOCK_RETURN_VALUE
@@ -56,11 +157,7 @@ class TestDockerOperator:
else iter(self.log_messages)
)
- self.client_class_patcher = mock.patch(
- "airflow.providers.docker.operators.docker.APIClient",
- return_value=self.client_mock,
- )
- self.client_class_mock = self.client_class_patcher.start()
+ docker_api_client_patcher.return_value = self.client_mock
def dotenv_mock_return_value(**kwargs):
env_dict = {}
@@ -74,9 +171,9 @@ class TestDockerOperator:
self.dotenv_mock = self.dotenv_patcher.start()
self.dotenv_mock.side_effect = dotenv_mock_return_value
- def teardown_method(self) -> None:
+ yield
+
self.tempdir_patcher.stop()
- self.client_class_patcher.stop()
self.dotenv_patcher.stop()
def test_execute(self):
@@ -85,55 +182,52 @@ class TestDockerOperator:
stringio_mock.side_effect = lambda *args: args[0]
operator = DockerOperator(
- api_version="1.19",
+ api_version=TEST_API_VERSION,
command="env",
environment={"UNIT": "TEST"},
private_environment={"PRIVATE": "MESSAGE"},
env_file="ENV=FILE\nVAR=VALUE",
- image="ubuntu:latest",
+ image=TEST_IMAGE,
network_mode="bridge",
owner="unittest",
task_id="unittest",
mounts=[Mount(source="/host/path", target="/container/path",
type="bind")],
- entrypoint='["sh", "-c"]',
+ entrypoint=TEST_ENTRYPOINT,
working_dir="/container/path",
shm_size=1000,
- host_tmp_dir="/host/airflow",
+ tmp_dir=TEST_AIRFLOW_TEMP_DIRECTORY,
+ host_tmp_dir=TEST_HOST_TEMP_DIRECTORY,
container_name="test_container",
tty=True,
- hostname="test.container.host",
+ hostname=TEST_CONTAINER_HOSTNAME,
device_requests=[DeviceRequest(count=-1, capabilities=[["gpu"]])],
log_opts_max_file="5",
log_opts_max_size="10m",
)
operator.execute(None)
- self.client_class_mock.assert_called_once_with(
- base_url="unix://var/run/docker.sock", tls=None, version="1.19",
timeout=DEFAULT_TIMEOUT_SECONDS
- )
-
self.client_mock.create_container.assert_called_once_with(
command="env",
name="test_container",
environment={
- "AIRFLOW_TMP_DIR": "/tmp/airflow",
+ "AIRFLOW_TMP_DIR": TEST_AIRFLOW_TEMP_DIRECTORY,
"UNIT": "TEST",
"PRIVATE": "MESSAGE",
"ENV": "FILE",
"VAR": "VALUE",
},
host_config=self.client_mock.create_host_config.return_value,
- image="ubuntu:latest",
+ image=TEST_IMAGE,
user=None,
entrypoint=["sh", "-c"],
working_dir="/container/path",
tty=True,
- hostname="test.container.host",
+ hostname=TEST_CONTAINER_HOSTNAME,
)
self.client_mock.create_host_config.assert_called_once_with(
mounts=[
Mount(source="/host/path", target="/container/path",
type="bind"),
- Mount(source="/mkdtemp", target="/tmp/airflow", type="bind"),
+ Mount(source="/mkdtemp", target=TEST_AIRFLOW_TEMP_DIRECTORY,
type="bind"),
],
network_mode="bridge",
shm_size=1000,
@@ -149,16 +243,14 @@ class TestDockerOperator:
log_config=LogConfig(config={"max-size": "10m", "max-file": "5"}),
ipc_mode=None,
)
- self.tempdir_mock.assert_called_once_with(dir="/host/airflow",
prefix="airflowtmp")
- self.client_mock.images.assert_called_once_with(name="ubuntu:latest")
+
self.tempdir_mock.assert_called_once_with(dir=TEST_HOST_TEMP_DIRECTORY,
prefix="airflowtmp")
+ self.client_mock.images.assert_called_once_with(name=TEST_IMAGE)
self.client_mock.attach.assert_called_once_with(
container="some_id", stdout=True, stderr=True, stream=True
)
- self.client_mock.pull.assert_called_once_with("ubuntu:latest",
stream=True, decode=True)
+ self.client_mock.pull.assert_called_once_with(TEST_IMAGE, stream=True,
decode=True)
self.client_mock.wait.assert_called_once_with("some_id")
- assert (
- operator.cli.pull("ubuntu:latest", stream=True, decode=True) ==
self.client_mock.pull.return_value
- )
+ assert operator.cli.pull(TEST_IMAGE, stream=True, decode=True) ==
self.client_mock.pull.return_value
stringio_mock.assert_called_once_with("ENV=FILE\nVAR=VALUE")
self.dotenv_mock.assert_called_once_with(stream="ENV=FILE\nVAR=VALUE")
stringio_patcher.stop()
@@ -174,37 +266,33 @@ class TestDockerOperator:
environment={"UNIT": "TEST"},
private_environment={"PRIVATE": "MESSAGE"},
env_file="ENV=FILE\nVAR=VALUE",
- image="ubuntu:latest",
+ image=TEST_IMAGE,
network_mode="bridge",
owner="unittest",
task_id="unittest",
mounts=[Mount(source="/host/path", target="/container/path",
type="bind")],
mount_tmp_dir=False,
- entrypoint='["sh", "-c"]',
+ entrypoint=TEST_ENTRYPOINT,
working_dir="/container/path",
shm_size=1000,
- host_tmp_dir="/host/airflow",
+ host_tmp_dir=TEST_HOST_TEMP_DIRECTORY,
container_name="test_container",
- hostname="test.container.host",
+ hostname=TEST_CONTAINER_HOSTNAME,
tty=True,
)
operator.execute(None)
- self.client_class_mock.assert_called_once_with(
- base_url="unix://var/run/docker.sock", tls=None, version="1.19",
timeout=DEFAULT_TIMEOUT_SECONDS
- )
-
self.client_mock.create_container.assert_called_once_with(
command="env",
name="test_container",
environment={"UNIT": "TEST", "PRIVATE": "MESSAGE", "ENV": "FILE",
"VAR": "VALUE"},
host_config=self.client_mock.create_host_config.return_value,
- image="ubuntu:latest",
+ image=TEST_IMAGE,
user=None,
entrypoint=["sh", "-c"],
working_dir="/container/path",
tty=True,
- hostname="test.container.host",
+ hostname=TEST_CONTAINER_HOSTNAME,
)
self.client_mock.create_host_config.assert_called_once_with(
mounts=[
@@ -225,22 +313,20 @@ class TestDockerOperator:
ipc_mode=None,
)
self.tempdir_mock.assert_not_called()
- self.client_mock.images.assert_called_once_with(name="ubuntu:latest")
+ self.client_mock.images.assert_called_once_with(name=TEST_IMAGE)
self.client_mock.attach.assert_called_once_with(
container="some_id", stdout=True, stderr=True, stream=True
)
- self.client_mock.pull.assert_called_once_with("ubuntu:latest",
stream=True, decode=True)
+ self.client_mock.pull.assert_called_once_with(TEST_IMAGE, stream=True,
decode=True)
self.client_mock.wait.assert_called_once_with("some_id")
- assert (
- operator.cli.pull("ubuntu:latest", stream=True, decode=True) ==
self.client_mock.pull.return_value
- )
+ assert operator.cli.pull(TEST_IMAGE, stream=True, decode=True) ==
self.client_mock.pull.return_value
stringio_mock.assert_called_once_with("ENV=FILE\nVAR=VALUE")
self.dotenv_mock.assert_called_once_with(stream="ENV=FILE\nVAR=VALUE")
stringio_patcher.stop()
def test_execute_fallback_temp_dir(self, caplog):
self.client_mock.create_container.side_effect = [
- APIError(message="wrong path: " + TEMPDIR_MOCK_RETURN_VALUE),
+ APIError(message=f"wrong path: {TEMPDIR_MOCK_RETURN_VALUE}"),
{"Id": "some_id"},
]
@@ -254,16 +340,17 @@ class TestDockerOperator:
environment={"UNIT": "TEST"},
private_environment={"PRIVATE": "MESSAGE"},
env_file="ENV=FILE\nVAR=VALUE",
- image="ubuntu:latest",
+ image=TEST_IMAGE,
network_mode="bridge",
owner="unittest",
task_id="unittest",
mounts=[Mount(source="/host/path", target="/container/path",
type="bind")],
mount_tmp_dir=True,
- entrypoint='["sh", "-c"]',
+ entrypoint=TEST_ENTRYPOINT,
working_dir="/container/path",
shm_size=1000,
- host_tmp_dir="/host/airflow",
+ host_tmp_dir=TEST_HOST_TEMP_DIRECTORY,
+ tmp_dir=TEST_AIRFLOW_TEMP_DIRECTORY,
container_name="test_container",
tty=True,
)
@@ -277,23 +364,20 @@ class TestDockerOperator:
)
assert warning_message in caplog.messages
- self.client_class_mock.assert_called_once_with(
- base_url="unix://var/run/docker.sock", tls=None, version="1.19",
timeout=DEFAULT_TIMEOUT_SECONDS
- )
self.client_mock.create_container.assert_has_calls(
[
call(
command="env",
name="test_container",
environment={
- "AIRFLOW_TMP_DIR": "/tmp/airflow",
+ "AIRFLOW_TMP_DIR": TEST_AIRFLOW_TEMP_DIRECTORY,
"UNIT": "TEST",
"PRIVATE": "MESSAGE",
"ENV": "FILE",
"VAR": "VALUE",
},
host_config=self.client_mock.create_host_config.return_value,
- image="ubuntu:latest",
+ image=TEST_IMAGE,
user=None,
entrypoint=["sh", "-c"],
working_dir="/container/path",
@@ -305,7 +389,7 @@ class TestDockerOperator:
name="test_container",
environment={"UNIT": "TEST", "PRIVATE": "MESSAGE", "ENV":
"FILE", "VAR": "VALUE"},
host_config=self.client_mock.create_host_config.return_value,
- image="ubuntu:latest",
+ image=TEST_IMAGE,
user=None,
entrypoint=["sh", "-c"],
working_dir="/container/path",
@@ -319,7 +403,7 @@ class TestDockerOperator:
call(
mounts=[
Mount(source="/host/path", target="/container/path",
type="bind"),
- Mount(source="/mkdtemp", target="/tmp/airflow",
type="bind"),
+ Mount(source="/mkdtemp",
target=TEST_AIRFLOW_TEMP_DIRECTORY, type="bind"),
],
network_mode="bridge",
shm_size=1000,
@@ -355,23 +439,21 @@ class TestDockerOperator:
),
]
)
- self.tempdir_mock.assert_called_once_with(dir="/host/airflow",
prefix="airflowtmp")
- self.client_mock.images.assert_called_once_with(name="ubuntu:latest")
+
self.tempdir_mock.assert_called_once_with(dir=TEST_HOST_TEMP_DIRECTORY,
prefix="airflowtmp")
+ self.client_mock.images.assert_called_once_with(name=TEST_IMAGE)
self.client_mock.attach.assert_called_once_with(
container="some_id", stdout=True, stderr=True, stream=True
)
- self.client_mock.pull.assert_called_once_with("ubuntu:latest",
stream=True, decode=True)
+ self.client_mock.pull.assert_called_once_with(TEST_IMAGE, stream=True,
decode=True)
self.client_mock.wait.assert_called_once_with("some_id")
- assert (
- operator.cli.pull("ubuntu:latest", stream=True, decode=True) ==
self.client_mock.pull.return_value
- )
+ assert operator.cli.pull(TEST_IMAGE, stream=True, decode=True) ==
self.client_mock.pull.return_value
stringio_mock.assert_called_with("ENV=FILE\nVAR=VALUE")
self.dotenv_mock.assert_called_with(stream="ENV=FILE\nVAR=VALUE")
stringio_patcher.stop()
def test_private_environment_is_private(self):
operator = DockerOperator(
- private_environment={"PRIVATE": "MESSAGE"}, image="ubuntu:latest",
task_id="unittest"
+ private_environment={"PRIVATE": "MESSAGE"}, image=TEST_IMAGE,
task_id="unittest"
)
assert operator._private_environment == {
"PRIVATE": "MESSAGE"
@@ -385,11 +467,12 @@ class TestDockerOperator:
environment={"UNIT": "TEST"},
private_environment={"PRIVATE": "MESSAGE"},
env_file="UNIT=FILE\nPRIVATE=FILE\nVAR=VALUE",
- image="ubuntu:latest",
+ image=TEST_IMAGE,
task_id="unittest",
- entrypoint='["sh", "-c"]',
+ entrypoint=TEST_ENTRYPOINT,
working_dir="/container/path",
- host_tmp_dir="/host/airflow",
+ host_tmp_dir=TEST_HOST_TEMP_DIRECTORY,
+ tmp_dir=TEST_AIRFLOW_TEMP_DIRECTORY,
container_name="test_container",
tty=True,
)
@@ -398,13 +481,13 @@ class TestDockerOperator:
command="env",
name="test_container",
environment={
- "AIRFLOW_TMP_DIR": "/tmp/airflow",
+ "AIRFLOW_TMP_DIR": TEST_AIRFLOW_TEMP_DIRECTORY,
"UNIT": "TEST",
"PRIVATE": "MESSAGE",
"VAR": "VALUE",
},
host_config=self.client_mock.create_host_config.return_value,
- image="ubuntu:latest",
+ image=TEST_IMAGE,
user=None,
entrypoint=["sh", "-c"],
working_dir="/container/path",
@@ -414,45 +497,17 @@ class TestDockerOperator:
stringio_mock.assert_called_once_with("UNIT=FILE\nPRIVATE=FILE\nVAR=VALUE")
self.dotenv_mock.assert_called_once_with(stream="UNIT=FILE\nPRIVATE=FILE\nVAR=VALUE")
- @mock.patch("airflow.providers.docker.operators.docker.tls.TLSConfig")
- def test_execute_tls(self, tls_class_mock):
- tls_mock = mock.Mock()
- tls_class_mock.return_value = tls_mock
-
- operator = DockerOperator(
- docker_url="tcp://127.0.0.1:2376",
- image="ubuntu",
- owner="unittest",
- task_id="unittest",
- tls_client_cert="cert.pem",
- tls_ca_cert="ca.pem",
- tls_client_key="key.pem",
- )
- operator.execute(None)
-
- tls_class_mock.assert_called_once_with(
- assert_hostname=None,
- ca_cert="ca.pem",
- client_cert=("cert.pem", "key.pem"),
- ssl_version=None,
- verify=True,
- )
-
- self.client_class_mock.assert_called_once_with(
- base_url="https://127.0.0.1:2376", tls=tls_mock, version=None,
timeout=DEFAULT_TIMEOUT_SECONDS
- )
-
def test_execute_unicode_logs(self):
self.client_mock.attach.return_value = ["unicode container log 😁"]
- originalRaiseExceptions = logging.raiseExceptions
+ original_raise_exceptions = logging.raiseExceptions
logging.raiseExceptions = True
- operator = DockerOperator(image="ubuntu", owner="unittest",
task_id="unittest")
+ operator = DockerOperator(image=TEST_IMAGE, owner="unittest",
task_id="unittest")
with mock.patch("traceback.print_exception") as print_exception_mock:
operator.execute(None)
- logging.raiseExceptions = originalRaiseExceptions
+ logging.raiseExceptions = original_raise_exceptions
print_exception_mock.assert_not_called()
def test_execute_container_fails(self):
@@ -474,60 +529,13 @@ class TestDockerOperator:
def test_auto_remove_container_fails(self):
self.client_mock.wait.return_value = {"StatusCode": 1}
- operator = DockerOperator(image="ubuntu", owner="unittest",
task_id="unittest", auto_remove=True)
+ operator = DockerOperator(image="ubuntu", owner="unittest",
task_id="unittest", auto_remove="success")
operator.container = {"Id": "some_id"}
with pytest.raises(AirflowException):
operator.execute(None)
self.client_mock.remove_container.assert_called_once_with("some_id")
- @staticmethod
- def test_on_kill():
- client_mock = mock.Mock(spec=APIClient)
-
- operator = DockerOperator(image="ubuntu", owner="unittest",
task_id="unittest")
- operator.cli = client_mock
- operator.container = {"Id": "some_id"}
-
- operator.on_kill()
-
- client_mock.stop.assert_called_once_with("some_id")
-
- def test_execute_no_docker_conn_id_no_hook(self):
- # Create the DockerOperator
- operator = DockerOperator(image="publicregistry/someimage",
owner="unittest", task_id="unittest")
-
- # Mock out the DockerHook
- hook_mock = mock.Mock(name="DockerHook mock", spec=DockerHook)
- hook_mock.get_conn.return_value = self.client_mock
- operator.get_hook = mock.Mock(
- name="DockerOperator.get_hook mock", spec=DockerOperator.get_hook,
return_value=hook_mock
- )
-
- operator.execute(None)
- assert operator.get_hook.call_count == 0, "Hook called though no
docker_conn_id configured"
-
- @mock.patch("airflow.providers.docker.operators.docker.DockerHook")
- def test_execute_with_docker_conn_id_use_hook(self, hook_class_mock):
- # Create the DockerOperator
- operator = DockerOperator(
- image="publicregistry/someimage",
- owner="unittest",
- task_id="unittest",
- docker_conn_id="some_conn_id",
- )
-
- # Mock out the DockerHook
- hook_mock = mock.Mock(name="DockerHook mock", spec=DockerHook)
- hook_mock.get_conn.return_value = self.client_mock
- hook_class_mock.return_value = hook_mock
-
- operator.execute(None)
-
- assert self.client_class_mock.call_count == 0, "Client was called on
the operator instead of the hook"
- assert hook_class_mock.call_count == 1, "Hook was not called although
docker_conn_id configured"
- assert self.client_mock.pull.call_count == 1, "Image was not pulled
using operator client"
-
def test_execute_xcom_behavior(self):
self.client_mock.pull.return_value = [b'{"status":"pull log"}']
kwargs = {
diff --git a/tests/providers/docker/operators/test_docker_swarm.py
b/tests/providers/docker/operators/test_docker_swarm.py
index d12be721f3..d84ebde312 100644
--- a/tests/providers/docker/operators/test_docker_swarm.py
+++ b/tests/providers/docker/operators/test_docker_swarm.py
@@ -22,15 +22,15 @@ from unittest import mock
import pytest
from docker import APIClient, types
from docker.constants import DEFAULT_TIMEOUT_SECONDS
+from docker.errors import APIError
from airflow.exceptions import AirflowException
from airflow.providers.docker.operators.docker_swarm import DockerSwarmOperator
class TestDockerSwarmOperator:
- @mock.patch("airflow.providers.docker.operators.docker.APIClient")
@mock.patch("airflow.providers.docker.operators.docker_swarm.types")
- def test_execute(self, types_mock, client_class_mock):
+ def test_execute(self, types_mock, docker_api_client_patcher):
mock_obj = mock.Mock()
@@ -54,7 +54,7 @@ class TestDockerSwarmOperator:
types_mock.RestartPolicy.return_value = mock_obj
types_mock.Resources.return_value = mock_obj
- client_class_mock.return_value = client_mock
+ docker_api_client_patcher.return_value = client_mock
operator = DockerSwarmOperator(
api_version="1.19",
@@ -65,7 +65,7 @@ class TestDockerSwarmOperator:
user="unittest",
task_id="unittest",
mounts=[types.Mount(source="/host/path", target="/container/path",
type="bind")],
- auto_remove=True,
+ auto_remove="success",
tty=True,
configs=[types.ConfigReference(config_id="dummy_cfg_id",
config_name="dummy_cfg_name")],
secrets=[types.SecretReference(secret_id="dummy_secret_id",
secret_name="dummy_secret_name")],
@@ -95,8 +95,8 @@ class TestDockerSwarmOperator:
types_mock.RestartPolicy.assert_called_once_with(condition="none")
types_mock.Resources.assert_called_once_with(mem_limit="128m")
- client_class_mock.assert_called_once_with(
- base_url="unix://var/run/docker.sock", tls=None, version="1.19",
timeout=DEFAULT_TIMEOUT_SECONDS
+ docker_api_client_patcher.assert_called_once_with(
+ base_url="unix://var/run/docker.sock", tls=False, version="1.19",
timeout=DEFAULT_TIMEOUT_SECONDS
)
client_mock.service_logs.assert_called_once_with(
@@ -112,9 +112,8 @@ class TestDockerSwarmOperator:
assert client_mock.tasks.call_count == 5
client_mock.remove_service.assert_called_once_with("some_id")
- @mock.patch("airflow.providers.docker.operators.docker.APIClient")
@mock.patch("airflow.providers.docker.operators.docker_swarm.types")
- def test_auto_remove(self, types_mock, client_class_mock):
+ def test_auto_remove(self, types_mock, docker_api_client_patcher):
mock_obj = mock.Mock()
@@ -128,16 +127,17 @@ class TestDockerSwarmOperator:
types_mock.RestartPolicy.return_value = mock_obj
types_mock.Resources.return_value = mock_obj
- client_class_mock.return_value = client_mock
+ docker_api_client_patcher.return_value = client_mock
- operator = DockerSwarmOperator(image="", auto_remove=True,
task_id="unittest", enable_logging=False)
+ operator = DockerSwarmOperator(
+ image="", auto_remove="success", task_id="unittest",
enable_logging=False
+ )
operator.execute(None)
client_mock.remove_service.assert_called_once_with("some_id")
- @mock.patch("airflow.providers.docker.operators.docker.APIClient")
@mock.patch("airflow.providers.docker.operators.docker_swarm.types")
- def test_no_auto_remove(self, types_mock, client_class_mock):
+ def test_no_auto_remove(self, types_mock, docker_api_client_patcher):
mock_obj = mock.Mock()
@@ -151,19 +151,20 @@ class TestDockerSwarmOperator:
types_mock.RestartPolicy.return_value = mock_obj
types_mock.Resources.return_value = mock_obj
- client_class_mock.return_value = client_mock
+ docker_api_client_patcher.return_value = client_mock
- operator = DockerSwarmOperator(image="", auto_remove=False,
task_id="unittest", enable_logging=False)
+ operator = DockerSwarmOperator(
+ image="", auto_remove="never", task_id="unittest",
enable_logging=False
+ )
operator.execute(None)
assert (
client_mock.remove_service.call_count == 0
- ), "Docker service being removed even when `auto_remove` set to
`False`"
+ ), "Docker service being removed even when `auto_remove` set to
`never`"
@pytest.mark.parametrize("status", ["failed", "shutdown", "rejected",
"orphaned", "remove"])
- @mock.patch("airflow.providers.docker.operators.docker.APIClient")
@mock.patch("airflow.providers.docker.operators.docker_swarm.types")
- def test_non_complete_service_raises_error(self, types_mock,
client_class_mock, status):
+ def test_non_complete_service_raises_error(self, types_mock,
docker_api_client_patcher, status):
mock_obj = mock.Mock()
@@ -177,21 +178,38 @@ class TestDockerSwarmOperator:
types_mock.RestartPolicy.return_value = mock_obj
types_mock.Resources.return_value = mock_obj
- client_class_mock.return_value = client_mock
+ docker_api_client_patcher.return_value = client_mock
- operator = DockerSwarmOperator(image="", auto_remove=False,
task_id="unittest", enable_logging=False)
+ operator = DockerSwarmOperator(
+ image="", auto_remove="never", task_id="unittest",
enable_logging=False
+ )
msg = "Service did not complete: {'ID': 'some_id'}"
with pytest.raises(AirflowException) as ctx:
operator.execute(None)
assert str(ctx.value) == msg
- def test_on_kill(self):
- client_mock = mock.Mock(spec=APIClient)
-
- operator = DockerSwarmOperator(image="", auto_remove=False,
task_id="unittest", enable_logging=False)
- operator.cli = client_mock
- operator.service = {"ID": "some_id"}
-
- operator.on_kill()
-
- client_mock.remove_service.assert_called_once_with("some_id")
+ @pytest.mark.parametrize("service_exists", [True, False])
+ def test_on_kill_client_created(self, docker_api_client_patcher,
service_exists):
+ """Test operator on_kill method if APIClient created."""
+ op = DockerSwarmOperator(image="", task_id="test_on_kill")
+ op.service = {"ID": "some_id"} if service_exists else None
+
+ op.hook.get_conn() # Try to create APIClient
+ op.on_kill()
+ if service_exists:
+
docker_api_client_patcher.return_value.remove_service.assert_called_once_with("some_id")
+ else:
+
docker_api_client_patcher.return_value.remove_service.assert_not_called()
+
+ def test_on_kill_client_not_created(self, docker_api_client_patcher):
+ """Test operator on_kill method if APIClient not created in case of
error."""
+ docker_api_client_patcher.side_effect = APIError("Fake Client Error")
+ op = DockerSwarmOperator(image="", task_id="test_on_kill")
+ mock_service = mock.MagicMock()
+ op.service = mock_service
+
+ with pytest.raises(APIError, match="Fake Client Error"):
+ op.hook.get_conn()
+ op.on_kill()
+
docker_api_client_patcher.return_value.remove_service.assert_not_called()
+ mock_service.assert_not_called()