This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch generate-task-jwt-tokens in repository https://gitbox.apache.org/repos/asf/airflow.git
commit fee57a190d08faa8a18a4db3bcf6962f0b140743 Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Thu Feb 20 15:33:12 2025 +0000 Add JWT validation and generation machinery for the Task Execution API to use --- airflow/config_templates/config.yml | 99 +++++++++- airflow/configuration.py | 2 + airflow/security/tokens.py | 375 ++++++++++++++++++++++++++++++++++++ tests/security/test_tokens.py | 185 ++++++++++++++++++ 4 files changed, 653 insertions(+), 8 deletions(-) diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 1c78e05a615..f38c530f335 100644 --- a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -1385,20 +1385,23 @@ api: version_added: 2.7.0 example: ~ default: "False" - auth_jwt_secret: - description: | - Secret key used to encode and decode JWT tokens to authenticate to public and private APIs. - It should be as random as possible. However, when running more than 1 instances of API services, - make sure all of them use the same ``jwt_secret`` otherwise calls will fail on authentication. + auth_jwt_audience: version_added: 3.0.0 + description: | + The audience claim to use when generating and validating JWTs for the API. + + This variable can be a single value, or a comma-separated string, in which case the first value is the + one that will be used when generating, and the others are accepted at validation time. + + Not required, but strongly encouraged + example: "urn:airflow.apache.org:task" + default: ~ type: string - sensitive: true - example: ~ - default: "{JWT_SECRET_KEY}" auth_jwt_expiration_time: description: | Number in seconds until the JWT token used for authentication expires. When the token expires, all API calls using this token will fail on authentication. + Make sure that time on ALL the machines that you run airflow components on is synchronized (for example using ntpd) otherwise you might get "forbidden" errors. 
version_added: 3.0.0 @@ -1409,12 +1412,77 @@ api: description: | Number in seconds until the JWT token used for authentication expires for CLI commands. When the token expires, all CLI calls using this token will fail on authentication. + Make sure that time on ALL the machines that you run airflow components on is synchronized (for example using ntpd) otherwise you might get "forbidden" errors. version_added: 3.0.0 type: integer example: ~ default: "3600" +api_auth: + description: Settings relating to authentication on the Airflow APIs + options: + jwt_secret: + description: | + Secret key used to encode and decode JWT tokens to authenticate to public and private APIs. + + It should be as random as possible. However, when running more than 1 instances of API services, + make sure all of them use the same ``jwt_secret`` otherwise calls will fail on authentication. + + Mutually exclusive with ``jwt_private_key_path``. + version_added: 3.0.0 + type: string + sensitive: true + example: ~ + default: "{JWT_SECRET_KEY}" + jwt_private_key_path: + version_added: 3.0.0 + description: | + The path to a file containing a PEM-encoded private key to use when generating Task Identity tokens in + the executor. + + Mutually exclusive with ``jwt_secret``. + default: ~ + example: /path/to/private_key.pem + type: string + jwt_algorithm: + version_added: 3.0.0 + description: | + The algorithm name to use when generating and validating JWT Task Identities. + + This value must be appropriate for the given private key type. + + Default is "HS512" if ``jwt_secret`` is set, or "EdDSA" otherwise + example: '"EdDSA" or "HS512"' + type: string + default: ~ + trusted_jwks_url: + version_added: 3.0.0 + description: | + The public signing keys of Task Execution token issuers to trust. It must contain the public key + related to ``jwt_private_key_path`` else tasks will be unlikely to execute successfully. + + Can be a local file path (without the ``file://`` url) or an http or https URL. 
+ + If a remote URL is given it will be polled periodically for changes. + default: ~ + example: '"/path/to/public-jwks.json" or "https://my-issuer/.well-known/jwks.json"' + type: string + jwt_issuer: + version_added: 3.0.0 + description: | + Issuer of the JWT. This becomes the ``iss`` claim of generated tokens, and is validated on incoming + requests. + + Ideally this should be unique per individual airflow deployment + + Not required, but strongly recommended to be set. + + See also :ref:`config:task_execution_api__auth_jwt_audience` and :ref:`config:api__auth_jwt_audience` + default: ~ + example: "http://my-airflow.mycompany.com" + type: string + lineage: description: ~ options: @@ -2691,3 +2759,18 @@ fastapi: type: string example: ~ default: "http://localhost:29091" +task_execution_api: + description: Settings for the Task Execution API + options: + auth_jwt_audience: + version_added: 3.0.0 + description: | + The audience claim to use when generating and validating JWT for the Execution API. + + This variable can be a single value, or a comma-separated string, in which case the first value is the + one that will be used when generating, and the others are accepted at validation time + + Not required. + example: "urn:airflow.apache.org:task" + default: ~ + type: string diff --git a/airflow/configuration.py b/airflow/configuration.py index a9a255dde47..d1af14642a1 100644 --- a/airflow/configuration.py +++ b/airflow/configuration.py @@ -1139,6 +1139,8 @@ class AirflowConfigParser(ConfigParser): def getlist(self, section: str, key: str, delimiter=",", **kwargs): val = self.get(section, key, **kwargs) if val is None: + if "fallback" in kwargs: + return val raise AirflowConfigException( f"Failed to convert value None to list. " f'Please check "{key}" key in "{section}" section is set.' 
diff --git a/airflow/security/tokens.py b/airflow/security/tokens.py new file mode 100644 index 00000000000..adc5a000c0b --- /dev/null +++ b/airflow/security/tokens.py @@ -0,0 +1,375 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import json +import os +import time +from collections.abc import Sequence +from datetime import datetime, timedelta +from typing import TYPE_CHECKING, Any + +import attrs +import httpx +import jwt +import structlog +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes + +from airflow.utils import timezone + +if TYPE_CHECKING: + from jwt.algorithms import AllowedKeys, AllowedPrivateKeys + +log = structlog.get_logger(logger_name=__name__) + +__all__ = [ + "InvalidClaimError", + "JWKS", + "JWTGenerator", + "JWTValidator", + "generate_private_key", + "key_to_pem", + "key_to_jwk_dict", +] + + +class InvalidClaimError(ValueError): + """Raised when a claim in the JWT is invalid.""" + + def __init__(self, claim: str): + super().__init__(f"Invalid claim: {claim}") + + +def key_to_jwk_dict(key: AllowedKeys, kid: str | None = None): + """Convert a public or private key into a valid JWKS dict.""" + from 
cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey + from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey + from jwt.algorithms import OKPAlgorithm, RSAAlgorithm + + if isinstance(key, (RSAPrivateKey, Ed25519PrivateKey)): + key = key.public_key() + + if isinstance(key, RSAPublicKey): + jwk_dict = RSAAlgorithm(RSAAlgorithm.SHA256).to_jwk(key, as_dict=True) + + elif isinstance(key, Ed25519PublicKey): + jwk_dict = OKPAlgorithm().to_jwk(key, as_dict=True) + else: + raise ValueError(f"Unknown key object {type(key)}") + + if not kid: + kid = thumbprint(jwk_dict) + + jwk_dict["kid"] = kid + + return jwk_dict + + [email protected](repr=False) +class JWKS: + """A class to fetch and sync a set of JSON Web Keys.""" + + url: str + fetched_at: float = 0 + last_fetch_attempt_at: float = 0 + + client: httpx.AsyncClient = attrs.field(factory=httpx.AsyncClient) + + _jwks: jwt.PyJWKSet | None = None + refresh_jwks: bool = True + refresh_interval_secs: int = 3600 + refresh_retry_interval_secs: int = 10 + + def __repr__(self) -> str: + return f"JWKS(url={self.url}, fetched_at={self.fetched_at})" + + @classmethod + def from_private_key(cls, **keys: AllowedPrivateKeys): + obj = cls(url=os.devnull) + + obj._jwks = jwt.PyJWKSet([key_to_jwk_dict(key, kid) for kid, key in keys.items()]) + return obj + + async def fetch_jwks(self) -> None: + if not self._should_fetch_jwks(): + return + if self.url.startswith("http"): + data = await self._fetch_remote_jwks() + else: + data = self._fetch_local_jwks() + + if not data: + return + + self._jwks = jwt.PyJWKSet.from_dict(data) + log.debug("Fetched JWKS", url=self.url, keys=len(self._jwks.keys)) + + async def _fetch_remote_jwks(self) -> dict[str, Any] | None: + try: + log.debug( + "Fetching JWKS", + url=self.url, + last_fetched_secs_ago=int(time.monotonic() - self.fetched_at) if self.fetched_at else None, + ) + if TYPE_CHECKING: + assert self.url + self.last_fetch_attempt_at = 
int(time.monotonic()) + response = await self.client.get(self.url) + response.raise_for_status() + self.fetched_at = int(time.monotonic()) + await response.aread() + await response.aclose() + return response.json() + except Exception: + log.exception("Failed to fetch remote JWKS", url=self.url) + return None + + def _fetch_local_jwks(self) -> dict[str, Any] | None: + try: + with open(self.url) as jwks_file: + content = json.load(jwks_file) + self.fetched_at = int(time.monotonic()) + return content + except Exception: + log.exception("Failed to read local JWKS", url=self.url) + return None + + def _should_fetch_jwks(self) -> bool: + """ + Check if we need to fetch the JWKS based on the last fetch time and the refresh interval. + + If the JWKS URL is local, we only fetch it once. For remote JWKS URLs we fetch it based + on the refresh interval if refreshing has been enabled with a minimum interval between + attempts. The fetcher functions set the fetched_at timestamp to the current monotonic time + when the JWKS is fetched. + """ + if not self.url.startswith("http"): + # Fetch local JWKS only if not already loaded + # This could be improved in future by looking at mtime of file. + return not self._jwks + # For remote fetches we check if the JWKS is not loaded (fetched_at = 0) or if the last fetch was more than + # refresh_interval_secs ago and the last fetch attempt was more than refresh_retry_interval_secs ago + now = time.monotonic() + return self.refresh_jwks and ( + not self._jwks + or ( + self.fetched_at == 0 + or ( + now - self.fetched_at > self.refresh_interval_secs + and now - self.last_fetch_attempt_at > self.refresh_retry_interval_secs + ) + ) + ) + + async def get_key(self, kid: str) -> jwt.PyJWK: + """Fetch the JWKS and find the matching key for the token.""" + await self.fetch_jwks() + + if self._jwks: + return self._jwks[kid] + + # It didn't load! 
+ raise KeyError(f"Key ID {kid} not found in keyset") + + +def _conf_factory(section, key, **kwargs): + def factory() -> str: + from airflow.configuration import conf + + return conf.get(section, key, **kwargs, suppress_warnings=True) # type: ignore[return-value] + + return factory + + +def _conf_list_factory(section, key, **kwargs): + from airflow.configuration import conf + + return conf.getlist(section, key, **kwargs, suppress_warnings=True) + + [email protected](repr=False, kw_only=True) +class JWTValidator: + """Validate the claims of a JWT.""" + + jwks: JWKS + issuer: str | list[str] | None = attrs.field( + factory=_conf_list_factory("api_auth", "jwt_issuer", fallback=None) + ) + required_claims: frozenset[str] = attrs.field(factory=frozenset) + audience: str | Sequence[str] + algorithm: str = attrs.field(factory=_conf_factory("api_auth", "jwt_algorithm")) + + def _get_kid_from_header(self, unvalidated: str) -> str: + header = jwt.get_unverified_header(unvalidated) + if "kid" not in header: + raise ValueError("Missing 'kid' in token header") + return header["kid"] + + async def validated_claims( + self, unvalidated: str, extra_claims: dict[str, Any] | None = None + ) -> dict[str, Any]: + """Decode the JWT token, returning the validated claims or raising an exception.""" + try: + kid = self._get_kid_from_header(unvalidated) + key = await self.jwks.get_key(kid) + + claims = jwt.decode( + unvalidated, + key, + audience=self.audience, + issuer=self.issuer, + options={"require": self.required_claims}, + algorithms=[self.algorithm], + ) + + # Validate additional claims if provided + if extra_claims: + for claim, expected_value in extra_claims.items(): + if expected_value["essential"] and ( + claim not in claims or claims[claim] != expected_value["value"] + ): + raise InvalidClaimError(claim) + + return claims + except jwt.PyJWTError as e: + raise ValueError(f"Invalid JWT: {e}") from e + + [email protected](repr=False, kw_only=True) +class JWTGenerator: + 
"""Generate JWT tokens.""" + + _private_key: AllowedPrivateKeys | str | bytes | None = attrs.field(repr=False, alias="private_key") + _secret_key: str | None = attrs.field( + repr=False, alias="secret_key", factory=_conf_factory("api_auth", "jwt_secret", fallback=None) + ) + """A pre-shared secret key to sign tokens with symmetric encryption""" + + kid: str + validity_period: timedelta + issuer: str + audience: str + algorithm: str + + @_private_key.default + def _load_key_from_file(self): + from airflow.configuration import conf + + path = conf.get("api_auth", "jwt_private_key_path", fallback=None) + if not path: + return None + + with open(path, mode="b") as fh: + return fh.read() + + def __attrs_post_init__(self): + if (self._private_key is None) ^ (self._secret_key is None): + raise ValueError("Exactly one of priavte_key and secret_key must be specified") + + @property + def signing_arg(self): + if callable(self._private_key): + return self._private_key() + if self._private_key: + return self._private_key + return self, self._secret_key + + def generate( + self, subject: str, extras: dict[str, Any] | None = None, headers: dict[str, Any] | None = None + ) -> str: + """Generate a signed JWT for the subject.""" + now = datetime.now(tz=timezone.utc) + claims = { + "iss": self.issuer, + "aud": self.audience, + "sub": subject, + "nbf": int(now.timestamp()), + "exp": int((now + self.validity_period).timestamp()), + "iat": int(now.timestamp()), + } + if extras is not None: + claims.update(extras) + headers = {"alg": self.algorithm, "kid": self.kid, **(headers or {})} + return jwt.encode(claims, self.signing_arg, algorithm=self.algorithm, headers=headers) + + +# @attrs.define(repr=False) +# class TaskJWTGenerator(JWTGenerator): +# issuer: str = attrs.field(factory=_default_issuer) +# audience: str = attrs.field( +# factory=_conf_factory("task_execution_api", "jwt_audience", fallback="urn:airflow.apache.org:task") +# ) +# algorithm: str = attrs.field( +# 
factory=_conf_factory("task_execution_api", "jwt_algorithm", default="EdDSA") +# ) + + +def generate_private_key(key_type: str = "RSA", key_size: int = 2048): + """ + Generate a valid private key for testing. + + Args: + key_type (str): Type of key to generate. Can be "RSA" or "Ed25516". Defaults to "RSA". + key_size (int): Size of the key in bits. Only applicable for RSA keys. Defaults to 2048. + + Returns: + tuple: A tuple containing the private key in PEM format and the corresponding public key in PEM format. + """ + from cryptography.hazmat.primitives.asymmetric import ed25519, rsa + + if key_type == "RSA": + # Generate an RSA private key + + return rsa.generate_private_key(public_exponent=65537, key_size=key_size, backend=default_backend()) + elif key_type == "Ed25519": + return ed25519.Ed25519PrivateKey.generate() + raise ValueError(f"unsupported key type: {key_type}") + + +def key_to_pem(key: AllowedPrivateKeys) -> str: + from cryptography.hazmat.primitives import serialization + + # Serialize the private key in PEM format + return key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode("utf-8") + + +def thumbprint(jwk: dict[str, Any], hashalg=hashes.SHA256()) -> str: + """ + Return the key thumbprint as specified by RFC 7638. 
+ + :param hashalg: A hash function (defaults to SHA256) + + :return: A base64url encoded digest of the key + """ + digest = hashes.Hash(hashalg, backend=default_backend()) + jsonstr = json.dumps(jwk, separators=(",", ":"), sort_keys=True) + digest.update(jsonstr.encode("utf8")) + return base64url_encode(digest.finalize()) + + +def base64url_encode(payload): + from base64 import urlsafe_b64encode + + if not isinstance(payload, bytes): + payload = payload.encode("utf-8") + encode = urlsafe_b64encode(payload) + return encode.decode("utf-8").rstrip("=") diff --git a/tests/security/test_tokens.py b/tests/security/test_tokens.py new file mode 100644 index 00000000000..4a31fddd72b --- /dev/null +++ b/tests/security/test_tokens.py @@ -0,0 +1,185 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import json +from datetime import datetime, timedelta +from typing import TYPE_CHECKING + +import httpx +import jwt +import pytest + +from airflow.security.tokens import ( + JWKS, + InvalidClaimError, + JWTGenerator, + JWTValidator, + generate_private_key, + key_to_jwk_dict, +) +from airflow.utils import timezone + +if TYPE_CHECKING: + from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey + from kgb import SpyAgency + from time_machine import TimeMachineFixture + + +pytestmark = [pytest.mark.asyncio] + + [email protected] +def private_key(request): + return request.getfixturevalue(request.param or "ed25519_private_key") + + [email protected] +def mock_kid(): + return "test-kid" + + [email protected] +def mock_subject(): + return "test-subject" + + +class TestJWKS: + @pytest.mark.parametrize("private_key", ["rsa_private_key", "ed25519_private_key"], indirect=True) + async def test_fetch_jwks_success(self, private_key): + jwk_content = json.dumps({"keys": [key_to_jwk_dict(private_key, "kid")]}) + + async def mock_transport(request): + return httpx.Response(status_code=200, content=jwk_content) + + client = httpx.AsyncClient(transport=httpx.MockTransport(mock_transport)) + jwks = JWKS(url="https://example.com/jwks.json", client=client) + + # Test fetching JWKS + await jwks.fetch_jwks() + assert isinstance(await jwks.get_key("kid"), jwt.PyJWK) + + async def test_refresh_remote_jwks( + self, time_machine: TimeMachineFixture, ed25519_private_key, spy_agency: SpyAgency + ): + time_machine.move_to(datetime(2023, 10, 1, 12, 0, 0)) # Initial time: 12:00 PM + jwk_content = json.dumps({"keys": [key_to_jwk_dict(ed25519_private_key, "kid")]}) + + async def mock_transport(request): + return httpx.Response(status_code=200, content=jwk_content) + + client = httpx.AsyncClient(transport=httpx.MockTransport(mock_transport)) + jwks = JWKS(url="https://example.com/jwks.json", client=client) + spy = 
spy_agency.spy_on(JWKS._fetch_remote_jwks) + + key = await jwks.get_key("kid") + assert isinstance(key, jwt.PyJWK) + + # Move forward in time, but not to a point where it updates. Should not end up re-requesting. + spy.reset_calls() + time_machine.shift(1800) + assert await jwks.get_key("kid") is key + spy_agency.assert_spy_not_called(spy) + + # Not to a point where it should refresh + time_machine.shift(1801) + + key2 = key_to_jwk_dict(generate_private_key("Ed25519"), "kid2") + jwk_content = json.dumps({"keys": [key2]}) + with pytest.raises(KeyError): + # Not in the document anymore, should have gone from the keyset + await jwks.get_key("kid") + assert isinstance(await jwks.get_key("kid2"), jwt.PyJWK) + spy_agency.assert_spy_called(spy) + + [email protected] +def jwt_generator(ed25519_private_key: Ed25519PrivateKey): + key = ed25519_private_key + return JWTGenerator( + private_key=key, + kid="kid1", + validity_period=timedelta(minutes=1), + issuer="http://test-issuer", + algorithm="EdDSA", + audience="abc", + ) + + [email protected] +def jwt_validator(ed25519_private_key: Ed25519PrivateKey): + key = ed25519_private_key + jwks = JWKS.from_private_key(kid1=key) + return JWTValidator(jwks=jwks, issuer="http://test-issuer", algorithm="EdDSA", audience="abc") + + +async def test_task_jwt_generator_validator(jwt_generator, jwt_validator, ed25519_private_key): + token = jwt_generator.generate(subject="test_subject") + # if this does not raise ValueError then the generated token can be decoded and contains all the required + # fields. 
+ claims = await jwt_validator.validated_claims( + token, extra_claims={"sub": {"essential": True, "value": "test_subject"}} + ) + nbf = datetime.fromtimestamp(claims["nbf"], timezone.utc) + iat = datetime.fromtimestamp(claims["iat"], timezone.utc) + exp = datetime.fromtimestamp(claims["exp"], timezone.utc) + now = datetime.now(timezone.utc) + assert nbf == iat, "issued at is different then not before" + assert nbf < exp, "not before is after expiration" + assert nbf <= now, "not before is in the future" + assert exp >= now, "expiration is in the past" + assert exp <= nbf + timedelta(minutes=10), "expiration is more then 10 minutes after not before" + + def token_without_claim(claim: str) -> str: + # remove claim and re-encode + bad_claims = claims.copy() + bad_claims.pop(required_claim) + return jwt.encode( + bad_claims, + ed25519_private_key, + headers={"kid": jwt_generator.kid}, + algorithm=jwt_generator.algorithm, + ) + + for required_claim in jwt_validator.required_claims: + bad_token = token_without_claim(required_claim) + # check that the missing claim is detected in validation + with pytest.raises(ValueError, match="Invalid JWT") as exc_info: + await jwt_validator.validated_claims(bad_token, extra_claims={"sub": "test_subject"}) + cause = exc_info.value.__cause__ + assert isinstance(cause, jwt.MissingRequiredClaimError) + assert cause.claim == required_claim + + +async def test_jwt_wrong_subject(jwt_generator, jwt_validator): + # check that the token is invalid if the subject is not as expected + wrong_subject = jwt_generator.generate(subject="wrong_subject") + with pytest.raises(InvalidClaimError, match="Invalid claim: sub"): + await jwt_validator.validated_claims( + wrong_subject, extra_claims={"sub": {"essential": True, "value": "test_subject"}} + ) + + [email protected](scope="session") +def rsa_private_key(): + return generate_private_key() + + [email protected](scope="session") +def ed25519_private_key(): + return 
generate_private_key(key_type="Ed25519")
