migrated to modern-rest

This commit is contained in:
straitz 2026-06-03 05:40:35 +09:00
parent d9a92569f0
commit 8e44c4501a
11 changed files with 1014 additions and 572 deletions

View file

@ -1,8 +1,7 @@
Django>=4.0,<5.0 # django-modern-rest 0.6.0 requires Django 5.x (upgraded from 4.2).
djangorestframework Django>=5.0,<6.0
djangorestframework-simplejwt django-modern-rest[jwt,pydantic]==0.6.0
psycopg2-binary psycopg2-binary
drf-spectacular
requests requests
django-cors-headers django-cors-headers
Pillow Pillow

View file

@ -13,6 +13,12 @@ https://docs.djangoproject.com/en/4.2/ref/settings/
from pathlib import Path from pathlib import Path
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
from dmr.settings import Settings
from dmr.security.django_session import DjangoSessionSyncAuth
from dmr.security.jwt import JWTSyncAuth
from dmr.openapi import OpenAPIConfig
load_dotenv() load_dotenv()
# Build paths inside the project like this: BASE_DIR / 'subdir'. # Build paths inside the project like this: BASE_DIR / 'subdir'.
@ -51,12 +57,10 @@ INSTALLED_APPS = [
'django.contrib.sessions', 'django.contrib.sessions',
'django.contrib.messages', 'django.contrib.messages',
'django.contrib.staticfiles', 'django.contrib.staticfiles',
'rest_framework',
'rest_framework.authtoken',
'drf_spectacular',
'corsheaders', 'corsheaders',
'stratoflights_api.apps.StratoflightsApiConfig', 'stratoflights_api.apps.StratoflightsApiConfig',
'channels', 'channels',
'dmr', # required to serve OpenAPI docs static assets
] ]
MIDDLEWARE = [ MIDDLEWARE = [
@ -165,25 +169,16 @@ STATIC_URL = 'static/'
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
AUTH_USER_MODEL = 'stratoflights_api.User' AUTH_USER_MODEL = 'stratoflights_api.User'
REST_FRAMEWORK = { # django-modern-rest configuration.
'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination', # Global auth mirrors the previous DRF default (Token + Session); the DRF
'PAGE_SIZE': 100, # TokenAuthentication is replaced by DMR's JWT (approved drift). Providing
# several auth instances means at least one of them must succeed, so this
'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema', # also enforces a 401 by default like DRF's global IsAuthenticated did.
# Public endpoints opt out per-endpoint with @modify(auth=None).
'DEFAULT_AUTHENTICATION_CLASSES': [ DMR_SETTINGS = {
'rest_framework.authentication.TokenAuthentication', Settings.auth: [JWTSyncAuth(), DjangoSessionSyncAuth()],
'rest_framework.authentication.SessionAuthentication', Settings.validate_responses: not PRODUCTION,
], Settings.openapi_config: OpenAPIConfig(title='Stratoflights API', version='1.0.0'),
'DEFAULT_PERMISSION_CLASSES': [
'rest_framework.permissions.IsAuthenticated',
# 'rest_framework.permissions.AllowAny',
],
'DEFAULT_RENDERER_CLASSES': [
'rest_framework.renderers.JSONRenderer',
],
} }

View file

@ -1,27 +1,17 @@
""" """URL configuration for the stratoflights project."""
URL configuration for stratoflights project.
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.contrib import admin from django.contrib import admin
from django.urls import path, include from django.urls import include
from drf_spectacular.views import SpectacularAPIView from dmr.openapi import build_schema
from drf_spectacular.views import SpectacularSwaggerView from dmr.openapi.views import OpenAPIJsonView, SwaggerView
from dmr.routing import path
from stratoflights_api.urls import router
schema = build_schema(router)
urlpatterns = [ urlpatterns = [
path('admin/', admin.site.urls), path('admin/', admin.site.urls),
path('api/', include('stratoflights_api.urls')), path(router.prefix, include(router.urls)),
path('api/schema/', SpectacularAPIView.as_view(), name='schema'), path('api/schema/', OpenAPIJsonView.as_view(schema), name='schema'),
path('api/docs/', SpectacularSwaggerView.as_view(url_name='schema'), name='docs'), path('api/docs/', SwaggerView.as_view(schema), name='docs'),
] ]

View file

@ -18,17 +18,19 @@ class TelemetryConsumer(AsyncWebsocketConsumer):
async def receive(self, text_data): async def receive(self, text_data):
from .serializers import TelemetryPacketSerializer from pydantic import ValidationError
from .dtos import TelemetryIn
if not self.write_enabled: if not self.write_enabled:
await self.send(text_data=json.dumps({"error": "Read-only mode"})) await self.send(text_data=json.dumps({"error": "Read-only mode"}))
return return
data = json.loads(text_data) data = json.loads(text_data)
serializer = TelemetryPacketSerializer(data=data)
if not serializer.is_valid(): try:
await self.send(text_data=json.dumps({"error": serializer.errors})) TelemetryIn(**data)
except ValidationError as exc:
await self.send(text_data=json.dumps({"error": json.loads(exc.json())}))
return return
saved_data = await self.save_telemetry(data) saved_data = await self.save_telemetry(data)
@ -51,15 +53,6 @@ class TelemetryConsumer(AsyncWebsocketConsumer):
async def telemetry_message(self, event): async def telemetry_message(self, event):
await self.send(text_data=json.dumps(event["data"])) await self.send(text_data=json.dumps(event["data"]))
@database_sync_to_async
def get_user_from_token(self, token_key, Token):
from rest_framework.authtoken.models import Token
User = get_user_model()
try:
token = Token.objects.select_related("user").get(key=token_key)
return token.user
except Token.DoesNotExist:
return None
@database_sync_to_async @database_sync_to_async
def save_telemetry(self, data): def save_telemetry(self, data):
@ -99,12 +92,20 @@ class StationTelemetryConsumer(TelemetryConsumer):
write_enabled = True write_enabled = True
async def connect(self): async def connect(self):
from rest_framework.authtoken.models import Token from django.conf import settings
from dmr.security.jwt.token import JWToken
token_key = self.scope["query_string"].decode().split("token=")[-1] token_key = self.scope["query_string"].decode().split("token=")[-1]
try: try:
token = await database_sync_to_async(Token.objects.select_related("user").get)(key=token_key) decoded = JWToken.decode(
self.scope["user"] = token.user encoded_token=token_key,
except Token.DoesNotExist: secret=settings.SECRET_KEY,
algorithm='HS256',
)
self.scope["user"] = await database_sync_to_async(
get_user_model().objects.get
)(pk=decoded.sub)
except Exception:
await self.close() await self.close()
return return

View file

@ -1,17 +0,0 @@
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.response import Response
class CustomLimitOffsetPagination(LimitOffsetPagination):
limit_query_param = 'limit'
offset_query_param = 'skip'
max_limit = 100
default_limit = 10
def get_paginated_response(self, data):
return Response({
'total': self.count,
'limit': self.limit,
'skip': self.offset,
'predictions': data
})

333
stratoflights_api/dtos.py Normal file
View file

@ -0,0 +1,333 @@
"""Pydantic DTOs replacing the DRF serializers from serializers.py.
Mapping is 1:1 with the former serializers. Output DTOs enable ``from_attributes``
so an endpoint can build them straight from a Django model instance via
``SomeOut.model_validate(obj)`` (the equivalent of ``Serializer(obj).data``).
Field-level and cross-field validation that DRF kept in ``validate_<field>()`` /
``validate()`` is reproduced with pydantic ``field_validator`` / ``model_validator``.
Business logic DRF kept in ``create()`` (base64 curve decoding, ORM writes, and
resolving related objects from their PKs) is intentionally NOT here -- per DMR's
split it moves into the endpoints (Step 2.9).
"""
import uuid
from datetime import datetime
from typing import Any, Optional
from django.contrib.auth.password_validation import validate_password
from django.core.validators import validate_email
from django.core.exceptions import ValidationError as DjangoValidationError
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
)
from .validators import validate_custom_curve, rate_clip
# Profile constants (moved verbatim from serializers.py).
PROFILE_STANDARD = "standard_profile"
PROFILE_FLOAT = "float_profile"
PROFILE_REVERSE = "reverse_profile"
PROFILE_CUSTOM = "custom_profile"
LATEST_DATASET_KEYWORD = "latest"
SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM]
# --- Prediction request (was PredictionRequestSerializer) ---------------------
class PredictionRequest(BaseModel):
launch_latitude: float = Field(ge=-90, le=90)
launch_longitude: float = Field(ge=0, le=360)
launch_datetime: datetime
launch_altitude: Optional[float] = None
format: str = "json"
profile: str = PROFILE_STANDARD
dataset: str = LATEST_DATASET_KEYWORD
# profile-dependent fields
ascent_rate: Optional[float] = Field(default=None, ge=0.01)
descent_rate: Optional[float] = Field(default=None, ge=0.01)
burst_altitude: Optional[float] = None
float_altitude: Optional[float] = None
stop_datetime: Optional[datetime] = None
ascent_curve: Optional[str] = None
descent_curve: Optional[str] = None
interpolate: bool = False
# Related objects are accepted as PKs; existence is resolved in the endpoint
# (was PrimaryKeyRelatedField(queryset=...)).
start_point: Optional[int] = None
rate_profile: Optional[int] = None
template: Optional[int] = None
@field_validator("profile")
@classmethod
def _validate_profile(cls, value: str) -> str:
if value not in SUPPORTED_PROFILES:
raise ValueError(f'"{value}" is not a valid choice.')
return value
@model_validator(mode="after")
def _validate_cross_fields(self) -> "PredictionRequest":
launch_alt = self.launch_altitude if self.launch_altitude is not None else 0
if self.profile == PROFILE_STANDARD:
if self.burst_altitude is None:
raise ValueError("burst_altitude is required for standard profile.")
if self.burst_altitude <= launch_alt:
raise ValueError("burst_altitude must be greater than launch_altitude.")
elif self.profile == PROFILE_FLOAT:
if self.float_altitude is None or self.float_altitude <= launch_alt:
raise ValueError("float_altitude must be greater than launch_altitude.")
if self.stop_datetime is None or self.stop_datetime <= self.launch_datetime:
raise ValueError("stop_datetime must be later than launch_datetime.")
elif self.profile == PROFILE_CUSTOM:
if self.ascent_curve is None or not validate_custom_curve(self.ascent_curve):
raise ValueError("Invalid ascent_curve.")
if self.descent_curve is None or not validate_custom_curve(self.descent_curve):
raise ValueError("Invalid descent_curve.")
if self.burst_altitude is None or self.burst_altitude <= launch_alt:
raise ValueError("burst_altitude must be greater than launch_altitude.")
# custom clipping logic (was in validate())
if self.ascent_rate is not None:
self.ascent_rate = rate_clip(self.ascent_rate)
if self.descent_rate is not None:
self.descent_rate = rate_clip(self.descent_rate)
return self
# --- Prediction outputs -------------------------------------------------------
class PredictionOut(BaseModel):
"""Was PredictionSerializer."""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
created_at: datetime
updated_at: datetime
result: Any
class PredictionListOut(BaseModel):
"""Was PredictionListSerializer."""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
created_at: datetime
updated_at: datetime
start_point: Optional[int] = Field(default=None, validation_alias="start_point_id")
template: Optional[int] = Field(default=None, validation_alias="template_id")
rate_profile: Optional[int] = Field(default=None, validation_alias="rate_profile_id")
class PredictionDetailOut(BaseModel):
"""Was PredictionDetailSerializer."""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
created_at: datetime
updated_at: datetime
result: Any
start_point: Optional[int] = Field(default=None, validation_alias="start_point_id")
template: Optional[int] = Field(default=None, validation_alias="template_id")
rate_profile: Optional[int] = Field(default=None, validation_alias="rate_profile_id")
class PredictionCreateOut(BaseModel):
"""Custom create() response shape: {id, created_at, result} (not the full model)."""
id: uuid.UUID
created_at: datetime
result: Any
# --- Telemetry (was TelemetryPacketSerializer) --------------------------------
class TelemetryIn(BaseModel):
# timestamp is required (model BigIntegerField has no default), matching the
# former ModelSerializer; the endpoint still overwrites it with server time.
timestamp: int
lat: float = 0.0
lon: float = 0.0
alt: float = 0.0
payload: Any = Field(default_factory=dict)
class TelemetryOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
timestamp: int
lat: float
lon: float
alt: float
payload: Any
# --- SavedPoint (was SavedPointSerializer) ------------------------------------
# `user` was a HiddenField(CurrentUserDefault) -> set from request in the endpoint,
# not part of the wire input/output. The unique_together (user, name) check that
# DRF did via UniqueTogetherValidator is reproduced in the endpoint (Step 2.8).
class SavedPointIn(BaseModel):
name: str
lat: float = 0.0
lon: float = 0.0
alt: float = 0.0
class SavedPointPatchIn(BaseModel):
# All-optional variant for PATCH (partial_update). Only set fields apply.
name: Optional[str] = None
lat: Optional[float] = None
lon: Optional[float] = None
alt: Optional[float] = None
class SavedPointOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
name: str
lat: float
lon: float
alt: float
# --- SavedRateProfile (was SavedRateProfileSerializer) ------------------------
# NOTE: no view referenced this serializer in routing; kept for parity.
class SavedRateProfileIn(BaseModel):
name: str
type: str = "ascent"
rate_profile_data: Any = Field(default_factory=dict)
class SavedRateProfileOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
name: str
type: str
rate_profile_data: Any
# --- PredictionTemplate (was PreditctionTemplateSerializer) -------------------
class PredictionTemplateIn(BaseModel):
# protected_namespaces=() avoids pydantic's warning about the `model` field.
model_config = ConfigDict(protected_namespaces=())
name: str
is_default: bool = False
description: Optional[str] = None
prediction_mode: str = ""
model: str = ""
dataset: str = ""
flight_parameters: Any = Field(default_factory=dict)
class PredictionTemplatePatchIn(BaseModel):
# All-optional variant for PATCH (partial_update). Only set fields apply.
model_config = ConfigDict(protected_namespaces=())
name: Optional[str] = None
is_default: Optional[bool] = None
description: Optional[str] = None
prediction_mode: Optional[str] = None
model: Optional[str] = None
dataset: Optional[str] = None
flight_parameters: Optional[Any] = None
class PredictionTemplateOut(BaseModel):
model_config = ConfigDict(from_attributes=True, protected_namespaces=())
id: int
name: str
is_default: bool
description: Optional[str] = None
prediction_mode: str
model: str
dataset: str
flight_parameters: Any
# --- User (was UserSerializer) ------------------------------------------------
class UserOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
username: str
email: str
first_name: str
last_name: str
class UserUpdateIn(BaseModel):
# `username` was read_only -> excluded from input. PATCH is partial, so all
# fields are optional; the endpoint applies only fields that were set.
email: Optional[str] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
@field_validator("email")
@classmethod
def _validate_email(cls, value: Optional[str]) -> Optional[str]:
if value is None:
return value
try:
validate_email(value)
except DjangoValidationError:
raise ValueError("Invalid email format")
return value
# --- Password / account (were ChangePasswordSerializer / DeleteAccountSerializer)
class ChangePasswordIn(BaseModel):
old_password: str
new_password: str
@field_validator("new_password")
@classmethod
def _validate_new_password(cls, value: str) -> str:
# validate_password raises Django's ValidationError (not a ValueError),
# which pydantic would not convert into a 4xx -- re-raise as ValueError.
try:
validate_password(value)
except DjangoValidationError as exc:
raise ValueError("; ".join(exc.messages))
return value
class DeleteAccountIn(BaseModel):
password: str
# --- Auth / session response shapes (were inline JsonResponse dicts) ----------
class DetailResponse(BaseModel):
detail: str
class SessionResponse(BaseModel):
isAuthenticated: bool
class WhoAmIResponse(BaseModel):
username: str
class TokenResponse(BaseModel):
token: str
# --- Shared path parameters ---------------------------------------------------
class PkPath(BaseModel):
pk: int
class UuidPkPath(BaseModel):
pk: uuid.UUID

View file

@ -0,0 +1,120 @@
"""Limit/offset pagination for the prediction list endpoints.
DMR has no built-in limit/offset paginator, so (as the docs suggest) we keep our
own envelope model -- matching the former ``CustomLimitOffsetPagination`` output
exactly -- plus a helper that slices the queryset. The query params and envelope
keys are preserved verbatim: ``limit`` / ``skip`` in, and
``{total, limit, skip, predictions}`` out.
"""
from http import HTTPStatus
from typing import Optional
from urllib.parse import urlsplit, urlunsplit
from django.core.paginator import InvalidPage, Paginator
from django.http import QueryDict
from pydantic import BaseModel
from dmr.response import APIError
from .dtos import PredictionOut, TelemetryOut
DEFAULT_LIMIT = 10 # was CustomLimitOffsetPagination.default_limit
MAX_LIMIT = 100 # was CustomLimitOffsetPagination.max_limit
PAGE_SIZE = 100 # global REST_FRAMEWORK PAGE_SIZE (PageNumberPagination)
class PaginationQuery(BaseModel):
# Param names preserved: limit_query_param='limit', offset_query_param='skip'.
limit: int = DEFAULT_LIMIT
skip: int = 0
class PredictionListUserQuery(BaseModel):
"""Query params for PredictionViewSet.list_user: filters + limit/offset."""
satellite_id: Optional[str] = None
created_from: Optional[str] = None
created_till: Optional[str] = None
limit: int = DEFAULT_LIMIT
skip: int = 0
class PredictionPage(BaseModel):
total: int
limit: int
skip: int
predictions: list[PredictionOut]
class TelemetryPage(BaseModel):
count: int
next: Optional[str] = None
previous: Optional[str] = None
results: list[TelemetryOut]
def _replace_query_param(url: str, key: str, value) -> str:
scheme, netloc, path, query, fragment = urlsplit(url)
query_dict = QueryDict(query, mutable=True)
query_dict[key] = value
return urlunsplit((scheme, netloc, path, query_dict.urlencode(), fragment))
def _remove_query_param(url: str, key: str) -> str:
scheme, netloc, path, query, fragment = urlsplit(url)
query_dict = QueryDict(query, mutable=True)
query_dict.pop(key, None)
return urlunsplit((scheme, netloc, path, query_dict.urlencode(), fragment))
def page_number_paginate(request, queryset, page_size: int = PAGE_SIZE):
"""Reproduce DRF PageNumberPagination.
Returns ``(count, next_link, previous_link, object_list)``. Raises a 404
"Invalid page." like DRF on an out-of-range / non-integer page. Honours
``page=last`` and builds absolute next/previous links off the current URL.
"""
paginator = Paginator(queryset, page_size)
page_number = request.GET.get('page', 1)
if page_number in ('last',):
page_number = paginator.num_pages
try:
page = paginator.page(page_number)
except InvalidPage:
raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Invalid page.'})
url = request.build_absolute_uri()
next_link = None
if page.has_next():
next_link = _replace_query_param(url, 'page', page.next_page_number())
previous_link = None
if page.has_previous():
previous_number = page.previous_page_number()
if previous_number == 1:
previous_link = _remove_query_param(url, 'page')
else:
previous_link = _replace_query_param(url, 'page', previous_number)
return paginator.count, next_link, previous_link, list(page.object_list)
def paginate_predictions(queryset, query: PaginationQuery) -> PredictionPage:
"""Slice ``queryset`` and build the prediction page envelope.
Mirrors the bounds DRF's LimitOffsetPagination enforced: limit falls back to
the default when non-positive and is capped at MAX_LIMIT; skip floors at 0.
"""
limit = query.limit if query.limit > 0 else DEFAULT_LIMIT
limit = min(limit, MAX_LIMIT)
skip = query.skip if query.skip >= 0 else 0
total = queryset.count()
page = list(queryset[skip:skip + limit])
return PredictionPage(
total=total,
limit=limit,
skip=skip,
predictions=[PredictionOut.model_validate(obj) for obj in page],
)

View file

@ -1,16 +0,0 @@
from rest_framework.permissions import BasePermission, SAFE_METHODS
class ReadOnlyOrAuthenticated(BasePermission):
def has_permission(self, request, view):
return (
request.method in SAFE_METHODS or
request.user and request.user.is_authenticated
)
class IsOwner(BasePermission):
def has_object_permission(self, request, view, obj):
return obj.user == request.user
def has_permission(self, request, view):
return request.user and request.user.is_authenticated

View file

@ -1,185 +0,0 @@
from rest_framework import serializers
from .models import Prediction, SavedPoint, SavedRateProfile, PreditctionTemplate
from datetime import datetime
from django.contrib.auth.password_validation import validate_password
from django.core.validators import validate_email
from django.core.exceptions import ValidationError as DjangoValidationError
from django.contrib.auth import get_user_model
from .validators import (
validate_custom_curve, rate_clip,
_rfc3339_to_timestamp, base64_to_curve
)
class PredictionSerializer(serializers.ModelSerializer):
class Meta:
model = Prediction
fields = ['id', 'created_at', 'updated_at', 'result']
User = get_user_model()
PROFILE_STANDARD = "standard_profile"
PROFILE_FLOAT = "float_profile"
PROFILE_REVERSE = "reverse_profile"
PROFILE_CUSTOM = "custom_profile"
LATEST_DATASET_KEYWORD = "latest"
SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM]
class PredictionRequestSerializer(serializers.Serializer):
launch_latitude = serializers.FloatField(min_value=-90, max_value=90)
launch_longitude = serializers.FloatField(min_value=0, max_value=360)
launch_datetime = serializers.DateTimeField()
launch_altitude = serializers.FloatField(required=False)
format = serializers.CharField(default="json")
profile = serializers.ChoiceField(choices=SUPPORTED_PROFILES, default=PROFILE_STANDARD)
dataset = serializers.CharField(default=LATEST_DATASET_KEYWORD)
# --- профиль-dependent поля ---
ascent_rate = serializers.FloatField(required=False, min_value=0.01)
descent_rate = serializers.FloatField(required=False, min_value=0.01)
burst_altitude = serializers.FloatField(required=False)
float_altitude = serializers.FloatField(required=False)
stop_datetime = serializers.DateTimeField(required=False)
ascent_curve = serializers.CharField(required=False)
descent_curve = serializers.CharField(required=False)
interpolate = serializers.BooleanField(required=False, default=False)
start_point = serializers.PrimaryKeyRelatedField(
queryset=SavedPoint.objects.all(), required=False, allow_null=True
)
rate_profile = serializers.PrimaryKeyRelatedField(
queryset=SavedRateProfile.objects.all(), required=False, allow_null=True
)
template = serializers.PrimaryKeyRelatedField(
queryset=PreditctionTemplate.objects.all(), required=False, allow_null=True
)
def validate(self, data):
profile = data.get("profile", PROFILE_STANDARD)
launch_alt = data.get("launch_altitude", 0)
if profile == PROFILE_STANDARD:
if 'burst_altitude' not in data:
raise serializers.ValidationError("burst_altitude is required for standard profile.")
if data['burst_altitude'] <= launch_alt:
raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.")
elif profile == PROFILE_FLOAT:
if 'float_altitude' not in data or data['float_altitude'] <= launch_alt:
raise serializers.ValidationError("float_altitude must be greater than launch_altitude.")
if 'stop_datetime' not in data or data['stop_datetime'] <= data['launch_datetime']:
raise serializers.ValidationError("stop_datetime must be later than launch_datetime.")
elif profile == PROFILE_CUSTOM:
if 'ascent_curve' not in data or not validate_custom_curve(data['ascent_curve']):
raise serializers.ValidationError("Invalid ascent_curve.")
if 'descent_curve' not in data or not validate_custom_curve(data['descent_curve']):
raise serializers.ValidationError("Invalid descent_curve.")
if 'burst_altitude' not in data or data['burst_altitude'] <= launch_alt:
raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.")
# кастомная логика clipping'а
if 'ascent_rate' in data:
data['ascent_rate'] = rate_clip(data['ascent_rate'])
if 'descent_rate' in data:
data['descent_rate'] = rate_clip(data['descent_rate'])
return data
def create(self, validated_data):
if 'ascent_curve' in validated_data:
validated_data['ascent_curve'] = base64_to_curve(validated_data['ascent_curve'])
if 'descent_curve' in validated_data:
validated_data['descent_curve'] = base64_to_curve(validated_data['descent_curve'])
prediction = Prediction(
user=validated_data.get('user'),
request=validated_data.get('request', {}),
result=validated_data.get('result', {}),
start_point=validated_data.get('start_point'),
template=validated_data.get('template'),
rate_profile=validated_data.get('rate_profile')
)
prediction.save()
return prediction
class PredictionListSerializer(serializers.ModelSerializer):
class Meta:
model = Prediction
fields = ["id", "created_at", "updated_at", "start_point", "template", "rate_profile"]
class PredictionDetailSerializer(serializers.ModelSerializer):
class Meta:
model = Prediction
fields = ["id", "created_at", "updated_at", "result", "start_point", "template", "rate_profile"]
from rest_framework import serializers
from .models import TelemetryPacket
class TelemetryPacketSerializer(serializers.ModelSerializer):
class Meta:
model = TelemetryPacket
fields = ['id', 'timestamp', 'lat', 'lon', 'alt', 'payload']
read_only_fields = ['id']
class SavedPointSerializer(serializers.ModelSerializer):
user = serializers.HiddenField(
default=serializers.CurrentUserDefault()
)
class Meta:
model = SavedPoint
fields = ['user', 'id', 'name', 'lat', 'lon', 'alt']
read_only_fields = ['id']
validators = [
serializers.UniqueTogetherValidator(
queryset=SavedPoint.objects.all(),
fields=['user', 'name'],
message="A saved point with this name already exists for the user."
)
]
class SavedRateProfileSerializer(serializers.ModelSerializer):
class Meta:
model = SavedRateProfile
fields = ['id', 'name', 'type', 'rate_profile_data']
read_only_fields = ['id']
class PreditctionTemplateSerializer(serializers.ModelSerializer):
class Meta:
model = PreditctionTemplate
fields = ['id', 'name', 'is_default', 'description', 'prediction_mode', 'model', 'dataset', 'flight_parameters']
read_only_fields = ['id']
class UserSerializer(serializers.ModelSerializer):
class Meta:
model = User
fields = ['username', 'email', 'first_name', 'last_name']
extra_kwargs = {
'username': {'read_only': True}
}
def validate_email(self, value):
try:
validate_email(value)
except DjangoValidationError:
raise serializers.ValidationError("Invalid email format")
return value
class ChangePasswordSerializer(serializers.Serializer):
old_password = serializers.CharField(required=True)
new_password = serializers.CharField(required=True)
def validate_new_password(self, value):
validate_password(value)
return value
class DeleteAccountSerializer(serializers.Serializer):
password = serializers.CharField(required=True)

View file

@ -1,48 +1,59 @@
from django.urls import path from dmr.routing import Router, path
from rest_framework.routers import DefaultRouter
from rest_framework.authtoken.views import obtain_auth_token
from .views import ( from .views import (
PredictionViewSet, PredictionCollectionController,
SavedPointViewset, PredictionListUserController,
PreditctionTemplateViewset, PredictionHistoryController,
TelemetryListCreateView, PredictionDetailController,
get_csrf, PredictionDeleteController,
login_view, SavedPointListController,
logout_view, SavedPointDetailController,
SessionView, PredictionTemplateListController,
WhoAmIView, PredictionTemplateDetailController,
UserProfileView, TelemetryController,
ChangePasswordView, CsrfController,
TokenManagementView, LoginController,
DeleteUserDataView, LogoutController,
DeleteAccountView SessionController,
WhoAmIController,
UserProfileController,
ChangePasswordController,
ObtainTokenController,
TokenManagementController,
DeleteUserDataController,
DeleteAccountController,
) )
# A Router (prefix + routes) is required so build_schema() can generate the
# OpenAPI document; see stratoflights/urls.py.
router = Router(
'api/',
[
path("csrf/", CsrfController.as_view(), name='api-csrf'),
path('token', ObtainTokenController.as_view(), name='get_token'),
path("login/", LoginController.as_view(), name='api-login'),
path("logout/", LogoutController.as_view(), name='api-logout'),
path("session/", SessionController.as_view(), name='api-session'),
path("whoami/", WhoAmIController.as_view(), name='api-whoami'),
path("<uuid:pk>/telemetry/", TelemetryController.as_view(), name="create_telemetry"),
path("profile/", UserProfileController.as_view(), name='api-profile'),
path("profile/change-password/", ChangePasswordController.as_view(), name='api-change-password'),
path("profile/token/", TokenManagementController.as_view(), name='api-token'),
path("profile/delete-data/", DeleteUserDataController.as_view(), name='api-delete-data'),
path("profile/delete-account/", DeleteAccountController.as_view(), name='api-delete-account'),
# Saved points (was SavedPointViewset via DefaultRouter).
path("saved-points/", SavedPointListController.as_view(), name='saved-points-list'),
path("saved-points/<int:pk>/", SavedPointDetailController.as_view(), name='saved-points-detail'),
# Prediction templates (was PreditctionTemplateViewset via DefaultRouter).
path("saved-templates/", PredictionTemplateListController.as_view(), name='saved-templates-list'),
path("saved-templates/<int:pk>/", PredictionTemplateDetailController.as_view(), name='saved-templates-detail'),
# Predictions (was PredictionViewSet via DefaultRouter).
path("predictions/", PredictionCollectionController.as_view(), name='predictions-list'),
path("predictions/list_user/", PredictionListUserController.as_view(), name='predictions-list-user'),
path("predictions/history/", PredictionHistoryController.as_view(), name='predictions-history'),
path("predictions/<uuid:pk>/detail/", PredictionDetailController.as_view(), name='predictions-detail'),
path("predictions/<uuid:pk>/delete/", PredictionDeleteController.as_view(), name='predictions-delete'),
],
)
router = DefaultRouter() urlpatterns = router.urls
router.register(r'predictions', PredictionViewSet, basename='predictions')
router.register(r'saved-points', SavedPointViewset, basename='saved-points')
router.register(r'saved-templates', PreditctionTemplateViewset, basename='saved-templates')
urlpatterns = [
path("csrf/", get_csrf, name='api-csrf'),
path('token', obtain_auth_token, name = 'get_token'),
path("login/", login_view, name='api-login'),
path("logout/", logout_view, name='api-logout'),
path("session/", SessionView.as_view(), name='api-session'),
path("whoami/", WhoAmIView.as_view(), name='api-whoami'),
path("<uuid:pk>/telemetry/", TelemetryListCreateView.as_view(), name="create_telemetry"),
path('csrf/', get_csrf, name='api-csrf'),
path('login/', login_view, name='api-login'),
path('logout/', logout_view, name='api-logout'),
path('session/', SessionView.as_view(), name='api-session'),
path('whoami/', WhoAmIView.as_view(), name='api-whoami'),
path("profile/", UserProfileView.as_view(), name='api-profile'),
path("profile/change-password/", ChangePasswordView.as_view(), name='api-change-password'),
path("profile/token/", TokenManagementView.as_view(), name='api-token'),
path("profile/delete-data/", DeleteUserDataView.as_view(), name='api-delete-data'),
path("profile/delete-account/", DeleteAccountView.as_view(), name='api-delete-account'),
]
urlpatterns += router.urls

View file

@ -1,305 +1,468 @@
import requests import requests
import time import time
import json import json
from rest_framework import status, generics, permissions from http import HTTPStatus
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.viewsets import ModelViewSet, ViewSet, GenericViewSet
from rest_framework.exceptions import APIException
from rest_framework.permissions import IsAuthenticated, AllowAny
from rest_framework.authentication import SessionAuthentication, BasicAuthentication, TokenAuthentication
from rest_framework.decorators import api_view, permission_classes, authentication_classes, action
from rest_framework.authtoken.models import Token
from django.utils import timezone
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
from django.http import JsonResponse
from django.contrib.auth import authenticate, login, logout, get_user_model from django.contrib.auth import authenticate, login, logout, get_user_model
from django.middleware.csrf import get_token from django.middleware.csrf import get_token
from django.core.exceptions import ValidationError
from django.utils.dateparse import parse_datetime from django.utils.dateparse import parse_datetime
from .models import Prediction, User, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket from .models import Prediction, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket
from .serializers import PredictionSerializer, TelemetryPacketSerializer, PredictionRequestSerializer, PredictionListSerializer, PredictionDetailSerializer, SavedPointSerializer, SavedRateProfileSerializer, PreditctionTemplateSerializer, UserSerializer, ChangePasswordSerializer, DeleteAccountSerializer
from .services.tawhiri import TawhiriClient from .services.tawhiri import TawhiriClient
from drf_spectacular.utils import extend_schema from datetime import datetime, timedelta, timezone
from .permissions import ReadOnlyOrAuthenticated, IsOwner
from .custom_pagination import CustomLimitOffsetPagination from pydantic import ValidationError
from datetime import datetime from dmr import Controller, modify
from dmr.components import Body, Path, Query
from dmr.plugins.pydantic import PydanticSerializer
from dmr.response import APIError
from dmr.security.jwt.views import (
ObtainTokensPayload,
ObtainTokensResponse,
ObtainTokensSyncController,
)
from .dtos import (
DetailResponse,
SessionResponse,
WhoAmIResponse,
UserOut,
UserUpdateIn,
ChangePasswordIn,
DeleteAccountIn,
TokenResponse,
PkPath,
UuidPkPath,
SavedPointIn,
SavedPointPatchIn,
SavedPointOut,
PredictionTemplateIn,
PredictionTemplatePatchIn,
PredictionTemplateOut,
PredictionRequest,
PredictionOut,
PredictionListOut,
PredictionDetailOut,
PredictionCreateOut,
TelemetryIn,
TelemetryOut,
)
from .pagination import (
PredictionListUserQuery,
PredictionPage,
paginate_predictions,
TelemetryPage,
page_number_paginate,
)
from .validators import base64_to_curve
User = get_user_model() User = get_user_model()
def get_prediction_from_tawhiri(params): def _resolve_related(model, pk, field_name):
"""Resolve a related object by PK (was PrimaryKeyRelatedField(queryset=...)).
base_url = "https://fly.stratonautica.ru/api/v2" The original queryset was ``.all()`` (not user-scoped); a missing PK yields a
response = requests.get(base_url, params=params) 400 matching DRF's "Invalid pk" error shape.
"""
if response.status_code == 200: if pk is None:
return response.json() # получаем результат предсказания return None
else: obj = model.objects.filter(pk=pk).first()
raise Exception( if obj is None:
f"Tawhiri error: {response.status_code} {response.text}") raise APIError(
status_code=HTTPStatus.BAD_REQUEST,
body={field_name: [f'Invalid pk "{pk}" - object does not exist.']},
)
return obj
class PredictionViewSet(GenericViewSet): class PredictionCollectionController(Controller[PydanticSerializer]):
permission_classes = [IsAuthenticated] """`predictions/` -- list the user's predictions and create a new one."""
pagination_class = CustomLimitOffsetPagination
def list(self, request): def get(self) -> list[PredictionOut]:
queryset = Prediction.objects.filter(user=request.user) queryset = Prediction.objects.filter(user=self.request.user)
return Response(PredictionSerializer(queryset, many=True).data) return [PredictionOut.model_validate(obj) for obj in queryset]
def create(self, request): def post(self, parsed_body: Body[PredictionRequest]) -> PredictionCreateOut:
serializer = PredictionRequestSerializer(data=request.data) user = self.request.user
if not serializer.is_valid(): # Resolve related objects before calling Tawhiri, so an invalid PK still
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) # fails with 400 first (as DRF validation did).
start_point = _resolve_related(SavedPoint, parsed_body.start_point, 'start_point')
validated_data = serializer.validated_data template = _resolve_related(PreditctionTemplate, parsed_body.template, 'template')
rate_profile = _resolve_related(SavedRateProfile, parsed_body.rate_profile, 'rate_profile')
try: try:
prediction_result = TawhiriClient.get_prediction(validated_data) prediction_result = TawhiriClient.get_prediction(parsed_body.model_dump())
except requests.RequestException as exc:
raise APIError(
status_code=HTTPStatus.BAD_GATEWAY,
body={'error': f'Tawhiri error: {str(exc)}'},
)
except requests.RequestException as e: # Carried over from the old serializer.create(): curves are decoded but
return Response({"error": f"Tawhiri error: {str(e)}"}, status=status.HTTP_502_BAD_GATEWAY) # never persisted (the model has no curve fields). Kept for parity --
# including that a malformed curve here raises (then surfaces as 500).
if parsed_body.ascent_curve is not None:
base64_to_curve(parsed_body.ascent_curve)
if parsed_body.descent_curve is not None:
base64_to_curve(parsed_body.descent_curve)
# prediction = Prediction.objects.create( prediction = Prediction(
# result=prediction_result, user=request.user, request=request.data, validated_data=validated_data) user=user,
prediction = serializer.save( request=json.loads(self.request.body),
user=request.user, result=prediction_result,
start_point=start_point,
template=template,
rate_profile=rate_profile,
)
prediction.save()
return PredictionCreateOut(
id=prediction.id,
created_at=prediction.created_at,
result=prediction_result, result=prediction_result,
request=request.data
) )
return Response({
"id": prediction.id,
"created_at": prediction.created_at,
"result": prediction_result
}, status=status.HTTP_201_CREATED)
@action(detail=False, methods=['get']) class PredictionListUserController(Controller[PydanticSerializer]):
def list_user(self, request): """`predictions/list_user/` -- filtered, paginated list of the user's predictions."""
user = request.user
satellite_id = request.query_params.get('satellite_id')
created_from = request.query_params.get('created_from')
created_till = request.query_params.get('created_till')
filters = { def get(self, parsed_query: Query[PredictionListUserQuery]) -> PredictionPage:
'user': user, user = self.request.user
} filters = {'user': user}
if created_from: if parsed_query.created_from:
filters['created_at__gte'] = parse_datetime(created_from) filters['created_at__gte'] = parse_datetime(parsed_query.created_from)
if parsed_query.created_till:
if created_till: filters['created_at__lte'] = parse_datetime(parsed_query.created_till)
filters['created_at__lte'] = parse_datetime(created_till) if parsed_query.satellite_id:
if not user.satellites.filter(id=parsed_query.satellite_id).exists():
if satellite_id: raise APIError(status_code=HTTPStatus.FORBIDDEN, body={'detail': 'Access denied'})
if not user.satellites.filter(id=satellite_id).exists(): filters['satellite_id'] = parsed_query.satellite_id
return Response({'detail': 'Access denied'}, status=403)
filters['satellite_id'] = satellite_id
queryset = Prediction.objects.filter(**filters) queryset = Prediction.objects.filter(**filters)
queryset = self.filter_queryset(queryset) return paginate_predictions(queryset, parsed_query)
page = self.paginate_queryset(queryset)
if page is not None: class PredictionHistoryController(Controller[PydanticSerializer]):
serializer = PredictionSerializer(page, many=True) """`predictions/history/` -- compact list of the user's predictions."""
return self.get_paginated_response(serializer.data)
serializer = PredictionSerializer(queryset, many=True) def get(self) -> list[PredictionListOut]:
return Response(serializer.data) queryset = Prediction.objects.filter(user=self.request.user)
return [PredictionListOut.model_validate(obj) for obj in queryset]
@action(detail=False, methods=["get"])
def history(self, request):
queryset = Prediction.objects.filter(user=request.user)
return Response(PredictionListSerializer(queryset, many=True).data)
@action(detail=True, methods=["get"]) class PredictionDetailController(Controller[PydanticSerializer]):
def detail(self, request, pk=None): """`predictions/<pk>/detail/` -- retrieve a single prediction."""
def get(self, parsed_path: Path[UuidPkPath]) -> PredictionDetailOut:
prediction = Prediction.objects.filter( prediction = Prediction.objects.filter(
user=request.user, pk=pk).first() user=self.request.user, pk=parsed_path.pk).first()
if not prediction: if prediction is None:
return Response({'detail': 'Not found'}, status=404) raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'})
return Response(PredictionDetailSerializer(prediction).data) return PredictionDetailOut.model_validate(prediction)
@action(detail=True, methods=["delete"])
def delete(self, request, pk=None): class PredictionDeleteController(Controller[PydanticSerializer]):
"""`predictions/<pk>/delete/` -- delete a single prediction."""
@modify(status_code=HTTPStatus.NO_CONTENT)
def delete(self, parsed_path: Path[UuidPkPath]) -> None:
prediction = Prediction.objects.filter( prediction = Prediction.objects.filter(
user=request.user, pk=pk).first() user=self.request.user, pk=parsed_path.pk).first()
if not prediction: if prediction is None:
return Response({'detail': 'Not found'}, status=404) raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'})
prediction.delete() prediction.delete()
return Response(status=204)
class TelemetryListCreateView(generics.ListCreateAPIView): class TelemetryController(Controller[PydanticSerializer]):
serializer_class = TelemetryPacketSerializer """`<uuid:pk>/telemetry/` -- list and ingest telemetry for a satellite.
permission_classes = [permissions.AllowAny]
def get_queryset(self): Public (was AllowAny). GET preserves the global PageNumberPagination envelope
qs = TelemetryPacket.objects.filter(satellite_id=self.kwargs["pk"]) {count, next, previous, results} with PAGE_SIZE=100.
"""
from_ts = self.request.query_params.get("from") auth = None # public (was AllowAny)
till_ts = self.request.query_params.get("till") csrf_exempt = True # DRF views bypass Django CSRF; ingestion is anonymous
def get(self, parsed_path: Path[UuidPkPath]) -> TelemetryPage:
qs = TelemetryPacket.objects.filter(satellite_id=parsed_path.pk)
from_ts = self.request.GET.get('from')
till_ts = self.request.GET.get('till')
if from_ts: if from_ts:
qs = qs.filter(timestamp__gte=int(from_ts)) qs = qs.filter(timestamp__gte=int(from_ts))
if till_ts: if till_ts:
qs = qs.filter(timestamp__lte=int(till_ts)) qs = qs.filter(timestamp__lte=int(till_ts))
return qs.order_by("-timestamp") qs = qs.order_by('-timestamp')
def post(self, request, pk): count, next_link, previous_link, page_objects = page_number_paginate(self.request, qs)
serializer = TelemetryPacketSerializer(data=request.data) return TelemetryPage(
if not serializer.is_valid(): count=count,
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) next=next_link,
previous=previous_link,
results=[TelemetryOut.model_validate(obj) for obj in page_objects],
)
validated_data = serializer.validated_data def post(self, parsed_path: Path[UuidPkPath], parsed_body: Body[TelemetryIn]) -> TelemetryOut:
# Bug fix (approved): the original returned serializer.errors on success;
TelemetryPacket.objects.create(timestamp=time.time(), # we return the created packet instead. timestamp is still server-set.
satellite=Satellite.objects.get(id=pk), packet = TelemetryPacket.objects.create(
lat=validated_data["lat"], timestamp=time.time(),
lon=validated_data["lon"], satellite=Satellite.objects.get(id=parsed_path.pk),
alt=validated_data["alt"], lat=parsed_body.lat,
payload=validated_data['payload'], lon=parsed_body.lon,
) alt=parsed_body.alt,
return Response(serializer.errors, status=status.HTTP_201_CREATED) payload=parsed_body.payload,
)
return TelemetryOut.model_validate(packet)
class SessionView(APIView): class SessionController(Controller[PydanticSerializer]):
permission_classes = [IsAuthenticated] """Report whether the current request is authenticated."""
@staticmethod def get(self) -> SessionResponse:
def get(request, format=None): return SessionResponse(isAuthenticated=True)
return JsonResponse({'isAuthenticated': True})
class WhoAmIView(APIView): class WhoAmIController(Controller[PydanticSerializer]):
permission_classes = [IsAuthenticated] """Return the current user's username."""
@staticmethod def get(self) -> WhoAmIResponse:
def get(request, format=None): return WhoAmIResponse(username=self.request.user.username)
return JsonResponse({'username': request.user.username})
@extend_schema(methods=["GET"], description="Get CSRF token") class CsrfController(Controller[PydanticSerializer]):
@csrf_exempt """Get CSRF token."""
@api_view(["GET"])
@permission_classes([AllowAny]) auth = None # public endpoint (was AllowAny)
def get_csrf(request): csrf_exempt = True
response = JsonResponse({'detail': 'CSRF cookie set'})
response['X-CSRFToken'] = get_token(request) def get(self) -> DetailResponse:
return response token = get_token(self.request)
return self.to_response(
DetailResponse(detail='CSRF cookie set'),
headers={'X-CSRFToken': token},
)
@extend_schema(methods=["POST"], description="Login user") class LoginController(Controller[PydanticSerializer]):
@csrf_exempt """Login user."""
@api_view(["POST"])
@authentication_classes([BasicAuthentication])
@permission_classes([AllowAny])
def login_view(request):
data = json.loads(request.body)
username = data.get('username')
password = data.get('password')
if username is None or password is None: auth = None # public endpoint (was AllowAny)
return JsonResponse({'detail': 'Please provide username and password.'}, status=400) csrf_exempt = True
user = authenticate(username=username, password=password) # post() would default to 201; the original returned 200.
if user is None: @modify(status_code=HTTPStatus.OK)
return JsonResponse({'detail': 'Invalid credentials.'}, status=400) def post(self) -> DetailResponse:
data = json.loads(self.request.body)
username = data.get('username')
password = data.get('password')
login(request, user) if username is None or password is None:
return JsonResponse({'detail': 'Successfully logged in.'}) raise APIError(
status_code=HTTPStatus.BAD_REQUEST,
body={'detail': 'Please provide username and password.'},
)
user = authenticate(username=username, password=password)
if user is None:
raise APIError(
status_code=HTTPStatus.BAD_REQUEST,
body={'detail': 'Invalid credentials.'},
)
login(self.request, user)
return DetailResponse(detail='Successfully logged in.')
@extend_schema(methods=["POST"], description="Logout user") class LogoutController(Controller[PydanticSerializer]):
@api_view(["POST"]) """Logout user."""
@permission_classes([AllowAny])
def logout_view(request):
if not request.user.is_authenticated:
return JsonResponse({'detail': 'You\'re not logged in.'}, status=400)
logout(request) auth = None # public endpoint (was AllowAny); checks auth state manually
return JsonResponse({'detail': 'Successfully logged out.'})
@modify(status_code=HTTPStatus.OK)
def post(self) -> DetailResponse:
if not self.request.user.is_authenticated:
raise APIError(
status_code=HTTPStatus.BAD_REQUEST,
body={'detail': "You're not logged in."},
)
logout(self.request)
return DetailResponse(detail='Successfully logged out.')
class SavedPointViewset(ModelViewSet): _SAVED_POINT_DUPLICATE = 'A saved point with this name already exists for the user.'
permission_classes = [IsOwner]
serializer_class = SavedPointSerializer
pagination_class = None
def get_queryset(self):
return SavedPoint.objects.filter(user=self.request.user)
def perform_create(self, serializer):
serializer.save(user=self.request.user)
class PreditctionTemplateViewset(ModelViewSet): def _check_saved_point_unique(user, name, exclude_pk=None):
permission_classes = [IsOwner] """Reproduce the former SavedPointSerializer UniqueTogetherValidator."""
serializer_class = PreditctionTemplateSerializer qs = SavedPoint.objects.filter(user=user, name=name)
pagination_class = None if exclude_pk is not None:
qs = qs.exclude(pk=exclude_pk)
def get_queryset(self): if qs.exists():
return PreditctionTemplate.objects.filter(user=self.request.user) raise APIError(
status_code=HTTPStatus.BAD_REQUEST,
def perform_create(self, serializer): body={'non_field_errors': [_SAVED_POINT_DUPLICATE]},
serializer.save(user=self.request.user) )
class UserProfileView(APIView): class SavedPointListController(Controller[PydanticSerializer]):
permission_classes = [IsAuthenticated] """Collection endpoint for the current user's saved points (was SavedPointViewset)."""
def get(self, request): def get(self) -> list[SavedPointOut]:
serializer = UserSerializer(request.user) qs = SavedPoint.objects.filter(user=self.request.user)
return Response(serializer.data) return [SavedPointOut.model_validate(obj) for obj in qs]
def patch(self, request): def post(self, parsed_body: Body[SavedPointIn]) -> SavedPointOut:
user = request.user user = self.request.user
serializer = UserSerializer(user, data=request.data, partial=True) _check_saved_point_unique(user, parsed_body.name)
obj = SavedPoint.objects.create(user=user, **parsed_body.model_dump())
if not serializer.is_valid(): return SavedPointOut.model_validate(obj)
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
serializer.save()
return Response(serializer.data)
class ChangePasswordView(APIView): class SavedPointDetailController(Controller[PydanticSerializer]):
permission_classes = [IsAuthenticated] """Detail endpoint for a single saved point owned by the current user."""
def post(self, request): def _get_object(self, pk: int):
user = request.user # Filtering by user means another user's object reads as 404 (not 403),
serializer = ChangePasswordSerializer(data=request.data) # matching DRF's get_object() over a user-scoped queryset.
obj = SavedPoint.objects.filter(user=self.request.user, pk=pk).first()
if obj is None:
raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'})
return obj
if not serializer.is_valid(): def get(self, parsed_path: Path[PkPath]) -> SavedPointOut:
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) return SavedPointOut.model_validate(self._get_object(parsed_path.pk))
if not user.check_password(serializer.validated_data['old_password']): def put(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointIn]) -> SavedPointOut:
return Response({'detail': 'Old password is incorrect'}, obj = self._get_object(parsed_path.pk)
status=status.HTTP_400_BAD_REQUEST) _check_saved_point_unique(self.request.user, parsed_body.name, exclude_pk=obj.pk)
for field, value in parsed_body.model_dump().items():
setattr(obj, field, value)
obj.save()
return SavedPointOut.model_validate(obj)
user.set_password(serializer.validated_data['new_password']) def patch(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointPatchIn]) -> SavedPointOut:
obj = self._get_object(parsed_path.pk)
updates = parsed_body.model_dump(exclude_unset=True)
if 'name' in updates:
_check_saved_point_unique(self.request.user, updates['name'], exclude_pk=obj.pk)
for field, value in updates.items():
setattr(obj, field, value)
obj.save()
return SavedPointOut.model_validate(obj)
@modify(status_code=HTTPStatus.NO_CONTENT)
def delete(self, parsed_path: Path[PkPath]) -> None:
self._get_object(parsed_path.pk).delete()
class PredictionTemplateListController(Controller[PydanticSerializer]):
"""Collection endpoint for the current user's templates (was PreditctionTemplateViewset).
NOTE: as before, there is no app-level uniqueness check; a duplicate
(user, name) hits the model's unique_together and surfaces as a DB error.
"""
def get(self) -> list[PredictionTemplateOut]:
qs = PreditctionTemplate.objects.filter(user=self.request.user)
return [PredictionTemplateOut.model_validate(obj) for obj in qs]
def post(self, parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut:
obj = PreditctionTemplate.objects.create(
user=self.request.user, **parsed_body.model_dump()
)
return PredictionTemplateOut.model_validate(obj)
class PredictionTemplateDetailController(Controller[PydanticSerializer]):
"""Detail endpoint for a single template owned by the current user."""
def _get_object(self, pk: int):
obj = PreditctionTemplate.objects.filter(user=self.request.user, pk=pk).first()
if obj is None:
raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'})
return obj
def get(self, parsed_path: Path[PkPath]) -> PredictionTemplateOut:
return PredictionTemplateOut.model_validate(self._get_object(parsed_path.pk))
def put(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut:
obj = self._get_object(parsed_path.pk)
for field, value in parsed_body.model_dump().items():
setattr(obj, field, value)
obj.save()
return PredictionTemplateOut.model_validate(obj)
def patch(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplatePatchIn]) -> PredictionTemplateOut:
obj = self._get_object(parsed_path.pk)
for field, value in parsed_body.model_dump(exclude_unset=True).items():
setattr(obj, field, value)
obj.save()
return PredictionTemplateOut.model_validate(obj)
@modify(status_code=HTTPStatus.NO_CONTENT)
def delete(self, parsed_path: Path[PkPath]) -> None:
self._get_object(parsed_path.pk).delete()
class UserProfileController(Controller[PydanticSerializer]):
"""Read and partially update the current user's profile."""
def get(self) -> UserOut:
return UserOut.model_validate(self.request.user)
def patch(self, parsed_body: Body[UserUpdateIn]) -> UserOut:
user = self.request.user
# partial update: apply only the fields the client actually sent.
for field, value in parsed_body.model_dump(exclude_unset=True).items():
setattr(user, field, value)
user.save() user.save()
return Response({'detail': 'Password changed successfully'}) return UserOut.model_validate(user)
class DeleteAccountView(APIView): class ChangePasswordController(Controller[PydanticSerializer]):
permission_classes = [IsAuthenticated] """Change the current user's password."""
def delete(self, request): # post() would default to 201; the original returned 200.
user = request.user @modify(status_code=HTTPStatus.OK)
serializer = DeleteAccountSerializer(data=request.data) def post(self, parsed_body: Body[ChangePasswordIn]) -> DetailResponse:
user = self.request.user
if not serializer.is_valid(): if not user.check_password(parsed_body.old_password):
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) raise APIError(
status_code=HTTPStatus.BAD_REQUEST,
body={'detail': 'Old password is incorrect'},
)
if not user.check_password(serializer.validated_data['password']): user.set_password(parsed_body.new_password)
return Response({'detail': 'Incorrect password'}, user.save()
status=status.HTTP_400_BAD_REQUEST) return DetailResponse(detail='Password changed successfully')
class DeleteAccountController(Controller[PydanticSerializer]):
"""Delete the current user's account and all their data."""
# DMR forbids a Body component on DELETE, but the original endpoint read the
# password from a DELETE request body. Preserve that contract by parsing the
# body manually instead of via Body[DeleteAccountIn].
def delete(self) -> DetailResponse:
user = self.request.user
try:
parsed_body = DeleteAccountIn(**json.loads(self.request.body or b'{}'))
except ValidationError as exc:
raise APIError(status_code=HTTPStatus.BAD_REQUEST, body=json.loads(exc.json()))
except ValueError:
raise APIError(status_code=HTTPStatus.BAD_REQUEST, body={'detail': 'Invalid request body.'})
if not user.check_password(parsed_body.password):
raise APIError(
status_code=HTTPStatus.BAD_REQUEST,
body={'detail': 'Incorrect password'},
)
Prediction.objects.filter(user=user).delete() Prediction.objects.filter(user=user).delete()
SavedPoint.objects.filter(user=user).delete() SavedPoint.objects.filter(user=user).delete()
@ -307,35 +470,83 @@ class DeleteAccountView(APIView):
user.delete() user.delete()
return Response({'detail': 'Account deleted successfully'}) return DetailResponse(detail='Account deleted successfully')
class DeleteUserDataView(APIView): class DeleteUserDataController(Controller[PydanticSerializer]):
permission_classes = [IsAuthenticated] """Delete all of the current user's data without deleting the account."""
def delete(self, request): def delete(self) -> DetailResponse:
user = request.user user = self.request.user
Prediction.objects.filter(user=user).delete() Prediction.objects.filter(user=user).delete()
SavedPoint.objects.filter(user=user).delete() SavedPoint.objects.filter(user=user).delete()
PreditctionTemplate.objects.filter(user=user).delete() PreditctionTemplate.objects.filter(user=user).delete()
return Response({'detail': 'All user data deleted successfully'}) return DetailResponse(detail='All user data deleted successfully')
class TokenManagementView(APIView): class ObtainTokenController(
permission_classes = [IsAuthenticated] ObtainTokensSyncController[
PydanticSerializer,
ObtainTokensPayload,
ObtainTokensResponse,
],
):
"""Exchange username/password for JWT access + refresh tokens.
def get(self, request): Replaces DRF's obtain_auth_token. Approved drift: the token format and
semantics change from a single stored DRF token to stateless JWTs, so the
response is {access_token, refresh_token} instead of {token}.
"""
token, created = Token.objects.get_or_create(user=request.user) auth = None # public: credentials are supplied in the request body
return Response({"token": token.key}) csrf_exempt = True
jwt_expiration = timedelta(hours=1)
jwt_refresh_expiration = timedelta(days=7)
def post(self, request): def convert_auth_payload(self, payload):
return payload
Token.objects.filter(user=request.user).delete() def make_api_response(self):
token = Token.objects.create(user=request.user) now = datetime.now(timezone.utc)
return Response({"token": token.key}) return {
'access_token': self.create_jwt_token(
expiration=now + self.jwt_expiration,
token_type='access',
),
'refresh_token': self.create_jwt_token(
expiration=now + self.jwt_refresh_expiration,
token_type='refresh',
),
}
class TokenManagementController(Controller[PydanticSerializer]):
"""Issue a fresh JWT access token for the current user.
Was TokenManagementView (DRF stored-token get/regenerate). With stateless
JWTs there is nothing stored to fetch or delete, so both GET and POST mint a
new access token (approved drift). The {"token": ...} response shape is kept.
"""
jwt_expiration = timedelta(hours=1)
def get(self) -> TokenResponse:
return TokenResponse(token=self._issue_token())
# post() would default to 201; the original returned 200.
@modify(status_code=HTTPStatus.OK)
def post(self) -> TokenResponse:
return TokenResponse(token=self._issue_token())
def _issue_token(self) -> str:
now = datetime.now(timezone.utc)
return self.create_jwt_token(
subject=str(self.request.user.pk),
expiration=now + self.jwt_expiration,
token_type='access',
)