From 8e44c4501a2585d136e21376f73de8d3c8ec9161 Mon Sep 17 00:00:00 2001 From: straitz Date: Wed, 3 Jun 2026 05:40:35 +0900 Subject: [PATCH 1/2] migrated to modern-rest --- requirements.txt | 7 +- stratoflights/settings.py | 39 +- stratoflights/urls.py | 36 +- stratoflights_api/consumers.py | 35 +- stratoflights_api/custom_pagination.py | 17 - stratoflights_api/dtos.py | 333 ++++++++++++ stratoflights_api/pagination.py | 120 +++++ stratoflights_api/permissions.py | 16 - stratoflights_api/serializers.py | 185 ------- stratoflights_api/urls.py | 97 ++-- stratoflights_api/views.py | 701 ++++++++++++++++--------- 11 files changed, 1014 insertions(+), 572 deletions(-) delete mode 100644 stratoflights_api/custom_pagination.py create mode 100644 stratoflights_api/dtos.py create mode 100644 stratoflights_api/pagination.py delete mode 100644 stratoflights_api/permissions.py delete mode 100644 stratoflights_api/serializers.py diff --git a/requirements.txt b/requirements.txt index 2672235..f91db5a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,7 @@ -Django>=4.0,<5.0 -djangorestframework -djangorestframework-simplejwt +# django-modern-rest 0.6.0 requires Django 5.x (upgraded from 4.2). +Django>=5.0,<6.0 +django-modern-rest[jwt,pydantic]==0.6.0 psycopg2-binary -drf-spectacular requests django-cors-headers Pillow diff --git a/stratoflights/settings.py b/stratoflights/settings.py index bb5b63e..decb5ef 100644 --- a/stratoflights/settings.py +++ b/stratoflights/settings.py @@ -13,6 +13,12 @@ https://docs.djangoproject.com/en/4.2/ref/settings/ from pathlib import Path import os from dotenv import load_dotenv + +from dmr.settings import Settings +from dmr.security.django_session import DjangoSessionSyncAuth +from dmr.security.jwt import JWTSyncAuth +from dmr.openapi import OpenAPIConfig + load_dotenv() # Build paths inside the project like this: BASE_DIR / 'subdir'. @@ -51,12 +57,10 @@ INSTALLED_APPS = [ 'django.contrib.sessions', 'django.contrib.messages', 'django.contrib.staticfiles', - 'rest_framework', - 'rest_framework.authtoken', - 'drf_spectacular', 'corsheaders', 'stratoflights_api.apps.StratoflightsApiConfig', 'channels', + 'dmr', # required to serve OpenAPI docs static assets ] MIDDLEWARE = [ @@ -165,25 +169,16 @@ STATIC_URL = 'static/' DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' AUTH_USER_MODEL = 'stratoflights_api.User' -REST_FRAMEWORK = { - 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination', - 'PAGE_SIZE': 100, - - 'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema', - - 'DEFAULT_AUTHENTICATION_CLASSES': [ - 'rest_framework.authentication.TokenAuthentication', - 'rest_framework.authentication.SessionAuthentication', - ], - - 'DEFAULT_PERMISSION_CLASSES': [ - 'rest_framework.permissions.IsAuthenticated', - # 'rest_framework.permissions.AllowAny', - ], - - 'DEFAULT_RENDERER_CLASSES': [ - 'rest_framework.renderers.JSONRenderer', - ], +# django-modern-rest configuration. +# Global auth mirrors the previous DRF default (Token + Session); the DRF +# TokenAuthentication is replaced by DMR's JWT (approved drift). Providing +# several auth instances means at least one of them must succeed, so this +# also enforces a 401 by default like DRF's global IsAuthenticated did. +# Public endpoints opt out per-endpoint with @modify(auth=None). +DMR_SETTINGS = { + Settings.auth: [JWTSyncAuth(), DjangoSessionSyncAuth()], + Settings.validate_responses: not PRODUCTION, + Settings.openapi_config: OpenAPIConfig(title='Stratoflights API', version='1.0.0'), } diff --git a/stratoflights/urls.py b/stratoflights/urls.py index d97bf9b..0541cd2 100644 --- a/stratoflights/urls.py +++ b/stratoflights/urls.py @@ -1,27 +1,17 @@ -""" -URL configuration for stratoflights project. - -The `urlpatterns` list routes URLs to views. For more information please see: - https://docs.djangoproject.com/en/4.2/topics/http/urls/ -Examples: -Function views - 1. Add an import: from my_app import views - 2. Add a URL to urlpatterns: path('', views.home, name='home') -Class-based views - 1. Add an import: from other_app.views import Home - 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') -Including another URLconf - 1. Import the include() function: from django.urls import include, path - 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) -""" +"""URL configuration for the stratoflights project.""" from django.contrib import admin -from django.urls import path, include -from drf_spectacular.views import SpectacularAPIView -from drf_spectacular.views import SpectacularSwaggerView +from django.urls import include +from dmr.openapi import build_schema +from dmr.openapi.views import OpenAPIJsonView, SwaggerView +from dmr.routing import path + +from stratoflights_api.urls import router + +schema = build_schema(router) urlpatterns = [ path('admin/', admin.site.urls), - path('api/', include('stratoflights_api.urls')), - path('api/schema/', SpectacularAPIView.as_view(), name='schema'), - path('api/docs/', SpectacularSwaggerView.as_view(url_name='schema'), name='docs'), -] \ No newline at end of file + path(router.prefix, include(router.urls)), + path('api/schema/', OpenAPIJsonView.as_view(schema), name='schema'), + path('api/docs/', SwaggerView.as_view(schema), name='docs'), +] diff --git a/stratoflights_api/consumers.py b/stratoflights_api/consumers.py index 66dce9b..3c44e16 100644 --- a/stratoflights_api/consumers.py +++ b/stratoflights_api/consumers.py @@ -18,17 +18,19 @@ class TelemetryConsumer(AsyncWebsocketConsumer): async def receive(self, text_data): - from .serializers import TelemetryPacketSerializer + from pydantic import ValidationError + from .dtos import TelemetryIn if not self.write_enabled: await self.send(text_data=json.dumps({"error": "Read-only mode"})) return data = json.loads(text_data) - serializer = TelemetryPacketSerializer(data=data) - if not serializer.is_valid(): - await self.send(text_data=json.dumps({"error": serializer.errors})) + try: + TelemetryIn(**data) + except ValidationError as exc: + await self.send(text_data=json.dumps({"error": json.loads(exc.json())})) return saved_data = await self.save_telemetry(data) @@ -51,15 +53,6 @@ class TelemetryConsumer(AsyncWebsocketConsumer): async def telemetry_message(self, event): await self.send(text_data=json.dumps(event["data"])) - @database_sync_to_async - def get_user_from_token(self, token_key, Token): - from rest_framework.authtoken.models import Token - User = get_user_model() - try: - token = Token.objects.select_related("user").get(key=token_key) - return token.user - except Token.DoesNotExist: - return None @database_sync_to_async def save_telemetry(self, data): @@ -99,12 +92,20 @@ class StationTelemetryConsumer(TelemetryConsumer): write_enabled = True async def connect(self): - from rest_framework.authtoken.models import Token + from django.conf import settings + from dmr.security.jwt.token import JWToken + token_key = self.scope["query_string"].decode().split("token=")[-1] try: - token = await database_sync_to_async(Token.objects.select_related("user").get)(key=token_key) - self.scope["user"] = token.user - except Token.DoesNotExist: + decoded = JWToken.decode( + encoded_token=token_key, + secret=settings.SECRET_KEY, + algorithm='HS256', + ) + self.scope["user"] = await database_sync_to_async( + get_user_model().objects.get + )(pk=decoded.sub) + except Exception: await self.close() return diff --git a/stratoflights_api/custom_pagination.py b/stratoflights_api/custom_pagination.py deleted file mode 100644 index 6d37de8..0000000 --- a/stratoflights_api/custom_pagination.py +++ /dev/null @@ -1,17 +0,0 @@ -from rest_framework.pagination import LimitOffsetPagination -from rest_framework.response import Response - -class CustomLimitOffsetPagination(LimitOffsetPagination): - limit_query_param = 'limit' - offset_query_param = 'skip' - max_limit = 100 - default_limit = 10 - - - def get_paginated_response(self, data): - return Response({ - 'total': self.count, - 'limit': self.limit, - 'skip': self.offset, - 'predictions': data - }) \ No newline at end of file diff --git a/stratoflights_api/dtos.py b/stratoflights_api/dtos.py new file mode 100644 index 0000000..3178f93 --- /dev/null +++ b/stratoflights_api/dtos.py @@ -0,0 +1,333 @@ +"""Pydantic DTOs replacing the DRF serializers from serializers.py. + +Mapping is 1:1 with the former serializers. Output DTOs enable ``from_attributes`` +so an endpoint can build them straight from a Django model instance via +``SomeOut.model_validate(obj)`` (the equivalent of ``Serializer(obj).data``). + +Field-level and cross-field validation that DRF kept in ``validate_()`` / +``validate()`` is reproduced with pydantic ``field_validator`` / ``model_validator``. +Business logic DRF kept in ``create()`` (base64 curve decoding, ORM writes, and +resolving related objects from their PKs) is intentionally NOT here -- per DMR's +split it moves into the endpoints (Step 2.9). +""" +import uuid +from datetime import datetime +from typing import Any, Optional + +from django.contrib.auth.password_validation import validate_password +from django.core.validators import validate_email +from django.core.exceptions import ValidationError as DjangoValidationError +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_validator, +) + +from .validators import validate_custom_curve, rate_clip + + +# Profile constants (moved verbatim from serializers.py). +PROFILE_STANDARD = "standard_profile" +PROFILE_FLOAT = "float_profile" +PROFILE_REVERSE = "reverse_profile" +PROFILE_CUSTOM = "custom_profile" +LATEST_DATASET_KEYWORD = "latest" +SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM] + + +# --- Prediction request (was PredictionRequestSerializer) --------------------- +class PredictionRequest(BaseModel): + launch_latitude: float = Field(ge=-90, le=90) + launch_longitude: float = Field(ge=0, le=360) + launch_datetime: datetime + launch_altitude: Optional[float] = None + format: str = "json" + profile: str = PROFILE_STANDARD + dataset: str = LATEST_DATASET_KEYWORD + + # profile-dependent fields + ascent_rate: Optional[float] = Field(default=None, ge=0.01) + descent_rate: Optional[float] = Field(default=None, ge=0.01) + burst_altitude: Optional[float] = None + float_altitude: Optional[float] = None + stop_datetime: Optional[datetime] = None + ascent_curve: Optional[str] = None + descent_curve: Optional[str] = None + interpolate: bool = False + # Related objects are accepted as PKs; existence is resolved in the endpoint + # (was PrimaryKeyRelatedField(queryset=...)). + start_point: Optional[int] = None + rate_profile: Optional[int] = None + template: Optional[int] = None + + @field_validator("profile") + @classmethod + def _validate_profile(cls, value: str) -> str: + if value not in SUPPORTED_PROFILES: + raise ValueError(f'"{value}" is not a valid choice.') + return value + + @model_validator(mode="after") + def _validate_cross_fields(self) -> "PredictionRequest": + launch_alt = self.launch_altitude if self.launch_altitude is not None else 0 + + if self.profile == PROFILE_STANDARD: + if self.burst_altitude is None: + raise ValueError("burst_altitude is required for standard profile.") + if self.burst_altitude <= launch_alt: + raise ValueError("burst_altitude must be greater than launch_altitude.") + + elif self.profile == PROFILE_FLOAT: + if self.float_altitude is None or self.float_altitude <= launch_alt: + raise ValueError("float_altitude must be greater than launch_altitude.") + if self.stop_datetime is None or self.stop_datetime <= self.launch_datetime: + raise ValueError("stop_datetime must be later than launch_datetime.") + + elif self.profile == PROFILE_CUSTOM: + if self.ascent_curve is None or not validate_custom_curve(self.ascent_curve): + raise ValueError("Invalid ascent_curve.") + if self.descent_curve is None or not validate_custom_curve(self.descent_curve): + raise ValueError("Invalid descent_curve.") + if self.burst_altitude is None or self.burst_altitude <= launch_alt: + raise ValueError("burst_altitude must be greater than launch_altitude.") + + # custom clipping logic (was in validate()) + if self.ascent_rate is not None: + self.ascent_rate = rate_clip(self.ascent_rate) + if self.descent_rate is not None: + self.descent_rate = rate_clip(self.descent_rate) + + return self + + +# --- Prediction outputs ------------------------------------------------------- +class PredictionOut(BaseModel): + """Was PredictionSerializer.""" + + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + created_at: datetime + updated_at: datetime + result: Any + + +class PredictionListOut(BaseModel): + """Was PredictionListSerializer.""" + + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + created_at: datetime + updated_at: datetime + start_point: Optional[int] = Field(default=None, validation_alias="start_point_id") + template: Optional[int] = Field(default=None, validation_alias="template_id") + rate_profile: Optional[int] = Field(default=None, validation_alias="rate_profile_id") + + +class PredictionDetailOut(BaseModel): + """Was PredictionDetailSerializer.""" + + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + created_at: datetime + updated_at: datetime + result: Any + start_point: Optional[int] = Field(default=None, validation_alias="start_point_id") + template: Optional[int] = Field(default=None, validation_alias="template_id") + rate_profile: Optional[int] = Field(default=None, validation_alias="rate_profile_id") + + +class PredictionCreateOut(BaseModel): + """Custom create() response shape: {id, created_at, result} (not the full model).""" + + id: uuid.UUID + created_at: datetime + result: Any + + +# --- Telemetry (was TelemetryPacketSerializer) -------------------------------- +class TelemetryIn(BaseModel): + # timestamp is required (model BigIntegerField has no default), matching the + # former ModelSerializer; the endpoint still overwrites it with server time. + timestamp: int + lat: float = 0.0 + lon: float = 0.0 + alt: float = 0.0 + payload: Any = Field(default_factory=dict) + + +class TelemetryOut(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: uuid.UUID + timestamp: int + lat: float + lon: float + alt: float + payload: Any + + +# --- SavedPoint (was SavedPointSerializer) ------------------------------------ +# `user` was a HiddenField(CurrentUserDefault) -> set from request in the endpoint, +# not part of the wire input/output. The unique_together (user, name) check that +# DRF did via UniqueTogetherValidator is reproduced in the endpoint (Step 2.8). +class SavedPointIn(BaseModel): + name: str + lat: float = 0.0 + lon: float = 0.0 + alt: float = 0.0 + + +class SavedPointPatchIn(BaseModel): + # All-optional variant for PATCH (partial_update). Only set fields apply. + name: Optional[str] = None + lat: Optional[float] = None + lon: Optional[float] = None + alt: Optional[float] = None + + +class SavedPointOut(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: int + name: str + lat: float + lon: float + alt: float + + +# --- SavedRateProfile (was SavedRateProfileSerializer) ------------------------ +# NOTE: no view referenced this serializer in routing; kept for parity. +class SavedRateProfileIn(BaseModel): + name: str + type: str = "ascent" + rate_profile_data: Any = Field(default_factory=dict) + + +class SavedRateProfileOut(BaseModel): + model_config = ConfigDict(from_attributes=True) + + id: int + name: str + type: str + rate_profile_data: Any + + +# --- PredictionTemplate (was PreditctionTemplateSerializer) ------------------- +class PredictionTemplateIn(BaseModel): + # protected_namespaces=() avoids pydantic's warning about the `model` field. + model_config = ConfigDict(protected_namespaces=()) + + name: str + is_default: bool = False + description: Optional[str] = None + prediction_mode: str = "" + model: str = "" + dataset: str = "" + flight_parameters: Any = Field(default_factory=dict) + + +class PredictionTemplatePatchIn(BaseModel): + # All-optional variant for PATCH (partial_update). Only set fields apply. + model_config = ConfigDict(protected_namespaces=()) + + name: Optional[str] = None + is_default: Optional[bool] = None + description: Optional[str] = None + prediction_mode: Optional[str] = None + model: Optional[str] = None + dataset: Optional[str] = None + flight_parameters: Optional[Any] = None + + +class PredictionTemplateOut(BaseModel): + model_config = ConfigDict(from_attributes=True, protected_namespaces=()) + + id: int + name: str + is_default: bool + description: Optional[str] = None + prediction_mode: str + model: str + dataset: str + flight_parameters: Any + + +# --- User (was UserSerializer) ------------------------------------------------ +class UserOut(BaseModel): + model_config = ConfigDict(from_attributes=True) + + username: str + email: str + first_name: str + last_name: str + + +class UserUpdateIn(BaseModel): + # `username` was read_only -> excluded from input. PATCH is partial, so all + # fields are optional; the endpoint applies only fields that were set. + email: Optional[str] = None + first_name: Optional[str] = None + last_name: Optional[str] = None + + @field_validator("email") + @classmethod + def _validate_email(cls, value: Optional[str]) -> Optional[str]: + if value is None: + return value + try: + validate_email(value) + except DjangoValidationError: + raise ValueError("Invalid email format") + return value + + +# --- Password / account (were ChangePasswordSerializer / DeleteAccountSerializer) +class ChangePasswordIn(BaseModel): + old_password: str + new_password: str + + @field_validator("new_password") + @classmethod + def _validate_new_password(cls, value: str) -> str: + # validate_password raises Django's ValidationError (not a ValueError), + # which pydantic would not convert into a 4xx -- re-raise as ValueError. + try: + validate_password(value) + except DjangoValidationError as exc: + raise ValueError("; ".join(exc.messages)) + return value + + +class DeleteAccountIn(BaseModel): + password: str + + +# --- Auth / session response shapes (were inline JsonResponse dicts) ---------- +class DetailResponse(BaseModel): + detail: str + + +class SessionResponse(BaseModel): + isAuthenticated: bool + + +class WhoAmIResponse(BaseModel): + username: str + + +class TokenResponse(BaseModel): + token: str + + +# --- Shared path parameters --------------------------------------------------- +class PkPath(BaseModel): + pk: int + + +class UuidPkPath(BaseModel): + pk: uuid.UUID diff --git a/stratoflights_api/pagination.py b/stratoflights_api/pagination.py new file mode 100644 index 0000000..797ac58 --- /dev/null +++ b/stratoflights_api/pagination.py @@ -0,0 +1,120 @@ +"""Limit/offset pagination for the prediction list endpoints. + +DMR has no built-in limit/offset paginator, so (as the docs suggest) we keep our +own envelope model -- matching the former ``CustomLimitOffsetPagination`` output +exactly -- plus a helper that slices the queryset. The query params and envelope +keys are preserved verbatim: ``limit`` / ``skip`` in, and +``{total, limit, skip, predictions}`` out. +""" +from http import HTTPStatus +from typing import Optional +from urllib.parse import urlsplit, urlunsplit + +from django.core.paginator import InvalidPage, Paginator +from django.http import QueryDict +from pydantic import BaseModel + +from dmr.response import APIError + +from .dtos import PredictionOut, TelemetryOut + +DEFAULT_LIMIT = 10 # was CustomLimitOffsetPagination.default_limit +MAX_LIMIT = 100 # was CustomLimitOffsetPagination.max_limit +PAGE_SIZE = 100 # global REST_FRAMEWORK PAGE_SIZE (PageNumberPagination) + + +class PaginationQuery(BaseModel): + # Param names preserved: limit_query_param='limit', offset_query_param='skip'. + limit: int = DEFAULT_LIMIT + skip: int = 0 + + +class PredictionListUserQuery(BaseModel): + """Query params for PredictionViewSet.list_user: filters + limit/offset.""" + + satellite_id: Optional[str] = None + created_from: Optional[str] = None + created_till: Optional[str] = None + limit: int = DEFAULT_LIMIT + skip: int = 0 + + +class PredictionPage(BaseModel): + total: int + limit: int + skip: int + predictions: list[PredictionOut] + + +class TelemetryPage(BaseModel): + count: int + next: Optional[str] = None + previous: Optional[str] = None + results: list[TelemetryOut] + + +def _replace_query_param(url: str, key: str, value) -> str: + scheme, netloc, path, query, fragment = urlsplit(url) + query_dict = QueryDict(query, mutable=True) + query_dict[key] = value + return urlunsplit((scheme, netloc, path, query_dict.urlencode(), fragment)) + + +def _remove_query_param(url: str, key: str) -> str: + scheme, netloc, path, query, fragment = urlsplit(url) + query_dict = QueryDict(query, mutable=True) + query_dict.pop(key, None) + return urlunsplit((scheme, netloc, path, query_dict.urlencode(), fragment)) + + +def page_number_paginate(request, queryset, page_size: int = PAGE_SIZE): + """Reproduce DRF PageNumberPagination. + + Returns ``(count, next_link, previous_link, object_list)``. Raises a 404 + "Invalid page." like DRF on an out-of-range / non-integer page. Honours + ``page=last`` and builds absolute next/previous links off the current URL. + """ + paginator = Paginator(queryset, page_size) + page_number = request.GET.get('page', 1) + if page_number in ('last',): + page_number = paginator.num_pages + try: + page = paginator.page(page_number) + except InvalidPage: + raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Invalid page.'}) + + url = request.build_absolute_uri() + + next_link = None + if page.has_next(): + next_link = _replace_query_param(url, 'page', page.next_page_number()) + + previous_link = None + if page.has_previous(): + previous_number = page.previous_page_number() + if previous_number == 1: + previous_link = _remove_query_param(url, 'page') + else: + previous_link = _replace_query_param(url, 'page', previous_number) + + return paginator.count, next_link, previous_link, list(page.object_list) + + +def paginate_predictions(queryset, query: PaginationQuery) -> PredictionPage: + """Slice ``queryset`` and build the prediction page envelope. + + Mirrors the bounds DRF's LimitOffsetPagination enforced: limit falls back to + the default when non-positive and is capped at MAX_LIMIT; skip floors at 0. + """ + limit = query.limit if query.limit > 0 else DEFAULT_LIMIT + limit = min(limit, MAX_LIMIT) + skip = query.skip if query.skip >= 0 else 0 + + total = queryset.count() + page = list(queryset[skip:skip + limit]) + return PredictionPage( + total=total, + limit=limit, + skip=skip, + predictions=[PredictionOut.model_validate(obj) for obj in page], + ) diff --git a/stratoflights_api/permissions.py b/stratoflights_api/permissions.py deleted file mode 100644 index 84c2ed4..0000000 --- a/stratoflights_api/permissions.py +++ /dev/null @@ -1,16 +0,0 @@ -from rest_framework.permissions import BasePermission, SAFE_METHODS - -class ReadOnlyOrAuthenticated(BasePermission): - def has_permission(self, request, view): - return ( - request.method in SAFE_METHODS or - request.user and request.user.is_authenticated - ) - - -class IsOwner(BasePermission): - def has_object_permission(self, request, view, obj): - return obj.user == request.user - - def has_permission(self, request, view): - return request.user and request.user.is_authenticated \ No newline at end of file diff --git a/stratoflights_api/serializers.py b/stratoflights_api/serializers.py deleted file mode 100644 index dacdf57..0000000 --- a/stratoflights_api/serializers.py +++ /dev/null @@ -1,185 +0,0 @@ -from rest_framework import serializers -from .models import Prediction, SavedPoint, SavedRateProfile, PreditctionTemplate -from datetime import datetime -from django.contrib.auth.password_validation import validate_password -from django.core.validators import validate_email -from django.core.exceptions import ValidationError as DjangoValidationError -from django.contrib.auth import get_user_model -from .validators import ( - validate_custom_curve, rate_clip, - _rfc3339_to_timestamp, base64_to_curve -) - -class PredictionSerializer(serializers.ModelSerializer): - class Meta: - model = Prediction - fields = ['id', 'created_at', 'updated_at', 'result'] - -User = get_user_model() - -PROFILE_STANDARD = "standard_profile" -PROFILE_FLOAT = "float_profile" -PROFILE_REVERSE = "reverse_profile" -PROFILE_CUSTOM = "custom_profile" -LATEST_DATASET_KEYWORD = "latest" -SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM] - - -class PredictionRequestSerializer(serializers.Serializer): - launch_latitude = serializers.FloatField(min_value=-90, max_value=90) - launch_longitude = serializers.FloatField(min_value=0, max_value=360) - launch_datetime = serializers.DateTimeField() - launch_altitude = serializers.FloatField(required=False) - format = serializers.CharField(default="json") - profile = serializers.ChoiceField(choices=SUPPORTED_PROFILES, default=PROFILE_STANDARD) - dataset = serializers.CharField(default=LATEST_DATASET_KEYWORD) - - # --- профиль-dependent поля --- - ascent_rate = serializers.FloatField(required=False, min_value=0.01) - descent_rate = serializers.FloatField(required=False, min_value=0.01) - burst_altitude = serializers.FloatField(required=False) - float_altitude = serializers.FloatField(required=False) - stop_datetime = serializers.DateTimeField(required=False) - ascent_curve = serializers.CharField(required=False) - descent_curve = serializers.CharField(required=False) - interpolate = serializers.BooleanField(required=False, default=False) - start_point = serializers.PrimaryKeyRelatedField( - queryset=SavedPoint.objects.all(), required=False, allow_null=True - ) - rate_profile = serializers.PrimaryKeyRelatedField( - queryset=SavedRateProfile.objects.all(), required=False, allow_null=True - ) - template = serializers.PrimaryKeyRelatedField( - queryset=PreditctionTemplate.objects.all(), required=False, allow_null=True - ) - - def validate(self, data): - profile = data.get("profile", PROFILE_STANDARD) - launch_alt = data.get("launch_altitude", 0) - - if profile == PROFILE_STANDARD: - if 'burst_altitude' not in data: - raise serializers.ValidationError("burst_altitude is required for standard profile.") - if data['burst_altitude'] <= launch_alt: - raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.") - - elif profile == PROFILE_FLOAT: - if 'float_altitude' not in data or data['float_altitude'] <= launch_alt: - raise serializers.ValidationError("float_altitude must be greater than launch_altitude.") - if 'stop_datetime' not in data or data['stop_datetime'] <= data['launch_datetime']: - raise serializers.ValidationError("stop_datetime must be later than launch_datetime.") - - elif profile == PROFILE_CUSTOM: - if 'ascent_curve' not in data or not validate_custom_curve(data['ascent_curve']): - raise serializers.ValidationError("Invalid ascent_curve.") - if 'descent_curve' not in data or not validate_custom_curve(data['descent_curve']): - raise serializers.ValidationError("Invalid descent_curve.") - if 'burst_altitude' not in data or data['burst_altitude'] <= launch_alt: - raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.") - - # кастомная логика clipping'а - if 'ascent_rate' in data: - data['ascent_rate'] = rate_clip(data['ascent_rate']) - if 'descent_rate' in data: - data['descent_rate'] = rate_clip(data['descent_rate']) - - return data - - def create(self, validated_data): - if 'ascent_curve' in validated_data: - validated_data['ascent_curve'] = base64_to_curve(validated_data['ascent_curve']) - if 'descent_curve' in validated_data: - validated_data['descent_curve'] = base64_to_curve(validated_data['descent_curve']) - - prediction = Prediction( - user=validated_data.get('user'), - request=validated_data.get('request', {}), - result=validated_data.get('result', {}), - start_point=validated_data.get('start_point'), - template=validated_data.get('template'), - rate_profile=validated_data.get('rate_profile') - ) - prediction.save() - - return prediction - - -class PredictionListSerializer(serializers.ModelSerializer): - class Meta: - model = Prediction - fields = ["id", "created_at", "updated_at", "start_point", "template", "rate_profile"] - - -class PredictionDetailSerializer(serializers.ModelSerializer): - class Meta: - model = Prediction - fields = ["id", "created_at", "updated_at", "result", "start_point", "template", "rate_profile"] - - -from rest_framework import serializers -from .models import TelemetryPacket - -class TelemetryPacketSerializer(serializers.ModelSerializer): - class Meta: - model = TelemetryPacket - fields = ['id', 'timestamp', 'lat', 'lon', 'alt', 'payload'] - read_only_fields = ['id'] - - -class SavedPointSerializer(serializers.ModelSerializer): - user = serializers.HiddenField( - default=serializers.CurrentUserDefault() - ) - class Meta: - model = SavedPoint - fields = ['user', 'id', 'name', 'lat', 'lon', 'alt'] - read_only_fields = ['id'] - - validators = [ - serializers.UniqueTogetherValidator( - queryset=SavedPoint.objects.all(), - fields=['user', 'name'], - message="A saved point with this name already exists for the user." - ) - ] - - -class SavedRateProfileSerializer(serializers.ModelSerializer): - class Meta: - model = SavedRateProfile - fields = ['id', 'name', 'type', 'rate_profile_data'] - read_only_fields = ['id'] - - -class PreditctionTemplateSerializer(serializers.ModelSerializer): - class Meta: - model = PreditctionTemplate - fields = ['id', 'name', 'is_default', 'description', 'prediction_mode', 'model', 'dataset', 'flight_parameters'] - read_only_fields = ['id'] - - -class UserSerializer(serializers.ModelSerializer): - class Meta: - model = User - fields = ['username', 'email', 'first_name', 'last_name'] - extra_kwargs = { - 'username': {'read_only': True} - } - - def validate_email(self, value): - try: - validate_email(value) - except DjangoValidationError: - raise serializers.ValidationError("Invalid email format") - return value - -class ChangePasswordSerializer(serializers.Serializer): - old_password = serializers.CharField(required=True) - new_password = serializers.CharField(required=True) - - def validate_new_password(self, value): - validate_password(value) - return value - -class DeleteAccountSerializer(serializers.Serializer): - password = serializers.CharField(required=True) \ No newline at end of file diff --git a/stratoflights_api/urls.py b/stratoflights_api/urls.py index 3a438c5..9eb84dc 100644 --- a/stratoflights_api/urls.py +++ b/stratoflights_api/urls.py @@ -1,48 +1,59 @@ -from django.urls import path -from rest_framework.routers import DefaultRouter -from rest_framework.authtoken.views import obtain_auth_token +from dmr.routing import Router, path from .views import ( - PredictionViewSet, - SavedPointViewset, - PreditctionTemplateViewset, - TelemetryListCreateView, - get_csrf, - login_view, - logout_view, - SessionView, - WhoAmIView, - UserProfileView, - ChangePasswordView, - TokenManagementView, - DeleteUserDataView, - DeleteAccountView + PredictionCollectionController, + PredictionListUserController, + PredictionHistoryController, + PredictionDetailController, + PredictionDeleteController, + SavedPointListController, + SavedPointDetailController, + PredictionTemplateListController, + PredictionTemplateDetailController, + TelemetryController, + CsrfController, + LoginController, + LogoutController, + SessionController, + WhoAmIController, + UserProfileController, + ChangePasswordController, + ObtainTokenController, + TokenManagementController, + DeleteUserDataController, + DeleteAccountController, ) +# A Router (prefix + routes) is required so build_schema() can generate the +# OpenAPI document; see stratoflights/urls.py. +router = Router( + 'api/', + [ + path("csrf/", CsrfController.as_view(), name='api-csrf'), + path('token', ObtainTokenController.as_view(), name='get_token'), + path("login/", LoginController.as_view(), name='api-login'), + path("logout/", LogoutController.as_view(), name='api-logout'), + path("session/", SessionController.as_view(), name='api-session'), + path("whoami/", WhoAmIController.as_view(), name='api-whoami'), + path("/telemetry/", TelemetryController.as_view(), name="create_telemetry"), + path("profile/", UserProfileController.as_view(), name='api-profile'), + path("profile/change-password/", ChangePasswordController.as_view(), name='api-change-password'), + path("profile/token/", TokenManagementController.as_view(), name='api-token'), + path("profile/delete-data/", DeleteUserDataController.as_view(), name='api-delete-data'), + path("profile/delete-account/", DeleteAccountController.as_view(), name='api-delete-account'), + # Saved points (was SavedPointViewset via DefaultRouter). + path("saved-points/", SavedPointListController.as_view(), name='saved-points-list'), + path("saved-points//", SavedPointDetailController.as_view(), name='saved-points-detail'), + # Prediction templates (was PreditctionTemplateViewset via DefaultRouter). + path("saved-templates/", PredictionTemplateListController.as_view(), name='saved-templates-list'), + path("saved-templates//", PredictionTemplateDetailController.as_view(), name='saved-templates-detail'), + # Predictions (was PredictionViewSet via DefaultRouter). + path("predictions/", PredictionCollectionController.as_view(), name='predictions-list'), + path("predictions/list_user/", PredictionListUserController.as_view(), name='predictions-list-user'), + path("predictions/history/", PredictionHistoryController.as_view(), name='predictions-history'), + path("predictions//detail/", PredictionDetailController.as_view(), name='predictions-detail'), + path("predictions//delete/", PredictionDeleteController.as_view(), name='predictions-delete'), + ], +) -router = DefaultRouter() -router.register(r'predictions', PredictionViewSet, basename='predictions') -router.register(r'saved-points', SavedPointViewset, basename='saved-points') -router.register(r'saved-templates', PreditctionTemplateViewset, basename='saved-templates') - - -urlpatterns = [ - path("csrf/", get_csrf, name='api-csrf'), - path('token', obtain_auth_token, name = 'get_token'), - path("login/", login_view, name='api-login'), - path("logout/", logout_view, name='api-logout'), - path("session/", SessionView.as_view(), name='api-session'), - path("whoami/", WhoAmIView.as_view(), name='api-whoami'), - path("/telemetry/", TelemetryListCreateView.as_view(), name="create_telemetry"), - path('csrf/', get_csrf, name='api-csrf'), - path('login/', login_view, name='api-login'), - path('logout/', logout_view, name='api-logout'), - path('session/', SessionView.as_view(), name='api-session'), - path('whoami/', WhoAmIView.as_view(), name='api-whoami'), - path("profile/", UserProfileView.as_view(), name='api-profile'), - path("profile/change-password/", ChangePasswordView.as_view(), name='api-change-password'), - path("profile/token/", TokenManagementView.as_view(), name='api-token'), - path("profile/delete-data/", DeleteUserDataView.as_view(), name='api-delete-data'), - path("profile/delete-account/", DeleteAccountView.as_view(), name='api-delete-account'), -] -urlpatterns += router.urls +urlpatterns = router.urls diff --git a/stratoflights_api/views.py b/stratoflights_api/views.py index 319f2a1..6f1d8b8 100644 --- a/stratoflights_api/views.py +++ b/stratoflights_api/views.py @@ -1,341 +1,552 @@ import requests import time import json -from rest_framework import status, generics, permissions -from rest_framework.response import Response -from rest_framework.views import APIView -from rest_framework.viewsets import ModelViewSet, ViewSet, GenericViewSet -from rest_framework.exceptions import APIException -from rest_framework.permissions import IsAuthenticated, AllowAny -from rest_framework.authentication import SessionAuthentication, BasicAuthentication, TokenAuthentication -from rest_framework.decorators import api_view, permission_classes, authentication_classes, action -from rest_framework.authtoken.models import Token -from django.utils import timezone -from django.views.decorators.csrf import csrf_exempt -from django.utils.decorators import method_decorator -from django.http import JsonResponse +from http import HTTPStatus from django.contrib.auth import authenticate, login, logout, get_user_model from django.middleware.csrf import get_token -from django.core.exceptions import ValidationError from django.utils.dateparse import parse_datetime -from .models import Prediction, User, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket -from .serializers import PredictionSerializer, TelemetryPacketSerializer, PredictionRequestSerializer, PredictionListSerializer, PredictionDetailSerializer, SavedPointSerializer, SavedRateProfileSerializer, PreditctionTemplateSerializer, UserSerializer, ChangePasswordSerializer, DeleteAccountSerializer +from .models import Prediction, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket from .services.tawhiri import TawhiriClient -from drf_spectacular.utils import extend_schema -from .permissions import ReadOnlyOrAuthenticated, IsOwner -from .custom_pagination import CustomLimitOffsetPagination -from datetime import datetime +from datetime import datetime, timedelta, timezone + +from pydantic import ValidationError +from dmr import Controller, modify +from dmr.components import Body, Path, Query +from dmr.plugins.pydantic import PydanticSerializer +from dmr.response import APIError +from dmr.security.jwt.views import ( + ObtainTokensPayload, + ObtainTokensResponse, + ObtainTokensSyncController, +) +from .dtos import ( + DetailResponse, + SessionResponse, + WhoAmIResponse, + UserOut, + UserUpdateIn, + ChangePasswordIn, + DeleteAccountIn, + TokenResponse, + PkPath, + UuidPkPath, + SavedPointIn, + SavedPointPatchIn, + SavedPointOut, + PredictionTemplateIn, + PredictionTemplatePatchIn, + PredictionTemplateOut, + PredictionRequest, + PredictionOut, + PredictionListOut, + PredictionDetailOut, + PredictionCreateOut, + TelemetryIn, + TelemetryOut, +) +from .pagination import ( + PredictionListUserQuery, + PredictionPage, + paginate_predictions, + TelemetryPage, + page_number_paginate, +) +from .validators import base64_to_curve + User = get_user_model() -def get_prediction_from_tawhiri(params): +def _resolve_related(model, pk, field_name): + """Resolve a related object by PK (was PrimaryKeyRelatedField(queryset=...)). - base_url = "https://fly.stratonautica.ru/api/v2" - response = requests.get(base_url, params=params) - - if response.status_code == 200: - return response.json() # получаем результат предсказания - else: - raise Exception( - f"Tawhiri error: {response.status_code} {response.text}") + The original queryset was ``.all()`` (not user-scoped); a missing PK yields a + 400 matching DRF's "Invalid pk" error shape. + """ + if pk is None: + return None + obj = model.objects.filter(pk=pk).first() + if obj is None: + raise APIError( + status_code=HTTPStatus.BAD_REQUEST, + body={field_name: [f'Invalid pk "{pk}" - object does not exist.']}, + ) + return obj -class PredictionViewSet(GenericViewSet): - permission_classes = [IsAuthenticated] - pagination_class = CustomLimitOffsetPagination +class PredictionCollectionController(Controller[PydanticSerializer]): + """`predictions/` -- list the user's predictions and create a new one.""" - def list(self, request): - queryset = Prediction.objects.filter(user=request.user) - return Response(PredictionSerializer(queryset, many=True).data) - - def create(self, request): - serializer = PredictionRequestSerializer(data=request.data) + def get(self) -> list[PredictionOut]: + queryset = Prediction.objects.filter(user=self.request.user) + return [PredictionOut.model_validate(obj) for obj in queryset] - if not serializer.is_valid(): - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + def post(self, parsed_body: Body[PredictionRequest]) -> PredictionCreateOut: + user = self.request.user - validated_data = serializer.validated_data + # Resolve related objects before calling Tawhiri, so an invalid PK still + # fails with 400 first (as DRF validation did). + start_point = _resolve_related(SavedPoint, parsed_body.start_point, 'start_point') + template = _resolve_related(PreditctionTemplate, parsed_body.template, 'template') + rate_profile = _resolve_related(SavedRateProfile, parsed_body.rate_profile, 'rate_profile') try: - prediction_result = TawhiriClient.get_prediction(validated_data) + prediction_result = TawhiriClient.get_prediction(parsed_body.model_dump()) + except requests.RequestException as exc: + raise APIError( + status_code=HTTPStatus.BAD_GATEWAY, + body={'error': f'Tawhiri error: {str(exc)}'}, + ) - except requests.RequestException as e: - return Response({"error": f"Tawhiri error: {str(e)}"}, status=status.HTTP_502_BAD_GATEWAY) + # Carried over from the old serializer.create(): curves are decoded but + # never persisted (the model has no curve fields). Kept for parity -- + # including that a malformed curve here raises (then surfaces as 500). + if parsed_body.ascent_curve is not None: + base64_to_curve(parsed_body.ascent_curve) + if parsed_body.descent_curve is not None: + base64_to_curve(parsed_body.descent_curve) - # prediction = Prediction.objects.create( - # result=prediction_result, user=request.user, request=request.data, validated_data=validated_data) - prediction = serializer.save( - user=request.user, + prediction = Prediction( + user=user, + request=json.loads(self.request.body), + result=prediction_result, + start_point=start_point, + template=template, + rate_profile=rate_profile, + ) + prediction.save() + + return PredictionCreateOut( + id=prediction.id, + created_at=prediction.created_at, result=prediction_result, - request=request.data ) - return Response({ - "id": prediction.id, - "created_at": prediction.created_at, - "result": prediction_result - }, status=status.HTTP_201_CREATED) - @action(detail=False, methods=['get']) - def list_user(self, request): - user = request.user - satellite_id = request.query_params.get('satellite_id') - created_from = request.query_params.get('created_from') - created_till = request.query_params.get('created_till') +class PredictionListUserController(Controller[PydanticSerializer]): + """`predictions/list_user/` -- filtered, paginated list of the user's predictions.""" - filters = { - 'user': user, - } + def get(self, parsed_query: Query[PredictionListUserQuery]) -> PredictionPage: + user = self.request.user + filters = {'user': user} - if created_from: - filters['created_at__gte'] = parse_datetime(created_from) - - if created_till: - filters['created_at__lte'] = parse_datetime(created_till) - - if satellite_id: - if not user.satellites.filter(id=satellite_id).exists(): - return Response({'detail': 'Access denied'}, status=403) - - filters['satellite_id'] = satellite_id + if parsed_query.created_from: + filters['created_at__gte'] = parse_datetime(parsed_query.created_from) + if parsed_query.created_till: + filters['created_at__lte'] = parse_datetime(parsed_query.created_till) + if parsed_query.satellite_id: + if not user.satellites.filter(id=parsed_query.satellite_id).exists(): + raise APIError(status_code=HTTPStatus.FORBIDDEN, body={'detail': 'Access denied'}) + filters['satellite_id'] = parsed_query.satellite_id queryset = Prediction.objects.filter(**filters) - queryset = self.filter_queryset(queryset) - - page = self.paginate_queryset(queryset) + return paginate_predictions(queryset, parsed_query) - if page is not None: - serializer = PredictionSerializer(page, many=True) - return self.get_paginated_response(serializer.data) - - serializer = PredictionSerializer(queryset, many=True) - return Response(serializer.data) - @action(detail=False, methods=["get"]) - def history(self, request): - queryset = Prediction.objects.filter(user=request.user) - return Response(PredictionListSerializer(queryset, many=True).data) +class PredictionHistoryController(Controller[PydanticSerializer]): + """`predictions/history/` -- compact list of the user's predictions.""" - @action(detail=True, methods=["get"]) - def detail(self, request, pk=None): + def get(self) -> list[PredictionListOut]: + queryset = Prediction.objects.filter(user=self.request.user) + return [PredictionListOut.model_validate(obj) for obj in queryset] + + +class PredictionDetailController(Controller[PydanticSerializer]): + """`predictions//detail/` -- retrieve a single prediction.""" + + def get(self, parsed_path: Path[UuidPkPath]) -> PredictionDetailOut: prediction = Prediction.objects.filter( - user=request.user, pk=pk).first() - if not prediction: - return Response({'detail': 'Not found'}, status=404) - return Response(PredictionDetailSerializer(prediction).data) + user=self.request.user, pk=parsed_path.pk).first() + if prediction is None: + raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'}) + return PredictionDetailOut.model_validate(prediction) - @action(detail=True, methods=["delete"]) - def delete(self, request, pk=None): + +class PredictionDeleteController(Controller[PydanticSerializer]): + """`predictions//delete/` -- delete a single prediction.""" + + @modify(status_code=HTTPStatus.NO_CONTENT) + def delete(self, parsed_path: Path[UuidPkPath]) -> None: prediction = Prediction.objects.filter( - user=request.user, pk=pk).first() - if not prediction: - return Response({'detail': 'Not found'}, status=404) + user=self.request.user, pk=parsed_path.pk).first() + if prediction is None: + raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'}) prediction.delete() - return Response(status=204) -class TelemetryListCreateView(generics.ListCreateAPIView): - serializer_class = TelemetryPacketSerializer - permission_classes = [permissions.AllowAny] +class TelemetryController(Controller[PydanticSerializer]): + """`/telemetry/` -- list and ingest telemetry for a satellite. - def get_queryset(self): - qs = TelemetryPacket.objects.filter(satellite_id=self.kwargs["pk"]) + Public (was AllowAny). GET preserves the global PageNumberPagination envelope + {count, next, previous, results} with PAGE_SIZE=100. + """ - from_ts = self.request.query_params.get("from") - till_ts = self.request.query_params.get("till") + auth = None # public (was AllowAny) + csrf_exempt = True # DRF views bypass Django CSRF; ingestion is anonymous + + def get(self, parsed_path: Path[UuidPkPath]) -> TelemetryPage: + qs = TelemetryPacket.objects.filter(satellite_id=parsed_path.pk) + + from_ts = self.request.GET.get('from') + till_ts = self.request.GET.get('till') if from_ts: qs = qs.filter(timestamp__gte=int(from_ts)) if till_ts: qs = qs.filter(timestamp__lte=int(till_ts)) - return qs.order_by("-timestamp") + qs = qs.order_by('-timestamp') - def post(self, request, pk): - serializer = TelemetryPacketSerializer(data=request.data) - if not serializer.is_valid(): - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + count, next_link, previous_link, page_objects = page_number_paginate(self.request, qs) + return TelemetryPage( + count=count, + next=next_link, + previous=previous_link, + results=[TelemetryOut.model_validate(obj) for obj in page_objects], + ) - validated_data = serializer.validated_data - - TelemetryPacket.objects.create(timestamp=time.time(), - satellite=Satellite.objects.get(id=pk), - lat=validated_data["lat"], - lon=validated_data["lon"], - alt=validated_data["alt"], - payload=validated_data['payload'], - ) - return Response(serializer.errors, status=status.HTTP_201_CREATED) + def post(self, parsed_path: Path[UuidPkPath], parsed_body: Body[TelemetryIn]) -> TelemetryOut: + # Bug fix (approved): the original returned serializer.errors on success; + # we return the created packet instead. timestamp is still server-set. + packet = TelemetryPacket.objects.create( + timestamp=time.time(), + satellite=Satellite.objects.get(id=parsed_path.pk), + lat=parsed_body.lat, + lon=parsed_body.lon, + alt=parsed_body.alt, + payload=parsed_body.payload, + ) + return TelemetryOut.model_validate(packet) -class SessionView(APIView): - permission_classes = [IsAuthenticated] +class SessionController(Controller[PydanticSerializer]): + """Report whether the current request is authenticated.""" - @staticmethod - def get(request, format=None): - return JsonResponse({'isAuthenticated': True}) + def get(self) -> SessionResponse: + return SessionResponse(isAuthenticated=True) -class WhoAmIView(APIView): - permission_classes = [IsAuthenticated] +class WhoAmIController(Controller[PydanticSerializer]): + """Return the current user's username.""" - @staticmethod - def get(request, format=None): - return JsonResponse({'username': request.user.username}) + def get(self) -> WhoAmIResponse: + return WhoAmIResponse(username=self.request.user.username) -@extend_schema(methods=["GET"], description="Get CSRF token") -@csrf_exempt -@api_view(["GET"]) -@permission_classes([AllowAny]) -def get_csrf(request): - response = JsonResponse({'detail': 'CSRF cookie set'}) - response['X-CSRFToken'] = get_token(request) - return response +class CsrfController(Controller[PydanticSerializer]): + """Get CSRF token.""" + + auth = None # public endpoint (was AllowAny) + csrf_exempt = True + + def get(self) -> DetailResponse: + token = get_token(self.request) + return self.to_response( + DetailResponse(detail='CSRF cookie set'), + headers={'X-CSRFToken': token}, + ) -@extend_schema(methods=["POST"], description="Login user") -@csrf_exempt -@api_view(["POST"]) -@authentication_classes([BasicAuthentication]) -@permission_classes([AllowAny]) -def login_view(request): - data = json.loads(request.body) - username = data.get('username') - password = data.get('password') +class LoginController(Controller[PydanticSerializer]): + """Login user.""" - if username is None or password is None: - return JsonResponse({'detail': 'Please provide username and password.'}, status=400) + auth = None # public endpoint (was AllowAny) + csrf_exempt = True - user = authenticate(username=username, password=password) - if user is None: - return JsonResponse({'detail': 'Invalid credentials.'}, status=400) + # post() would default to 201; the original returned 200. + @modify(status_code=HTTPStatus.OK) + def post(self) -> DetailResponse: + data = json.loads(self.request.body) + username = data.get('username') + password = data.get('password') - login(request, user) - return JsonResponse({'detail': 'Successfully logged in.'}) + if username is None or password is None: + raise APIError( + status_code=HTTPStatus.BAD_REQUEST, + body={'detail': 'Please provide username and password.'}, + ) + + user = authenticate(username=username, password=password) + if user is None: + raise APIError( + status_code=HTTPStatus.BAD_REQUEST, + body={'detail': 'Invalid credentials.'}, + ) + + login(self.request, user) + return DetailResponse(detail='Successfully logged in.') -@extend_schema(methods=["POST"], description="Logout user") -@api_view(["POST"]) -@permission_classes([AllowAny]) -def logout_view(request): - if not request.user.is_authenticated: - return JsonResponse({'detail': 'You\'re not logged in.'}, status=400) +class LogoutController(Controller[PydanticSerializer]): + """Logout user.""" - logout(request) - return JsonResponse({'detail': 'Successfully logged out.'}) + auth = None # public endpoint (was AllowAny); checks auth state manually + + @modify(status_code=HTTPStatus.OK) + def post(self) -> DetailResponse: + if not self.request.user.is_authenticated: + raise APIError( + status_code=HTTPStatus.BAD_REQUEST, + body={'detail': "You're not logged in."}, + ) + + logout(self.request) + return DetailResponse(detail='Successfully logged out.') -class SavedPointViewset(ModelViewSet): - permission_classes = [IsOwner] - serializer_class = SavedPointSerializer - pagination_class = None - - def get_queryset(self): - return SavedPoint.objects.filter(user=self.request.user) - - def perform_create(self, serializer): - serializer.save(user=self.request.user) +_SAVED_POINT_DUPLICATE = 'A saved point with this name already exists for the user.' -class PreditctionTemplateViewset(ModelViewSet): - permission_classes = [IsOwner] - serializer_class = PreditctionTemplateSerializer - pagination_class = None - - def get_queryset(self): - return PreditctionTemplate.objects.filter(user=self.request.user) - - def perform_create(self, serializer): - serializer.save(user=self.request.user) +def _check_saved_point_unique(user, name, exclude_pk=None): + """Reproduce the former SavedPointSerializer UniqueTogetherValidator.""" + qs = SavedPoint.objects.filter(user=user, name=name) + if exclude_pk is not None: + qs = qs.exclude(pk=exclude_pk) + if qs.exists(): + raise APIError( + status_code=HTTPStatus.BAD_REQUEST, + body={'non_field_errors': [_SAVED_POINT_DUPLICATE]}, + ) -class UserProfileView(APIView): - permission_classes = [IsAuthenticated] +class SavedPointListController(Controller[PydanticSerializer]): + """Collection endpoint for the current user's saved points (was SavedPointViewset).""" - def get(self, request): - serializer = UserSerializer(request.user) - return Response(serializer.data) + def get(self) -> list[SavedPointOut]: + qs = SavedPoint.objects.filter(user=self.request.user) + return [SavedPointOut.model_validate(obj) for obj in qs] - def patch(self, request): - user = request.user - serializer = UserSerializer(user, data=request.data, partial=True) - - if not serializer.is_valid(): - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - - serializer.save() - return Response(serializer.data) + def post(self, parsed_body: Body[SavedPointIn]) -> SavedPointOut: + user = self.request.user + _check_saved_point_unique(user, parsed_body.name) + obj = SavedPoint.objects.create(user=user, **parsed_body.model_dump()) + return SavedPointOut.model_validate(obj) -class ChangePasswordView(APIView): - permission_classes = [IsAuthenticated] +class SavedPointDetailController(Controller[PydanticSerializer]): + """Detail endpoint for a single saved point owned by the current user.""" - def post(self, request): - user = request.user - serializer = ChangePasswordSerializer(data=request.data) - - if not serializer.is_valid(): - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - - if not user.check_password(serializer.validated_data['old_password']): - return Response({'detail': 'Old password is incorrect'}, - status=status.HTTP_400_BAD_REQUEST) + def _get_object(self, pk: int): + # Filtering by user means another user's object reads as 404 (not 403), + # matching DRF's get_object() over a user-scoped queryset. + obj = SavedPoint.objects.filter(user=self.request.user, pk=pk).first() + if obj is None: + raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'}) + return obj - user.set_password(serializer.validated_data['new_password']) + def get(self, parsed_path: Path[PkPath]) -> SavedPointOut: + return SavedPointOut.model_validate(self._get_object(parsed_path.pk)) + + def put(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointIn]) -> SavedPointOut: + obj = self._get_object(parsed_path.pk) + _check_saved_point_unique(self.request.user, parsed_body.name, exclude_pk=obj.pk) + for field, value in parsed_body.model_dump().items(): + setattr(obj, field, value) + obj.save() + return SavedPointOut.model_validate(obj) + + def patch(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointPatchIn]) -> SavedPointOut: + obj = self._get_object(parsed_path.pk) + updates = parsed_body.model_dump(exclude_unset=True) + if 'name' in updates: + _check_saved_point_unique(self.request.user, updates['name'], exclude_pk=obj.pk) + for field, value in updates.items(): + setattr(obj, field, value) + obj.save() + return SavedPointOut.model_validate(obj) + + @modify(status_code=HTTPStatus.NO_CONTENT) + def delete(self, parsed_path: Path[PkPath]) -> None: + self._get_object(parsed_path.pk).delete() + + +class PredictionTemplateListController(Controller[PydanticSerializer]): + """Collection endpoint for the current user's templates (was PreditctionTemplateViewset). + + NOTE: as before, there is no app-level uniqueness check; a duplicate + (user, name) hits the model's unique_together and surfaces as a DB error. + """ + + def get(self) -> list[PredictionTemplateOut]: + qs = PreditctionTemplate.objects.filter(user=self.request.user) + return [PredictionTemplateOut.model_validate(obj) for obj in qs] + + def post(self, parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut: + obj = PreditctionTemplate.objects.create( + user=self.request.user, **parsed_body.model_dump() + ) + return PredictionTemplateOut.model_validate(obj) + + +class PredictionTemplateDetailController(Controller[PydanticSerializer]): + """Detail endpoint for a single template owned by the current user.""" + + def _get_object(self, pk: int): + obj = PreditctionTemplate.objects.filter(user=self.request.user, pk=pk).first() + if obj is None: + raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'}) + return obj + + def get(self, parsed_path: Path[PkPath]) -> PredictionTemplateOut: + return PredictionTemplateOut.model_validate(self._get_object(parsed_path.pk)) + + def put(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut: + obj = self._get_object(parsed_path.pk) + for field, value in parsed_body.model_dump().items(): + setattr(obj, field, value) + obj.save() + return PredictionTemplateOut.model_validate(obj) + + def patch(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplatePatchIn]) -> PredictionTemplateOut: + obj = self._get_object(parsed_path.pk) + for field, value in parsed_body.model_dump(exclude_unset=True).items(): + setattr(obj, field, value) + obj.save() + return PredictionTemplateOut.model_validate(obj) + + @modify(status_code=HTTPStatus.NO_CONTENT) + def delete(self, parsed_path: Path[PkPath]) -> None: + self._get_object(parsed_path.pk).delete() + + +class UserProfileController(Controller[PydanticSerializer]): + """Read and partially update the current user's profile.""" + + def get(self) -> UserOut: + return UserOut.model_validate(self.request.user) + + def patch(self, parsed_body: Body[UserUpdateIn]) -> UserOut: + user = self.request.user + # partial update: apply only the fields the client actually sent. + for field, value in parsed_body.model_dump(exclude_unset=True).items(): + setattr(user, field, value) user.save() - return Response({'detail': 'Password changed successfully'}) + return UserOut.model_validate(user) -class DeleteAccountView(APIView): - permission_classes = [IsAuthenticated] +class ChangePasswordController(Controller[PydanticSerializer]): + """Change the current user's password.""" + + # post() would default to 201; the original returned 200. + @modify(status_code=HTTPStatus.OK) + def post(self, parsed_body: Body[ChangePasswordIn]) -> DetailResponse: + user = self.request.user + + if not user.check_password(parsed_body.old_password): + raise APIError( + status_code=HTTPStatus.BAD_REQUEST, + body={'detail': 'Old password is incorrect'}, + ) + + user.set_password(parsed_body.new_password) + user.save() + return DetailResponse(detail='Password changed successfully') + + +class DeleteAccountController(Controller[PydanticSerializer]): + """Delete the current user's account and all their data.""" + + # DMR forbids a Body component on DELETE, but the original endpoint read the + # password from a DELETE request body. Preserve that contract by parsing the + # body manually instead of via Body[DeleteAccountIn]. + def delete(self) -> DetailResponse: + user = self.request.user + + try: + parsed_body = DeleteAccountIn(**json.loads(self.request.body or b'{}')) + except ValidationError as exc: + raise APIError(status_code=HTTPStatus.BAD_REQUEST, body=json.loads(exc.json())) + except ValueError: + raise APIError(status_code=HTTPStatus.BAD_REQUEST, body={'detail': 'Invalid request body.'}) + + if not user.check_password(parsed_body.password): + raise APIError( + status_code=HTTPStatus.BAD_REQUEST, + body={'detail': 'Incorrect password'}, + ) - def delete(self, request): - user = request.user - serializer = DeleteAccountSerializer(data=request.data) - - if not serializer.is_valid(): - return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - - if not user.check_password(serializer.validated_data['password']): - return Response({'detail': 'Incorrect password'}, - status=status.HTTP_400_BAD_REQUEST) - Prediction.objects.filter(user=user).delete() SavedPoint.objects.filter(user=user).delete() PreditctionTemplate.objects.filter(user=user).delete() - + user.delete() - - return Response({'detail': 'Account deleted successfully'}) + + return DetailResponse(detail='Account deleted successfully') -class DeleteUserDataView(APIView): - permission_classes = [IsAuthenticated] +class DeleteUserDataController(Controller[PydanticSerializer]): + """Delete all of the current user's data without deleting the account.""" + + def delete(self) -> DetailResponse: + user = self.request.user - def delete(self, request): - user = request.user - Prediction.objects.filter(user=user).delete() SavedPoint.objects.filter(user=user).delete() PreditctionTemplate.objects.filter(user=user).delete() - - return Response({'detail': 'All user data deleted successfully'}) + + return DetailResponse(detail='All user data deleted successfully') -class TokenManagementView(APIView): - permission_classes = [IsAuthenticated] +class ObtainTokenController( + ObtainTokensSyncController[ + PydanticSerializer, + ObtainTokensPayload, + ObtainTokensResponse, + ], +): + """Exchange username/password for JWT access + refresh tokens. - def get(self, request): - - token, created = Token.objects.get_or_create(user=request.user) - return Response({"token": token.key}) + Replaces DRF's obtain_auth_token. Approved drift: the token format and + semantics change from a single stored DRF token to stateless JWTs, so the + response is {access_token, refresh_token} instead of {token}. + """ - def post(self, request): - - Token.objects.filter(user=request.user).delete() - token = Token.objects.create(user=request.user) - return Response({"token": token.key}) + auth = None # public: credentials are supplied in the request body + csrf_exempt = True + jwt_expiration = timedelta(hours=1) + jwt_refresh_expiration = timedelta(days=7) + + def convert_auth_payload(self, payload): + return payload + + def make_api_response(self): + now = datetime.now(timezone.utc) + return { + 'access_token': self.create_jwt_token( + expiration=now + self.jwt_expiration, + token_type='access', + ), + 'refresh_token': self.create_jwt_token( + expiration=now + self.jwt_refresh_expiration, + token_type='refresh', + ), + } + + +class TokenManagementController(Controller[PydanticSerializer]): + """Issue a fresh JWT access token for the current user. + + Was TokenManagementView (DRF stored-token get/regenerate). With stateless + JWTs there is nothing stored to fetch or delete, so both GET and POST mint a + new access token (approved drift). The {"token": ...} response shape is kept. + """ + + jwt_expiration = timedelta(hours=1) + + def get(self) -> TokenResponse: + return TokenResponse(token=self._issue_token()) + + # post() would default to 201; the original returned 200. + @modify(status_code=HTTPStatus.OK) + def post(self) -> TokenResponse: + return TokenResponse(token=self._issue_token()) + + def _issue_token(self) -> str: + now = datetime.now(timezone.utc) + return self.create_jwt_token( + subject=str(self.request.user.pk), + expiration=now + self.jwt_expiration, + token_type='access', + ) From 98214393a63d07a0ed564da8c185ec9913eda7b6 Mon Sep 17 00:00:00 2001 From: straitz Date: Wed, 3 Jun 2026 05:55:39 +0900 Subject: [PATCH 2/2] fix APIError signature and DELETE body for dmr 0.6.0 --- stratoflights_api/pagination.py | 2 +- stratoflights_api/views.py | 39 ++++++++++++++++++++------------- 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/stratoflights_api/pagination.py b/stratoflights_api/pagination.py index 797ac58..6f1fbaf 100644 --- a/stratoflights_api/pagination.py +++ b/stratoflights_api/pagination.py @@ -81,7 +81,7 @@ def page_number_paginate(request, queryset, page_size: int = PAGE_SIZE): try: page = paginator.page(page_number) except InvalidPage: - raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Invalid page.'}) + raise APIError({'detail': 'Invalid page.'}, status_code=HTTPStatus.NOT_FOUND) url = request.build_absolute_uri() diff --git a/stratoflights_api/views.py b/stratoflights_api/views.py index 6f1d8b8..7de0cf3 100644 --- a/stratoflights_api/views.py +++ b/stratoflights_api/views.py @@ -56,6 +56,15 @@ from .validators import base64_to_curve User = get_user_model() +def _api_error(status_code, body, headers=None): + """Adapter for DMR's APIError(raw_data, *, status_code, ...) signature. + + Lets call sites keep the (status_code=, body=) style; the body is DMR's + first positional ``raw_data`` argument. + """ + return APIError(body, status_code=status_code, headers=headers) + + def _resolve_related(model, pk, field_name): """Resolve a related object by PK (was PrimaryKeyRelatedField(queryset=...)). @@ -66,7 +75,7 @@ def _resolve_related(model, pk, field_name): return None obj = model.objects.filter(pk=pk).first() if obj is None: - raise APIError( + raise _api_error( status_code=HTTPStatus.BAD_REQUEST, body={field_name: [f'Invalid pk "{pk}" - object does not exist.']}, ) @@ -92,7 +101,7 @@ class PredictionCollectionController(Controller[PydanticSerializer]): try: prediction_result = TawhiriClient.get_prediction(parsed_body.model_dump()) except requests.RequestException as exc: - raise APIError( + raise _api_error( status_code=HTTPStatus.BAD_GATEWAY, body={'error': f'Tawhiri error: {str(exc)}'}, ) @@ -135,7 +144,7 @@ class PredictionListUserController(Controller[PydanticSerializer]): filters['created_at__lte'] = parse_datetime(parsed_query.created_till) if parsed_query.satellite_id: if not user.satellites.filter(id=parsed_query.satellite_id).exists(): - raise APIError(status_code=HTTPStatus.FORBIDDEN, body={'detail': 'Access denied'}) + raise _api_error(status_code=HTTPStatus.FORBIDDEN, body={'detail': 'Access denied'}) filters['satellite_id'] = parsed_query.satellite_id queryset = Prediction.objects.filter(**filters) @@ -157,7 +166,7 @@ class PredictionDetailController(Controller[PydanticSerializer]): prediction = Prediction.objects.filter( user=self.request.user, pk=parsed_path.pk).first() if prediction is None: - raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'}) + raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'}) return PredictionDetailOut.model_validate(prediction) @@ -169,7 +178,7 @@ class PredictionDeleteController(Controller[PydanticSerializer]): prediction = Prediction.objects.filter( user=self.request.user, pk=parsed_path.pk).first() if prediction is None: - raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'}) + raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'}) prediction.delete() @@ -260,14 +269,14 @@ class LoginController(Controller[PydanticSerializer]): password = data.get('password') if username is None or password is None: - raise APIError( + raise _api_error( status_code=HTTPStatus.BAD_REQUEST, body={'detail': 'Please provide username and password.'}, ) user = authenticate(username=username, password=password) if user is None: - raise APIError( + raise _api_error( status_code=HTTPStatus.BAD_REQUEST, body={'detail': 'Invalid credentials.'}, ) @@ -284,7 +293,7 @@ class LogoutController(Controller[PydanticSerializer]): @modify(status_code=HTTPStatus.OK) def post(self) -> DetailResponse: if not self.request.user.is_authenticated: - raise APIError( + raise _api_error( status_code=HTTPStatus.BAD_REQUEST, body={'detail': "You're not logged in."}, ) @@ -302,7 +311,7 @@ def _check_saved_point_unique(user, name, exclude_pk=None): if exclude_pk is not None: qs = qs.exclude(pk=exclude_pk) if qs.exists(): - raise APIError( + raise _api_error( status_code=HTTPStatus.BAD_REQUEST, body={'non_field_errors': [_SAVED_POINT_DUPLICATE]}, ) @@ -330,7 +339,7 @@ class SavedPointDetailController(Controller[PydanticSerializer]): # matching DRF's get_object() over a user-scoped queryset. obj = SavedPoint.objects.filter(user=self.request.user, pk=pk).first() if obj is None: - raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'}) + raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'}) return obj def get(self, parsed_path: Path[PkPath]) -> SavedPointOut: @@ -383,7 +392,7 @@ class PredictionTemplateDetailController(Controller[PydanticSerializer]): def _get_object(self, pk: int): obj = PreditctionTemplate.objects.filter(user=self.request.user, pk=pk).first() if obj is None: - raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'}) + raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'}) return obj def get(self, parsed_path: Path[PkPath]) -> PredictionTemplateOut: @@ -432,7 +441,7 @@ class ChangePasswordController(Controller[PydanticSerializer]): user = self.request.user if not user.check_password(parsed_body.old_password): - raise APIError( + raise _api_error( status_code=HTTPStatus.BAD_REQUEST, body={'detail': 'Old password is incorrect'}, ) @@ -454,12 +463,12 @@ class DeleteAccountController(Controller[PydanticSerializer]): try: parsed_body = DeleteAccountIn(**json.loads(self.request.body or b'{}')) except ValidationError as exc: - raise APIError(status_code=HTTPStatus.BAD_REQUEST, body=json.loads(exc.json())) + raise _api_error(status_code=HTTPStatus.BAD_REQUEST, body=json.loads(exc.json())) except ValueError: - raise APIError(status_code=HTTPStatus.BAD_REQUEST, body={'detail': 'Invalid request body.'}) + raise _api_error(status_code=HTTPStatus.BAD_REQUEST, body={'detail': 'Invalid request body.'}) if not user.check_password(parsed_body.password): - raise APIError( + raise _api_error( status_code=HTTPStatus.BAD_REQUEST, body={'detail': 'Incorrect password'}, )