This is an automated email from the ASF dual-hosted git repository. yasith pushed a commit to branch feat/sdk-facade-migration in repository https://gitbox.apache.org/repos/asf/airavata-portals.git
commit 89a5c1994af4fa9c6a1801afcfd9a10c02f196f8 Author: yasithdev <[email protected]> AuthorDate: Wed Apr 8 01:39:05 2026 -0500 feat: add type hints to core API files (10 files, ~200 functions) --- .../django_airavata/apps/api/helpers.py | 16 +- .../django_airavata/apps/api/output_views.py | 42 +-- .../django_airavata/apps/api/serializers.py | 87 +++--- .../django_airavata/apps/api/user_storage.py | 50 ++-- .../django_airavata/apps/api/view_utils.py | 61 ++-- .../django_airavata/apps/api/views.py | 308 +++++++++++---------- .../django_airavata/apps/auth/middleware.py | 17 +- .../django_airavata/context_processors.py | 20 +- .../django_airavata/middleware.py | 10 +- airavata-django-portal/django_airavata/utils.py | 2 +- 10 files changed, 317 insertions(+), 296 deletions(-) diff --git a/airavata-django-portal/django_airavata/apps/api/helpers.py b/airavata-django-portal/django_airavata/apps/api/helpers.py index 47475efc1..de7eef814 100644 --- a/airavata-django-portal/django_airavata/apps/api/helpers.py +++ b/airavata-django-portal/django_airavata/apps/api/helpers.py @@ -1,7 +1,9 @@ import logging +from typing import Any from django.conf import settings from django.core.exceptions import ObjectDoesNotExist +from django.http import HttpRequest from django_airavata.proto_compat import ResourcePermissionType @@ -11,7 +13,7 @@ logger = logging.getLogger(__name__) class WorkspacePreferencesHelper: - def get(self, request): + def get(self, request: HttpRequest) -> models.WorkspacePreferences: try: workspace_preferences = models.WorkspacePreferences.objects.get(username=request.user.username) self._check(request, workspace_preferences) @@ -20,7 +22,7 @@ class WorkspacePreferencesHelper: workspace_preferences.save() return workspace_preferences - def _create_default(self, request): + def _create_default(self, request: HttpRequest) -> models.WorkspacePreferences: workspace_preferences = models.WorkspacePreferences.create(request.user.username) most_recent_project = self._get_most_recent_project(request) workspace_preferences.most_recent_project_id = most_recent_project.projectID @@ -30,7 +32,7 @@ class WorkspacePreferencesHelper: ) return workspace_preferences - def _get_most_recent_project(self, request): + def _get_most_recent_project(self, request: HttpRequest) -> Any: "Return most recent writeable project." projects = request.airavata_client.research.get_user_projects(settings.GATEWAY_ID, request.user.username, -1, 0) for project in projects: @@ -38,7 +40,7 @@ class WorkspacePreferencesHelper: return project return None - def _get_first_group_resource_profile(self, request): + def _get_first_group_resource_profile(self, request: HttpRequest) -> Any: "Return first accessible group resource profile" group_resource_profiles = request.airavata_client.compute.get_group_resource_list(settings.GATEWAY_ID) @@ -47,7 +49,7 @@ class WorkspacePreferencesHelper: else: return None - def _check(self, request, prefs): + def _check(self, request: HttpRequest, prefs: models.WorkspacePreferences) -> None: "Validate preference values and update as needed." if not prefs.most_recent_project_id or not self._can_write(request, prefs.most_recent_project_id): most_recent_project = self._get_most_recent_project(request) @@ -70,8 +72,8 @@ class WorkspacePreferencesHelper: prefs.most_recent_group_resource_profile_id = first_grp_id prefs.save() - def _can_write(self, request, entity_id): + def _can_write(self, request: HttpRequest, entity_id: str) -> bool: return request.airavata_client.sharing.user_has_access(entity_id, ResourcePermissionType.WRITE) - def _can_read(self, request, entity_id): + def _can_read(self, request: HttpRequest, entity_id: str) -> bool: return request.airavata_client.sharing.user_has_access(entity_id, ResourcePermissionType.READ) diff --git a/airavata-django-portal/django_airavata/apps/api/output_views.py b/airavata-django-portal/django_airavata/apps/api/output_views.py index 5a02bfa31..46f1325cb 100644 --- a/airavata-django-portal/django_airavata/apps/api/output_views.py +++ b/airavata-django-portal/django_airavata/apps/api/output_views.py @@ -4,11 +4,13 @@ import json import logging import os from functools import partial +from typing import Any import nbformat import papermill as pm from django_airavata.apps.api import user_storage from django.conf import settings +from django.http import HttpRequest from nbconvert import HTMLExporter from django_airavata.proto_compat import DataType @@ -18,7 +20,7 @@ logger = logging.getLogger(__name__) BASE_DIR = os.path.dirname(os.path.abspath(__file__)) # This is populated by apps.ApiConfig.ready() -OUTPUT_VIEW_PROVIDERS = {} +OUTPUT_VIEW_PROVIDERS: dict[str, Any] = {} class DefaultViewProvider: @@ -26,7 +28,7 @@ class DefaultViewProvider: immediate = False name = "Default" - def generate_data(self, request, experiment_output, experiment, output_file=None, **kwargs): + def generate_data(self, request: HttpRequest, experiment_output: Any, experiment: Any, output_file: Any = None, **kwargs: Any) -> dict[str, Any]: return {} @@ -35,7 +37,7 @@ class ParameterizedNotebookViewProvider: name = "Example Parameterized Notebook View" # test_output_file = os.path.join(BASE_DIR, "data", "Gaussian.log") - def generate_data(self, request, experiment_output, experiment, output_file=None, output_dir=None): + def generate_data(self, request: HttpRequest, experiment_output: Any, experiment: Any, output_file: Any = None, output_dir: str | None = None) -> dict[str, str]: # use papermill to generate the output notebook output_file_path = os.path.realpath(output_file.name) pm.execute_notebook( @@ -52,11 +54,11 @@ class ParameterizedNotebookViewProvider: return {"output": body} -DEFAULT_VIEW_PROVIDERS = {"default": DefaultViewProvider()} +DEFAULT_VIEW_PROVIDERS: dict[str, Any] = {"default": DefaultViewProvider()} -def get_output_views(request, experiment, application_interface=None): - output_views = {} +def get_output_views(request: HttpRequest, experiment: Any, application_interface: Any = None) -> dict[str, list[dict[str, Any]]]: + output_views: dict[str, list[dict[str, Any]]] = {} for output in experiment.experimentOutputs: output_views[output.name] = [] output_view_provider_ids = _get_output_view_providers(output, application_interface) @@ -69,7 +71,7 @@ def get_output_views(request, experiment, application_interface=None): else: logger.warning(f"Unable to find output view provider with name '{output_view_provider_id}'") if output_view_provider is not None: - view_config = { + view_config: dict[str, Any] = { "provider-id": output_view_provider_id, "display-type": output_view_provider.display_type, "name": getattr(output_view_provider, "name", output_view_provider_id), @@ -84,7 +86,7 @@ def get_output_views(request, experiment, application_interface=None): return output_views -def _get_output_view_provider(output_view_provider_id): +def _get_output_view_provider(output_view_provider_id: str) -> Any: if output_view_provider_id in DEFAULT_VIEW_PROVIDERS: return DEFAULT_VIEW_PROVIDERS[output_view_provider_id] @@ -92,8 +94,8 @@ def _get_output_view_provider(output_view_provider_id): return OUTPUT_VIEW_PROVIDERS[output_view_provider_id] -def _get_output_view_providers(experiment_output, application_interface): - output_view_providers = [] +def _get_output_view_providers(experiment_output: Any, application_interface: Any) -> list[str]: + output_view_providers: list[str] = [] logger.debug(f"experiment_output={experiment_output}") if experiment_output.metaData: try: @@ -116,7 +118,7 @@ def _get_output_view_providers(experiment_output, application_interface): return output_view_providers -def _get_application_output_view_providers(application_interface, output_name): +def _get_application_output_view_providers(application_interface: Any, output_name: str) -> list[str]: app_output = [o for o in application_interface.applicationOutputs if o.name == output_name] if len(app_output) == 1: logger.debug(f"{output_name}: {app_output}") @@ -133,7 +135,7 @@ def _get_application_output_view_providers(application_interface, output_name): return [] -def generate_data(request, output_view_provider_id, experiment_output_name, experiment_id, test_mode=False, **kwargs): +def generate_data(request: HttpRequest, output_view_provider_id: str, experiment_output_name: str, experiment_id: str, test_mode: bool = False, **kwargs: Any) -> dict[str, Any]: output_view_provider = _get_output_view_provider(output_view_provider_id) # TODO if output_view_provider is None, return 404 experiment = request.airavata_client.research.get_experiment(experiment_id) @@ -146,8 +148,8 @@ def generate_data(request, output_view_provider_id, experiment_output_name, expe return _generate_data(request, output_view_provider, experiment_output, experiment, test_mode=test_mode, **kwargs) -def _generate_data(request, output_view_provider, experiment_output, experiment, test_mode=False, **kwargs): - output_files = [] +def _generate_data(request: HttpRequest, output_view_provider: Any, experiment_output: Any, experiment: Any, test_mode: bool = False, **kwargs: Any) -> dict[str, Any]: + output_files: list[Any] = [] # test_mode can only be used in DEBUG=True mode if test_mode and settings.DEBUG: test_output_file = getattr(output_view_provider, "test_output_file", None) @@ -181,7 +183,7 @@ def _generate_data(request, output_view_provider, experiment_output, experiment, return data -def _process_interactive_params(data): +def _process_interactive_params(data: dict[str, Any]) -> None: if "interactive" in data: _convert_options(data) for param in data["interactive"]: @@ -192,7 +194,7 @@ def _process_interactive_params(data): param["step"] = 1 -def _convert_options(data): +def _convert_options(data: dict[str, Any]) -> None: """Convert interactive options to explicit text/value dicts.""" for param in data["interactive"]: if "options" in param and isinstance(param["options"][0], str): @@ -201,15 +203,15 @@ def _convert_options(data): param["options"] = _convert_options_sequences(param["options"]) -def _convert_options_strings(options): +def _convert_options_strings(options: list[str]) -> list[dict[str, str]]: return [{"text": o, "value": o} for o in options] -def _convert_options_sequences(options): +def _convert_options_sequences(options: list[Any]) -> list[dict[str, Any]]: return [{"text": o[0], "value": o[1]} for o in options] -def _infer_interactive_param_type(param): +def _infer_interactive_param_type(param: dict[str, Any]) -> str | None: v = param["value"] # Boolean test must come first since bools are also integers if isinstance(v, bool): @@ -222,7 +224,7 @@ def _infer_interactive_param_type(param): return "string" -def _convert_params_to_type(output_view_provider, params): +def _convert_params_to_type(output_view_provider: Any, params: dict[str, Any]) -> dict[str, Any]: method_sig = inspect.signature(output_view_provider.generate_data) method_params = method_sig.parameters # Special query parameter _meta holds type information for interactive diff --git a/airavata-django-portal/django_airavata/apps/api/serializers.py b/airavata-django-portal/django_airavata/apps/api/serializers.py index 80881faba..a016d5832 100644 --- a/airavata-django-portal/django_airavata/apps/api/serializers.py +++ b/airavata-django-portal/django_airavata/apps/api/serializers.py @@ -3,13 +3,16 @@ import datetime import json import logging from pathlib import Path +from typing import Any from urllib.parse import quote from django_airavata.apps.api import user_storage from django.conf import settings from django.contrib.auth import get_user_model +from django.http import HttpRequest from django.urls import reverse from rest_framework import serializers +from rest_framework.request import Request from django_airavata.proto_compat import ( ApplicationDeploymentDescription, @@ -59,7 +62,7 @@ log = logging.getLogger(__name__) class FullyEncodedHyperlinkedIdentityField(serializers.HyperlinkedIdentityField): - def get_url(self, obj, view_name, request, format): + def get_url(self, obj: Any, view_name: str, request: Request, format: str | None) -> str: if hasattr(obj, self.lookup_field): lookup_value = getattr(obj, self.lookup_field) else: @@ -87,27 +90,27 @@ class UTCPosixTimestampDateTimeField(serializers.DateTimeField): self.initial = self.initial_value self.required = False - def to_representation(self, obj): + def to_representation(self, obj: int) -> str: # Create datetime instance from milliseconds that is aware of timezon dt = datetime.datetime.fromtimestamp(obj / 1000, datetime.UTC) return super().to_representation(dt) - def to_internal_value(self, data): + def to_internal_value(self, data: str) -> int: dt = super().to_internal_value(data) return int(dt.timestamp() * 1000) - def initial_value(self): + def initial_value(self) -> str: return self.to_representation(self.current_time_ms()) - def current_time_ms(self): + def current_time_ms(self) -> int: return int(datetime.datetime.utcnow().timestamp() * 1000) class StoredJSONField(serializers.JSONField): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - def to_representation(self, value): + def to_representation(self, value: str | None) -> Any: try: if value: return json.loads(value) @@ -116,7 +119,7 @@ class StoredJSONField(serializers.JSONField): except Exception: return value - def to_internal_value(self, data): + def to_internal_value(self, data: Any) -> str: try: return json.dumps(data) except (TypeError, ValueError): @@ -124,17 +127,17 @@ class StoredJSONField(serializers.JSONField): class OrderedListField(serializers.ListField): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self.order_by = kwargs.pop("order_by", None) super().__init__(*args, **kwargs) - def to_representation(self, instance): + def to_representation(self, instance: list[Any]) -> list[dict[str, Any]] | None: rep = super().to_representation(instance) if rep is not None: rep.sort(key=lambda item: item[self.order_by]) return rep - def to_internal_value(self, data): + def to_internal_value(self, data: list[Any]) -> list[Any]: validated_data = super().to_internal_value(data) # Update order field based on order in array items = validated_data if validated_data else [] @@ -158,12 +161,12 @@ class GroupSerializer(proto_utils.create_serializer_class(GroupModel)): required = ("name",) read_only = ("ownerId",) - def create(self, validated_data): + def create(self, validated_data: dict[str, Any]) -> Any: group = super().create(validated_data) group.ownerId = self.context["request"].user.username + "@" + settings.GATEWAY_ID return group - def update(self, instance, validated_data): + def update(self, instance: Any, validated_data: dict[str, Any]) -> Any: instance.name = validated_data.get("name", instance.name) instance.description = validated_data.get("description", instance.description) # Calculate added and removed members @@ -187,31 +190,31 @@ class GroupSerializer(proto_utils.create_serializer_class(GroupModel)): instance.members.extend(list(added_admins - new_members)) return instance - def get_isAdmin(self, group): + def get_isAdmin(self, group: Any) -> bool: request = self.context["request"] return request.airavata_client.sharing.has_admin_access( group.id, request.user.username + "@" + settings.GATEWAY_ID ) - def get_isOwner(self, group): + def get_isOwner(self, group: Any) -> bool: request = self.context["request"] return group.ownerId == (request.user.username + "@" + settings.GATEWAY_ID) - def get_isMember(self, group): + def get_isMember(self, group: Any) -> bool: request = self.context["request"] username = request.user.username + "@" + settings.GATEWAY_ID return group.members and username in group.members - def get_isGatewayAdminsGroup(self, group): + def get_isGatewayAdminsGroup(self, group: Any) -> bool: return group.id == self._gateway_groups()["adminsGroupId"] - def get_isReadOnlyGatewayAdminsGroup(self, group): + def get_isReadOnlyGatewayAdminsGroup(self, group: Any) -> bool: return group.id == self._gateway_groups()["readOnlyAdminsGroupId"] - def get_isDefaultGatewayUsersGroup(self, group): + def get_isDefaultGatewayUsersGroup(self, group: Any) -> bool: return group.id == self._gateway_groups()["defaultGatewayUsersGroupId"] - def _gateway_groups(self): + def _gateway_groups(self) -> dict[str, Any]: request = self.context["request"] # gateway_groups_middleware sets this session variable if "GATEWAY_GROUPS" in request.session: @@ -236,19 +239,19 @@ class ProjectSerializer(proto_utils.create_serializer_class(Project)): userHasWriteAccess = serializers.SerializerMethodField() isOwner = serializers.SerializerMethodField() - def create(self, validated_data): + def create(self, validated_data: dict[str, Any]) -> Any: return Project(**validated_data) - def update(self, instance, validated_data): + def update(self, instance: Any, validated_data: dict[str, Any]) -> Any: instance.name = validated_data.get("name", instance.name) instance.description = validated_data.get("description", instance.description) return instance - def get_userHasWriteAccess(self, project): + def get_userHasWriteAccess(self, project: Any) -> bool: request = self.context["request"] return request.airavata_client.sharing.user_has_access(project.projectID, ResourcePermissionType.WRITE) - def get_isOwner(self, project): + def get_isOwner(self, project: Any) -> bool: request = self.context["request"] return project.owner == request.user.username @@ -278,12 +281,12 @@ class ApplicationModuleSerializer(proto_utils.create_serializer_class(Applicatio class EnumChoiceField(serializers.ChoiceField): - def __init__(self, enum_class, **kwargs): + def __init__(self, enum_class: type, **kwargs: Any) -> None: self.enum_class = enum_class kwargs["choices"] = [(member.name, member.name) for member in enum_class] super().__init__(**kwargs) - def to_internal_value(self, data): + def to_internal_value(self, data: str | int) -> Any: if isinstance(data, int): try: return self.enum_class(data) @@ -294,7 +297,7 @@ class EnumChoiceField(serializers.ChoiceField): except KeyError: self.fail("invalid_choice", input=data) - def to_representation(self, value): + def to_representation(self, value: Any) -> str: return value.name @@ -552,7 +555,7 @@ class DataProductSerializer(proto_utils.create_serializer_class(DataProductModel filesize = serializers.SerializerMethodField() userHasWriteAccess = serializers.SerializerMethodField() - def get_downloadURL(self, data_product): + def get_downloadURL(self, data_product: Any) -> str | None: """Getter for downloadURL field. Returns None if file is not available.""" request = self.context["request"] if user_storage.exists(request, data_product): @@ -560,12 +563,12 @@ class DataProductSerializer(proto_utils.create_serializer_class(DataProductModel else: return None - def get_isInputFileUpload(self, data_product): + def get_isInputFileUpload(self, data_product: Any) -> bool: """Return True if this is an uploaded input file.""" request = self.context["request"] return user_storage.is_input_file(request, data_product) - def get_filesize(self, data_product): + def get_filesize(self, data_product: Any) -> int: request = self.context["request"] if user_storage.exists(request, data_product): metadata = user_storage.get_data_product_metadata(request, data_product) @@ -573,7 +576,7 @@ class DataProductSerializer(proto_utils.create_serializer_class(DataProductModel else: return 0 - def get_userHasWriteAccess(self, data_product: DataProductModel): + def get_userHasWriteAccess(self, data_product: DataProductModel) -> bool: request = self.context["request"] if user_storage.exists(request, data_product): file_metadata = user_storage.get_data_product_metadata(request, data_product=data_product) @@ -597,15 +600,15 @@ class FullExperiment: def __init__( self, - experimentModel, - project=None, - outputDataProducts=None, - inputDataProducts=None, - applicationModule=None, - computeResource=None, - jobDetails=None, - outputViews=None, - ): + experimentModel: Any, + project: Any = None, + outputDataProducts: list[Any] | None = None, + inputDataProducts: list[Any] | None = None, + applicationModule: Any = None, + computeResource: Any = None, + jobDetails: list[Any] | None = None, + outputViews: dict[str, list[dict[str, Any]]] | None = None, + ) -> None: self.experiment = experimentModel self.experimentId = experimentModel.experimentId self.project = project @@ -1726,11 +1729,11 @@ class SharedEntitySerializer(serializers.Serializer): # return tuple: permissions to revoke and permissions to grant return (current_permissions_set - new_permissions_set, new_permissions_set - current_permissions_set) - def get_isOwner(self, shared_entity): + def get_isOwner(self, shared_entity: dict[str, Any]) -> bool: request = self.context["request"] return shared_entity["owner"].userId == request.user.username - def get_hasSharingPermission(self, shared_entity): + def get_hasSharingPermission(self, shared_entity: dict[str, Any]) -> bool: request = self.context["request"] return request.airavata_client.sharing.user_has_access( shared_entity["entityId"], ResourcePermissionType.MANAGE_SHARING diff --git a/airavata-django-portal/django_airavata/apps/api/user_storage.py b/airavata-django-portal/django_airavata/apps/api/user_storage.py index 3b68b7f51..aed80ba33 100644 --- a/airavata-django-portal/django_airavata/apps/api/user_storage.py +++ b/airavata-django-portal/django_airavata/apps/api/user_storage.py @@ -8,8 +8,10 @@ old API) and delegates to ``request.airavata_client.storage``. import io import logging import os +from typing import Any, BinaryIO from django.conf import settings +from django.http import HttpRequest log = logging.getLogger(__name__) @@ -18,7 +20,7 @@ log = logging.getLogger(__name__) # Helpers to extract file paths from DataProductModel proto objects # --------------------------------------------------------------------------- -def _get_replica_filepath(data_product): +def _get_replica_filepath(data_product: Any) -> str | None: """Return the file_path from the first GATEWAY_DATA_STORE replica location.""" for replica in data_product.replica_locations: if replica.file_path: @@ -26,7 +28,7 @@ def _get_replica_filepath(data_product): return None -def _get_replica_storage_resource_id(data_product): +def _get_replica_storage_resource_id(data_product: Any) -> str | None: """Return the storage_resource_id from the first replica location.""" for replica in data_product.replica_locations: if replica.storage_resource_id: @@ -38,7 +40,7 @@ def _get_replica_storage_resource_id(data_product): # File existence / metadata # --------------------------------------------------------------------------- -def exists(request, data_product): +def exists(request: HttpRequest, data_product: Any) -> bool: """Check whether the file backing *data_product* exists in user storage.""" path = _get_replica_filepath(data_product) if not path: @@ -49,14 +51,14 @@ def exists(request, data_product): return False -def dir_exists(request, path, experiment_id=None): +def dir_exists(request: HttpRequest, path: str, experiment_id: str | None = None) -> bool: """Check whether *path* exists as a directory in user storage.""" if experiment_id: return experiment_dir_exists(request, experiment_id, path) return request.airavata_client.storage.dir_exists(path) -def experiment_dir_exists(request, experiment_id, path=""): +def experiment_dir_exists(request: HttpRequest, experiment_id: str, path: str = "") -> bool: """Check whether the experiment output directory exists.""" try: request.airavata_client.storage.list_experiment_dir(experiment_id, path) @@ -65,7 +67,7 @@ def experiment_dir_exists(request, experiment_id, path=""): return False -def is_input_file(request, data_product): +def is_input_file(request: HttpRequest, data_product: Any) -> bool: """Return True if the data product's path is under the inputs directory.""" path = _get_replica_filepath(data_product) if not path: @@ -79,7 +81,7 @@ def is_input_file(request, data_product): # File / directory listing # --------------------------------------------------------------------------- -def listdir(request, path, experiment_id=None): +def listdir(request: HttpRequest, path: str, experiment_id: str | None = None) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """List the contents of *path*, returning (directories, files) dicts.""" if experiment_id: return list_experiment_dir(request, experiment_id, path) @@ -89,7 +91,7 @@ def listdir(request, path, experiment_id=None): return directories, files -def list_experiment_dir(request, experiment_id, path=""): +def list_experiment_dir(request: HttpRequest, experiment_id: str, path: str = "") -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: """List the experiment output directory.""" resp = request.airavata_client.storage.list_experiment_dir(experiment_id, path) directories = _metadata_list_to_dicts(resp.directories) @@ -97,9 +99,9 @@ def list_experiment_dir(request, experiment_id, path=""): return directories, files -def _metadata_list_to_dicts(items): +def _metadata_list_to_dicts(items: Any) -> list[dict[str, Any]]: """Convert repeated FileMetadataResponse protos to plain dicts.""" - result = [] + result: list[dict[str, Any]] = [] for item in items: result.append({ "name": item.name, @@ -118,7 +120,7 @@ def _metadata_list_to_dicts(items): # File open / download # --------------------------------------------------------------------------- -def open_file(request, data_product): +def open_file(request: HttpRequest, data_product: Any) -> io.BytesIO: """Download the file for *data_product* and return a file-like object.""" path = _get_replica_filepath(data_product) resp = request.airavata_client.storage.download_file(path) @@ -131,7 +133,7 @@ def open_file(request, data_product): # File upload / save # --------------------------------------------------------------------------- -def save_input_file(request, input_file, name=None, content_type=""): +def save_input_file(request: HttpRequest, input_file: BinaryIO, name: str | None = None, content_type: str = "") -> Any: """Upload *input_file* to the user's input files directory. Returns a DataProductModel proto. @@ -149,7 +151,7 @@ def save_input_file(request, input_file, name=None, content_type=""): return request.airavata_client.research.get_data_product(resp.uri) -def save(request, path, file_obj, name=None, content_type="", experiment_id=None): +def save(request: HttpRequest, path: str, file_obj: BinaryIO, name: str | None = None, content_type: str = "", experiment_id: str | None = None) -> Any: """Upload *file_obj* to *path* in user storage. Returns a DataProductModel proto. @@ -169,7 +171,7 @@ def save(request, path, file_obj, name=None, content_type="", experiment_id=None # File content update # --------------------------------------------------------------------------- -def update_data_product_content(request, data_product, fileContentText): +def update_data_product_content(request: HttpRequest, data_product: Any, fileContentText: str) -> None: """Replace the content of the file backing *data_product* with *fileContentText*.""" path = _get_replica_filepath(data_product) name = os.path.basename(path) @@ -180,7 +182,7 @@ def update_data_product_content(request, data_product, fileContentText): ) -def update_file_content(request, path, fileContentText): +def update_file_content(request: HttpRequest, path: str, fileContentText: str) -> None: """Replace the content of the file at *path* with *fileContentText*.""" name = os.path.basename(path) request.airavata_client.storage.upload_file( @@ -194,30 +196,30 @@ def update_file_content(request, path, fileContentText): # File / directory creation and deletion # --------------------------------------------------------------------------- -def create_user_dir(request, path, experiment_id=None): +def create_user_dir(request: HttpRequest, path: str, experiment_id: str | None = None) -> tuple[None, str]: """Create a directory at *path*. Returns (storage_resource_id, created_path).""" resp = request.airavata_client.storage.create_dir(path) return None, resp.created_path -def create_symlink(request, source_path, dest_path): +def create_symlink(request: HttpRequest, source_path: str, dest_path: str) -> None: """Create a symlink from *source_path* to *dest_path*.""" request.airavata_client.storage.create_symlink(source_path, dest_path) -def delete(request, data_product): +def delete(request: HttpRequest, data_product: Any) -> None: """Delete the file backing *data_product*.""" path = _get_replica_filepath(data_product) if path: request.airavata_client.storage.delete_file(path) -def delete_user_file(request, path, experiment_id=None): +def delete_user_file(request: HttpRequest, path: str, experiment_id: str | None = None) -> None: """Delete a user file at *path*.""" request.airavata_client.storage.delete_file(path) -def delete_dir(request, path, experiment_id=None): +def delete_dir(request: HttpRequest, path: str, experiment_id: str | None = None) -> None: """Delete a directory at *path*.""" request.airavata_client.storage.delete_dir(path) @@ -226,7 +228,7 @@ def delete_dir(request, path, experiment_id=None): # File metadata # --------------------------------------------------------------------------- -def get_file_metadata(request, path, experiment_id=None): +def get_file_metadata(request: HttpRequest, path: str, experiment_id: str | None = None) -> dict[str, Any]: """Get metadata for the file at *path*. Returns a dict.""" resp = request.airavata_client.storage.get_file_metadata(path) return { @@ -241,7 +243,7 @@ def get_file_metadata(request, path, experiment_id=None): } -def get_data_product_metadata(request, data_product=None, data_product_uri=None): +def get_data_product_metadata(request: HttpRequest, data_product: Any = None, data_product_uri: str | None = None) -> dict[str, Any]: """Get metadata for a data product. Returns a dict with path, size, etc.""" if data_product is None and data_product_uri: data_product = request.airavata_client.research.get_data_product(data_product_uri) @@ -268,14 +270,14 @@ def get_data_product_metadata(request, data_product=None, data_product_uri=None) # Download URL helpers # --------------------------------------------------------------------------- -def get_download_url(request, data_product_uri=None): +def get_download_url(request: HttpRequest, data_product_uri: str | None = None) -> str: """Return a URL to download the file for *data_product_uri*.""" from django.urls import reverse from urllib.parse import quote return reverse("django_airavata_api:download_file") + f"?data-product-uri={quote(data_product_uri)}" -def get_lazy_download_url(request, data_product=None, data_product_uri=None): +def get_lazy_download_url(request: HttpRequest, data_product: Any = None, data_product_uri: str | None = None) -> str | None: """Return a download URL. Accepts either a data_product or data_product_uri.""" if data_product_uri: return get_download_url(request, data_product_uri=data_product_uri) diff --git a/airavata-django-portal/django_airavata/apps/api/view_utils.py b/airavata-django-portal/django_airavata/apps/api/view_utils.py index 9c90de970..a1bc917a1 100644 --- a/airavata-django-portal/django_airavata/apps/api/view_utils.py +++ b/airavata-django-portal/django_airavata/apps/api/view_utils.py @@ -1,8 +1,10 @@ import logging import os from collections.__init__ import OrderedDict +from collections.abc import Iterator from datetime import datetime from pathlib import Path +from typing import Any import pytz from django_airavata.apps.api import user_storage @@ -10,6 +12,7 @@ from django.conf import settings from django.http import Http404 from django.http.request import QueryDict from rest_framework import mixins, pagination, permissions +from rest_framework.request import Request from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.utils.urls import remove_query_param, replace_query_param @@ -24,19 +27,19 @@ class GenericAPIBackedViewSet(GenericViewSet): # in DRF doesn't allow. lookup_value_regex = "[^/]+" - def get_list(self): + def get_list(self) -> Any: """ Subclasses must implement. """ raise NotImplementedError() - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> Any: """ Subclasses must implement. """ raise NotImplementedError() - def get_queryset(self): + def get_queryset(self) -> Any: if isinstance(self, mixins.ListModelMixin): return self.get_list() else: @@ -46,7 +49,7 @@ class GenericAPIBackedViewSet(GenericViewSet): # get_list() implementation return None - def get_object(self): + def get_object(self) -> Any: lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field lookup_value = self.kwargs[lookup_url_kwarg] inst = self.get_instance(lookup_value) @@ -56,15 +59,15 @@ class GenericAPIBackedViewSet(GenericViewSet): return inst @property - def username(self): + def username(self) -> str: return self.request.user.username @property - def gateway_id(self): + def gateway_id(self) -> str: return settings.GATEWAY_ID @property - def authz_token(self): + def authz_token(self) -> Any: # Deprecated: SDK facade handles auth internally. # Kept for backward compatibility with code that references self.authz_token. return getattr(self.request, "authz_token", None) @@ -110,21 +113,21 @@ class APIResultIterator: Iterable container over API results which allow limit/offset style slicing. """ - limit = -1 - offset = 0 + limit: int = -1 + offset: int = 0 - def __init__(self, query_params=None): + def __init__(self, query_params: QueryDict | None = None) -> None: self.query_params = query_params if query_params is not None else QueryDict() - def get_results(self, limit=-1, offset=0): + def get_results(self, limit: int = -1, offset: int = 0) -> list[Any]: raise NotImplementedError("Subclasses must implement get_results") - def __iter__(self): + def __iter__(self) -> Iterator[Any]: results = self.get_results(self.limit, self.offset) for result in results: yield result - def __getitem__(self, key): + def __getitem__(self, key: int | slice) -> Iterator[Any] | list[Any]: if isinstance(key, slice): self.limit = key.stop - key.start self.offset = key.start @@ -142,7 +145,7 @@ class APIResultPagination(pagination.LimitOffsetPagination): default_limit = 10 - def paginate_queryset(self, queryset, request, view=None): + def paginate_queryset(self, queryset: "APIResultIterator", request: Request, view: Any = None) -> list[Any] | None: assert isinstance(queryset, APIResultIterator), f"queryset is not an APIResultIterator: {queryset}" self.query_params = queryset.query_params.copy() self.limit = self.get_limit(request) @@ -160,13 +163,13 @@ class APIResultPagination(pagination.LimitOffsetPagination): return list(queryset[self.offset : self.offset + self.limit]) - def get_limit(self, request): + def get_limit(self, request: Request) -> int | None: # If limit <= 0 then don't paginate if self.limit_query_param in request.query_params and int(request.query_params[self.limit_query_param]) <= 0: return None return super().get_limit(request) - def get_paginated_response(self, data): + def get_paginated_response(self, data: list[Any]) -> Response: has_next_link = len(data) >= self.limit return Response( OrderedDict( @@ -180,14 +183,14 @@ class APIResultPagination(pagination.LimitOffsetPagination): ) ) - def get_next_link(self): + def get_next_link(self) -> str: url = self.get_base_url() url = replace_query_param(url, self.limit_query_param, self.limit) offset = self.offset + self.limit return replace_query_param(url, self.offset_query_param, offset) - def get_previous_link(self): + def get_previous_link(self) -> str | None: if self.offset <= 0: return None @@ -200,7 +203,7 @@ class APIResultPagination(pagination.LimitOffsetPagination): offset = self.offset - self.limit return replace_query_param(url, self.offset_query_param, offset) - def get_base_url(self): + def get_base_url(self) -> str: if hasattr(self, "viewname"): base_url = self.request.build_absolute_uri(reverse(self.viewname)) if len(self.query_params) > 0: @@ -210,7 +213,7 @@ class APIResultPagination(pagination.LimitOffsetPagination): return self.request.build_absolute_uri() -def convert_utc_iso8601_to_date(iso8601_utc_string): +def convert_utc_iso8601_to_date(iso8601_utc_string: str) -> datetime: # This is meant to convert a JavaScript `new Date().toJSON()` into a # datetime instance timestamp = datetime.strptime(iso8601_utc_string, "%Y-%m-%dT%H:%M:%S.%fZ") @@ -222,7 +225,7 @@ def convert_utc_iso8601_to_date(iso8601_utc_string): class IsInAdminsGroupPermission(permissions.BasePermission): message = "User must be member of the Admins or Read Only Admins groups." - def has_permission(self, request, view): + def has_permission(self, request: Request, view: Any) -> bool: # Read Only Admins can make GET requests only if request.method in permissions.SAFE_METHODS: return request.is_gateway_admin or request.is_read_only_gateway_admin @@ -231,16 +234,16 @@ class IsInAdminsGroupPermission(permissions.BasePermission): class ReadOnly(permissions.BasePermission): - def has_permission(self, request, view): + def has_permission(self, request: Request, view: Any) -> bool: return request.method in permissions.SAFE_METHODS -def is_shared_dir(path): +def is_shared_dir(path: str) -> bool: shared_dirs: dict = getattr(settings, "GATEWAY_DATA_SHARED_DIRECTORIES", {}) return any(map(lambda n: Path(n) == Path(path), shared_dirs.keys())) -def is_shared_path(path): +def is_shared_path(path: str) -> bool: shared_dirs: dict = getattr(settings, "GATEWAY_DATA_SHARED_DIRECTORIES", {}) # FIXME: path returned when creating a new directory in user storage is an # absolute path. Assume that when an absolute path is given that it was for @@ -252,10 +255,10 @@ def is_shared_path(path): class BaseSharedDirPermission(permissions.BasePermission): - def get_path(self, request, view) -> str: + def get_path(self, request: Request, view: Any) -> str: raise NotImplementedError() - def has_permission(self, request, view): + def has_permission(self, request: Request, view: Any) -> bool: if request.method in permissions.SAFE_METHODS: return True @@ -275,12 +278,12 @@ class BaseSharedDirPermission(permissions.BasePermission): class DataProductSharedDirPermission(BaseSharedDirPermission): - def get_path(self, request, view) -> str: + def get_path(self, request: Request, view: Any) -> str: data_product_uri = request.query_params.get("data-product-uri", request.query_params.get("product-uri", "")) file_metadata = user_storage.get_data_product_metadata(request, data_product_uri=data_product_uri) return file_metadata["path"] - def has_permission(self, request, view): + def has_permission(self, request: Request, view: Any) -> bool: # Special handling for remote API, just get the userHasWriteAccess attribute and use that if hasattr(settings, "GATEWAY_DATA_STORE_REMOTE_API"): if request.method in permissions.SAFE_METHODS: @@ -293,6 +296,6 @@ class DataProductSharedDirPermission(BaseSharedDirPermission): class UserStorageSharedDirPermission(BaseSharedDirPermission): - def get_path(self, request, view): + def get_path(self, request: Request, view: Any) -> str: # 'path' can be a url path parameter, query parameter or in the request body (data) return request.query_params.get("path", request.data.get("path", view.kwargs.get("path"))) diff --git a/airavata-django-portal/django_airavata/apps/api/views.py b/airavata-django-portal/django_airavata/apps/api/views.py index 4af37c93f..d5e421aa8 100644 --- a/airavata-django-portal/django_airavata/apps/api/views.py +++ b/airavata-django-portal/django_airavata/apps/api/views.py @@ -4,6 +4,7 @@ import logging import os import warnings from datetime import datetime, timedelta +from typing import Any from django_airavata.apps.api import user_storage from django.conf import settings @@ -18,6 +19,7 @@ from rest_framework.decorators import action, api_view, permission_classes from rest_framework.exceptions import ParseError from rest_framework.permissions import IsAuthenticated from rest_framework.renderers import JSONRenderer +from rest_framework.request import Request from rest_framework.response import Response from rest_framework.views import APIView @@ -71,28 +73,28 @@ class GroupViewSet(APIBackedViewSet): pagination_class = APIResultPagination pagination_viewname = "django_airavata_api:group-list" - def get_list(self): + def get_list(self) -> APIResultIterator: view = self class GroupResultsIterator(APIResultIterator): - def get_results(self, limit=-1, offset=0): + def get_results(self, limit: int = -1, offset: int = 0) -> list[Any]: groups = view.request.airavata_client.sharing.get_groups() end = offset + limit if limit > 0 else len(groups) return groups[offset:end] if groups else [] return GroupResultsIterator() - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> Any: return self.request.airavata_client.sharing.get_group(lookup_value) - def perform_create(self, serializer): + def perform_create(self, serializer: Any) -> None: group = serializer.save() group_id = self.request.airavata_client.sharing.create_group(group) group.id = group_id users_added_to_group = set(group.members) - {group.ownerId} self._send_users_added_to_group(users_added_to_group, group) - def perform_update(self, serializer): + def perform_update(self, serializer: Any) -> None: group = serializer.save() sharing_client = self.request.airavata_client.sharing if len(group._added_members) > 0: @@ -106,10 +108,10 @@ class GroupViewSet(APIBackedViewSet): sharing_client.remove_group_admins(group.id, group._removed_admins) sharing_client.update_group(group) - def perform_destroy(self, group): + def perform_destroy(self, group: Any) -> None: self.request.airavata_client.sharing.delete_group(group.id, group.ownerId) - def _send_users_added_to_group(self, internal_user_ids, group): + def _send_users_added_to_group(self, internal_user_ids: set[str], group: Any) -> None: for internal_user_id in internal_user_ids: user_id, gateway_id = internal_user_id.rsplit("@", maxsplit=1) user_profile = self.request.airavata_client.iam.get_user_profile_by_id(user_id, gateway_id) @@ -124,44 +126,44 @@ class ProjectViewSet(APIBackedViewSet): pagination_class = APIResultPagination pagination_viewname = "django_airavata_api:project-list" - def get_list(self): + def get_list(self) -> APIResultIterator: view = self class ProjectResultIterator(APIResultIterator): - def get_results(self, limit=-1, offset=0): + def get_results(self, limit: int = -1, offset: int = 0) -> list[Any]: return view.request.airavata_client.research.get_user_projects( view.gateway_id, view.username, limit, offset ) return ProjectResultIterator() - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> Any: return self.request.airavata_client.research.get_project(lookup_value) - def perform_create(self, serializer): + def perform_create(self, serializer: Any) -> None: project = serializer.save(owner=self.username, gatewayId=self.gateway_id) project_id = self.request.airavata_client.research.create_project(self.gateway_id, project) project.projectID = project_id self._update_most_recent_project(project_id) - def perform_update(self, serializer): + def perform_update(self, serializer: Any) -> None: project = serializer.save() self.request.airavata_client.research.update_project(project.projectID, project) self._update_most_recent_project(project.projectID) @action(detail=False) - def list_all(self, request): + def list_all(self, request: Request) -> Response: projects = self.request.airavata_client.research.get_user_projects(self.gateway_id, self.username, -1, 0) serializer = serializers.ProjectSerializer(projects, many=True, context={"request": request}) return Response(serializer.data) @action(detail=True) - def experiments(self, request, project_id=None): + def experiments(self, request: Request, project_id: str | None = None) -> Response: experiments = request.airavata_client.research.get_experiments_in_project(project_id, -1, 0) serializer = serializers.ExperimentSerializer(experiments, many=True, context={"request": request}) return Response(serializer.data) - def _update_most_recent_project(self, project_id): + def _update_most_recent_project(self, project_id: str) -> None: prefs = helpers.WorkspacePreferencesHelper().get(self.request) prefs.most_recent_project_id = project_id prefs.save() @@ -173,10 +175,10 @@ class ExperimentViewSet( serializer_class = serializers.ExperimentSerializer lookup_field = "experiment_id" - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> Any: return self.request.airavata_client.research.get_experiment(lookup_value) - def perform_create(self, serializer): + def perform_create(self, serializer: Any) -> None: experiment = serializer.save(gatewayId=self.gateway_id, userName=self.username) experiment_id = self.request.airavata_client.research.create_experiment(self.gateway_id, experiment) self._update_workspace_preferences( @@ -186,7 +188,7 @@ class ExperimentViewSet( ) experiment.experimentId = experiment_id - def perform_update(self, serializer): + def perform_update(self, serializer: Any) -> None: experiment = serializer.save(gatewayId=self.gateway_id, userName=self.username) self.request.airavata_client.research.update_experiment(experiment.experimentId, experiment) self._update_workspace_preferences( @@ -196,7 +198,7 @@ class ExperimentViewSet( ) @action(methods=["post"], detail=True) - def launch(self, request, experiment_id=None): + def launch(self, request: Request, experiment_id: str | None = None) -> Response: try: experiment = request.airavata_client.research.get_experiment(experiment_id) if experiment.enableEmailNotification: @@ -209,20 +211,20 @@ class ExperimentViewSet( return Response({"success": False, "errorMessage": str(e)}) @action(methods=["get"], detail=True) - def jobs(self, request, experiment_id=None): + def jobs(self, request: Request, experiment_id: str | None = None) -> Response: jobs = request.airavata_client.research.get_job_details(experiment_id) serializer = serializers.JobSerializer(jobs, many=True, context={"request": request}) return Response(serializer.data) @action(methods=["post"], detail=True) - def clone(self, request, experiment_id=None): + def clone(self, request: Request, experiment_id: str | None = None) -> Response: cloned_experiment_id = request.airavata_client.research.clone_experiment(experiment_id) cloned_experiment = request.airavata_client.research.get_experiment(cloned_experiment_id) serializer = self.serializer_class(cloned_experiment, context={"request": request}) return Response(serializer.data) @action(methods=["post"], detail=True) - def cancel(self, request, experiment_id=None): + def cancel(self, request: Request, experiment_id: str | None = None) -> Response: try: request.airavata_client.research.terminate_experiment(experiment_id, self.gateway_id) return Response({"success": True}) @@ -231,7 +233,7 @@ class ExperimentViewSet( raise e @action(methods=["post"], detail=True) - def fetch_intermediate_outputs(self, request, experiment_id=None): + def fetch_intermediate_outputs(self, request: Request, experiment_id: str | None = None) -> Response: if "outputNames" not in request.data: return Response(status=status.HTTP_400_BAD_REQUEST) try: @@ -243,7 +245,7 @@ class ExperimentViewSet( log.exception("fetchIntermediateOutputs failed with the following error", extra={"request": request}) raise e - def _update_workspace_preferences(self, project_id, group_resource_profile_id, compute_resource_id): + def _update_workspace_preferences(self, project_id: str, group_resource_profile_id: str, compute_resource_id: str) -> None: prefs = helpers.WorkspacePreferencesHelper().get(self.request) prefs.most_recent_project_id = project_id prefs.most_recent_group_resource_profile_id = group_resource_profile_id @@ -256,10 +258,10 @@ class ExperimentSearchViewSet(mixins.ListModelMixin, GenericAPIBackedViewSet): pagination_class = APIResultPagination pagination_viewname = "django_airavata_api:experiment-search-list" - def get_list(self): + def get_list(self) -> APIResultIterator: view = self - filters = {} + filters: dict[Any, str] = {} for filter_item in self.request.query_params.items(): if filter_item[0] in ExperimentSearchFields.__members__: # Lookup enum value for this ExperimentSearchFields @@ -267,7 +269,7 @@ class ExperimentSearchViewSet(mixins.ListModelMixin, GenericAPIBackedViewSet): filters[search_field] = filter_item[1] class ExperimentSearchResultIterator(APIResultIterator): - def get_results(self, limit=-1, offset=0): + def get_results(self, limit: int = -1, offset: int = 0) -> list[Any]: return view.request.airavata_client.research.search_experiments( view.gateway_id, view.username, filters, limit, offset ) @@ -275,7 +277,7 @@ class ExperimentSearchViewSet(mixins.ListModelMixin, GenericAPIBackedViewSet): # Preserve query parameters when moving to next and previous links return ExperimentSearchResultIterator(query_params=self.request.query_params.copy()) - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> Any: raise NotImplementedError() @@ -283,7 +285,7 @@ class FullExperimentViewSet(mixins.RetrieveModelMixin, GenericAPIBackedViewSet): serializer_class = serializers.FullExperimentSerializer lookup_field = "experiment_id" - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> serializers.FullExperiment: """Get FullExperiment instance with resolved references.""" # TODO: move loading experiment and references to airavata_sdk? experimentModel = self.request.airavata_client.research.get_experiment(lookup_value) @@ -373,26 +375,26 @@ class ApplicationModuleViewSet(APIBackedViewSet): serializer_class = serializers.ApplicationModuleSerializer lookup_field = "app_module_id" - def get_list(self): + def get_list(self) -> list[Any]: return self.request.airavata_client.research.get_accessible_app_modules(self.gateway_id) - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> Any: return self.request.airavata_client.research.get_application_module(lookup_value) - def perform_create(self, serializer): + def perform_create(self, serializer: Any) -> None: app_module = serializer.save() app_module_id = self.request.airavata_client.research.register_application_module(self.gateway_id, app_module) app_module.appModuleId = app_module_id - def perform_update(self, serializer): + def perform_update(self, serializer: Any) -> None: app_module = serializer.save() self.request.airavata_client.research.update_application_module(app_module.appModuleId, app_module) - def perform_destroy(self, instance): + def perform_destroy(self, instance: Any) -> None: self.request.airavata_client.research.delete_application_module(instance.appModuleId) @action(detail=True) - def application_interface(self, request, app_module_id): + def application_interface(self, request: Request, app_module_id: str) -> Response: all_app_interfaces = request.airavata_client.research.get_all_application_interfaces(self.gateway_id) app_interfaces = [] for app_interface in all_app_interfaces: @@ -415,7 +417,7 @@ class ApplicationModuleViewSet(APIBackedViewSet): raise Http404(f"No application interface found for module id {app_module_id}") @action(detail=True) - def application_deployments(self, request, app_module_id): + def application_deployments(self, request: Request, app_module_id: str) -> Response: all_deployments = self.request.airavata_client.research.get_all_application_deployments(self.gateway_id) app_deployments = [dep for dep in all_deployments if dep.appModuleId == app_module_id] serializer = serializers.ApplicationDeploymentDescriptionSerializer( @@ -424,7 +426,7 @@ class ApplicationModuleViewSet(APIBackedViewSet): return Response(serializer.data) @action(methods=["post"], detail=True) - def favorite(self, request, app_module_id): + def favorite(self, request: Request, app_module_id: str) -> HttpResponse: helper = helpers.WorkspacePreferencesHelper() workspace_preferences = helper.get(request) try: @@ -439,7 +441,7 @@ class ApplicationModuleViewSet(APIBackedViewSet): return HttpResponse(status=204) @action(methods=["post"], detail=True) - def unfavorite(self, request, app_module_id): + def unfavorite(self, request: Request, app_module_id: str) -> HttpResponse: helper = helpers.WorkspacePreferencesHelper() workspace_preferences = helper.get(request) try: @@ -454,7 +456,7 @@ class ApplicationModuleViewSet(APIBackedViewSet): return HttpResponse(status=204) @action(detail=False) - def list_all(self, request, format=None): + def list_all(self, request: Request, format: str | None = None) -> Response: all_modules = self.request.airavata_client.research.get_all_app_modules(self.gateway_id) serializer = self.serializer_class(all_modules, many=True, context={"request": request}) return Response(serializer.data) @@ -464,10 +466,10 @@ class ApplicationInterfaceViewSet(APIBackedViewSet): serializer_class = serializers.ApplicationInterfaceDescriptionSerializer lookup_field = "app_interface_id" - def get_list(self): + def get_list(self) -> list[Any]: return self.request.airavata_client.research.get_all_application_interfaces(self.gateway_id) - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> Any: try: return self.request.airavata_client.research.get_application_interface(lookup_value) except Exception: @@ -479,7 +481,7 @@ class ApplicationInterfaceViewSet(APIBackedViewSet): else: raise # re-raise - def perform_create(self, serializer): + def perform_create(self, serializer: Any) -> None: application_interface = serializer.save() self._update_input_metadata(application_interface) log.debug(f"application_interface: {application_interface}") @@ -488,17 +490,17 @@ class ApplicationInterfaceViewSet(APIBackedViewSet): ) application_interface.applicationInterfaceId = app_interface_id - def perform_update(self, serializer): + def perform_update(self, serializer: Any) -> None: application_interface = serializer.save() self._update_input_metadata(application_interface) self.request.airavata_client.research.update_application_interface( application_interface.applicationInterfaceId, application_interface ) - def perform_destroy(self, instance): + def perform_destroy(self, instance: Any) -> None: self.request.airavata_client.research.delete_application_interface(instance.applicationInterfaceId) - def _update_input_metadata(self, app_interface): + def _update_input_metadata(self, app_interface: Any) -> None: for app_input in app_interface.applicationInputs: if app_input.metaData: metadata = json.loads(app_input.metaData) @@ -516,7 +518,7 @@ class ApplicationInterfaceViewSet(APIBackedViewSet): app_input.metaData = json.dumps(metadata) @action(detail=True) - def compute_resources(self, request, app_interface_id): + def compute_resources(self, request: Request, app_interface_id: str) -> Response: compute_resources = request.airavata_client.research.get_available_app_interface_compute_resources( app_interface_id ) @@ -527,7 +529,7 @@ class ApplicationDeploymentViewSet(APIBackedViewSet): serializer_class = serializers.ApplicationDeploymentDescriptionSerializer lookup_field = "app_deployment_id" - def get_list(self): + def get_list(self) -> list[Any]: app_module_id = self.request.query_params.get("appModuleId", None) group_resource_profile_id = self.request.query_params.get("groupResourceProfileId", None) if (app_module_id and not group_resource_profile_id) or (not app_module_id and group_resource_profile_id): @@ -541,27 +543,27 @@ class ApplicationDeploymentViewSet(APIBackedViewSet): self.gateway_id, ResourcePermissionType.READ ) - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> Any: return self.request.airavata_client.research.get_application_deployment(lookup_value) - def perform_create(self, serializer): + def perform_create(self, serializer: Any) -> None: application_deployment = serializer.save() app_deployment_id = self.request.airavata_client.research.register_application_deployment( self.gateway_id, application_deployment ) application_deployment.appDeploymentId = app_deployment_id - def perform_update(self, serializer): + def perform_update(self, serializer: Any) -> None: application_deployment = serializer.save() self.request.airavata_client.research.update_application_deployment( application_deployment.appDeploymentId, application_deployment ) - def perform_destroy(self, instance): + def perform_destroy(self, instance: Any) -> None: self.request.airavata_client.research.delete_application_deployment(instance.appDeploymentId) @action(detail=True) - def queues(self, request, app_deployment_id): + def queues(self, request: Request, app_deployment_id: str) -> Response: """Return queues for this deployment with defaults overridden by deployment defaults if they exist""" app_deployment = self.request.airavata_client.research.get_application_deployment(app_deployment_id) compute_resource = request.airavata_client.compute.get_compute_resource(app_deployment.computeHostId) @@ -585,16 +587,16 @@ class ComputeResourceViewSet(mixins.RetrieveModelMixin, GenericAPIBackedViewSet) serializer_class = serializers.ComputeResourceDescriptionSerializer lookup_field = "compute_resource_id" - def get_instance(self, lookup_value, format=None): + def get_instance(self, lookup_value: str, format: str | None = None) -> Any: return self.request.airavata_client.compute.get_compute_resource(lookup_value) @action(detail=False) - def all_names(self, request, format=None): + def all_names(self, request: Request, format: str | None = None) -> Response: """Return a map of compute resource names keyed by resource id.""" return Response(request.airavata_client.compute.get_all_compute_resource_names()) @action(detail=False) - def all_names_list(self, request, format=None): + def all_names_list(self, request: Request, format: str | None = None) -> Response: """Return a list of compute resource names keyed by resource id.""" all_names = request.airavata_client.compute.get_all_compute_resource_names() return Response( @@ -611,7 +613,7 @@ class ComputeResourceViewSet(mixins.RetrieveModelMixin, GenericAPIBackedViewSet) ) @action(detail=True) - def queues(self, request, compute_resource_id, format=None): + def queues(self, request: Request, compute_resource_id: str, format: str | None = None) -> Response: details = request.airavata_client.compute.get_compute_resource(compute_resource_id) serializer = self.serializer_class(instance=details, context={"request": request}) data = serializer.data @@ -621,7 +623,7 @@ class ComputeResourceViewSet(mixins.RetrieveModelMixin, GenericAPIBackedViewSet) class LocalJobSubmissionView(APIView): renderer_classes = (JSONRenderer,) - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: job_submission_id = request.query_params["id"] local_job_submission = request.airavata_client.compute.get_local_job_submission(job_submission_id) from . import proto_utils @@ -632,7 +634,7 @@ class LocalJobSubmissionView(APIView): class CloudJobSubmissionView(APIView): renderer_classes = (JSONRenderer,) - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: job_submission_id = request.query_params["id"] job_submission = request.airavata_client.compute.get_cloud_job_submission(job_submission_id) from . import proto_utils @@ -643,7 +645,7 @@ class CloudJobSubmissionView(APIView): class GlobusJobSubmissionView(APIView): renderer_classes = (JSONRenderer,) - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: job_submission_id = request.query_params["id"] job_submission = request.airavata_client.compute.get_globus_job_submission(job_submission_id) from . import proto_utils @@ -654,7 +656,7 @@ class GlobusJobSubmissionView(APIView): class SshJobSubmissionView(APIView): renderer_classes = (JSONRenderer,) - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: job_submission_id = request.query_params["id"] job_submission = request.airavata_client.compute.get_ssh_job_submission(job_submission_id) from . import proto_utils @@ -665,7 +667,7 @@ class SshJobSubmissionView(APIView): class UnicoreJobSubmissionView(APIView): renderer_classes = (JSONRenderer,) - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: job_submission_id = request.query_params["id"] job_submission = request.airavata_client.compute.get_unicore_job_submission(job_submission_id) from . import proto_utils @@ -676,7 +678,7 @@ class UnicoreJobSubmissionView(APIView): class GridFtpDataMovementView(APIView): renderer_classes = (JSONRenderer,) - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: data_movement_id = request.query_params["id"] data_movement = request.airavata_client.compute.get_grid_ftp_data_movement(data_movement_id) from . import proto_utils @@ -687,7 +689,7 @@ class GridFtpDataMovementView(APIView): class ScpDataMovementView(APIView): renderer_classes = (JSONRenderer,) - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: data_movement_id = request.query_params["id"] data_movement = request.airavata_client.compute.get_scp_data_movement(data_movement_id) from . import proto_utils @@ -698,7 +700,7 @@ class ScpDataMovementView(APIView): class UnicoreDataMovementView(APIView): renderer_classes = (JSONRenderer,) - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: data_movement_id = request.query_params["id"] data_movement = request.airavata_client.compute.get_unicore_data_movement(data_movement_id) from . import proto_utils @@ -709,7 +711,7 @@ class UnicoreDataMovementView(APIView): class LocalDataMovementView(APIView): renderer_classes = (JSONRenderer,) - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: data_movement_id = request.query_params["id"] data_movement = request.airavata_client.compute.get_local_data_movement(data_movement_id) from . import proto_utils @@ -721,13 +723,13 @@ class DataProductView(APIView): serializer_class = serializers.DataProductSerializer permission_classes = [IsAuthenticated, DataProductSharedDirPermission] - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: data_product_uri = request.query_params["product-uri"] data_product = request.airavata_client.research.get_data_product(data_product_uri) serializer = self.serializer_class(data_product, context={"request": request}) return Response(serializer.data) - def put(self, request, format=None): + def put(self, request: Request, format: str | None = None) -> Response: data_product_uri = request.query_params["product-uri"] data_product = request.airavata_client.research.get_data_product(data_product_uri) if request.data and "fileContentText" in request.data: @@ -740,7 +742,7 @@ class DataProductView(APIView): @api_view(http_method_names=["POST"]) -def upload_input_file(request): +def upload_input_file(request: Request) -> JsonResponse: try: input_file = request.FILES["file"] data_product = user_storage.save_input_file(request, input_file, content_type=input_file.content_type) @@ -754,7 +756,7 @@ def upload_input_file(request): @api_view(http_method_names=["POST"]) -def tus_upload_finish(request): +def tus_upload_finish(request: Request) -> JsonResponse: uploadURL = request.POST["uploadURL"] def save_upload(file_path, file_name, file_type): @@ -771,7 +773,7 @@ def tus_upload_finish(request): @gzip_page @api_view() -def download_file(request): +def download_file(request: Request) -> HttpResponse: # TODO: remove this deprecated view warnings.warn("download_file view has moved to SDK", DeprecationWarning) # redirect to /sdk/download @@ -781,7 +783,7 @@ def download_file(request): @api_view(http_method_names=["DELETE"]) @permission_classes([IsAuthenticated, DataProductSharedDirPermission]) -def delete_file(request): +def delete_file(request: Request) -> HttpResponse: # TODO check that user has write access to this file using sharing API data_product_uri = request.GET.get("data-product-uri", "") data_product = None @@ -802,10 +804,10 @@ def delete_file(request): class UserProfileViewSet(mixins.RetrieveModelMixin, mixins.ListModelMixin, GenericAPIBackedViewSet): serializer_class = serializers.UserProfileSerializer - def get_list(self): + def get_list(self) -> list[Any]: return self.request.airavata_client.iam.get_all_user_profiles_in_gateway(self.gateway_id, 0, -1) - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> Any: return self.request.airavata_client.iam.get_user_profile_by_id(self.request.user.username, self.gateway_id) @@ -813,13 +815,13 @@ class GroupResourceProfileViewSet(APIBackedViewSet): serializer_class = serializers.GroupResourceProfileSerializer lookup_field = "group_resource_profile_id" - def get_list(self): + def get_list(self) -> list[Any]: return self.request.airavata_client.compute.get_group_resource_list(self.gateway_id) - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> Any: return self.request.airavata_client.compute.get_group_resource_profile(lookup_value) - def perform_create(self, serializer): + def perform_create(self, serializer: Any) -> None: group_resource_profile = serializer.save() group_resource_profile.gatewayId = self.gateway_id group_resource_profile_id = self.request.airavata_client.compute.create_group_resource_profile( @@ -832,7 +834,7 @@ class GroupResourceProfileViewSet(APIBackedViewSet): ) group_resource_profile.creationTime = new_group_resource_profile.creationTime - def perform_update(self, serializer): + def perform_update(self, serializer: Any) -> None: original_instance = serializer.instance grp = serializer.save() @@ -1019,7 +1021,7 @@ class GroupResourceProfileViewSet(APIBackedViewSet): self.request.airavata_client.compute.update_group_resource_profile(grp) - def perform_destroy(self, instance): + def perform_destroy(self, instance: Any) -> None: self.request.airavata_client.compute.remove_group_resource_profile(instance.groupResourceProfileId) @@ -1027,7 +1029,7 @@ class SharedEntityViewSet(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, Ge serializer_class = serializers.SharedEntitySerializer lookup_field = "entity_id" - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> dict[str, Any]: users = {} # Only load *directly* granted permissions since these are the only # ones that can be edited @@ -1061,30 +1063,30 @@ class SharedEntityViewSet(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, Ge "owner": self._load_user_profile(owner_id), } - def _load_accessible_users(self, entity_id, permission_type): + def _load_accessible_users(self, entity_id: str, permission_type: Any) -> dict[str, Any]: users = self.request.airavata_client.sharing.get_all_accessible_users(entity_id, permission_type) return {user_id: permission_type for user_id in users} - def _load_directly_accessible_users(self, entity_id, permission_type): + def _load_directly_accessible_users(self, entity_id: str, permission_type: Any) -> dict[str, Any]: users = self.request.airavata_client.sharing.get_all_directly_accessible_users(entity_id, permission_type) return {user_id: permission_type for user_id in users} - def _load_user_profile(self, user_id): + def _load_user_profile(self, user_id: str) -> Any: username = user_id[0 : user_id.rindex("@")] return self.request.airavata_client.iam.get_user_profile_by_id(username, settings.GATEWAY_ID) - def _load_accessible_groups(self, entity_id, permission_type): + def _load_accessible_groups(self, entity_id: str, permission_type: Any) -> dict[str, Any]: groups = self.request.airavata_client.sharing.get_all_accessible_groups(entity_id, permission_type) return {group_id: permission_type for group_id in groups} - def _load_directly_accessible_groups(self, entity_id, permission_type): + def _load_directly_accessible_groups(self, entity_id: str, permission_type: Any) -> dict[str, Any]: groups = self.request.airavata_client.sharing.get_all_directly_accessible_groups(entity_id, permission_type) return {group_id: permission_type for group_id in groups} - def _load_group(self, group_id): + def _load_group(self, group_id: str) -> Any: return self.request.airavata_client.sharing.get_group(group_id) - def perform_update(self, serializer): + def perform_update(self, serializer: Any) -> None: shared_entity = serializer.save() entity_id = shared_entity["entityId"] if len(shared_entity["_user_grant_read_permission"]) > 0: @@ -1140,28 +1142,28 @@ class SharedEntityViewSet(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, Ge shared_entity["_group_revoke_manage_sharing_permission"], ) - def _share_with_users(self, entity_id, permission_type, user_ids): + def _share_with_users(self, entity_id: str, permission_type: Any, user_ids: list[str]) -> None: self.request.airavata_client.sharing.share_resource_with_users( entity_id, {user_id: permission_type for user_id in user_ids} ) - def _revoke_from_users(self, entity_id, permission_type, user_ids): + def _revoke_from_users(self, entity_id: str, permission_type: Any, user_ids: list[str]) -> None: self.request.airavata_client.sharing.revoke_sharing_of_resource_from_users( entity_id, {user_id: permission_type for user_id in user_ids} ) - def _share_with_groups(self, entity_id, permission_type, group_ids): + def _share_with_groups(self, entity_id: str, permission_type: Any, group_ids: list[str]) -> None: self.request.airavata_client.sharing.share_resource_with_groups( entity_id, {group_id: permission_type for group_id in group_ids} ) - def _revoke_from_groups(self, entity_id, permission_type, group_ids): + def _revoke_from_groups(self, entity_id: str, permission_type: Any, group_ids: list[str]) -> None: self.request.airavata_client.sharing.revoke_sharing_of_resource_from_groups( entity_id, {group_id: permission_type for group_id in group_ids} ) @action(methods=["put"], detail=True) - def merge(self, request, entity_id=None): + def merge(self, request: Request, entity_id: str | None = None) -> Response: # Validate updated sharing settings updated = self.get_serializer(data=request.data) updated.is_valid(raise_exception=True) @@ -1181,7 +1183,7 @@ class SharedEntityViewSet(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, Ge return Response(merged_serializer.data) @action(methods=["get"], detail=True) - def all(self, request, entity_id=None): + def all(self, request: Request, entity_id: str | None = None) -> Response: """Load direct plus indirectly (inherited) shared permissions.""" users = {} # Load accessible users in order of permission precedence: users that @@ -1220,28 +1222,28 @@ class SharedEntityViewSet(mixins.RetrieveModelMixin, mixins.UpdateModelMixin, Ge class CredentialSummaryViewSet(APIBackedViewSet): serializer_class = serializers.CredentialSummarySerializer - def get_list(self): + def get_list(self) -> list[Any]: ssh_creds = self.request.airavata_client.credential.get_all_credential_summaries(SummaryType.SSH) pwd_creds = self.request.airavata_client.credential.get_all_credential_summaries(SummaryType.PASSWD) return ssh_creds + pwd_creds - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> Any: return self.request.airavata_client.credential.get_credential_summary(lookup_value) @action(detail=False) - def ssh(self, request): + def ssh(self, request: Request) -> Response: summaries = self.request.airavata_client.credential.get_all_credential_summaries(SummaryType.SSH) serializer = self.get_serializer(summaries, many=True) return Response(serializer.data) @action(detail=False) - def password(self, request): + def password(self, request: Request) -> Response: summaries = self.request.airavata_client.credential.get_all_credential_summaries(SummaryType.PASSWD) serializer = self.get_serializer(summaries, many=True) return Response(serializer.data) @action(methods=["post"], detail=False) - def create_ssh(self, request): + def create_ssh(self, request: Request) -> Response: if "description" not in request.data: raise ParseError("'description' is required in request") description = request.data.get("description") @@ -1251,7 +1253,7 @@ class CredentialSummaryViewSet(APIBackedViewSet): return Response(serializer.data) @action(methods=["post"], detail=False) - def create_password(self, request): + def create_password(self, request: Request) -> Response: if "username" not in request.data or "password" not in request.data or "description" not in request.data: raise ParseError("'username', 'password' and 'description' are all required in request") username = request.data.get("username") @@ -1262,7 +1264,7 @@ class CredentialSummaryViewSet(APIBackedViewSet): serializer = self.get_serializer(credential_summary) return Response(serializer.data) - def perform_destroy(self, instance): + def perform_destroy(self, instance: Any) -> None: if instance.type == SummaryType.SSH: self.request.airavata_client.credential.delete_ssh_pub_key(instance.token) elif instance.type == SummaryType.PASSWD: @@ -1270,14 +1272,14 @@ class CredentialSummaryViewSet(APIBackedViewSet): class CurrentGatewayResourceProfile(APIView): - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: gateway_resource_profile = request.airavata_client.compute.get_gateway_resource_profile(settings.GATEWAY_ID) serializer = serializers.GatewayResourceProfileSerializer( gateway_resource_profile, context={"request": request} ) return Response(serializer.data) - def put(self, request, format=None): + def put(self, request: Request, format: str | None = None) -> Response: serializer = serializers.GatewayResourceProfileSerializer(data=request.data, context={"request": request}) if serializer.is_valid(): gateway_resource_profile = serializer.save() @@ -1290,7 +1292,7 @@ class CurrentGatewayResourceProfile(APIView): class ExperimentArchiveView(APIView): - def get(self, request, experiment_id=None, format=None): + def get(self, request: Request, experiment_id: str | None = None, format: str | None = None) -> Response: experiment: ExperimentModel = request.airavata_client.research.get_experiment(experiment_id) result = dict( archived=False, @@ -1314,11 +1316,11 @@ class StorageResourceViewSet(mixins.RetrieveModelMixin, GenericAPIBackedViewSet) serializer_class = serializers.StorageResourceSerializer lookup_field = "storage_resource_id" - def get_instance(self, lookup_value, format=None): + def get_instance(self, lookup_value: str, format: str | None = None) -> Any: return self.request.airavata_client.storage.get_storage_resource(lookup_value) @action(detail=False) - def all_names(self, request, format=None): + def all_names(self, request: Request, format: str | None = None) -> Response: """Return a map of compute resource names keyed by resource id.""" return Response(request.airavata_client.storage.get_all_storage_resource_names()) @@ -1327,25 +1329,25 @@ class StoragePreferenceViewSet(APIBackedViewSet): serializer_class = serializers.StoragePreferenceSerializer lookup_field = "storage_resource_id" - def get_list(self): + def get_list(self) -> list[Any]: return self.request.airavata_client.compute.get_all_gateway_storage_preferences(settings.GATEWAY_ID) - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> Any: return self.request.airavata_client.compute.get_gateway_storage_preference(settings.GATEWAY_ID, lookup_value) - def perform_create(self, serializer): + def perform_create(self, serializer: Any) -> None: storage_preference = serializer.save() self.request.airavata_client.compute.add_gateway_storage_preference( settings.GATEWAY_ID, storage_preference.storageResourceId, storage_preference ) - def perform_update(self, serializer): + def perform_update(self, serializer: Any) -> None: storage_preference = serializer.save() self.request.airavata_client.compute.update_gateway_storage_preference( settings.GATEWAY_ID, storage_preference.storageResourceId, storage_preference ) - def perform_destroy(self, instance): + def perform_destroy(self, instance: Any) -> None: self.request.airavata_client.compute.delete_gateway_storage_preference( settings.GATEWAY_ID, instance.storageResourceId ) @@ -1361,17 +1363,17 @@ class ParserViewSet( serializer_class = serializers.ParserSerializer lookup_field = "parser_id" - def get_list(self): + def get_list(self) -> list[Any]: return self.request.airavata_client.research.list_all_parsers(settings.GATEWAY_ID) - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> Any: return self.request.airavata_client.research.get_parser(lookup_value, settings.GATEWAY_ID) - def perform_create(self, serializer): + def perform_create(self, serializer: Any) -> None: parser = serializer.save() self.request.airavata_client.research.save_parser(parser) - def perform_update(self, serializer): + def perform_update(self, serializer: Any) -> None: parser = serializer.save() self.request.airavata_client.research.save_parser(parser) @@ -1380,13 +1382,13 @@ class UserStoragePathView(APIView): serializer_class = serializers.UserStoragePathSerializer permission_classes = (IsAuthenticated, UserStorageSharedDirPermission) - def get(self, request, path="/", format=None): + def get(self, request: Request, path: str = "/", format: str | None = None) -> Response: # AIRAVATA-3460 Allow passing path as a query parameter instead path = request.query_params.get("path", path) experiment_id = request.query_params.get("experiment-id") return self._create_response(request, path, experiment_id=experiment_id) - def post(self, request, path="/", format=None): + def post(self, request: Request, path: str = "/", format: str | None = None) -> Response: path = request.data.get("path", path) experiment_id = request.data.get("experiment-id") if not user_storage.dir_exists(request, path, experiment_id=experiment_id): @@ -1422,7 +1424,7 @@ class UserStoragePathView(APIView): return self._create_response(request, path, uploaded=data_product, experiment_id=experiment_id) # Accept wither to replace file or to replace file content text. - def put(self, request, path="/", format=None): + def put(self, request: Request, path: str = "/", format: str | None = None) -> Response: path = request.POST.get("path", path) # Replace the file if the request has a file upload. if "file" in request.FILES: @@ -1439,7 +1441,7 @@ class UserStoragePathView(APIView): return self._create_response(request=request, path=path) - def delete(self, request, path="/", format=None): + def delete(self, request: Request, path: str = "/", format: str | None = None) -> Response: path = request.data.get("path", path) experiment_id = request.data.get("experiment-id") if user_storage.dir_exists(request, path, experiment_id=experiment_id): @@ -1449,7 +1451,7 @@ class UserStoragePathView(APIView): return Response(status=204) - def _create_response(self, request, path, uploaded=None, experiment_id=None): + def _create_response(self, request: Request, path: str, uploaded: Any = None, experiment_id: str | None = None) -> Response: if user_storage.dir_exists(request, path, experiment_id=experiment_id): directories, files = user_storage.listdir(request, path, experiment_id=experiment_id) data = {"isDir": True, "directories": directories, "files": files} @@ -1468,7 +1470,7 @@ class UserStoragePathView(APIView): serializer = self.serializer_class(data, context={"request": request}) return Response(serializer.data) - def _split_path(self, path): + def _split_path(self, path: str) -> list[str]: head, tail = os.path.split(path) if head != path: return self._split_path(head) + [tail] @@ -1481,10 +1483,10 @@ class UserStoragePathView(APIView): class ExperimentStoragePathView(APIView): serializer_class = serializers.ExperimentStoragePathSerializer - def get(self, request, experiment_id=None, path="", format=None): + def get(self, request: Request, experiment_id: str | None = None, path: str = "", format: str | None = None) -> Response: return self._create_response(request, experiment_id, path) - def _create_response(self, request, experiment_id, path): + def _create_response(self, request: Request, experiment_id: str, path: str) -> Response: if user_storage.experiment_dir_exists(request, experiment_id, path): directories, files = user_storage.list_experiment_dir(request, experiment_id, path) @@ -1499,7 +1501,7 @@ class ExperimentStoragePathView(APIView): else: raise Http404(f"Path '{path}' does not exist for {experiment_id}") - def _split_path(self, path): + def _split_path(self, path: str) -> list[str]: head, tail = os.path.split(path) if head != "": return self._split_path(head) + [tail] @@ -1512,7 +1514,7 @@ class ExperimentStoragePathView(APIView): class WorkspacePreferencesView(APIView): serializer_class = serializers.WorkspacePreferencesSerializer - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: helper = helpers.WorkspacePreferencesHelper() workspace_preferences = helper.get(request) serializer = self.serializer_class(workspace_preferences, context={"request": request}) @@ -1523,23 +1525,23 @@ class ManageNotificationViewSet(APIBackedViewSet): serializer_class = serializers.NotificationSerializer lookup_field = "notification_id" - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> Any: return self.request.airavata_client.research.get_notification(settings.GATEWAY_ID, lookup_value) - def get_list(self): + def get_list(self) -> list[Any]: return self.request.airavata_client.research.get_all_notifications(self.gateway_id) - def perform_destroy(self, instance): + def perform_destroy(self, instance: Any) -> None: self.request.airavata_client.research.delete_notification(settings.GATEWAY_ID, instance.notificationId) - def perform_create(self, serializer): + def perform_create(self, serializer: Any) -> None: notification = serializer.save(gatewayId=self.gateway_id) notificationId = self.request.airavata_client.research.create_notification(notification) notification.notificationId = notificationId serializer.update_notification_extension(self.request, notification) - def perform_update(self, serializer): + def perform_update(self, serializer: Any) -> None: notification = serializer.save() self.request.airavata_client.research.update_notification(notification) @@ -1547,7 +1549,7 @@ class ManageNotificationViewSet(APIBackedViewSet): class AckNotificationViewSet(APIView): - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> HttpResponse: if "id" in request.GET: notification_id = request.GET["id"] try: @@ -1578,21 +1580,21 @@ class IAMUserViewSet( ) lookup_field = "user_id" - def get_list(self): + def get_list(self) -> APIResultIterator: search = self.request.GET.get("search", None) convert_user_profile = self._convert_user_profile class IAMUsersResultIterator(APIResultIterator): - def get_results(self, limit=-1, offset=0): + def get_results(self, limit: int = -1, offset: int = 0) -> Any: return map(convert_user_profile, iam_admin_client.get_users(offset, limit, search)) return IAMUsersResultIterator(query_params=self.request.query_params.copy()) - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> dict[str, Any]: return self._convert_user_profile(iam_admin_client.get_user(lookup_value)) - def perform_update(self, serializer): + def perform_update(self, serializer: Any) -> None: managed_user_profile = serializer.save() sharing_client = self.request.airavata_client.sharing iam_client = self.request.airavata_client.iam @@ -1610,18 +1612,18 @@ class IAMUserViewSet( for group_id in managed_user_profile["_removed_group_ids"]: sharing_client.remove_users_from_group([user_id], group_id) - def perform_destroy(self, instance): + def perform_destroy(self, instance: dict[str, Any]) -> None: iam_admin_client.delete_user(instance["userId"]) @action(methods=["post"], detail=True) - def enable(self, request, user_id=None): + def enable(self, request: Request, user_id: str | None = None) -> Response: iam_admin_client.enable_user(user_id) instance = self.get_instance(user_id) serializer = self.serializer_class(instance=instance, context={"request": request}) return Response(serializer.data) @action(methods=["put"], detail=False) - def update_username(self, request): + def update_username(self, request: Request) -> Response: serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) old_username = serializer.validated_data["userId"] @@ -1640,7 +1642,7 @@ class IAMUserViewSet( serializer = self.serializer_class(instance=instance, context={"request": request}) return Response(serializer.data) - def _convert_user_profile(self, user_profile): + def _convert_user_profile(self, user_profile: Any) -> dict[str, Any]: iam_client = self.request.airavata_client.iam sharing_client = self.request.airavata_client.sharing airavata_user_profile_exists = iam_client.does_user_exist(user_profile.userId, self.gateway_id) @@ -1666,7 +1668,7 @@ class ExperimentStatisticsView(APIView): # TODO: restrict to only Admins or Read Only Admins group members serializer_class = serializers.ExperimentStatisticsSerializer - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: if "fromTime" in request.GET: from_time = view_utils.convert_utc_iso8601_to_date(request.GET["fromTime"]).timestamp() * 1000 else: @@ -1709,23 +1711,23 @@ class UnverifiedEmailUserViewSet(mixins.ListModelMixin, mixins.RetrieveModelMixi ) lookup_field = "user_id" - def get_list(self): + def get_list(self) -> APIResultIterator: get_users = self._get_unverified_email_user_profiles class UnverifiedEmailUsersResultIterator(APIResultIterator): - def get_results(self, limit=-1, offset=0): + def get_results(self, limit: int = -1, offset: int = 0) -> list[dict[str, Any]]: return get_users(limit, offset) return UnverifiedEmailUsersResultIterator() - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> dict[str, Any]: users = self._get_unverified_email_user_profiles(limit=1, username=lookup_value) if len(users) == 0: raise Http404(f"No unverified email record found for user {lookup_value}") else: return users[0] - def _get_unverified_email_user_profiles(self, limit=-1, offset=0, username=None): + def _get_unverified_email_user_profiles(self, limit: int = -1, offset: int = 0, username: str | None = None) -> list[dict[str, Any]]: unverified_emails = ( EmailVerification.objects.filter(verified=False).order_by("username").values("username").distinct() ) @@ -1766,7 +1768,7 @@ class UnverifiedEmailUserViewSet(mixins.ListModelMixin, mixins.RetrieveModelMixi class LogRecordConsumer(APIView): serializer_class = serializers.LogRecordSerializer - def post(self, request, format=None): + def post(self, request: Request, format: str | None = None) -> Response: serializer = self.serializer_class(data=request.data) serializer.is_valid(raise_exception=True) log_record = serializer.validated_data @@ -1786,7 +1788,7 @@ class LogRecordConsumer(APIView): class SettingsAPIView(APIView): serializer_class = serializers.SettingsSerializer - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: data = { "fileUploadMaxFileSize": settings.FILE_UPLOAD_MAX_FILE_SIZE, "tusEndpoint": settings.TUS_ENDPOINT, @@ -1797,7 +1799,7 @@ class SettingsAPIView(APIView): class APIServerStatusCheckView(APIView): - def get(self, request, format=None): + def get(self, request: Request, format: str | None = None) -> Response: try: request.airavata_client.research.get_user_projects( settings.GATEWAY_ID, @@ -1813,7 +1815,7 @@ class APIServerStatusCheckView(APIView): @api_view() -def notebook_output_view(request): +def notebook_output_view(request: Request) -> HttpResponse: provider_id = request.GET["provider-id"] experiment_id = request.GET["experiment-id"] experiment_output_name = request.GET["experiment-output-name"] @@ -1822,13 +1824,13 @@ def notebook_output_view(request): @api_view() -def html_output_view(request): +def html_output_view(request: Request) -> JsonResponse: data = _generate_output_view_data(request) return JsonResponse(data) @api_view() -def image_output_view(request): +def image_output_view(request: Request) -> JsonResponse: data = _generate_output_view_data(request) # data should contain 'image' as a file-like object or raw bytes with the # file data and 'mime-type' with the images mimetype @@ -1837,12 +1839,12 @@ def image_output_view(request): @api_view() -def link_output_view(request): +def link_output_view(request: Request) -> JsonResponse: data = _generate_output_view_data(request) return JsonResponse(data) -def _generate_output_view_data(request): +def _generate_output_view_data(request: Request) -> dict[str, Any]: params = request.GET.copy() provider_id = params.pop("provider-id")[0] experiment_id = params.pop("experiment-id")[0] @@ -1856,10 +1858,10 @@ def _generate_output_view_data(request): class QueueSettingsCalculatorViewSet(mixins.ListModelMixin, mixins.RetrieveModelMixin, GenericAPIBackedViewSet): serializer_class = serializers.QueueSettingsCalculatorSerializer - def get_list(self): + def get_list(self) -> list[Any]: return queue_settings_calculators.get_all() - def get_instance(self, lookup_value): + def get_instance(self, lookup_value: str) -> Any: calcs = queue_settings_calculators.get_all() calc = [calc for calc in calcs if calc.id == lookup_value] if len(calc) == 0: @@ -1867,7 +1869,7 @@ class QueueSettingsCalculatorViewSet(mixins.ListModelMixin, mixins.RetrieveModel return calc[0] @action(methods=["post"], detail=True, serializer_class=serializers.ExperimentSerializer) - def calculate(self, request, pk=None): + def calculate(self, request: Request, pk: str | None = None) -> Response: serializer = self.get_serializer(data=request.data) result = {} diff --git a/airavata-django-portal/django_airavata/apps/auth/middleware.py b/airavata-django-portal/django_airavata/apps/auth/middleware.py index a81c39840..195090cf3 100644 --- a/airavata-django-portal/django_airavata/apps/auth/middleware.py +++ b/airavata-django-portal/django_airavata/apps/auth/middleware.py @@ -2,9 +2,12 @@ import copy import logging +from collections.abc import Callable +from typing import Any from django.conf import settings from django.contrib.auth import logout +from django.http import HttpRequest, HttpResponse from django.shortcuts import redirect from django.urls import reverse @@ -13,10 +16,10 @@ from . import utils log = logging.getLogger(__name__) -def authz_token_middleware(get_response): +def authz_token_middleware(get_response: Callable[[HttpRequest], HttpResponse]) -> Callable[[HttpRequest], HttpResponse]: """Automatically add the 'authz_token' to the request.""" - def middleware(request): + def middleware(request: HttpRequest) -> HttpResponse: authz_token = None if request.user.is_authenticated: @@ -33,7 +36,7 @@ def authz_token_middleware(get_response): return middleware -def set_admin_group_attributes(request, gateway_groups=None): +def set_admin_group_attributes(request: HttpRequest, gateway_groups: Any = None) -> None: """Set is_gateway_admin and is_read_only_gateway_admin request attrs.""" if gateway_groups is None: gateway_groups = request.airavata_client.iam.get_gateway_groups() @@ -52,10 +55,10 @@ def set_admin_group_attributes(request, gateway_groups=None): request.is_read_only_gateway_admin = read_only_admins_group_id in group_ids -def gateway_groups_middleware(get_response): +def gateway_groups_middleware(get_response: Callable[[HttpRequest], HttpResponse]) -> Callable[[HttpRequest], HttpResponse]: """Add 'is_gateway_admin' and 'is_read_only_gateway_admin' to request.""" - def middleware(request): + def middleware(request: HttpRequest) -> HttpResponse: request.is_gateway_admin = False request.is_read_only_gateway_admin = False @@ -93,10 +96,10 @@ def gateway_groups_middleware(get_response): return middleware -def user_profile_completeness_check(get_response): +def user_profile_completeness_check(get_response: Callable[[HttpRequest], HttpResponse]) -> Callable[[HttpRequest], HttpResponse]: """Check if user profile is complete and if not, redirect to user profile editor.""" - def middleware(request): + def middleware(request: HttpRequest) -> HttpResponse: if not request.user.is_authenticated: return get_response(request) diff --git a/airavata-django-portal/django_airavata/context_processors.py b/airavata-django-portal/django_airavata/context_processors.py index 9f9e9ab3f..fe0077520 100644 --- a/airavata-django-portal/django_airavata/context_processors.py +++ b/airavata-django-portal/django_airavata/context_processors.py @@ -3,10 +3,12 @@ import datetime import json import logging import re +from typing import Any from django.apps import apps from django.conf import settings from django.core.exceptions import ObjectDoesNotExist +from django.http import HttpRequest from django.urls import reverse from django_airavata.app_config import AiravataAppConfig @@ -15,7 +17,7 @@ from django_airavata.apps.api.models import User_Notifications logger = logging.getLogger(__name__) -def get_notifications(request): +def get_notifications(request: HttpRequest) -> dict[str, Any]: if request.user.is_authenticated and hasattr(request, "airavata_client"): unread_notifications = 0 try: @@ -24,7 +26,7 @@ def get_notifications(request): logger.warning("Failed to load notifications") notifications = [] current_time = datetime.datetime.utcnow() - valid_notifications = [] + valid_notifications: list[dict[str, Any]] = [] for notification in notifications: notification_data = notification.__dict__ expirationTime = datetime.datetime.fromtimestamp(notification.expirationTime / 1000) @@ -55,8 +57,8 @@ def get_notifications(request): return {"notifications": json.dumps([])} -def user_session_data(request): - data = {} +def user_session_data(request: HttpRequest) -> dict[str, str]: + data: dict[str, Any] = {} if request.user.is_authenticated: data["username"] = request.user.username data["airavataInternalUserId"] = request.user.username + "@" + settings.GATEWAY_ID @@ -65,9 +67,9 @@ def user_session_data(request): return {"user_session_data": json.dumps(data)} -def airavata_app_registry(request): +def airavata_app_registry(request: HttpRequest) -> dict[str, Any]: """Put airavata django apps into the context.""" - airavata_apps = [ + airavata_apps: list[AiravataAppConfig] = [ app for app in apps.get_app_configs() if isinstance(app, AiravataAppConfig) @@ -85,14 +87,14 @@ def airavata_app_registry(request): } -def _get_current_app(request, apps): +def _get_current_app(request: HttpRequest, apps: list[AiravataAppConfig]) -> AiravataAppConfig | None: current_app = [ app for app in apps if request.resolver_match and app.url_app_name == request.resolver_match.app_name ] return current_app[0] if len(current_app) > 0 else None -def _get_app_nav(request, current_app): +def _get_app_nav(request: HttpRequest, current_app: AiravataAppConfig) -> list[dict[str, Any]]: if hasattr(current_app, "nav"): # Copy and filter current_app's nav items nav = [item for item in copy.copy(current_app.nav) if "enabled" not in item or item["enabled"](request)] @@ -116,6 +118,6 @@ def _get_app_nav(request, current_app): return nav -def google_analytics_tracking_id(request): +def google_analytics_tracking_id(request: HttpRequest) -> dict[str, str | None]: """Put the Google Analytics tracking id into context.""" return {"ga_tracking_id": getattr(settings, "GOOGLE_ANALYTICS_TRACKING_ID", None)} diff --git a/airavata-django-portal/django_airavata/middleware.py b/airavata-django-portal/django_airavata/middleware.py index 9bfe37a14..c908ca184 100644 --- a/airavata-django-portal/django_airavata/middleware.py +++ b/airavata-django-portal/django_airavata/middleware.py @@ -1,6 +1,8 @@ import logging +from collections.abc import Callable from django.conf import settings +from django.http import HttpRequest, HttpResponse from django.shortcuts import render from .utils import create_airavata_client @@ -9,10 +11,10 @@ logger = logging.getLogger(__name__) class AiravataClientMiddleware: - def __init__(self, get_response): + def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None: self.get_response = get_response - def __call__(self, request): + def __call__(self, request: HttpRequest) -> HttpResponse: access_token = _get_access_token(request) gateway_id = settings.GATEWAY_ID request.airavata_client = create_airavata_client(access_token, gateway_id) @@ -25,7 +27,7 @@ class AiravataClientMiddleware: request.airavata_client.close() return response - def process_exception(self, request, exception): + def process_exception(self, request: HttpRequest, exception: Exception) -> HttpResponse | None: # Handle connection errors to the Airavata API server if isinstance(exception, ConnectionError): return render( @@ -40,7 +42,7 @@ class AiravataClientMiddleware: return None -def _get_access_token(request): +def _get_access_token(request: HttpRequest) -> str: """Extract access token from request auth or session.""" if hasattr(request, "auth") and request.auth is not None: return request.auth diff --git a/airavata-django-portal/django_airavata/utils.py b/airavata-django-portal/django_airavata/utils.py index 6a463e3d4..23b7bac21 100644 --- a/airavata-django-portal/django_airavata/utils.py +++ b/airavata-django-portal/django_airavata/utils.py @@ -6,7 +6,7 @@ from django.conf import settings log = logging.getLogger(__name__) -def create_airavata_client(access_token, gateway_id): +def create_airavata_client(access_token: str, gateway_id: str) -> AiravataClient: """Create an AiravataClient instance for the given auth token.""" return AiravataClient( host=settings.AIRAVATA_API_HOST,
