This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch generate-task-jwt-tokens
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit fee57a190d08faa8a18a4db3bcf6962f0b140743
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Thu Feb 20 15:33:12 2025 +0000

    Add JWT validation and generation machinery for the Task Execution API to use
---
 airflow/config_templates/config.yml |  99 +++++++++-
 airflow/configuration.py            |   2 +
 airflow/security/tokens.py          | 375 ++++++++++++++++++++++++++++++++++++
 tests/security/test_tokens.py       | 185 ++++++++++++++++++
 4 files changed, 653 insertions(+), 8 deletions(-)

diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml
index 1c78e05a615..f38c530f335 100644
--- a/airflow/config_templates/config.yml
+++ b/airflow/config_templates/config.yml
@@ -1385,20 +1385,23 @@ api:
       version_added: 2.7.0
       example: ~
       default: "False"
-    auth_jwt_secret:
-      description: |
-        Secret key used to encode and decode JWT tokens to authenticate to 
public and private APIs.
-        It should be as random as possible. However, when running more than 1 
instances of API services,
-        make sure all of them use the same ``jwt_secret`` otherwise calls will 
fail on authentication.
+    auth_jwt_audience:
       version_added: 3.0.0
+      description: |
+        The audience claim to use when generating and validating JWTs for the API.
+
+        This variable can be a single value, or a comma-separated string, in which case the first value is the
+        one that will be used when generating, and the others are accepted at validation time.
+
+        Not required, but strongly encouraged.
+      example: "urn:airflow.apache.org:task"
+      default: ~
       type: string
-      sensitive: true
-      example: ~
-      default: "{JWT_SECRET_KEY}"
     auth_jwt_expiration_time:
       description: |
         Number in seconds until the JWT token used for authentication expires. When the token expires,
         all API calls using this token will fail on authentication.
+
         Make sure that time on ALL the machines that you run airflow components on is synchronized
         (for example using ntpd) otherwise you might get "forbidden" errors.
       version_added: 3.0.0
@@ -1409,12 +1412,77 @@ api:
       description: |
         Number in seconds until the JWT token used for authentication expires for CLI commands.
         When the token expires, all CLI calls using this token will fail on authentication.
+
         Make sure that time on ALL the machines that you run airflow components on is synchronized
         (for example using ntpd) otherwise you might get "forbidden" errors.
       version_added: 3.0.0
       type: integer
       example: ~
       default: "3600"
+api_auth:
+  description: Settings relating to authentication on the Airflow APIs
+  options:
+    jwt_secret:
+      description: |
+        Secret key used to encode and decode JWT tokens to authenticate to 
public and private APIs.
+
+        It should be as random as possible. However, when running more than 1 
instances of API services,
+        make sure all of them use the same ``jwt_secret`` otherwise calls will 
fail on authentication.
+
+        Mutually exclusive with ``jwt_private_key_path``.
+      version_added: 3.0.0
+      type: string
+      sensitive: true
+      example: ~
+      default: "{JWT_SECRET_KEY}"
+    jwt_private_key_path:
+      version_added: 3.0.0
+      description: |
+        The path to a file containing a PEM-encoded private key used when generating Task Identity tokens in
+        the executor.
+
+        Mutually exclusive with ``jwt_secret``.
+      default: ~
+      example: /path/to/private_key.pem
+      type: string
+    jwt_algorithm:
+      version_added: 3.0.0
+      description: |
+        The name of the algorithm to use when generating and validating JWT Task Identities.
+
+        This value must be appropriate for the given private key type.
+
+        Default is "HS512" if ``jwt_secret`` is set, or "EdDSA" otherwise.
+      example: '"EdDSA" or "HS512"'
+      type: string
+      default: ~
+    trusted_jwks_url:
+      version_added: 3.0.0
+      description: |
+        The public signing keys of Task Execution token issuers to trust. It must contain the public key
+        related to ``jwt_private_key_path``, otherwise tasks are unlikely to execute successfully.
+
+        Can be a local file path (without the ``file://`` prefix) or an http or https URL.
+
+        If a remote URL is given it will be polled periodically for changes.
+      default: ~
+      example: '"/path/to/public-jwks.json" or "https://my-issuer/.well-known/jwks.json"'
+      type: string
+    jwt_issuer:
+      version_added: 3.0.0
+      description: |
+        Issuer of the JWT. This becomes the ``iss`` claim of generated tokens, and is validated on incoming
+        requests.
+
+        Ideally this should be unique per individual Airflow deployment.
+
+        Not required, but strongly recommended to be set.
+
+        See also :ref:`config:task_execution_api__auth_jwt_audience` and :ref:`config:api__auth_jwt_audience`.
+      default: ~
+      example: "http://my-airflow.mycompany.com"
+      type: string
+
 lineage:
   description: ~
   options:
@@ -2691,3 +2759,18 @@ fastapi:
       type: string
       example: ~
       default: "http://localhost:29091"
+task_execution_api:
+  description: Settings for the Task Execution API
+  options:
+    auth_jwt_audience:
+      version_added: 3.0.0
+      description: |
+        The audience claim to use when generating and validating JWTs for the Execution API.
+
+        This variable can be a single value, or a comma-separated string, in which case the first value is the
+        one that will be used when generating, and the others are accepted at validation time.
+
+        Not required.
+      example: "urn:airflow.apache.org:task"
+      default: ~
+      type: string
diff --git a/airflow/configuration.py b/airflow/configuration.py
index a9a255dde47..d1af14642a1 100644
--- a/airflow/configuration.py
+++ b/airflow/configuration.py
@@ -1139,6 +1139,8 @@ class AirflowConfigParser(ConfigParser):
     def getlist(self, section: str, key: str, delimiter=",", **kwargs):
         val = self.get(section, key, **kwargs)
         if val is None:
+            if "fallback" in kwargs:
+                return val
             raise AirflowConfigException(
                 f"Failed to convert value None to list. "
                 f'Please check "{key}" key in "{section}" section is set.'
diff --git a/airflow/security/tokens.py b/airflow/security/tokens.py
new file mode 100644
index 00000000000..adc5a000c0b
--- /dev/null
+++ b/airflow/security/tokens.py
@@ -0,0 +1,375 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import json
+import os
+import time
+from collections.abc import Sequence
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING, Any
+
+import attrs
+import httpx
+import jwt
+import structlog
+from cryptography.hazmat.backends import default_backend
+from cryptography.hazmat.primitives import hashes
+
+from airflow.utils import timezone
+
+if TYPE_CHECKING:
+    from jwt.algorithms import AllowedKeys, AllowedPrivateKeys
+
# Module-level structlog logger, created with this module's name as ``logger_name``.
log = structlog.get_logger(logger_name=__name__)

# Public names of this module. The helpers ``thumbprint`` and ``base64url_encode`` are not listed.
__all__ = [
    "InvalidClaimError",
    "JWKS",
    "JWTGenerator",
    "JWTValidator",
    "generate_private_key",
    "key_to_pem",
    "key_to_jwk_dict",
]
+
+
class InvalidClaimError(ValueError):
    """
    Signal that a specific JWT claim failed validation.

    The exception message has the form ``"Invalid claim: <claim>"``.
    """

    def __init__(self, claim: str):
        message = f"Invalid claim: {claim}"
        super().__init__(message)
+
+
def key_to_jwk_dict(key: AllowedKeys, kid: str | None = None):
    """
    Build a JWK dictionary describing the public half of ``key``.

    Private keys are reduced to their public key first, so the result never contains private material.

    :param key: An RSA or Ed25519 key, public or private.
    :param kid: Key ID stored under ``"kid"``. When falsy, the key's ``thumbprint`` is used instead.
    :return: The JWK as a dict, including ``"kid"``.
    :raises ValueError: If ``key`` is neither an RSA nor an Ed25519 key.
    """
    from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey
    from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
    from jwt.algorithms import OKPAlgorithm, RSAAlgorithm

    is_private = isinstance(key, (RSAPrivateKey, Ed25519PrivateKey))
    public = key.public_key() if is_private else key

    if isinstance(public, RSAPublicKey):
        algorithm = RSAAlgorithm(RSAAlgorithm.SHA256)
    elif isinstance(public, Ed25519PublicKey):
        algorithm = OKPAlgorithm()
    else:
        raise ValueError(f"Unknown key object {type(public)}")

    jwk_dict = algorithm.to_jwk(public, as_dict=True)
    # The thumbprint is computed before "kid" is added, so it never covers the kid itself.
    jwk_dict["kid"] = kid or thumbprint(jwk_dict)
    return jwk_dict
+
+
@attrs.define(repr=False)
class JWKS:
    """
    A class to fetch and sync a set of JSON Web Keys.

    ``url`` is either an ``http``/``https`` URL, downloaded with ``client`` and (when ``refresh_jwks`` is true)
    periodically re-fetched, or a local file path that is read once and then kept.
    """

    url: str
    # int(time.monotonic()) of the last successful fetch; 0 means "never fetched".
    fetched_at: float = 0
    # int(time.monotonic()) of the last remote fetch attempt, successful or not.
    last_fetch_attempt_at: float = 0

    client: httpx.AsyncClient = attrs.field(factory=httpx.AsyncClient)

    _jwks: jwt.PyJWKSet | None = None
    # Periodically re-fetch a remote keyset once it is loaded. The first load always happens regardless.
    refresh_jwks: bool = True
    # Minimum age (seconds) of a loaded remote keyset before it is re-fetched.
    refresh_interval_secs: int = 3600
    # Minimum gap (seconds) between re-fetch attempts of a loaded remote keyset, so a failing endpoint
    # is not requested on every call.
    refresh_retry_interval_secs: int = 10

    def __repr__(self) -> str:
        return f"JWKS(url={self.url}, fetched_at={self.fetched_at})"

    @classmethod
    def from_private_key(cls, **keys: AllowedPrivateKeys):
        """
        Build an in-memory keyset from private keys passed as ``kid=key`` keyword arguments.

        Only the public halves end up in the keyset (``key_to_jwk_dict`` converts private keys to public
        ones). ``url`` is set to ``os.devnull``, so the object is treated as an already-loaded local keyset
        and is never read from disk.
        """
        obj = cls(url=os.devnull)

        obj._jwks = jwt.PyJWKSet([key_to_jwk_dict(key, kid) for kid, key in keys.items()])
        return obj

    async def fetch_jwks(self) -> None:
        """
        Load or refresh the keyset if it is due (see ``_should_fetch_jwks``).

        Failures to download or read the document are logged by the fetch helpers and leave the current
        keyset in place. Errors raised by ``jwt.PyJWKSet.from_dict`` while parsing a document propagate.
        """
        if not self._should_fetch_jwks():
            return
        if self.url.startswith("http"):
            data = await self._fetch_remote_jwks()
        else:
            data = self._fetch_local_jwks()

        if not data:
            return

        self._jwks = jwt.PyJWKSet.from_dict(data)
        log.debug("Fetched JWKS", url=self.url, keys=len(self._jwks.keys))

    async def _fetch_remote_jwks(self) -> dict[str, Any] | None:
        """Download the JWKS document; return the decoded JSON, or ``None`` (after logging) on any error."""
        try:
            log.debug(
                "Fetching JWKS",
                url=self.url,
                last_fetched_secs_ago=int(time.monotonic() - self.fetched_at) if self.fetched_at else None,
            )
            if TYPE_CHECKING:
                assert self.url
            # Record the attempt before the request so failed attempts also throttle retries.
            self.last_fetch_attempt_at = int(time.monotonic())
            response = await self.client.get(self.url)
            response.raise_for_status()
            self.fetched_at = int(time.monotonic())
            await response.aread()
            await response.aclose()
            return response.json()
        except Exception:
            # Best effort: keep serving the previously loaded keyset (if any) rather than failing the caller.
            log.exception("Failed to fetch remote JWKS", url=self.url)
            return None

    def _fetch_local_jwks(self) -> dict[str, Any] | None:
        """Read the JWKS document from the local path; return ``None`` (after logging) on any error."""
        try:
            with open(self.url) as jwks_file:
                content = json.load(jwks_file)
            self.fetched_at = int(time.monotonic())
            return content
        except Exception:
            log.exception("Failed to read local JWKS", url=self.url)
            return None

    def _should_fetch_jwks(self) -> bool:
        """
        Check if we need to fetch the JWKS based on the last fetch time and the refresh interval.

        A local keyset is fetched only while it is not loaded. A remote keyset is always fetched while it is
        not loaded (even when ``refresh_jwks`` is false, otherwise ``get_key`` could never succeed). Once it
        is loaded, it is only re-fetched when ``refresh_jwks`` is true, it is older than
        ``refresh_interval_secs``, and the last attempt was more than ``refresh_retry_interval_secs`` ago.
        The fetcher functions set the timestamps from the current monotonic time.
        """
        if not self.url.startswith("http"):
            # Fetch local JWKS only if not already loaded
            # This could be improved in future by looking at mtime of file.
            return not self._jwks
        if not self._jwks:
            return True
        if not self.refresh_jwks:
            return False
        now = time.monotonic()
        return self.fetched_at == 0 or (
            now - self.fetched_at > self.refresh_interval_secs
            and now - self.last_fetch_attempt_at > self.refresh_retry_interval_secs
        )

    async def get_key(self, kid: str) -> jwt.PyJWK:
        """
        Fetch the JWKS if due and return the key whose ID is ``kid``.

        :raises KeyError: If no keyset could be loaded, or the keyset has no key with this ``kid``.
        """
        await self.fetch_jwks()

        if self._jwks:
            return self._jwks[kid]

        # It didn't load!
        raise KeyError(f"Key ID {kid} not found in keyset")
+
+
+def _conf_factory(section, key, **kwargs):
+    def factory() -> str:
+        from airflow.configuration import conf
+
+        return conf.get(section, key, **kwargs, suppress_warnings=True)  # 
type: ignore[return-value]
+
+    return factory
+
+
+def _conf_list_factory(section, key, **kwargs):
+    from airflow.configuration import conf
+
+    return conf.getlist(section, key, **kwargs, suppress_warnings=True)
+
+
@attrs.define(repr=False, kw_only=True)
class JWTValidator:
    """
    Validate the signature and claims of a JWT against the keys in a :class:`JWKS`.

    ``issuer`` and ``algorithm`` default to values derived from the ``[api_auth]`` config section.
    """

    jwks: JWKS
    issuer: str | list[str] | None = attrs.field(
        factory=_conf_list_factory("api_auth", "jwt_issuer", fallback=None)
    )
    # Claims that must be present in every token (passed to PyJWT's ``require`` option).
    required_claims: frozenset[str] = attrs.field(factory=frozenset)
    audience: str | Sequence[str]
    algorithm: str = attrs.field(factory=_conf_factory("api_auth", "jwt_algorithm"))

    def _get_kid_from_header(self, unvalidated: str) -> str:
        """Return the ``kid`` from the (unverified) token header, raising ValueError if it is missing."""
        header = jwt.get_unverified_header(unvalidated)
        if "kid" not in header:
            raise ValueError("Missing 'kid' in token header")
        return header["kid"]

    async def validated_claims(
        self, unvalidated: str, extra_claims: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """
        Decode the JWT token, returning the validated claims or raising an exception.

        :param unvalidated: The encoded token.
        :param extra_claims: Optional ``{claim: {"essential": bool, "value": expected}}`` mapping. Each entry
            marked essential must be present in the token with exactly the expected value; non-essential
            entries are not checked.
        :return: The decoded claims.
        :raises ValueError: If the token fails PyJWT validation (wrapping the ``jwt.PyJWTError``), or if its
            header has no ``kid``.
        :raises InvalidClaimError: If an essential entry of ``extra_claims`` does not match.
        """
        try:
            kid = self._get_kid_from_header(unvalidated)
            # NOTE(review): an unknown kid surfaces as KeyError from JWKS.get_key, not as ValueError --
            # callers must handle both.
            key = await self.jwks.get_key(kid)

            claims = jwt.decode(
                unvalidated,
                key,
                audience=self.audience,
                issuer=self.issuer,
                options={"require": self.required_claims},
                algorithms=[self.algorithm],
            )

            # Validate additional claims if provided
            if extra_claims:
                for claim, expected_value in extra_claims.items():
                    if expected_value["essential"] and (
                        claim not in claims or claims[claim] != expected_value["value"]
                    ):
                        raise InvalidClaimError(claim)

            return claims
        except jwt.PyJWTError as e:
            raise ValueError(f"Invalid JWT: {e}") from e
+
+
@attrs.define(repr=False, kw_only=True)
class JWTGenerator:
    """
    Generate signed JWT tokens.

    Tokens are signed either with an asymmetric private key or with a pre-shared secret. Exactly one of
    ``private_key`` and ``secret_key`` must end up set, otherwise construction raises ValueError.
    """

    # Asymmetric signing key: a key object, PEM ``str``/``bytes``, or a zero-argument callable returning one.
    # Defaults to the bytes of the file named by ``[api_auth] jwt_private_key_path``, if that is set.
    _private_key: AllowedPrivateKeys | str | bytes | None = attrs.field(repr=False, alias="private_key")
    # A pre-shared secret key to sign tokens with symmetric encryption. Defaults to ``[api_auth] jwt_secret``,
    # but only when no private key is available, because the two options are mutually exclusive.
    _secret_key: str | None = attrs.field(repr=False, alias="secret_key")

    kid: str
    validity_period: timedelta
    issuer: str
    audience: str
    algorithm: str

    @_private_key.default
    def _load_key_from_file(self):
        """Read the PEM private key named by ``[api_auth] jwt_private_key_path``, or return None if unset."""
        from airflow.configuration import conf

        path = conf.get("api_auth", "jwt_private_key_path", fallback=None)
        if not path:
            return None

        # Read as raw bytes: the PEM content is handed to PyJWT as the signing key.
        with open(path, mode="rb") as fh:
            return fh.read()

    @_secret_key.default
    def _load_secret_from_config(self):
        """Return ``[api_auth] jwt_secret`` when no private key is set; empty values count as unset."""
        # attrs assigns fields in definition order, so ``_private_key`` is already populated here.
        if self._private_key is not None:
            return None
        from airflow.configuration import conf

        return conf.get("api_auth", "jwt_secret", fallback=None, suppress_warnings=True) or None

    def __attrs_post_init__(self):
        # Having both keys, or neither, leaves it ambiguous (or impossible) to decide how to sign.
        if (self._private_key is None) == (self._secret_key is None):
            raise ValueError("Exactly one of private_key and secret_key must be specified")

    @property
    def signing_arg(self) -> AllowedPrivateKeys | str | bytes | None:
        """The key passed to ``jwt.encode``: the private key (calling it if callable), else the secret."""
        if callable(self._private_key):
            return self._private_key()
        if self._private_key:
            return self._private_key
        return self._secret_key

    def generate(
        self, subject: str, extras: dict[str, Any] | None = None, headers: dict[str, Any] | None = None
    ) -> str:
        """
        Generate a signed JWT for the subject.

        :param subject: Value of the ``sub`` claim.
        :param extras: Extra claims, merged over the standard ones (so they may override them).
        :param headers: Extra token headers, merged over the default ``alg`` and ``kid`` headers.
        :return: The encoded token.
        """
        now = datetime.now(tz=timezone.utc)
        claims = {
            "iss": self.issuer,
            "aud": self.audience,
            "sub": subject,
            "nbf": int(now.timestamp()),
            "exp": int((now + self.validity_period).timestamp()),
            "iat": int(now.timestamp()),
        }
        if extras is not None:
            claims.update(extras)
        headers = {"alg": self.algorithm, "kid": self.kid, **(headers or {})}
        return jwt.encode(claims, self.signing_arg, algorithm=self.algorithm, headers=headers)
+
+
+# @attrs.define(repr=False)
+# class TaskJWTGenerator(JWTGenerator):
+#     issuer: str = attrs.field(factory=_default_issuer)
+#     audience: str = attrs.field(
+#         factory=_conf_factory("task_execution_api", "jwt_audience", 
fallback="urn:airflow.apache.org:task")
+#     )
+#     algorithm: str = attrs.field(
+#         factory=_conf_factory("task_execution_api", "jwt_algorithm", 
default="EdDSA")
+#     )
+
+
def generate_private_key(key_type: str = "RSA", key_size: int = 2048):
    """
    Generate a valid private key for testing.

    Args:
        key_type (str): Type of key to generate. Can be "RSA" or "Ed25519". Defaults to "RSA".
        key_size (int): Size of the key in bits. Only applicable for RSA keys. Defaults to 2048.

    Returns:
        The generated ``cryptography`` private key object (RSA or Ed25519). It is not PEM-encoded; use
        :func:`key_to_pem` to serialize it.

    Raises:
        ValueError: If ``key_type`` is not "RSA" or "Ed25519".
    """
    # Validate first, so an unsupported key type fails with a clear error before any import happens.
    if key_type not in ("RSA", "Ed25519"):
        raise ValueError(f"unsupported key type: {key_type}")

    from cryptography.hazmat.primitives.asymmetric import ed25519, rsa

    if key_type == "RSA":
        return rsa.generate_private_key(public_exponent=65537, key_size=key_size, backend=default_backend())
    return ed25519.Ed25519PrivateKey.generate()
+
+
def key_to_pem(key: AllowedPrivateKeys) -> str:
    """
    Serialize a private key to an unencrypted PKCS#8 PEM string.

    :param key: A ``cryptography`` private key object.
    :return: The PEM text, decoded as UTF-8.
    """
    from cryptography.hazmat.primitives.serialization import Encoding, NoEncryption, PrivateFormat

    pem_bytes = key.private_bytes(
        encoding=Encoding.PEM,
        format=PrivateFormat.PKCS8,
        encryption_algorithm=NoEncryption(),
    )
    return pem_bytes.decode("utf-8")
+
+
def thumbprint(jwk: dict[str, Any], hashalg=hashes.SHA256()) -> str:
    """
    Return the key thumbprint as specified by RFC 7638.

    Only the members required for the key's ``kty`` are hashed: RFC 7638 for ``RSA``/``EC``/``oct``, and
    RFC 8037 for ``OKP``. Optional members such as ``key_ops`` or ``kid`` therefore do not affect the result.
    For an unrecognised ``kty``, every member of ``jwk`` is hashed.

    :param jwk: The JSON Web Key as a dict.
    :param hashalg: A hash function (defaults to SHA256)

    :return: A base64url encoded digest of the key
    :raises KeyError: If a member required for the key's ``kty`` is missing from ``jwk``.
    """
    required_members = {
        "EC": ("crv", "kty", "x", "y"),
        "OKP": ("crv", "kty", "x"),
        "RSA": ("e", "kty", "n"),
        "oct": ("k", "kty"),
    }.get(jwk.get("kty"))
    members = jwk if required_members is None else {name: jwk[name] for name in required_members}

    digest = hashes.Hash(hashalg, backend=default_backend())
    # RFC 7638 canonical form: lexicographically sorted members, no whitespace, UTF-8 encoded.
    jsonstr = json.dumps(members, separators=(",", ":"), sort_keys=True)
    digest.update(jsonstr.encode("utf8"))
    return base64url_encode(digest.finalize())
+
+
def base64url_encode(payload):
    """
    Encode ``payload`` as unpadded base64url text.

    ``bytes`` are encoded as-is; anything else is first encoded with ``.encode("utf-8")``. Trailing ``=``
    padding is stripped from the result.
    """
    import base64

    raw = payload if isinstance(payload, bytes) else payload.encode("utf-8")
    return base64.urlsafe_b64encode(raw).decode("utf-8").rstrip("=")
diff --git a/tests/security/test_tokens.py b/tests/security/test_tokens.py
new file mode 100644
index 00000000000..4a31fddd72b
--- /dev/null
+++ b/tests/security/test_tokens.py
@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import json
+from datetime import datetime, timedelta
+from typing import TYPE_CHECKING
+
+import httpx
+import jwt
+import pytest
+
+from airflow.security.tokens import (
+    JWKS,
+    InvalidClaimError,
+    JWTGenerator,
+    JWTValidator,
+    generate_private_key,
+    key_to_jwk_dict,
+)
+from airflow.utils import timezone
+
+if TYPE_CHECKING:
+    from cryptography.hazmat.primitives.asymmetric.ed25519 import 
Ed25519PrivateKey
+    from kgb import SpyAgency
+    from time_machine import TimeMachineFixture
+
+
# Apply the ``asyncio`` marker to every test in this module, so the ``async def`` tests (module-level
# functions and ``TestJWKS`` methods) are run by the asyncio pytest plugin.
pytestmark = [pytest.mark.asyncio]
+
+
@pytest.fixture
def private_key(request):
    """
    Resolve to the key fixture named by the indirect parameter, defaulting to the Ed25519 key.

    ``request.param`` only exists when the fixture is parametrized, so it is read with ``getattr``.
    """
    return request.getfixturevalue(getattr(request, "param", None) or "ed25519_private_key")
+
+
@pytest.fixture
def mock_kid():
    """Return a fixed key ID for tests."""
    return "test-kid"
+
+
@pytest.fixture
def mock_subject():
    """Return a fixed token subject for tests."""
    return "test-subject"
+
+
class TestJWKS:
    """Tests for fetching and refreshing a remote JWKS document."""

    @pytest.mark.parametrize("private_key", ["rsa_private_key", "ed25519_private_key"], indirect=True)
    async def test_fetch_jwks_success(self, private_key):
        jwk_content = json.dumps({"keys": [key_to_jwk_dict(private_key, "kid")]})

        async def mock_transport(request):
            return httpx.Response(status_code=200, content=jwk_content)

        client = httpx.AsyncClient(transport=httpx.MockTransport(mock_transport))
        jwks = JWKS(url="https://example.com/jwks.json", client=client)

        # Test fetching JWKS
        await jwks.fetch_jwks()
        assert isinstance(await jwks.get_key("kid"), jwt.PyJWK)

    async def test_refresh_remote_jwks(
        self, time_machine: TimeMachineFixture, ed25519_private_key, spy_agency: SpyAgency
    ):
        time_machine.move_to(datetime(2023, 10, 1, 12, 0, 0))  # Initial time: 12:00 PM
        jwk_content = json.dumps({"keys": [key_to_jwk_dict(ed25519_private_key, "kid")]})

        # Reads ``jwk_content`` at call time, so reassigning it below changes what the "server" returns.
        async def mock_transport(request):
            return httpx.Response(status_code=200, content=jwk_content)

        client = httpx.AsyncClient(transport=httpx.MockTransport(mock_transport))
        jwks = JWKS(url="https://example.com/jwks.json", client=client)
        spy = spy_agency.spy_on(JWKS._fetch_remote_jwks)

        key = await jwks.get_key("kid")
        assert isinstance(key, jwt.PyJWK)

        # Move forward in time, but not to a point where it updates. Should not end up re-requesting.
        spy.reset_calls()
        time_machine.shift(1800)
        assert await jwks.get_key("kid") is key
        spy_agency.assert_spy_not_called(spy)

        # Now move to a point where it should refresh
        time_machine.shift(1801)

        key2 = key_to_jwk_dict(generate_private_key("Ed25519"), "kid2")
        jwk_content = json.dumps({"keys": [key2]})
        with pytest.raises(KeyError):
            # Not in the document anymore, should have gone from the keyset
            await jwks.get_key("kid")
        assert isinstance(await jwks.get_key("kid2"), jwt.PyJWK)
        spy_agency.assert_spy_called(spy)
+
+
@pytest.fixture
def jwt_generator(ed25519_private_key: Ed25519PrivateKey):
    """A generator signing EdDSA tokens with the session Ed25519 key under kid ``kid1``."""
    key = ed25519_private_key
    return JWTGenerator(
        private_key=key,
        kid="kid1",
        validity_period=timedelta(minutes=1),
        issuer="http://test-issuer",
        algorithm="EdDSA",
        audience="abc",
    )
+
+
@pytest.fixture
def jwt_validator(ed25519_private_key: Ed25519PrivateKey):
    """A validator trusting the public half of the session Ed25519 key under kid ``kid1``."""
    key = ed25519_private_key
    jwks = JWKS.from_private_key(kid1=key)
    return JWTValidator(jwks=jwks, issuer="http://test-issuer", algorithm="EdDSA", audience="abc")
+
+
async def test_task_jwt_generator_validator(jwt_generator, jwt_validator, ed25519_private_key):
    token = jwt_generator.generate(subject="test_subject")
    # if this does not raise ValueError then the generated token can be decoded and contains all the required
    # fields.
    claims = await jwt_validator.validated_claims(
        token, extra_claims={"sub": {"essential": True, "value": "test_subject"}}
    )
    nbf = datetime.fromtimestamp(claims["nbf"], timezone.utc)
    iat = datetime.fromtimestamp(claims["iat"], timezone.utc)
    exp = datetime.fromtimestamp(claims["exp"], timezone.utc)
    now = datetime.now(timezone.utc)
    assert nbf == iat, "issued at is different then not before"
    assert nbf < exp, "not before is after expiration"
    assert nbf <= now, "not before is in the future"
    assert exp >= now, "expiration is in the past"
    assert exp <= nbf + timedelta(minutes=10), "expiration is more then 10 minutes after not before"

    def token_without_claim(claim: str) -> str:
        # Drop ``claim`` and re-sign with the same key and kid, so the missing claim is the only defect.
        bad_claims = claims.copy()
        bad_claims.pop(claim)
        return jwt.encode(
            bad_claims,
            ed25519_private_key,
            headers={"kid": jwt_generator.kid},
            algorithm=jwt_generator.algorithm,
        )

    # NOTE(review): the ``jwt_validator`` fixture leaves ``required_claims`` at its empty default, so this loop
    # does not run as written -- TODO configure required claims there to actually exercise it.
    for required_claim in jwt_validator.required_claims:
        bad_token = token_without_claim(required_claim)
        # check that the missing claim is detected in validation
        with pytest.raises(ValueError, match="Invalid JWT") as exc_info:
            await jwt_validator.validated_claims(
                bad_token, extra_claims={"sub": {"essential": True, "value": "test_subject"}}
            )
        cause = exc_info.value.__cause__
        assert isinstance(cause, jwt.MissingRequiredClaimError)
        assert cause.claim == required_claim
+
+
async def test_jwt_wrong_subject(jwt_generator, jwt_validator):
    # A token issued for another subject must fail the essential "sub" claim check.
    token_for_other_subject = jwt_generator.generate(subject="wrong_subject")
    expected = {"sub": {"essential": True, "value": "test_subject"}}
    with pytest.raises(InvalidClaimError, match="Invalid claim: sub"):
        await jwt_validator.validated_claims(token_for_other_subject, extra_claims=expected)
+
+
@pytest.fixture(scope="session")
def rsa_private_key():
    """A 2048-bit RSA private key, generated once per test session."""
    return generate_private_key()
+
+
@pytest.fixture(scope="session")
def ed25519_private_key():
    """An Ed25519 private key, generated once per test session."""
    return generate_private_key(key_type="Ed25519")

Reply via email to