This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-0-test by this push:
new 8f99e25e954 Add `@task.kuberenetes_cmd` (#46913)
8f99e25e954 is described below
commit 8f99e25e95461830ec7713874a8906cbca5cebac
Author: Mikhail Dengin <[email protected]>
AuthorDate: Mon Apr 28 15:19:05 2025 +0200
Add `@task.kuberenetes_cmd` (#46913)
closes: #46414
(cherry picked from commit 75140f621961e8e85665ffb058da14fa81791b33)
---
providers/cncf/kubernetes/docs/operators.rst | 34 ++
providers/cncf/kubernetes/provider.yaml | 2 +
.../cncf/kubernetes/decorators/kubernetes_cmd.py | 123 +++++++
.../providers/cncf/kubernetes/get_provider_info.py | 6 +-
.../kubernetes/example_kubernetes_cmd_decorator.py | 76 ++++
.../cncf/kubernetes/decorators/test_kubernetes.py | 366 +++++--------------
.../kubernetes/decorators/test_kubernetes_cmd.py | 390 +++++++++++++++++++++
.../decorators/test_kubernetes_commons.py | 280 +++++++++++++++
.../sdk/definitions/decorators/__init__.pyi | 173 +++++++++
9 files changed, 1165 insertions(+), 285 deletions(-)
diff --git a/providers/cncf/kubernetes/docs/operators.rst
b/providers/cncf/kubernetes/docs/operators.rst
index 94b3875072d..e7bb06883cd 100644
--- a/providers/cncf/kubernetes/docs/operators.rst
+++ b/providers/cncf/kubernetes/docs/operators.rst
@@ -182,6 +182,40 @@ Also for this action you can use operator in the
deferrable mode:
:start-after: [START howto_operator_k8s_write_xcom_async]
:end-before: [END howto_operator_k8s_write_xcom_async]
+
+Run command in KubernetesPodOperator from TaskFlow
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+With the usage of the ``@task.kubernetes_cmd`` decorator, you can run a
command returned by a function
+in a ``KubernetesPodOperator``, simplifying its connection to TaskFlow.
+
+Difference between ``@task.kubernetes`` and ``@task.kubernetes_cmd``
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+``@task.kubernetes`` decorator is designed to run a Python function inside a
Kubernetes pod using KPO.
+It does this by serializing the function into a temporary Python script that
is executed inside the container.
+This is well-suited for cases where you want to isolate Python code execution
and manage complex dependencies,
+as described in the :doc:`TaskFlow documentation
<apache-airflow:tutorial/taskflow>`.
+
+In contrast, ``@task.kubernetes_cmd`` decorator allows the decorated function
to return
+a shell command (as a list of strings), which is then passed as cmds or
arguments to
+``KubernetesPodOperator``.
+This enables executing arbitrary commands available inside a Kubernetes pod --
+without needing to wrap them in Python code.
+
+A key benefit here is that Python excels at composing and templating these
commands.
+Shell commands can be dynamically generated using Python's string formatting,
templating,
+extra function calls and logic. This makes it a flexible tool for
orchestrating complex pipelines
+where the task is to invoke CLI-based operations in containers without the
need to leave
+a TaskFlow context.
+
+How does this decorator work?
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+See the following examples on how the decorator works:
+
+.. exampleinclude::
/../tests/system/cncf/kubernetes/example_kubernetes_cmd_decorator.py
+ :language: python
+ :start-after: [START howto_decorator_kubernetes_cmd]
+ :end-before: [END howto_decorator_kubernetes_cmd]
+
Include error message in email alert
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
diff --git a/providers/cncf/kubernetes/provider.yaml
b/providers/cncf/kubernetes/provider.yaml
index 833822c3442..dbdb7dee67e 100644
--- a/providers/cncf/kubernetes/provider.yaml
+++ b/providers/cncf/kubernetes/provider.yaml
@@ -146,6 +146,8 @@ connection-types:
task-decorators:
- class-name:
airflow.providers.cncf.kubernetes.decorators.kubernetes.kubernetes_task
name: kubernetes
+ - class-name:
airflow.providers.cncf.kubernetes.decorators.kubernetes_cmd.kubernetes_cmd_task
+ name: kubernetes_cmd
config:
local_kubernetes_executor:
diff --git
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/decorators/kubernetes_cmd.py
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/decorators/kubernetes_cmd.py
new file mode 100644
index 00000000000..a65efad1ae6
--- /dev/null
+++
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/decorators/kubernetes_cmd.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import warnings
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, Callable
+
+from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS
+
+if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk.bases.decorator import DecoratedOperator, TaskDecorator,
task_decorator_factory
+else:
+ from airflow.decorators.base import ( # type: ignore[no-redef]
+ DecoratedOperator,
+ TaskDecorator,
+ task_decorator_factory,
+ )
+from airflow.providers.cncf.kubernetes.operators.pod import
KubernetesPodOperator
+from airflow.utils.context import context_merge
+from airflow.utils.operator_helpers import determine_kwargs
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
+class _KubernetesCmdDecoratedOperator(DecoratedOperator,
KubernetesPodOperator):
+ custom_operator_name = "@task.kubernetes_cmd"
+
+ template_fields: Sequence[str] = KubernetesPodOperator.template_fields
+ overwrite_rtif_after_execution: bool = True
+
+ def __init__(self, *, python_callable: Callable, args_only: bool = False,
**kwargs) -> None:
+ self.args_only = args_only
+
+ cmds = kwargs.pop("cmds", None)
+ arguments = kwargs.pop("arguments", None)
+
+ if cmds is not None or arguments is not None:
+ warnings.warn(
+ f"The `cmds` and `arguments` are unused in
{self.custom_operator_name} decorator. "
+ "You should return a list of commands or image entrypoint
arguments with "
+ "args_only=True from the python_callable.",
+ UserWarning,
+ stacklevel=3,
+ )
+
+ # If the name was not provided, we generate operator name from the
python_callable
+ # we also instruct operator to add a random suffix to avoid collisions
by default
+ op_name = kwargs.pop("name",
f"k8s-airflow-pod-{python_callable.__name__}")
+ random_name_suffix = kwargs.pop("random_name_suffix", True)
+
+ super().__init__(
+ python_callable=python_callable,
+ name=op_name,
+ random_name_suffix=random_name_suffix,
+ cmds=None,
+ arguments=None,
+ **kwargs,
+ )
+
+ def execute(self, context: Context):
+ generated = self._generate_cmds(context)
+ if self.args_only:
+ self.cmds = []
+ self.arguments = generated
+ else:
+ self.cmds = generated
+ self.arguments = []
+ context["ti"].render_templates() # type: ignore[attr-defined]
+ return super().execute(context)
+
+ def _generate_cmds(self, context: Context) -> list[str]:
+ context_merge(context, self.op_kwargs)
+ kwargs = determine_kwargs(self.python_callable, self.op_args, context)
+ generated_cmds = self.python_callable(*self.op_args, **kwargs)
+ func_name = self.python_callable.__name__
+ if not isinstance(generated_cmds, list):
+ raise TypeError(
+ f"Expected python_callable to return a list of strings, but
got {type(generated_cmds)}"
+ )
+ if not all(isinstance(cmd, str) for cmd in generated_cmds):
+ raise TypeError(f"Expected {func_name} to return a list of
strings, but got {generated_cmds}")
+ if not generated_cmds:
+ raise ValueError(f"The {func_name} returned an empty list of
commands")
+
+ return generated_cmds
+
+
+def kubernetes_cmd_task(
+ python_callable: Callable | None = None,
+ **kwargs,
+) -> TaskDecorator:
+ """
+ Kubernetes cmd operator decorator.
+
+    This wraps a function which should return a command to be executed
+ in K8s using KubernetesPodOperator. The function should return a list of
strings.
+ If args_only is set to True, the function should return a list of
arguments for
+    container default command. Also accepts any argument that
KubernetesPodOperator
+    will accept via ``kwargs``. Can be reused in a single DAG.
+
+ :param python_callable: Function to decorate
+ """
+ return task_decorator_factory(
+ python_callable=python_callable,
+ decorated_operator_class=_KubernetesCmdDecoratedOperator,
+ **kwargs,
+ )
diff --git
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/get_provider_info.py
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/get_provider_info.py
index 6426b82404e..821f95cf614 100644
---
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/get_provider_info.py
+++
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/get_provider_info.py
@@ -85,7 +85,11 @@ def get_provider_info():
{
"class-name":
"airflow.providers.cncf.kubernetes.decorators.kubernetes.kubernetes_task",
"name": "kubernetes",
- }
+ },
+ {
+ "class-name":
"airflow.providers.cncf.kubernetes.decorators.kubernetes_cmd.kubernetes_cmd_task",
+ "name": "kubernetes_cmd",
+ },
],
"config": {
"local_kubernetes_executor": {
diff --git
a/providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_cmd_decorator.py
b/providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_cmd_decorator.py
new file mode 100644
index 00000000000..3235dffe0f3
--- /dev/null
+++
b/providers/cncf/kubernetes/tests/system/cncf/kubernetes/example_kubernetes_cmd_decorator.py
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from datetime import datetime
+
+from airflow.sdk import DAG, task
+
+with DAG(
+ dag_id="example_kubernetes_cmd_decorator",
+ schedule=None,
+ start_date=datetime(2021, 1, 1),
+ tags=["example", "cncf", "kubernetes"],
+ catchup=False,
+) as dag:
+ # [START howto_decorator_kubernetes_cmd]
+ @task
+ def foo() -> str:
+ return "foo"
+
+ @task
+ def bar() -> str:
+ return "bar"
+
+ @task.kubernetes_cmd(
+ image="bash:5.2",
+ name="full_cmd",
+ in_cluster=False,
+ )
+ def execute_in_k8s_pod_full_cmd(foo_result: str, bar_result: str) ->
list[str]:
+ return ["echo", "-e", f"With full cmd:\\t{foo_result}\\t{bar_result}"]
+
+ # The args_only parameter is used to indicate that the decorated function
will
+ # return a list of arguments to be passed as arguments to the container
entrypoint:
+ # in this case, the `bash` command
+ @task.kubernetes_cmd(args_only=True, image="bash:5.2", in_cluster=False)
+ def execute_in_k8s_pod_args_only(foo_result: str, bar_result: str) ->
list[str]:
+ return ["-c", f"echo -e 'With args
only:\\t{foo_result}\\t{bar_result}'"]
+
+ # Templating can be used in the returned command and all other templated
fields in
+ # the decorator parameters.
+ @task.kubernetes_cmd(image="bash:5.2", name="my-pod-{{ ti.task_id }}",
in_cluster=False)
+ def apply_templating(message: str) -> list[str]:
+ full_message = "Templated task_id: {{ ti.task_id }}, dag_id: " +
message
+ return ["echo", full_message]
+
+ foo_result = foo()
+ bar_result = bar()
+
+ full_cmd_instance = execute_in_k8s_pod_full_cmd(foo_result, bar_result)
+ args_instance = execute_in_k8s_pod_args_only(foo_result, bar_result)
+
+ [full_cmd_instance, args_instance] >> apply_templating("{{ dag.dag_id }}")
+
+ # [END howto_decorator_kubernetes_cmd]
+
+
+from tests_common.test_utils.system_tests import get_test_run
+
+# Needed to run the example DAG with pytest (see:
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes.py
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes.py
index 528bce1780e..b5c4448e193 100644
---
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes.py
+++
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes.py
@@ -18,311 +18,109 @@ from __future__ import annotations
import base64
import pickle
-from unittest import mock
-import pytest
+from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS
-from airflow.decorators import setup, task, teardown
-from airflow.utils import timezone
+if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk import task
+else:
+ from airflow.decorators import task
+from unit.cncf.kubernetes.decorators.test_kubernetes_commons import
TestKubernetesDecoratorsBase
-pytestmark = pytest.mark.db_test
-
-
-DEFAULT_DATE = timezone.datetime(2021, 9, 1)
-
-KPO_MODULE = "airflow.providers.cncf.kubernetes.operators.pod"
-POD_MANAGER_CLASS =
"airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
-HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.pod.KubernetesHook"
XCOM_IMAGE = "XCOM_IMAGE"
[email protected](autouse=True)
-def mock_create_pod() -> mock.Mock:
- return mock.patch(f"{POD_MANAGER_CLASS}.create_pod").start()
-
-
[email protected](autouse=True)
-def mock_await_pod_start() -> mock.Mock:
- return mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start").start()
-
-
[email protected](autouse=True)
-def await_xcom_sidecar_container_start() -> mock.Mock:
- return
mock.patch(f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start").start()
-
-
[email protected](autouse=True)
-def extract_xcom() -> mock.Mock:
- f = mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom").start()
- f.return_value = '{"key1": "value1", "key2": "value2"}'
- return f
-
-
[email protected](autouse=True)
-def mock_await_pod_completion() -> mock.Mock:
- f = mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion").start()
- f.return_value = mock.MagicMock(**{"status.phase": "Succeeded"})
- return f
-
-
[email protected](autouse=True)
-def mock_hook():
- return mock.patch(HOOK_CLASS).start()
-
-
-# Without this patch each time pod manager would try to extract logs from the
pod
-# and log an error about it's inability to get containers for the log
-# {pod_manager.py:572} ERROR - Could not retrieve containers for the pod: ...
[email protected](autouse=True)
-def mock_fetch_logs() -> mock.Mock:
- f =
mock.patch(f"{POD_MANAGER_CLASS}.fetch_requested_container_logs").start()
- f.return_value = "logs"
- return f
-
-
-def test_basic_kubernetes(dag_maker, session, mock_create_pod: mock.Mock,
mock_hook: mock.Mock) -> None:
- with dag_maker(session=session) as dag:
-
- @task.kubernetes(
- image="python:3.10-slim-buster",
- in_cluster=False,
- cluster_context="default",
- config_file="/tmp/fake_file",
- namespace="default",
- )
- def f():
- import random
-
- return [random.random() for _ in range(100)]
-
- f()
-
- dr = dag_maker.create_dagrun()
- (ti,) = dr.task_instances
- session.add(ti)
- session.commit()
- dag.get_task("f").execute(context=ti.get_template_context(session=session))
- mock_hook.assert_called_once_with(
- conn_id="kubernetes_default",
- in_cluster=False,
- cluster_context="default",
- config_file="/tmp/fake_file",
- )
- assert mock_create_pod.call_count == 1
-
- containers = mock_create_pod.call_args.kwargs["pod"].spec.containers
- assert len(containers) == 1
- assert containers[0].command[0] == "bash"
- assert len(containers[0].args) == 0
- assert containers[0].env[0].name == "__PYTHON_SCRIPT"
- assert containers[0].env[0].value
- assert containers[0].env[1].name == "__PYTHON_INPUT"
-
- # Ensure we pass input through a b64 encoded env var
- decoded_input = pickle.loads(base64.b64decode(containers[0].env[1].value))
- assert decoded_input == {"args": [], "kwargs": {}}
-
-
-def test_kubernetes_with_input_output(
- dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock
-) -> None:
- with dag_maker(session=session) as dag:
-
- @task.kubernetes(
- image="python:3.10-slim-buster",
- in_cluster=False,
- cluster_context="default",
- config_file="/tmp/fake_file",
- namespace="default",
- )
- def f(arg1, arg2, kwarg1=None, kwarg2=None):
- return {"key1": "value1", "key2": "value2"}
-
- f.override(task_id="my_task_id", do_xcom_push=True)("arg1", "arg2",
kwarg1="kwarg1")
-
- mock_hook.return_value.get_xcom_sidecar_container_image.return_value =
XCOM_IMAGE
- mock_hook.return_value.get_xcom_sidecar_container_resources.return_value =
{
- "requests": {"cpu": "1m", "memory": "10Mi"},
- "limits": {"cpu": "1m", "memory": "50Mi"},
- }
-
- dr = dag_maker.create_dagrun()
- (ti,) = dr.task_instances
- session.add(dr)
- session.commit()
-
dag.get_task("my_task_id").execute(context=ti.get_template_context(session=session))
-
- mock_hook.assert_called_once_with(
- conn_id="kubernetes_default",
- in_cluster=False,
- cluster_context="default",
- config_file="/tmp/fake_file",
- )
- assert mock_create_pod.call_count == 1
- assert mock_hook.return_value.get_xcom_sidecar_container_image.call_count
== 1
- assert
mock_hook.return_value.get_xcom_sidecar_container_resources.call_count == 1
-
- containers = mock_create_pod.call_args.kwargs["pod"].spec.containers
-
- # First container is Python script
- assert len(containers) == 2
- assert containers[0].command[0] == "bash"
- assert len(containers[0].args) == 0
-
- assert containers[0].env[0].name == "__PYTHON_SCRIPT"
- assert containers[0].env[0].value
- assert containers[0].env[1].name == "__PYTHON_INPUT"
- assert containers[0].env[1].value
-
- # Ensure we pass input through a b64 encoded env var
- decoded_input = pickle.loads(base64.b64decode(containers[0].env[1].value))
- assert decoded_input == {"args": ("arg1", "arg2"), "kwargs": {"kwarg1":
"kwarg1"}}
-
- # Second container is xcom image
- assert containers[1].image == XCOM_IMAGE
- assert containers[1].volume_mounts[0].mount_path == "/airflow/xcom"
-
-
-def test_kubernetes_with_marked_as_setup(
- dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock
-) -> None:
- with dag_maker(session=session) as dag:
-
- @setup
- @task.kubernetes(
- image="python:3.10-slim-buster",
- in_cluster=False,
- cluster_context="default",
- config_file="/tmp/fake_file",
- )
- def f():
- return {"key1": "value1", "key2": "value2"}
+class TestKubernetesDecorator(TestKubernetesDecoratorsBase):
+ def test_basic_kubernetes(self):
+ """Test basic proper KubernetesPodOperator creation from
@task.kubernetes decorator"""
+ with self.dag:
- f()
+ @task.kubernetes(
+ image="python:3.10-slim-buster",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ namespace="default",
+ )
+ def f():
+ import random
- assert len(dag.task_group.children) == 1
- setup_task = dag.task_group.children["f"]
- assert setup_task.is_setup
+ return [random.random() for _ in range(100)]
+ k8s_task = f()
-def test_kubernetes_with_marked_as_teardown(
- dag_maker, session, mock_create_pod: mock.Mock, mock_hook: mock.Mock
-) -> None:
- with dag_maker(session=session) as dag:
+ self.execute_task(k8s_task)
- @teardown
- @task.kubernetes(
- image="python:3.10-slim-buster",
+ self.mock_hook.assert_called_once_with(
+ conn_id="kubernetes_default",
in_cluster=False,
cluster_context="default",
config_file="/tmp/fake_file",
)
- def f():
- return {"key1": "value1", "key2": "value2"}
-
- f()
-
- assert len(dag.task_group.children) == 1
- teardown_task = dag.task_group.children["f"]
- assert teardown_task.is_teardown
-
-
[email protected](
- "name",
- ["no_name_in_args", None, "test_task_name"],
- ids=["no_name_in_args", "name_set_to_None", "with_name"],
-)
[email protected](
- "random_name_suffix",
- [True, False],
- ids=["rand_suffix", "no_rand_suffix"],
-)
-def test_pod_naming(
- dag_maker,
- session,
- mock_create_pod: mock.Mock,
- name: str | None,
- random_name_suffix: bool,
-) -> None:
- """
- Idea behind this test is to check naming conventions are respected in
various
- decorator arguments combinations scenarios.
-
- @task.kubernetes differs from KubernetesPodOperator in a way that it
distinguishes
- between no name argument was provided and name was set to None.
- In the first case, the operator name is generated from the python_callable
name,
- in the second case default KubernetesPodOperator behavior is preserved.
- """
- extra_kwargs = {"name": name}
- if name == "no_name_in_args":
- extra_kwargs.pop("name")
-
- with dag_maker(session=session) as dag:
-
- @task.kubernetes(
- image="python:3.10-slim-buster",
+ assert self.mock_create_pod.call_count == 1
+
+ containers =
self.mock_create_pod.call_args.kwargs["pod"].spec.containers
+ assert len(containers) == 1
+ assert containers[0].command[0] == "bash"
+ assert len(containers[0].args) == 0
+ assert containers[0].env[0].name == "__PYTHON_SCRIPT"
+ assert containers[0].env[0].value
+ assert containers[0].env[1].name == "__PYTHON_INPUT"
+
+ # Ensure we pass input through a b64 encoded env var
+ decoded_input =
pickle.loads(base64.b64decode(containers[0].env[1].value))
+ assert decoded_input == {"args": [], "kwargs": {}}
+
+ def test_kubernetes_with_input_output(self):
+ """Verify @task.kubernetes will run XCom container if do_xcom_push is
set."""
+ with self.dag:
+
+ @task.kubernetes(
+ image="python:3.10-slim-buster",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ namespace="default",
+ )
+ def f(arg1, arg2, kwarg1=None, kwarg2=None):
+ return {"key1": "value1", "key2": "value2"}
+
+ k8s_task = f.override(task_id="my_task_id",
do_xcom_push=True)("arg1", "arg2", kwarg1="kwarg1")
+
+
self.mock_hook.return_value.get_xcom_sidecar_container_image.return_value =
XCOM_IMAGE
+
self.mock_hook.return_value.get_xcom_sidecar_container_resources.return_value =
{
+ "requests": {"cpu": "1m", "memory": "10Mi"},
+ "limits": {"cpu": "1m", "memory": "50Mi"},
+ }
+
+ self.execute_task(k8s_task)
+ assert self.mock_create_pod.call_count == 1
+
+ self.mock_hook.assert_called_once_with(
+ conn_id="kubernetes_default",
in_cluster=False,
cluster_context="default",
config_file="/tmp/fake_file",
- random_name_suffix=random_name_suffix,
- namespace="default",
- **extra_kwargs, # type: ignore
)
- def task_function_name():
- return 42
-
- task_function_name()
-
- dr = dag_maker.create_dagrun()
- (ti,) = dr.task_instances
- session.add(ti)
- session.commit()
-
- task_id = "task_function_name"
- op = dag.get_task(task_id)
- if name is not None:
- assert isinstance(op.name, str)
-
- # If name was explicitly set to None, we expect the operator name to be
None
- if name is None:
- assert op.name is None
- # If name was not provided in decorator, it would be generated:
- # f"k8s-airflow-pod-{python_callable.__name__}"
- elif name == "no_name_in_args":
- assert op.name == f"k8s-airflow-pod-{task_id}"
- # Otherwise, we expect the name to be exactly the same as provided
- else:
- assert op.name == name
-
- op.execute(context=ti.get_template_context(session=session))
- pod_meta = mock_create_pod.call_args.kwargs["pod"].metadata
- assert isinstance(pod_meta.name, str)
-
- # After execution pod names should not contain underscores
- task_id_normalized = task_id.replace("_", "-")
-
- def check_op_name(name_arg: str | None) -> str:
- if name_arg is None:
- assert op.name is None
- return task_id_normalized
+ assert
self.mock_hook.return_value.get_xcom_sidecar_container_image.call_count == 1
+ assert
self.mock_hook.return_value.get_xcom_sidecar_container_resources.call_count == 1
- assert isinstance(op.name, str)
- if name_arg == "no_name_in_args":
- generated_name = f"k8s-airflow-pod-{task_id_normalized}"
- assert op.name == generated_name
- return generated_name
+ containers =
self.mock_create_pod.call_args.kwargs["pod"].spec.containers
- normalized_name = name_arg.replace("_", "-")
- assert op.name == normalized_name
+ # First container is Python script
+ assert len(containers) == 2
+ assert containers[0].command[0] == "bash"
+ assert len(containers[0].args) == 0
- return normalized_name
+ assert containers[0].env[0].name == "__PYTHON_SCRIPT"
+ assert containers[0].env[0].value
+ assert containers[0].env[1].name == "__PYTHON_INPUT"
+ assert containers[0].env[1].value
- def check_pod_name(name_base: str):
- if random_name_suffix:
- assert pod_meta.name.startswith(f"{name_base}")
- assert pod_meta.name != name_base
- else:
- assert pod_meta.name == name_base
+ # Ensure we pass input through a b64 encoded env var
+ decoded_input =
pickle.loads(base64.b64decode(containers[0].env[1].value))
+ assert decoded_input == {"args": ("arg1", "arg2"), "kwargs":
{"kwarg1": "kwarg1"}}
- pod_name = check_op_name(name)
- check_pod_name(pod_name)
+ # Second container is xcom image
+ assert containers[1].image == XCOM_IMAGE
+ assert containers[1].volume_mounts[0].mount_path == "/airflow/xcom"
diff --git
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_cmd.py
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_cmd.py
new file mode 100644
index 00000000000..9236ecf86e5
--- /dev/null
+++
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_cmd.py
@@ -0,0 +1,390 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import contextlib
+
+import pytest
+
+from airflow.exceptions import AirflowSkipException
+from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS
+
+if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk import task
+else:
+ from airflow.decorators import task
+from unit.cncf.kubernetes.decorators.test_kubernetes_commons import DAG_ID,
TestKubernetesDecoratorsBase
+
+XCOM_IMAGE = "XCOM_IMAGE"
+
+
+class TestKubernetesCmdDecorator(TestKubernetesDecoratorsBase):
+ @pytest.mark.parametrize(
+ "args_only",
+ [True, False],
+ )
+ def test_basic_kubernetes(self, args_only: bool):
+ """Test basic proper KubernetesPodOperator creation from
@task.kubernetes_cmd decorator"""
+ expected = ["echo", "Hello world!"]
+ with self.dag:
+
+ @task.kubernetes_cmd(
+ image="python:3.10-slim-buster",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ namespace="default",
+ args_only=args_only,
+ )
+ def hello():
+ return expected
+
+ k8s_task = hello()
+
+ self.execute_task(k8s_task)
+
+ self.mock_hook.assert_called_once_with(
+ conn_id="kubernetes_default",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ )
+ assert self.mock_create_pod.call_count == 1
+
+ containers =
self.mock_create_pod.call_args.kwargs["pod"].spec.containers
+ assert len(containers) == 1
+
+ expected_command = expected
+ expected_args = []
+ if args_only:
+ expected_args = expected_command
+ expected_command = []
+
+ assert containers[0].command == expected_command
+ assert containers[0].args == expected_args
+
+ @pytest.mark.parametrize(
+ "func_return, exception",
+ [
+ pytest.param("string", TypeError, id="iterable_str"),
+ pytest.param(True, TypeError, id="bool"),
+ pytest.param(42, TypeError, id="int"),
+ pytest.param(None, TypeError, id="None"),
+ pytest.param(("a", "b"), TypeError, id="tuple"),
+ pytest.param([], ValueError, id="empty_list"),
+ pytest.param(["echo", 123], TypeError, id="mixed_list"),
+ pytest.param(["echo", "Hello world!"], None, id="valid_list"),
+ ],
+ )
+ def test_kubernetes_cmd_wrong_cmd(
+ self,
+ func_return,
+ exception,
+ ):
+ """
+ Test that @task.kubernetes_cmd raises an error if the python_callable
returns
+ an invalid value.
+ """
+ with self.dag:
+
+ @task.kubernetes_cmd(
+ image="python:3.10-slim-buster",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ namespace="default",
+ )
+ def hello():
+ return func_return
+
+ k8s_task = hello()
+
+ context_manager = pytest.raises(exception) if exception else
contextlib.nullcontext()
+ with context_manager:
+ self.execute_task(k8s_task)
+
+ def test_kubernetes_cmd_with_input_output(self):
+ """Verify @task.kubernetes_cmd will run XCom container if do_xcom_push
is set."""
+ with self.dag:
+
+ @task.kubernetes_cmd(
+ image="python:3.10-slim-buster",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ namespace="default",
+ )
+ def f(arg1: str, arg2: str, kwarg1: str | None = None, kwarg2: str
| None = None):
+ return [
+ "echo",
+ f"arg1={arg1}",
+ f"arg2={arg2}",
+ f"kwarg1={kwarg1}",
+ f"kwarg2={kwarg2}",
+ ]
+
+ k8s_task = f.override(task_id="my_task_id",
do_xcom_push=True)("arg1", "arg2", kwarg1="kwarg1")
+
+
self.mock_hook.return_value.get_xcom_sidecar_container_image.return_value =
XCOM_IMAGE
+
self.mock_hook.return_value.get_xcom_sidecar_container_resources.return_value =
{
+ "requests": {"cpu": "1m", "memory": "10Mi"},
+ "limits": {"cpu": "1m", "memory": "50Mi"},
+ }
+ self.execute_task(k8s_task)
+
+ self.mock_hook.assert_called_once_with(
+ conn_id="kubernetes_default",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ )
+ assert self.mock_create_pod.call_count == 1
+ assert
self.mock_hook.return_value.get_xcom_sidecar_container_image.call_count == 1
+ assert
self.mock_hook.return_value.get_xcom_sidecar_container_resources.call_count == 1
+
+ containers =
self.mock_create_pod.call_args.kwargs["pod"].spec.containers
+
+ # First container is main one with command
+ assert len(containers) == 2
+ assert containers[0].command == ["echo", "arg1=arg1", "arg2=arg2",
"kwarg1=kwarg1", "kwarg2=None"]
+ assert len(containers[0].args) == 0
+
+ # Second container is xcom image
+ assert containers[1].image == XCOM_IMAGE
+ assert containers[1].volume_mounts[0].mount_path == "/airflow/xcom"
+
+ @pytest.mark.parametrize(
+ "cmds",
+ [None, ["ignored_cmd"], "ignored_cmd"],
+ )
+ @pytest.mark.parametrize(
+ "arguments",
+ [None, ["ignored_arg"], "ignored_arg"],
+ )
+ @pytest.mark.parametrize(
+ "args_only",
+ [True, False],
+ )
+ def test_ignored_decorator_parameters(
+ self,
+ cmds: list[str] | None,
+ arguments: list[str] | None,
+ args_only: bool,
+ ) -> None:
+ """
+ Test setting `cmds` or `arguments` from decorator does not affect the
operator.
+ And the warning is shown only if `cmds` or `arguments` are not None.
+ """
+ context_manager = pytest.warns(UserWarning, match="The `cmds` and
`arguments` are unused")
+ # Don't warn if both `cmds` and `arguments` are None
+ if cmds is None and arguments is None:
+ context_manager = contextlib.nullcontext() # type: ignore
+
+ expected = ["func", "return"]
+ with self.dag:
+ # We need to suppress the warning about `cmds` and `arguments`
being unused
+ with context_manager:
+
+ @task.kubernetes_cmd(
+ image="python:3.10-slim-buster",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ namespace="default",
+ cmds=cmds,
+ arguments=arguments,
+ args_only=args_only,
+ )
+ def hello():
+ return expected
+
+ hello_task = hello()
+
+ assert hello_task.operator.cmds == []
+ assert hello_task.operator.arguments == []
+
+ self.execute_task(hello_task)
+ containers =
self.mock_create_pod.call_args.kwargs["pod"].spec.containers
+ assert len(containers) == 1
+
+ expected_command = expected
+ expected_args = []
+ if args_only:
+ expected_args = expected_command
+ expected_command = []
+ assert containers[0].command == expected_command
+ assert containers[0].args == expected_args
+
+ @pytest.mark.parametrize(
+ argnames=["command", "op_arg", "expected_command"],
+ argvalues=[
+ pytest.param(
+ ["echo", "hello"],
+ "world",
+ ["echo", "hello", "world"],
+ id="not_templated",
+ ),
+ pytest.param(
+ ["echo", "{{ ti.task_id }}"], "{{ ti.dag_id }}", ["echo",
"hello", DAG_ID], id="templated"
+ ),
+ ],
+ )
+ def test_rendering_kubernetes_cmd(
+ self,
+ command: list[str],
+ op_arg: str,
+ expected_command: list[str],
+ ):
+ """Test that templating works in function return value"""
+ with self.dag:
+
+ @task.kubernetes_cmd(
+ image="python:3.10-slim-buster",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ namespace="default",
+ )
+ def hello(add_to_command: str):
+ return command + [add_to_command]
+
+ hello_task = hello(op_arg)
+
+ self.execute_task(hello_task)
+
+ self.mock_hook.assert_called_once_with(
+ conn_id="kubernetes_default",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ )
+ containers =
self.mock_create_pod.call_args.kwargs["pod"].spec.containers
+ assert len(containers) == 1
+
+ assert containers[0].command == expected_command
+ assert containers[0].args == []
+
+ def test_basic_context_works(self):
+        """Test that decorator works with context as kwargs unpacked in
function arguments"""
+ with self.dag:
+
+ @task.kubernetes_cmd(
+ image="python:3.10-slim-buster",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ namespace="default",
+ )
+ def hello(**context):
+ return ["echo", context["ti"].task_id,
context["dag_run"].dag_id]
+
+ hello_task = hello()
+
+ self.execute_task(hello_task)
+
+ self.mock_hook.assert_called_once_with(
+ conn_id="kubernetes_default",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ )
+ containers =
self.mock_create_pod.call_args.kwargs["pod"].spec.containers
+ assert len(containers) == 1
+
+ assert containers[0].command == ["echo", "hello", DAG_ID]
+ assert containers[0].args == []
+
+ def test_named_context_variables(self):
+ """Test that decorator works with specific context variable as kwargs
in function arguments"""
+ with self.dag:
+
+ @task.kubernetes_cmd(
+ image="python:3.10-slim-buster",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ namespace="default",
+ )
+ def hello(ti=None, dag_run=None):
+ return ["echo", ti.task_id, dag_run.dag_id]
+
+ hello_task = hello()
+
+ self.execute_task(hello_task)
+
+ self.mock_hook.assert_called_once_with(
+ conn_id="kubernetes_default",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ )
+ containers =
self.mock_create_pod.call_args.kwargs["pod"].spec.containers
+ assert len(containers) == 1
+
+ assert containers[0].command == ["echo", "hello", DAG_ID]
+ assert containers[0].args == []
+
+ def test_rendering_kubernetes_cmd_decorator_params(self):
+ """Test that templating works in decorator parameters"""
+ with self.dag:
+
+ @task.kubernetes_cmd(
+ image="python:{{ dag.dag_id }}",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ namespace="default",
+ kubernetes_conn_id="kubernetes_{{ dag.dag_id }}",
+ )
+ def hello():
+ return ["echo", "Hello world!"]
+
+ hello_task = hello()
+
+ self.execute_task(hello_task)
+
+ self.mock_hook.assert_called_once_with(
+ conn_id="kubernetes_" + DAG_ID,
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ )
+ containers =
self.mock_create_pod.call_args.kwargs["pod"].spec.containers
+ assert len(containers) == 1
+
+ assert containers[0].image == f"python:{DAG_ID}"
+
+ def test_airflow_skip(self):
+ """Test that the operator is skipped if the task is skipped"""
+ with self.dag:
+
+ @task.kubernetes_cmd(
+ image="python:3.10-slim-buster",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ namespace="default",
+ )
+ def hello():
+ raise AirflowSkipException("This task should be skipped")
+
+ hello_task = hello()
+
+ with pytest.raises(AirflowSkipException):
+ self.execute_task(hello_task)
+ self.mock_hook.assert_not_called()
+ self.mock_create_pod.assert_not_called()
diff --git
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py
new file mode 100644
index 00000000000..16db4b120fb
--- /dev/null
+++
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/decorators/test_kubernetes_commons.py
@@ -0,0 +1,280 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Callable
+from unittest import mock
+
+import pytest
+
+from airflow.providers.cncf.kubernetes.version_compat import AIRFLOW_V_3_0_PLUS
+
+if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk import setup, task, teardown
+else:
+ from airflow.decorators import setup, task, teardown
+
+from airflow.utils import timezone
+
+from tests_common.test_utils.db import clear_db_dags, clear_db_runs,
clear_rendered_ti_fields
+
+TASK_FUNCTION_NAME_ID = "task_function_name"
+DEFAULT_DATE = timezone.datetime(2023, 1, 1)
+DAG_ID = "k8s_deco_test_dag"
+
+
+def _kubernetes_func():
+ return {"key1": "value1", "key2": "value2"}
+
+
+def _kubernetes_cmd_func():
+ return ["echo", "Hello world!"]
+
+
+def _get_decorator_func(decorator_name: str) -> Callable:
+ if decorator_name == "kubernetes":
+ return _kubernetes_func
+ if decorator_name == "kubernetes_cmd":
+ return _kubernetes_cmd_func
+ raise ValueError(f"Unknown decorator {decorator_name}")
+
+
+def _prepare_task(
+ task_decorator: Callable,
+ decorator_name: str,
+ **decorator_kwargs,
+) -> Callable:
+ func_to_use = _get_decorator_func(decorator_name)
+
+ @task_decorator(
+ image="python:3.10-slim-buster",
+ in_cluster=False,
+ cluster_context="default",
+ config_file="/tmp/fake_file",
+ **decorator_kwargs,
+ )
+ def task_function_name():
+ return func_to_use()
+
+ return task_function_name
+
+
+KPO_MODULE = "airflow.providers.cncf.kubernetes.operators.pod"
+POD_MANAGER_CLASS =
"airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
+HOOK_CLASS = "airflow.providers.cncf.kubernetes.operators.pod.KubernetesHook"
+
+
[email protected]_test
+class TestKubernetesDecoratorsBase:
+ @pytest.fixture(autouse=True)
+ def setup(self, dag_maker):
+ self.dag_maker = dag_maker
+
+ with dag_maker(dag_id=DAG_ID) as dag:
+ ...
+
+ self.dag = dag
+
+ self.mock_create_pod =
mock.patch(f"{POD_MANAGER_CLASS}.create_pod").start()
+ self.mock_await_pod_start =
mock.patch(f"{POD_MANAGER_CLASS}.await_pod_start").start()
+ self.mock_await_xcom_sidecar_container_start = mock.patch(
+ f"{POD_MANAGER_CLASS}.await_xcom_sidecar_container_start"
+ ).start()
+
+ self.mock_extract_xcom =
mock.patch(f"{POD_MANAGER_CLASS}.extract_xcom").start()
+ self.mock_extract_xcom.return_value = '{"key1": "value1", "key2":
"value2"}'
+
+ self.mock_await_pod_completion =
mock.patch(f"{POD_MANAGER_CLASS}.await_pod_completion").start()
+ self.mock_await_pod_completion.return_value =
mock.MagicMock(**{"status.phase": "Succeeded"})
+ self.mock_hook = mock.patch(HOOK_CLASS).start()
+
+ # Without this patch each time pod manager would try to extract logs
from the pod
+        # and log an error about its inability to get containers for the log
+ # {pod_manager.py:572} ERROR - Could not retrieve containers for the
pod: ...
+ self.mock_fetch_logs =
mock.patch(f"{POD_MANAGER_CLASS}.fetch_requested_container_logs").start()
+ self.mock_fetch_logs.return_value = "logs"
+
+ def teardown_method(self):
+ clear_db_runs()
+ clear_db_dags()
+ clear_rendered_ti_fields()
+
+ def execute_task(self, task):
+ session = self.dag_maker.session
+ dag_run = self.dag_maker.create_dagrun(
+ run_id=f"k8s_decorator_test_{DEFAULT_DATE.date()}", session=session
+ )
+ ti = dag_run.get_task_instance(task.operator.task_id, session=session)
+ return_val =
task.operator.execute(context=ti.get_template_context(session=session))
+
+ return ti, return_val
+
+
+def parametrize_kubernetes_decorators_commons(cls):
+ for name, method in cls.__dict__.items():
+ if not name.startswith("test_") or not callable(method):
+ continue
+ new_method = pytest.mark.parametrize(
+ "task_decorator,decorator_name",
+ [
+ (task.kubernetes, "kubernetes"),
+ (task.kubernetes_cmd, "kubernetes_cmd"),
+ ],
+ ids=["kubernetes", "kubernetes_cmd"],
+ )(method)
+ setattr(cls, name, new_method)
+
+ return cls
+
+
+@parametrize_kubernetes_decorators_commons
+class TestKubernetesDecoratorsCommons(TestKubernetesDecoratorsBase):
+ def test_k8s_decorator_init(self, task_decorator, decorator_name):
+ """Test the initialization of the @task.kubernetes[_cmd] decorated
task."""
+
+ with self.dag:
+
+ @task_decorator(
+ image="python:3.10-slim-buster",
+ in_cluster=False,
+ cluster_context="default",
+ )
+ def k8s_task_function() -> list[str]:
+ return ["return", "value"]
+
+ k8s_task = k8s_task_function()
+
+ assert k8s_task.operator.task_id == "k8s_task_function"
+ assert k8s_task.operator.image == "python:3.10-slim-buster"
+
+ expected_cmds = ["placeholder-command"] if decorator_name ==
"kubernetes" else []
+ assert k8s_task.operator.cmds == expected_cmds
+ assert k8s_task.operator.random_name_suffix is True
+
+ def test_decorators_with_marked_as_setup(self, task_decorator,
decorator_name):
+ """Test the @task.kubernetes[_cmd] decorated task works with setup
decorator."""
+ with self.dag:
+ task_function_name = setup(_prepare_task(task_decorator,
decorator_name))
+ task_function_name()
+
+ assert len(self.dag.task_group.children) == 1
+ setup_task = self.dag.task_group.children[TASK_FUNCTION_NAME_ID]
+ assert setup_task.is_setup
+
+ def test_decorators_with_marked_as_teardown(self, task_decorator,
decorator_name):
+ """Test the @task.kubernetes[_cmd] decorated task works with teardown
decorator."""
+ with self.dag:
+ task_function_name = teardown(_prepare_task(task_decorator,
decorator_name))
+ task_function_name()
+
+ assert len(self.dag.task_group.children) == 1
+ teardown_task = self.dag.task_group.children[TASK_FUNCTION_NAME_ID]
+ assert teardown_task.is_teardown
+
+ @pytest.mark.parametrize(
+ "name",
+ ["no_name_in_args", None, "test_task_name"],
+ ids=["no_name_in_args", "name_set_to_None", "with_name"],
+ )
+ @pytest.mark.parametrize(
+ "random_name_suffix",
+ [True, False],
+ ids=["rand_suffix", "no_rand_suffix"],
+ )
+ def test_pod_naming(
+ self,
+ task_decorator,
+ decorator_name,
+ name: str | None,
+ random_name_suffix: bool,
+ ) -> None:
+ """
+ Idea behind this test is to check naming conventions are respected in
various
+ decorator arguments combinations scenarios.
+
+        @task.kubernetes[_cmd] differs from KubernetesPodOperator in that
it distinguishes
+        between the case where no name argument was provided and the case
where name was explicitly set to None.
+ In the first case, the operator name is generated from the
python_callable name,
+ in the second case default KubernetesPodOperator behavior is preserved.
+ """
+ extra_kwargs = {"name": name}
+ if name == "no_name_in_args":
+ extra_kwargs.pop("name")
+
+ decorator_kwargs = {
+ "random_name_suffix": random_name_suffix,
+ "namespace": "default",
+ **extra_kwargs,
+ }
+
+ with self.dag:
+ task_function_name = _prepare_task(
+ task_decorator,
+ decorator_name,
+ **decorator_kwargs,
+ )
+
+ k8s_task = task_function_name()
+
+ task_id = TASK_FUNCTION_NAME_ID
+ op = self.dag.get_task(task_id)
+ if name is not None:
+ assert isinstance(op.name, str)
+
+ # If name was explicitly set to None, we expect the operator name to
be None
+ if name is None:
+ assert op.name is None
+ # If name was not provided in decorator, it would be generated:
+ # f"k8s-airflow-pod-{python_callable.__name__}"
+ elif name == "no_name_in_args":
+ assert op.name == f"k8s-airflow-pod-{task_id}"
+ # Otherwise, we expect the name to be exactly the same as provided
+ else:
+ assert op.name == name
+
+ self.execute_task(k8s_task)
+ pod_meta = self.mock_create_pod.call_args.kwargs["pod"].metadata
+ assert isinstance(pod_meta.name, str)
+
+ # After execution pod names should not contain underscores
+ task_id_normalized = task_id.replace("_", "-")
+
+ def check_op_name(name_arg: str | None) -> str:
+ if name_arg is None:
+ assert op.name is None
+ return task_id_normalized
+
+ assert isinstance(op.name, str)
+ if name_arg == "no_name_in_args":
+ generated_name = f"k8s-airflow-pod-{task_id_normalized}"
+ assert op.name == generated_name
+ return generated_name
+
+ normalized_name = name_arg.replace("_", "-")
+ assert op.name == normalized_name
+
+ return normalized_name
+
+ def check_pod_name(name_base: str):
+ if random_name_suffix:
+ assert pod_meta.name.startswith(f"{name_base}")
+ assert pod_meta.name != name_base
+ else:
+ assert pod_meta.name == name_base
+
+ pod_name = check_op_name(name)
+ check_pod_name(pod_name)
diff --git a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi
b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi
index f3ebe6cecf4..30e921f2f48 100644
--- a/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi
+++ b/task-sdk/src/airflow/sdk/definitions/decorators/__init__.pyi
@@ -495,6 +495,7 @@ class TaskDecoratorCollection:
or a list of names of labels to set with empty values (e.g.
``["label1", "label2"]``)
"""
# [END decorator_signature]
+ @overload
def kubernetes(
self,
*,
@@ -667,6 +668,178 @@ class TaskDecoratorCollection:
:param progress_callback: Callback function for receiving k8s
container logs.
"""
@overload
+ def kubernetes(self, python_callable: Callable[FParams, FReturn]) ->
Task[FParams, FReturn]: ...
+ @overload
+ def kubernetes_cmd(
+ self,
+ *,
+ args_only: bool = False, # Added by _KubernetesCmdDecoratedOperator.
+ # 'cmds' filled by _KubernetesCmdDecoratedOperator.
+ # 'arguments' filled by _KubernetesCmdDecoratedOperator.
+ kubernetes_conn_id: str | None = ...,
+ namespace: str | None = None,
+ image: str | None = None,
+ name: str | None = None,
+ random_name_suffix: bool = ...,
+ ports: list[k8s.V1ContainerPort] | None = None,
+ volume_mounts: list[k8s.V1VolumeMount] | None = None,
+ volumes: list[k8s.V1Volume] | None = None,
+ env_vars: list[k8s.V1EnvVar] | dict[str, str] | None = None,
+ env_from: list[k8s.V1EnvFromSource] | None = None,
+ secrets: list[Secret] | None = None,
+ in_cluster: bool | None = None,
+ cluster_context: str | None = None,
+ labels: dict | None = None,
+ reattach_on_restart: bool = ...,
+ startup_timeout_seconds: int = ...,
+ startup_check_interval_seconds: int = ...,
+ get_logs: bool = True,
+ container_logs: Iterable[str] | str | Literal[True] = ...,
+ image_pull_policy: str | None = None,
+ annotations: dict | None = None,
+ container_resources: k8s.V1ResourceRequirements | None = None,
+ affinity: k8s.V1Affinity | None = None,
+ config_file: str | None = None,
+ node_selector: dict | None = None,
+ image_pull_secrets: list[k8s.V1LocalObjectReference] | None = None,
+ service_account_name: str | None = None,
+ hostnetwork: bool = False,
+ host_aliases: list[k8s.V1HostAlias] | None = None,
+ tolerations: list[k8s.V1Toleration] | None = None,
+ security_context: k8s.V1PodSecurityContext | dict | None = None,
+ container_security_context: k8s.V1SecurityContext | dict | None = None,
+ dnspolicy: str | None = None,
+ dns_config: k8s.V1PodDNSConfig | None = None,
+ hostname: str | None = None,
+ subdomain: str | None = None,
+ schedulername: str | None = None,
+ full_pod_spec: k8s.V1Pod | None = None,
+ init_containers: list[k8s.V1Container] | None = None,
+ log_events_on_failure: bool = False,
+ do_xcom_push: bool = False,
+ pod_template_file: str | None = None,
+ pod_template_dict: dict | None = None,
+ priority_class_name: str | None = None,
+ pod_runtime_info_envs: list[k8s.V1EnvVar] | None = None,
+ termination_grace_period: int | None = None,
+ configmaps: list[str] | None = None,
+ skip_on_exit_code: int | Container[int] | None = None,
+ base_container_name: str | None = None,
+ base_container_status_polling_interval: float = ...,
+ deferrable: bool = ...,
+ poll_interval: float = ...,
+ log_pod_spec_on_failure: bool = ...,
+ on_finish_action: str = ...,
+ termination_message_policy: str = ...,
+ active_deadline_seconds: int | None = None,
+ progress_callback: Callable[[str], None] | None = None,
+ **kwargs,
+ ) -> TaskDecorator:
+        """Create a decorator to run a command returned by a callable in a
Kubernetes pod.
+
+        :param args_only: If True, the decorated function should return a list
of arguments
+ to be passed to the entrypoint of the container image. Defaults to
False.
+ :param kubernetes_conn_id: The Kubernetes cluster's
+ :ref:`connection ID <howto/connection:kubernetes>`.
+ :param namespace: Namespace to run within Kubernetes. Defaults to
*default*.
+ :param image: Docker image to launch. Defaults to *hub.docker.com*, but
+ a fully qualified URL will point to a custom repository.
(templated)
+ :param name: Name of the pod to run. This will be used (plus a random
+ suffix if *random_name_suffix* is *True*) to generate a pod ID
+ (DNS-1123 subdomain, containing only ``[a-z0-9.-]``). Defaults to
+ ``k8s-airflow-pod-{python_callable.__name__}``.
+ :param random_name_suffix: If *True*, will generate a random suffix.
+ :param ports: Ports for the launched pod.
+ :param volume_mounts: *volumeMounts* for the launched pod.
+ :param volumes: Volumes for the launched pod. Includes *ConfigMaps* and
+ *PersistentVolumes*.
+ :param env_vars: Environment variables initialized in the container.
+ (templated)
+ :param env_from: List of sources to populate environment variables in
+ the container.
+ :param secrets: Kubernetes secrets to inject in the container. They can
+ be exposed as environment variables or files in a volume.
+ :param in_cluster: Run kubernetes client with *in_cluster*
configuration.
+ :param cluster_context: Context that points to the Kubernetes cluster.
+ Ignored when *in_cluster* is *True*. If *None*, current-context
will
+ be used.
+ :param reattach_on_restart: If the worker dies while the pod is
running,
+ reattach and monitor during the next try. If *False*, always create
+ a new pod for each try.
+ :param labels: Labels to apply to the pod. (templated)
+ :param startup_timeout_seconds: Timeout in seconds to startup the pod.
+ :param startup_check_interval_seconds: interval in seconds to check if
the pod has already started
+ :param get_logs: Get the stdout of the container as logs of the tasks.
+ :param container_logs: list of containers whose logs will be published
to stdout
+ Takes a sequence of containers, a single container name or True.
+ If True, all the containers logs are published. Works in
conjunction with ``get_logs`` param.
+ The default value is the base container.
+ :param image_pull_policy: Specify a policy to cache or always pull an
+ image.
+ :param annotations: Non-identifying metadata you can attach to the pod.
+ Can be a large range of data, and can include characters that are
+ not permitted by labels.
+ :param container_resources: Resources for the launched pod.
+ :param affinity: Affinity scheduling rules for the launched pod.
+ :param config_file: The path to the Kubernetes config file. If not
+ specified, default value is ``~/.kube/config``. (templated)
+ :param node_selector: A dict containing a group of scheduling rules.
+ :param image_pull_secrets: Any image pull secrets to be given to the
+ pod. If more than one secret is required, provide a comma separated
+ list, e.g. ``secret_a,secret_b``.
+ :param service_account_name: Name of the service account.
+ :param hostnetwork: If *True*, enable host networking on the pod.
+ :param host_aliases: A list of host aliases to apply to the containers
in the pod.
+ :param tolerations: A list of Kubernetes tolerations.
+ :param security_context: Security options the pod should run with
+ (PodSecurityContext).
+ :param container_security_context: security options the container
should run with.
+ :param dnspolicy: DNS policy for the pod.
+ :param dns_config: dns configuration (ip addresses, searches, options)
for the pod.
+ :param hostname: hostname for the pod.
+ :param subdomain: subdomain for the pod.
+ :param schedulername: Specify a scheduler name for the pod
+ :param full_pod_spec: The complete podSpec
+ :param init_containers: Init containers for the launched pod.
+ :param log_events_on_failure: Log the pod's events if a failure occurs.
+ :param do_xcom_push: If *True*, the content of
+ ``/airflow/xcom/return.json`` in the container will also be pushed
+ to an XCom when the container completes.
+ :param pod_template_file: Path to pod template file (templated)
+ :param pod_template_dict: pod template dictionary (templated)
+ :param priority_class_name: Priority class name for the launched pod.
+ :param pod_runtime_info_envs: A list of environment variables
+ to be set in the container.
+ :param termination_grace_period: Termination grace period if task
killed
+ in UI, defaults to kubernetes default
+ :param configmaps: A list of names of config maps from which it
collects
+ ConfigMaps to populate the environment variables with. The contents
+ of the target ConfigMap's Data field will represent the key-value
+ pairs as environment variables. Extends env_from.
+ :param skip_on_exit_code: If task exits with this exit code, leave the
task
+ in ``skipped`` state (default: None). If set to ``None``, any
non-zero
+ exit code will be treated as a failure.
+ :param base_container_name: The name of the base container in the pod.
This container's logs
+ will appear as part of this task's logs if get_logs is True.
Defaults to None. If None,
+ will consult the class variable BASE_CONTAINER_NAME (which
defaults to "base") for the base
+ container name to use.
+ :param base_container_status_polling_interval: Polling period in
seconds to check for the pod base
+ container status.
+ :param deferrable: Run operator in the deferrable mode.
+ :param poll_interval: Polling period in seconds to check for the
status. Used only in deferrable mode.
+ :param log_pod_spec_on_failure: Log the pod's specification if a
failure occurs
+ :param on_finish_action: What to do when the pod reaches its final
state, or the execution is interrupted.
+            If "delete_pod", the pod will be deleted regardless of its state; if
"delete_succeeded_pod",
+ only succeeded pod will be deleted. You can set to "keep_pod" to
keep the pod.
+ :param termination_message_policy: The termination message policy of
the base container.
+ Default value is "File"
+ :param active_deadline_seconds: The active_deadline_seconds which
matches to active_deadline_seconds
+ in V1PodSpec.
+ :param progress_callback: Callback function for receiving k8s
container logs.
+ """
+ @overload
+ def kubernetes_cmd(self, python_callable: Callable[FParams, FReturn]) ->
Task[FParams, FReturn]: ...
+ @overload
def sensor( # type: ignore[misc]
self,
*,