Compare commits

..

2 commits

Author SHA1 Message Date
straitz
98214393a6 fix APIError signature and DELETE body for dmr 0.6.0 2026-06-03 05:55:39 +09:00
straitz
8e44c4501a migrated to modern-rest 2026-06-03 05:40:35 +09:00
11 changed files with 1023 additions and 572 deletions

View file

@ -1,8 +1,7 @@
Django>=4.0,<5.0
djangorestframework
djangorestframework-simplejwt
# django-modern-rest 0.6.0 requires Django 5.x (upgraded from 4.2).
Django>=5.0,<6.0
django-modern-rest[jwt,pydantic]==0.6.0
psycopg2-binary
drf-spectacular
requests
django-cors-headers
Pillow

View file

@ -13,6 +13,12 @@ https://docs.djangoproject.com/en/4.2/ref/settings/
from pathlib import Path
import os
from dotenv import load_dotenv
from dmr.settings import Settings
from dmr.security.django_session import DjangoSessionSyncAuth
from dmr.security.jwt import JWTSyncAuth
from dmr.openapi import OpenAPIConfig
load_dotenv()
# Build paths inside the project like this: BASE_DIR / 'subdir'.
@ -51,12 +57,10 @@ INSTALLED_APPS = [
'django.contrib.sessions',
'django.contrib.messages',
'django.contrib.staticfiles',
'rest_framework',
'rest_framework.authtoken',
'drf_spectacular',
'corsheaders',
'stratoflights_api.apps.StratoflightsApiConfig',
'channels',
'dmr', # required to serve OpenAPI docs static assets
]
MIDDLEWARE = [
@ -165,25 +169,16 @@ STATIC_URL = 'static/'
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
AUTH_USER_MODEL = 'stratoflights_api.User'
REST_FRAMEWORK = {
'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
'PAGE_SIZE': 100,
'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
'DEFAULT_AUTHENTICATION_CLASSES': [
'rest_framework.authentication.TokenAuthentication',
'rest_framework.authentication.SessionAuthentication',
],
'DEFAULT_PERMISSION_CLASSES': [
'rest_framework.permissions.IsAuthenticated',
# 'rest_framework.permissions.AllowAny',
],
'DEFAULT_RENDERER_CLASSES': [
'rest_framework.renderers.JSONRenderer',
],
# django-modern-rest configuration.
# Global auth mirrors the previous DRF default (Token + Session); the DRF
# TokenAuthentication is replaced by DMR's JWT (approved drift). Providing
# several auth instances means at least one of them must succeed, so this
# also enforces a 401 by default like DRF's global IsAuthenticated did.
# Public endpoints opt out per-endpoint with @modify(auth=None).
DMR_SETTINGS = {
Settings.auth: [JWTSyncAuth(), DjangoSessionSyncAuth()],
Settings.validate_responses: not PRODUCTION,
Settings.openapi_config: OpenAPIConfig(title='Stratoflights API', version='1.0.0'),
}

View file

@ -1,27 +1,17 @@
"""
URL configuration for stratoflights project.
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
"""URL configuration for the stratoflights project."""
from django.contrib import admin
from django.urls import path, include
from drf_spectacular.views import SpectacularAPIView
from drf_spectacular.views import SpectacularSwaggerView
from django.urls import include
from dmr.openapi import build_schema
from dmr.openapi.views import OpenAPIJsonView, SwaggerView
from dmr.routing import path
from stratoflights_api.urls import router
schema = build_schema(router)
urlpatterns = [
path('admin/', admin.site.urls),
path('api/', include('stratoflights_api.urls')),
path('api/schema/', SpectacularAPIView.as_view(), name='schema'),
path('api/docs/', SpectacularSwaggerView.as_view(url_name='schema'), name='docs'),
path(router.prefix, include(router.urls)),
path('api/schema/', OpenAPIJsonView.as_view(schema), name='schema'),
path('api/docs/', SwaggerView.as_view(schema), name='docs'),
]

View file

@ -18,17 +18,19 @@ class TelemetryConsumer(AsyncWebsocketConsumer):
async def receive(self, text_data):
from .serializers import TelemetryPacketSerializer
from pydantic import ValidationError
from .dtos import TelemetryIn
if not self.write_enabled:
await self.send(text_data=json.dumps({"error": "Read-only mode"}))
return
data = json.loads(text_data)
serializer = TelemetryPacketSerializer(data=data)
if not serializer.is_valid():
await self.send(text_data=json.dumps({"error": serializer.errors}))
try:
TelemetryIn(**data)
except ValidationError as exc:
await self.send(text_data=json.dumps({"error": json.loads(exc.json())}))
return
saved_data = await self.save_telemetry(data)
@ -51,15 +53,6 @@ class TelemetryConsumer(AsyncWebsocketConsumer):
async def telemetry_message(self, event):
await self.send(text_data=json.dumps(event["data"]))
@database_sync_to_async
def get_user_from_token(self, token_key, Token):
from rest_framework.authtoken.models import Token
User = get_user_model()
try:
token = Token.objects.select_related("user").get(key=token_key)
return token.user
except Token.DoesNotExist:
return None
@database_sync_to_async
def save_telemetry(self, data):
@ -99,12 +92,20 @@ class StationTelemetryConsumer(TelemetryConsumer):
write_enabled = True
async def connect(self):
from rest_framework.authtoken.models import Token
from django.conf import settings
from dmr.security.jwt.token import JWToken
token_key = self.scope["query_string"].decode().split("token=")[-1]
try:
token = await database_sync_to_async(Token.objects.select_related("user").get)(key=token_key)
self.scope["user"] = token.user
except Token.DoesNotExist:
decoded = JWToken.decode(
encoded_token=token_key,
secret=settings.SECRET_KEY,
algorithm='HS256',
)
self.scope["user"] = await database_sync_to_async(
get_user_model().objects.get
)(pk=decoded.sub)
except Exception:
await self.close()
return

View file

@ -1,17 +0,0 @@
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.response import Response
class CustomLimitOffsetPagination(LimitOffsetPagination):
limit_query_param = 'limit'
offset_query_param = 'skip'
max_limit = 100
default_limit = 10
def get_paginated_response(self, data):
return Response({
'total': self.count,
'limit': self.limit,
'skip': self.offset,
'predictions': data
})

333
stratoflights_api/dtos.py Normal file
View file

@ -0,0 +1,333 @@
"""Pydantic DTOs replacing the DRF serializers from serializers.py.
Mapping is 1:1 with the former serializers. Output DTOs enable ``from_attributes``
so an endpoint can build them straight from a Django model instance via
``SomeOut.model_validate(obj)`` (the equivalent of ``Serializer(obj).data``).
Field-level and cross-field validation that DRF kept in ``validate_<field>()`` /
``validate()`` is reproduced with pydantic ``field_validator`` / ``model_validator``.
Business logic DRF kept in ``create()`` (base64 curve decoding, ORM writes, and
resolving related objects from their PKs) is intentionally NOT here -- per DMR's
split it moves into the endpoints (Step 2.9).
"""
import uuid
from datetime import datetime
from typing import Any, Optional
from django.contrib.auth.password_validation import validate_password
from django.core.validators import validate_email
from django.core.exceptions import ValidationError as DjangoValidationError
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
)
from .validators import validate_custom_curve, rate_clip
# Profile constants (moved verbatim from serializers.py).
PROFILE_STANDARD = "standard_profile"
PROFILE_FLOAT = "float_profile"
PROFILE_REVERSE = "reverse_profile"
PROFILE_CUSTOM = "custom_profile"
LATEST_DATASET_KEYWORD = "latest"
SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM]
# --- Prediction request (was PredictionRequestSerializer) ---------------------
class PredictionRequest(BaseModel):
launch_latitude: float = Field(ge=-90, le=90)
launch_longitude: float = Field(ge=0, le=360)
launch_datetime: datetime
launch_altitude: Optional[float] = None
format: str = "json"
profile: str = PROFILE_STANDARD
dataset: str = LATEST_DATASET_KEYWORD
# profile-dependent fields
ascent_rate: Optional[float] = Field(default=None, ge=0.01)
descent_rate: Optional[float] = Field(default=None, ge=0.01)
burst_altitude: Optional[float] = None
float_altitude: Optional[float] = None
stop_datetime: Optional[datetime] = None
ascent_curve: Optional[str] = None
descent_curve: Optional[str] = None
interpolate: bool = False
# Related objects are accepted as PKs; existence is resolved in the endpoint
# (was PrimaryKeyRelatedField(queryset=...)).
start_point: Optional[int] = None
rate_profile: Optional[int] = None
template: Optional[int] = None
@field_validator("profile")
@classmethod
def _validate_profile(cls, value: str) -> str:
if value not in SUPPORTED_PROFILES:
raise ValueError(f'"{value}" is not a valid choice.')
return value
@model_validator(mode="after")
def _validate_cross_fields(self) -> "PredictionRequest":
launch_alt = self.launch_altitude if self.launch_altitude is not None else 0
if self.profile == PROFILE_STANDARD:
if self.burst_altitude is None:
raise ValueError("burst_altitude is required for standard profile.")
if self.burst_altitude <= launch_alt:
raise ValueError("burst_altitude must be greater than launch_altitude.")
elif self.profile == PROFILE_FLOAT:
if self.float_altitude is None or self.float_altitude <= launch_alt:
raise ValueError("float_altitude must be greater than launch_altitude.")
if self.stop_datetime is None or self.stop_datetime <= self.launch_datetime:
raise ValueError("stop_datetime must be later than launch_datetime.")
elif self.profile == PROFILE_CUSTOM:
if self.ascent_curve is None or not validate_custom_curve(self.ascent_curve):
raise ValueError("Invalid ascent_curve.")
if self.descent_curve is None or not validate_custom_curve(self.descent_curve):
raise ValueError("Invalid descent_curve.")
if self.burst_altitude is None or self.burst_altitude <= launch_alt:
raise ValueError("burst_altitude must be greater than launch_altitude.")
# custom clipping logic (was in validate())
if self.ascent_rate is not None:
self.ascent_rate = rate_clip(self.ascent_rate)
if self.descent_rate is not None:
self.descent_rate = rate_clip(self.descent_rate)
return self
# --- Prediction outputs -------------------------------------------------------
class PredictionOut(BaseModel):
"""Was PredictionSerializer."""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
created_at: datetime
updated_at: datetime
result: Any
class PredictionListOut(BaseModel):
"""Was PredictionListSerializer."""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
created_at: datetime
updated_at: datetime
start_point: Optional[int] = Field(default=None, validation_alias="start_point_id")
template: Optional[int] = Field(default=None, validation_alias="template_id")
rate_profile: Optional[int] = Field(default=None, validation_alias="rate_profile_id")
class PredictionDetailOut(BaseModel):
"""Was PredictionDetailSerializer."""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
created_at: datetime
updated_at: datetime
result: Any
start_point: Optional[int] = Field(default=None, validation_alias="start_point_id")
template: Optional[int] = Field(default=None, validation_alias="template_id")
rate_profile: Optional[int] = Field(default=None, validation_alias="rate_profile_id")
class PredictionCreateOut(BaseModel):
"""Custom create() response shape: {id, created_at, result} (not the full model)."""
id: uuid.UUID
created_at: datetime
result: Any
# --- Telemetry (was TelemetryPacketSerializer) --------------------------------
class TelemetryIn(BaseModel):
# timestamp is required (model BigIntegerField has no default), matching the
# former ModelSerializer; the endpoint still overwrites it with server time.
timestamp: int
lat: float = 0.0
lon: float = 0.0
alt: float = 0.0
payload: Any = Field(default_factory=dict)
class TelemetryOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
timestamp: int
lat: float
lon: float
alt: float
payload: Any
# --- SavedPoint (was SavedPointSerializer) ------------------------------------
# `user` was a HiddenField(CurrentUserDefault) -> set from request in the endpoint,
# not part of the wire input/output. The unique_together (user, name) check that
# DRF did via UniqueTogetherValidator is reproduced in the endpoint (Step 2.8).
class SavedPointIn(BaseModel):
name: str
lat: float = 0.0
lon: float = 0.0
alt: float = 0.0
class SavedPointPatchIn(BaseModel):
# All-optional variant for PATCH (partial_update). Only set fields apply.
name: Optional[str] = None
lat: Optional[float] = None
lon: Optional[float] = None
alt: Optional[float] = None
class SavedPointOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
name: str
lat: float
lon: float
alt: float
# --- SavedRateProfile (was SavedRateProfileSerializer) ------------------------
# NOTE: no view referenced this serializer in routing; kept for parity.
class SavedRateProfileIn(BaseModel):
name: str
type: str = "ascent"
rate_profile_data: Any = Field(default_factory=dict)
class SavedRateProfileOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
name: str
type: str
rate_profile_data: Any
# --- PredictionTemplate (was PreditctionTemplateSerializer) -------------------
class PredictionTemplateIn(BaseModel):
# protected_namespaces=() avoids pydantic's warning about the `model` field.
model_config = ConfigDict(protected_namespaces=())
name: str
is_default: bool = False
description: Optional[str] = None
prediction_mode: str = ""
model: str = ""
dataset: str = ""
flight_parameters: Any = Field(default_factory=dict)
class PredictionTemplatePatchIn(BaseModel):
# All-optional variant for PATCH (partial_update). Only set fields apply.
model_config = ConfigDict(protected_namespaces=())
name: Optional[str] = None
is_default: Optional[bool] = None
description: Optional[str] = None
prediction_mode: Optional[str] = None
model: Optional[str] = None
dataset: Optional[str] = None
flight_parameters: Optional[Any] = None
class PredictionTemplateOut(BaseModel):
model_config = ConfigDict(from_attributes=True, protected_namespaces=())
id: int
name: str
is_default: bool
description: Optional[str] = None
prediction_mode: str
model: str
dataset: str
flight_parameters: Any
# --- User (was UserSerializer) ------------------------------------------------
class UserOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
username: str
email: str
first_name: str
last_name: str
class UserUpdateIn(BaseModel):
# `username` was read_only -> excluded from input. PATCH is partial, so all
# fields are optional; the endpoint applies only fields that were set.
email: Optional[str] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
@field_validator("email")
@classmethod
def _validate_email(cls, value: Optional[str]) -> Optional[str]:
if value is None:
return value
try:
validate_email(value)
except DjangoValidationError:
raise ValueError("Invalid email format")
return value
# --- Password / account (were ChangePasswordSerializer / DeleteAccountSerializer)
class ChangePasswordIn(BaseModel):
old_password: str
new_password: str
@field_validator("new_password")
@classmethod
def _validate_new_password(cls, value: str) -> str:
# validate_password raises Django's ValidationError (not a ValueError),
# which pydantic would not convert into a 4xx -- re-raise as ValueError.
try:
validate_password(value)
except DjangoValidationError as exc:
raise ValueError("; ".join(exc.messages))
return value
class DeleteAccountIn(BaseModel):
password: str
# --- Auth / session response shapes (were inline JsonResponse dicts) ----------
class DetailResponse(BaseModel):
detail: str
class SessionResponse(BaseModel):
isAuthenticated: bool
class WhoAmIResponse(BaseModel):
username: str
class TokenResponse(BaseModel):
token: str
# --- Shared path parameters ---------------------------------------------------
class PkPath(BaseModel):
pk: int
class UuidPkPath(BaseModel):
pk: uuid.UUID

View file

@ -0,0 +1,120 @@
"""Limit/offset pagination for the prediction list endpoints.
DMR has no built-in limit/offset paginator, so (as the docs suggest) we keep our
own envelope model -- matching the former ``CustomLimitOffsetPagination`` output
exactly -- plus a helper that slices the queryset. The query params and envelope
keys are preserved verbatim: ``limit`` / ``skip`` in, and
``{total, limit, skip, predictions}`` out.
"""
from http import HTTPStatus
from typing import Optional
from urllib.parse import urlsplit, urlunsplit
from django.core.paginator import InvalidPage, Paginator
from django.http import QueryDict
from pydantic import BaseModel
from dmr.response import APIError
from .dtos import PredictionOut, TelemetryOut
DEFAULT_LIMIT = 10 # was CustomLimitOffsetPagination.default_limit
MAX_LIMIT = 100 # was CustomLimitOffsetPagination.max_limit
PAGE_SIZE = 100 # global REST_FRAMEWORK PAGE_SIZE (PageNumberPagination)
class PaginationQuery(BaseModel):
# Param names preserved: limit_query_param='limit', offset_query_param='skip'.
limit: int = DEFAULT_LIMIT
skip: int = 0
class PredictionListUserQuery(BaseModel):
"""Query params for PredictionViewSet.list_user: filters + limit/offset."""
satellite_id: Optional[str] = None
created_from: Optional[str] = None
created_till: Optional[str] = None
limit: int = DEFAULT_LIMIT
skip: int = 0
class PredictionPage(BaseModel):
total: int
limit: int
skip: int
predictions: list[PredictionOut]
class TelemetryPage(BaseModel):
count: int
next: Optional[str] = None
previous: Optional[str] = None
results: list[TelemetryOut]
def _replace_query_param(url: str, key: str, value) -> str:
scheme, netloc, path, query, fragment = urlsplit(url)
query_dict = QueryDict(query, mutable=True)
query_dict[key] = value
return urlunsplit((scheme, netloc, path, query_dict.urlencode(), fragment))
def _remove_query_param(url: str, key: str) -> str:
scheme, netloc, path, query, fragment = urlsplit(url)
query_dict = QueryDict(query, mutable=True)
query_dict.pop(key, None)
return urlunsplit((scheme, netloc, path, query_dict.urlencode(), fragment))
def page_number_paginate(request, queryset, page_size: int = PAGE_SIZE):
"""Reproduce DRF PageNumberPagination.
Returns ``(count, next_link, previous_link, object_list)``. Raises a 404
"Invalid page." like DRF on an out-of-range / non-integer page. Honours
``page=last`` and builds absolute next/previous links off the current URL.
"""
paginator = Paginator(queryset, page_size)
page_number = request.GET.get('page', 1)
if page_number in ('last',):
page_number = paginator.num_pages
try:
page = paginator.page(page_number)
except InvalidPage:
raise APIError({'detail': 'Invalid page.'}, status_code=HTTPStatus.NOT_FOUND)
url = request.build_absolute_uri()
next_link = None
if page.has_next():
next_link = _replace_query_param(url, 'page', page.next_page_number())
previous_link = None
if page.has_previous():
previous_number = page.previous_page_number()
if previous_number == 1:
previous_link = _remove_query_param(url, 'page')
else:
previous_link = _replace_query_param(url, 'page', previous_number)
return paginator.count, next_link, previous_link, list(page.object_list)
def paginate_predictions(queryset, query: PaginationQuery) -> PredictionPage:
"""Slice ``queryset`` and build the prediction page envelope.
Mirrors the bounds DRF's LimitOffsetPagination enforced: limit falls back to
the default when non-positive and is capped at MAX_LIMIT; skip floors at 0.
"""
limit = query.limit if query.limit > 0 else DEFAULT_LIMIT
limit = min(limit, MAX_LIMIT)
skip = query.skip if query.skip >= 0 else 0
total = queryset.count()
page = list(queryset[skip:skip + limit])
return PredictionPage(
total=total,
limit=limit,
skip=skip,
predictions=[PredictionOut.model_validate(obj) for obj in page],
)

View file

@ -1,16 +0,0 @@
from rest_framework.permissions import BasePermission, SAFE_METHODS
class ReadOnlyOrAuthenticated(BasePermission):
def has_permission(self, request, view):
return (
request.method in SAFE_METHODS or
request.user and request.user.is_authenticated
)
class IsOwner(BasePermission):
def has_object_permission(self, request, view, obj):
return obj.user == request.user
def has_permission(self, request, view):
return request.user and request.user.is_authenticated

View file

@ -1,185 +0,0 @@
from rest_framework import serializers
from .models import Prediction, SavedPoint, SavedRateProfile, PreditctionTemplate
from datetime import datetime
from django.contrib.auth.password_validation import validate_password
from django.core.validators import validate_email
from django.core.exceptions import ValidationError as DjangoValidationError
from django.contrib.auth import get_user_model
from .validators import (
validate_custom_curve, rate_clip,
_rfc3339_to_timestamp, base64_to_curve
)
class PredictionSerializer(serializers.ModelSerializer):
class Meta:
model = Prediction
fields = ['id', 'created_at', 'updated_at', 'result']
User = get_user_model()
PROFILE_STANDARD = "standard_profile"
PROFILE_FLOAT = "float_profile"
PROFILE_REVERSE = "reverse_profile"
PROFILE_CUSTOM = "custom_profile"
LATEST_DATASET_KEYWORD = "latest"
SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM]
class PredictionRequestSerializer(serializers.Serializer):
launch_latitude = serializers.FloatField(min_value=-90, max_value=90)
launch_longitude = serializers.FloatField(min_value=0, max_value=360)
launch_datetime = serializers.DateTimeField()
launch_altitude = serializers.FloatField(required=False)
format = serializers.CharField(default="json")
profile = serializers.ChoiceField(choices=SUPPORTED_PROFILES, default=PROFILE_STANDARD)
dataset = serializers.CharField(default=LATEST_DATASET_KEYWORD)
# --- профиль-dependent поля ---
ascent_rate = serializers.FloatField(required=False, min_value=0.01)
descent_rate = serializers.FloatField(required=False, min_value=0.01)
burst_altitude = serializers.FloatField(required=False)
float_altitude = serializers.FloatField(required=False)
stop_datetime = serializers.DateTimeField(required=False)
ascent_curve = serializers.CharField(required=False)
descent_curve = serializers.CharField(required=False)
interpolate = serializers.BooleanField(required=False, default=False)
start_point = serializers.PrimaryKeyRelatedField(
queryset=SavedPoint.objects.all(), required=False, allow_null=True
)
rate_profile = serializers.PrimaryKeyRelatedField(
queryset=SavedRateProfile.objects.all(), required=False, allow_null=True
)
template = serializers.PrimaryKeyRelatedField(
queryset=PreditctionTemplate.objects.all(), required=False, allow_null=True
)
def validate(self, data):
profile = data.get("profile", PROFILE_STANDARD)
launch_alt = data.get("launch_altitude", 0)
if profile == PROFILE_STANDARD:
if 'burst_altitude' not in data:
raise serializers.ValidationError("burst_altitude is required for standard profile.")
if data['burst_altitude'] <= launch_alt:
raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.")
elif profile == PROFILE_FLOAT:
if 'float_altitude' not in data or data['float_altitude'] <= launch_alt:
raise serializers.ValidationError("float_altitude must be greater than launch_altitude.")
if 'stop_datetime' not in data or data['stop_datetime'] <= data['launch_datetime']:
raise serializers.ValidationError("stop_datetime must be later than launch_datetime.")
elif profile == PROFILE_CUSTOM:
if 'ascent_curve' not in data or not validate_custom_curve(data['ascent_curve']):
raise serializers.ValidationError("Invalid ascent_curve.")
if 'descent_curve' not in data or not validate_custom_curve(data['descent_curve']):
raise serializers.ValidationError("Invalid descent_curve.")
if 'burst_altitude' not in data or data['burst_altitude'] <= launch_alt:
raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.")
# кастомная логика clipping'а
if 'ascent_rate' in data:
data['ascent_rate'] = rate_clip(data['ascent_rate'])
if 'descent_rate' in data:
data['descent_rate'] = rate_clip(data['descent_rate'])
return data
def create(self, validated_data):
if 'ascent_curve' in validated_data:
validated_data['ascent_curve'] = base64_to_curve(validated_data['ascent_curve'])
if 'descent_curve' in validated_data:
validated_data['descent_curve'] = base64_to_curve(validated_data['descent_curve'])
prediction = Prediction(
user=validated_data.get('user'),
request=validated_data.get('request', {}),
result=validated_data.get('result', {}),
start_point=validated_data.get('start_point'),
template=validated_data.get('template'),
rate_profile=validated_data.get('rate_profile')
)
prediction.save()
return prediction
class PredictionListSerializer(serializers.ModelSerializer):
class Meta:
model = Prediction
fields = ["id", "created_at", "updated_at", "start_point", "template", "rate_profile"]
class PredictionDetailSerializer(serializers.ModelSerializer):
class Meta:
model = Prediction
fields = ["id", "created_at", "updated_at", "result", "start_point", "template", "rate_profile"]
from rest_framework import serializers
from .models import TelemetryPacket
class TelemetryPacketSerializer(serializers.ModelSerializer):
class Meta:
model = TelemetryPacket
fields = ['id', 'timestamp', 'lat', 'lon', 'alt', 'payload']
read_only_fields = ['id']
class SavedPointSerializer(serializers.ModelSerializer):
user = serializers.HiddenField(
default=serializers.CurrentUserDefault()
)
class Meta:
model = SavedPoint
fields = ['user', 'id', 'name', 'lat', 'lon', 'alt']
read_only_fields = ['id']
validators = [
serializers.UniqueTogetherValidator(
queryset=SavedPoint.objects.all(),
fields=['user', 'name'],
message="A saved point with this name already exists for the user."
)
]
class SavedRateProfileSerializer(serializers.ModelSerializer):
class Meta:
model = SavedRateProfile
fields = ['id', 'name', 'type', 'rate_profile_data']
read_only_fields = ['id']
class PreditctionTemplateSerializer(serializers.ModelSerializer):
class Meta:
model = PreditctionTemplate
fields = ['id', 'name', 'is_default', 'description', 'prediction_mode', 'model', 'dataset', 'flight_parameters']
read_only_fields = ['id']
class UserSerializer(serializers.ModelSerializer):
class Meta:
model = User
fields = ['username', 'email', 'first_name', 'last_name']
extra_kwargs = {
'username': {'read_only': True}
}
def validate_email(self, value):
try:
validate_email(value)
except DjangoValidationError:
raise serializers.ValidationError("Invalid email format")
return value
class ChangePasswordSerializer(serializers.Serializer):
old_password = serializers.CharField(required=True)
new_password = serializers.CharField(required=True)
def validate_new_password(self, value):
validate_password(value)
return value
class DeleteAccountSerializer(serializers.Serializer):
password = serializers.CharField(required=True)

View file

@ -1,48 +1,59 @@
from django.urls import path
from rest_framework.routers import DefaultRouter
from rest_framework.authtoken.views import obtain_auth_token
from dmr.routing import Router, path
from .views import (
PredictionViewSet,
SavedPointViewset,
PreditctionTemplateViewset,
TelemetryListCreateView,
get_csrf,
login_view,
logout_view,
SessionView,
WhoAmIView,
UserProfileView,
ChangePasswordView,
TokenManagementView,
DeleteUserDataView,
DeleteAccountView
PredictionCollectionController,
PredictionListUserController,
PredictionHistoryController,
PredictionDetailController,
PredictionDeleteController,
SavedPointListController,
SavedPointDetailController,
PredictionTemplateListController,
PredictionTemplateDetailController,
TelemetryController,
CsrfController,
LoginController,
LogoutController,
SessionController,
WhoAmIController,
UserProfileController,
ChangePasswordController,
ObtainTokenController,
TokenManagementController,
DeleteUserDataController,
DeleteAccountController,
)
# A Router (prefix + routes) is required so build_schema() can generate the
# OpenAPI document; see stratoflights/urls.py.
router = Router(
'api/',
[
path("csrf/", CsrfController.as_view(), name='api-csrf'),
path('token', ObtainTokenController.as_view(), name='get_token'),
path("login/", LoginController.as_view(), name='api-login'),
path("logout/", LogoutController.as_view(), name='api-logout'),
path("session/", SessionController.as_view(), name='api-session'),
path("whoami/", WhoAmIController.as_view(), name='api-whoami'),
path("<uuid:pk>/telemetry/", TelemetryController.as_view(), name="create_telemetry"),
path("profile/", UserProfileController.as_view(), name='api-profile'),
path("profile/change-password/", ChangePasswordController.as_view(), name='api-change-password'),
path("profile/token/", TokenManagementController.as_view(), name='api-token'),
path("profile/delete-data/", DeleteUserDataController.as_view(), name='api-delete-data'),
path("profile/delete-account/", DeleteAccountController.as_view(), name='api-delete-account'),
# Saved points (was SavedPointViewset via DefaultRouter).
path("saved-points/", SavedPointListController.as_view(), name='saved-points-list'),
path("saved-points/<int:pk>/", SavedPointDetailController.as_view(), name='saved-points-detail'),
# Prediction templates (was PreditctionTemplateViewset via DefaultRouter).
path("saved-templates/", PredictionTemplateListController.as_view(), name='saved-templates-list'),
path("saved-templates/<int:pk>/", PredictionTemplateDetailController.as_view(), name='saved-templates-detail'),
# Predictions (was PredictionViewSet via DefaultRouter).
path("predictions/", PredictionCollectionController.as_view(), name='predictions-list'),
path("predictions/list_user/", PredictionListUserController.as_view(), name='predictions-list-user'),
path("predictions/history/", PredictionHistoryController.as_view(), name='predictions-history'),
path("predictions/<uuid:pk>/detail/", PredictionDetailController.as_view(), name='predictions-detail'),
path("predictions/<uuid:pk>/delete/", PredictionDeleteController.as_view(), name='predictions-delete'),
],
)
router = DefaultRouter()
router.register(r'predictions', PredictionViewSet, basename='predictions')
router.register(r'saved-points', SavedPointViewset, basename='saved-points')
router.register(r'saved-templates', PreditctionTemplateViewset, basename='saved-templates')
urlpatterns = [
path("csrf/", get_csrf, name='api-csrf'),
path('token', obtain_auth_token, name = 'get_token'),
path("login/", login_view, name='api-login'),
path("logout/", logout_view, name='api-logout'),
path("session/", SessionView.as_view(), name='api-session'),
path("whoami/", WhoAmIView.as_view(), name='api-whoami'),
path("<uuid:pk>/telemetry/", TelemetryListCreateView.as_view(), name="create_telemetry"),
path('csrf/', get_csrf, name='api-csrf'),
path('login/', login_view, name='api-login'),
path('logout/', logout_view, name='api-logout'),
path('session/', SessionView.as_view(), name='api-session'),
path('whoami/', WhoAmIView.as_view(), name='api-whoami'),
path("profile/", UserProfileView.as_view(), name='api-profile'),
path("profile/change-password/", ChangePasswordView.as_view(), name='api-change-password'),
path("profile/token/", TokenManagementView.as_view(), name='api-token'),
path("profile/delete-data/", DeleteUserDataView.as_view(), name='api-delete-data'),
path("profile/delete-account/", DeleteAccountView.as_view(), name='api-delete-account'),
]
urlpatterns += router.urls
urlpatterns = router.urls

View file

@ -1,305 +1,477 @@
import requests
import time
import json
from rest_framework import status, generics, permissions
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.viewsets import ModelViewSet, ViewSet, GenericViewSet
from rest_framework.exceptions import APIException
from rest_framework.permissions import IsAuthenticated, AllowAny
from rest_framework.authentication import SessionAuthentication, BasicAuthentication, TokenAuthentication
from rest_framework.decorators import api_view, permission_classes, authentication_classes, action
from rest_framework.authtoken.models import Token
from django.utils import timezone
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
from django.http import JsonResponse
from http import HTTPStatus
from django.contrib.auth import authenticate, login, logout, get_user_model
from django.middleware.csrf import get_token
from django.core.exceptions import ValidationError
from django.utils.dateparse import parse_datetime
from .models import Prediction, User, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket
from .serializers import PredictionSerializer, TelemetryPacketSerializer, PredictionRequestSerializer, PredictionListSerializer, PredictionDetailSerializer, SavedPointSerializer, SavedRateProfileSerializer, PreditctionTemplateSerializer, UserSerializer, ChangePasswordSerializer, DeleteAccountSerializer
from .models import Prediction, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket
from .services.tawhiri import TawhiriClient
from drf_spectacular.utils import extend_schema
from .permissions import ReadOnlyOrAuthenticated, IsOwner
from .custom_pagination import CustomLimitOffsetPagination
from datetime import datetime
from datetime import datetime, timedelta, timezone
from pydantic import ValidationError
from dmr import Controller, modify
from dmr.components import Body, Path, Query
from dmr.plugins.pydantic import PydanticSerializer
from dmr.response import APIError
from dmr.security.jwt.views import (
ObtainTokensPayload,
ObtainTokensResponse,
ObtainTokensSyncController,
)
from .dtos import (
DetailResponse,
SessionResponse,
WhoAmIResponse,
UserOut,
UserUpdateIn,
ChangePasswordIn,
DeleteAccountIn,
TokenResponse,
PkPath,
UuidPkPath,
SavedPointIn,
SavedPointPatchIn,
SavedPointOut,
PredictionTemplateIn,
PredictionTemplatePatchIn,
PredictionTemplateOut,
PredictionRequest,
PredictionOut,
PredictionListOut,
PredictionDetailOut,
PredictionCreateOut,
TelemetryIn,
TelemetryOut,
)
from .pagination import (
PredictionListUserQuery,
PredictionPage,
paginate_predictions,
TelemetryPage,
page_number_paginate,
)
from .validators import base64_to_curve
User = get_user_model()
def get_prediction_from_tawhiri(params):
def _api_error(status_code, body, headers=None):
"""Adapter for DMR's APIError(raw_data, *, status_code, ...) signature.
base_url = "https://fly.stratonautica.ru/api/v2"
response = requests.get(base_url, params=params)
if response.status_code == 200:
return response.json() # получаем результат предсказания
else:
raise Exception(
f"Tawhiri error: {response.status_code} {response.text}")
Lets call sites keep the (status_code=, body=) style; the body is DMR's
first positional ``raw_data`` argument.
"""
return APIError(body, status_code=status_code, headers=headers)
class PredictionViewSet(GenericViewSet):
permission_classes = [IsAuthenticated]
pagination_class = CustomLimitOffsetPagination
def _resolve_related(model, pk, field_name):
"""Resolve a related object by PK (was PrimaryKeyRelatedField(queryset=...)).
def list(self, request):
queryset = Prediction.objects.filter(user=request.user)
return Response(PredictionSerializer(queryset, many=True).data)
The original queryset was ``.all()`` (not user-scoped); a missing PK yields a
400 matching DRF's "Invalid pk" error shape.
"""
if pk is None:
return None
obj = model.objects.filter(pk=pk).first()
if obj is None:
raise _api_error(
status_code=HTTPStatus.BAD_REQUEST,
body={field_name: [f'Invalid pk "{pk}" - object does not exist.']},
)
return obj
def create(self, request):
serializer = PredictionRequestSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
class PredictionCollectionController(Controller[PydanticSerializer]):
"""`predictions/` -- list the user's predictions and create a new one."""
validated_data = serializer.validated_data
def get(self) -> list[PredictionOut]:
queryset = Prediction.objects.filter(user=self.request.user)
return [PredictionOut.model_validate(obj) for obj in queryset]
def post(self, parsed_body: Body[PredictionRequest]) -> PredictionCreateOut:
user = self.request.user
# Resolve related objects before calling Tawhiri, so an invalid PK still
# fails with 400 first (as DRF validation did).
start_point = _resolve_related(SavedPoint, parsed_body.start_point, 'start_point')
template = _resolve_related(PreditctionTemplate, parsed_body.template, 'template')
rate_profile = _resolve_related(SavedRateProfile, parsed_body.rate_profile, 'rate_profile')
try:
prediction_result = TawhiriClient.get_prediction(validated_data)
prediction_result = TawhiriClient.get_prediction(parsed_body.model_dump())
except requests.RequestException as exc:
raise _api_error(
status_code=HTTPStatus.BAD_GATEWAY,
body={'error': f'Tawhiri error: {str(exc)}'},
)
except requests.RequestException as e:
return Response({"error": f"Tawhiri error: {str(e)}"}, status=status.HTTP_502_BAD_GATEWAY)
# Carried over from the old serializer.create(): curves are decoded but
# never persisted (the model has no curve fields). Kept for parity --
# including that a malformed curve here raises (then surfaces as 500).
if parsed_body.ascent_curve is not None:
base64_to_curve(parsed_body.ascent_curve)
if parsed_body.descent_curve is not None:
base64_to_curve(parsed_body.descent_curve)
# prediction = Prediction.objects.create(
# result=prediction_result, user=request.user, request=request.data, validated_data=validated_data)
prediction = serializer.save(
user=request.user,
prediction = Prediction(
user=user,
request=json.loads(self.request.body),
result=prediction_result,
start_point=start_point,
template=template,
rate_profile=rate_profile,
)
prediction.save()
return PredictionCreateOut(
id=prediction.id,
created_at=prediction.created_at,
result=prediction_result,
request=request.data
)
return Response({
"id": prediction.id,
"created_at": prediction.created_at,
"result": prediction_result
}, status=status.HTTP_201_CREATED)
@action(detail=False, methods=['get'])
def list_user(self, request):
user = request.user
satellite_id = request.query_params.get('satellite_id')
created_from = request.query_params.get('created_from')
created_till = request.query_params.get('created_till')
class PredictionListUserController(Controller[PydanticSerializer]):
"""`predictions/list_user/` -- filtered, paginated list of the user's predictions."""
filters = {
'user': user,
}
def get(self, parsed_query: Query[PredictionListUserQuery]) -> PredictionPage:
user = self.request.user
filters = {'user': user}
if created_from:
filters['created_at__gte'] = parse_datetime(created_from)
if created_till:
filters['created_at__lte'] = parse_datetime(created_till)
if satellite_id:
if not user.satellites.filter(id=satellite_id).exists():
return Response({'detail': 'Access denied'}, status=403)
filters['satellite_id'] = satellite_id
if parsed_query.created_from:
filters['created_at__gte'] = parse_datetime(parsed_query.created_from)
if parsed_query.created_till:
filters['created_at__lte'] = parse_datetime(parsed_query.created_till)
if parsed_query.satellite_id:
if not user.satellites.filter(id=parsed_query.satellite_id).exists():
raise _api_error(status_code=HTTPStatus.FORBIDDEN, body={'detail': 'Access denied'})
filters['satellite_id'] = parsed_query.satellite_id
queryset = Prediction.objects.filter(**filters)
queryset = self.filter_queryset(queryset)
return paginate_predictions(queryset, parsed_query)
page = self.paginate_queryset(queryset)
if page is not None:
serializer = PredictionSerializer(page, many=True)
return self.get_paginated_response(serializer.data)
class PredictionHistoryController(Controller[PydanticSerializer]):
"""`predictions/history/` -- compact list of the user's predictions."""
serializer = PredictionSerializer(queryset, many=True)
return Response(serializer.data)
def get(self) -> list[PredictionListOut]:
queryset = Prediction.objects.filter(user=self.request.user)
return [PredictionListOut.model_validate(obj) for obj in queryset]
@action(detail=False, methods=["get"])
def history(self, request):
queryset = Prediction.objects.filter(user=request.user)
return Response(PredictionListSerializer(queryset, many=True).data)
@action(detail=True, methods=["get"])
def detail(self, request, pk=None):
class PredictionDetailController(Controller[PydanticSerializer]):
"""`predictions/<pk>/detail/` -- retrieve a single prediction."""
def get(self, parsed_path: Path[UuidPkPath]) -> PredictionDetailOut:
prediction = Prediction.objects.filter(
user=request.user, pk=pk).first()
if not prediction:
return Response({'detail': 'Not found'}, status=404)
return Response(PredictionDetailSerializer(prediction).data)
user=self.request.user, pk=parsed_path.pk).first()
if prediction is None:
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'})
return PredictionDetailOut.model_validate(prediction)
@action(detail=True, methods=["delete"])
def delete(self, request, pk=None):
class PredictionDeleteController(Controller[PydanticSerializer]):
"""`predictions/<pk>/delete/` -- delete a single prediction."""
@modify(status_code=HTTPStatus.NO_CONTENT)
def delete(self, parsed_path: Path[UuidPkPath]) -> None:
prediction = Prediction.objects.filter(
user=request.user, pk=pk).first()
if not prediction:
return Response({'detail': 'Not found'}, status=404)
user=self.request.user, pk=parsed_path.pk).first()
if prediction is None:
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'})
prediction.delete()
return Response(status=204)
class TelemetryListCreateView(generics.ListCreateAPIView):
serializer_class = TelemetryPacketSerializer
permission_classes = [permissions.AllowAny]
class TelemetryController(Controller[PydanticSerializer]):
"""`<uuid:pk>/telemetry/` -- list and ingest telemetry for a satellite.
def get_queryset(self):
qs = TelemetryPacket.objects.filter(satellite_id=self.kwargs["pk"])
Public (was AllowAny). GET preserves the global PageNumberPagination envelope
{count, next, previous, results} with PAGE_SIZE=100.
"""
from_ts = self.request.query_params.get("from")
till_ts = self.request.query_params.get("till")
auth = None # public (was AllowAny)
csrf_exempt = True # DRF views bypass Django CSRF; ingestion is anonymous
def get(self, parsed_path: Path[UuidPkPath]) -> TelemetryPage:
qs = TelemetryPacket.objects.filter(satellite_id=parsed_path.pk)
from_ts = self.request.GET.get('from')
till_ts = self.request.GET.get('till')
if from_ts:
qs = qs.filter(timestamp__gte=int(from_ts))
if till_ts:
qs = qs.filter(timestamp__lte=int(till_ts))
return qs.order_by("-timestamp")
qs = qs.order_by('-timestamp')
def post(self, request, pk):
serializer = TelemetryPacketSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
count, next_link, previous_link, page_objects = page_number_paginate(self.request, qs)
return TelemetryPage(
count=count,
next=next_link,
previous=previous_link,
results=[TelemetryOut.model_validate(obj) for obj in page_objects],
)
validated_data = serializer.validated_data
TelemetryPacket.objects.create(timestamp=time.time(),
satellite=Satellite.objects.get(id=pk),
lat=validated_data["lat"],
lon=validated_data["lon"],
alt=validated_data["alt"],
payload=validated_data['payload'],
)
return Response(serializer.errors, status=status.HTTP_201_CREATED)
def post(self, parsed_path: Path[UuidPkPath], parsed_body: Body[TelemetryIn]) -> TelemetryOut:
# Bug fix (approved): the original returned serializer.errors on success;
# we return the created packet instead. timestamp is still server-set.
packet = TelemetryPacket.objects.create(
timestamp=time.time(),
satellite=Satellite.objects.get(id=parsed_path.pk),
lat=parsed_body.lat,
lon=parsed_body.lon,
alt=parsed_body.alt,
payload=parsed_body.payload,
)
return TelemetryOut.model_validate(packet)
class SessionView(APIView):
permission_classes = [IsAuthenticated]
class SessionController(Controller[PydanticSerializer]):
"""Report whether the current request is authenticated."""
@staticmethod
def get(request, format=None):
return JsonResponse({'isAuthenticated': True})
def get(self) -> SessionResponse:
return SessionResponse(isAuthenticated=True)
class WhoAmIView(APIView):
permission_classes = [IsAuthenticated]
class WhoAmIController(Controller[PydanticSerializer]):
"""Return the current user's username."""
@staticmethod
def get(request, format=None):
return JsonResponse({'username': request.user.username})
def get(self) -> WhoAmIResponse:
return WhoAmIResponse(username=self.request.user.username)
@extend_schema(methods=["GET"], description="Get CSRF token")
@csrf_exempt
@api_view(["GET"])
@permission_classes([AllowAny])
def get_csrf(request):
response = JsonResponse({'detail': 'CSRF cookie set'})
response['X-CSRFToken'] = get_token(request)
return response
class CsrfController(Controller[PydanticSerializer]):
"""Get CSRF token."""
auth = None # public endpoint (was AllowAny)
csrf_exempt = True
def get(self) -> DetailResponse:
token = get_token(self.request)
return self.to_response(
DetailResponse(detail='CSRF cookie set'),
headers={'X-CSRFToken': token},
)
@extend_schema(methods=["POST"], description="Login user")
@csrf_exempt
@api_view(["POST"])
@authentication_classes([BasicAuthentication])
@permission_classes([AllowAny])
def login_view(request):
data = json.loads(request.body)
username = data.get('username')
password = data.get('password')
class LoginController(Controller[PydanticSerializer]):
"""Login user."""
if username is None or password is None:
return JsonResponse({'detail': 'Please provide username and password.'}, status=400)
auth = None # public endpoint (was AllowAny)
csrf_exempt = True
user = authenticate(username=username, password=password)
if user is None:
return JsonResponse({'detail': 'Invalid credentials.'}, status=400)
# post() would default to 201; the original returned 200.
@modify(status_code=HTTPStatus.OK)
def post(self) -> DetailResponse:
data = json.loads(self.request.body)
username = data.get('username')
password = data.get('password')
login(request, user)
return JsonResponse({'detail': 'Successfully logged in.'})
if username is None or password is None:
raise _api_error(
status_code=HTTPStatus.BAD_REQUEST,
body={'detail': 'Please provide username and password.'},
)
user = authenticate(username=username, password=password)
if user is None:
raise _api_error(
status_code=HTTPStatus.BAD_REQUEST,
body={'detail': 'Invalid credentials.'},
)
login(self.request, user)
return DetailResponse(detail='Successfully logged in.')
@extend_schema(methods=["POST"], description="Logout user")
@api_view(["POST"])
@permission_classes([AllowAny])
def logout_view(request):
if not request.user.is_authenticated:
return JsonResponse({'detail': 'You\'re not logged in.'}, status=400)
class LogoutController(Controller[PydanticSerializer]):
"""Logout user."""
logout(request)
return JsonResponse({'detail': 'Successfully logged out.'})
auth = None # public endpoint (was AllowAny); checks auth state manually
@modify(status_code=HTTPStatus.OK)
def post(self) -> DetailResponse:
if not self.request.user.is_authenticated:
raise _api_error(
status_code=HTTPStatus.BAD_REQUEST,
body={'detail': "You're not logged in."},
)
logout(self.request)
return DetailResponse(detail='Successfully logged out.')
class SavedPointViewset(ModelViewSet):
permission_classes = [IsOwner]
serializer_class = SavedPointSerializer
pagination_class = None
def get_queryset(self):
return SavedPoint.objects.filter(user=self.request.user)
def perform_create(self, serializer):
serializer.save(user=self.request.user)
_SAVED_POINT_DUPLICATE = 'A saved point with this name already exists for the user.'
class PreditctionTemplateViewset(ModelViewSet):
permission_classes = [IsOwner]
serializer_class = PreditctionTemplateSerializer
pagination_class = None
def get_queryset(self):
return PreditctionTemplate.objects.filter(user=self.request.user)
def perform_create(self, serializer):
serializer.save(user=self.request.user)
def _check_saved_point_unique(user, name, exclude_pk=None):
"""Reproduce the former SavedPointSerializer UniqueTogetherValidator."""
qs = SavedPoint.objects.filter(user=user, name=name)
if exclude_pk is not None:
qs = qs.exclude(pk=exclude_pk)
if qs.exists():
raise _api_error(
status_code=HTTPStatus.BAD_REQUEST,
body={'non_field_errors': [_SAVED_POINT_DUPLICATE]},
)
class UserProfileView(APIView):
permission_classes = [IsAuthenticated]
class SavedPointListController(Controller[PydanticSerializer]):
"""Collection endpoint for the current user's saved points (was SavedPointViewset)."""
def get(self, request):
serializer = UserSerializer(request.user)
return Response(serializer.data)
def get(self) -> list[SavedPointOut]:
qs = SavedPoint.objects.filter(user=self.request.user)
return [SavedPointOut.model_validate(obj) for obj in qs]
def patch(self, request):
user = request.user
serializer = UserSerializer(user, data=request.data, partial=True)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
serializer.save()
return Response(serializer.data)
def post(self, parsed_body: Body[SavedPointIn]) -> SavedPointOut:
user = self.request.user
_check_saved_point_unique(user, parsed_body.name)
obj = SavedPoint.objects.create(user=user, **parsed_body.model_dump())
return SavedPointOut.model_validate(obj)
class ChangePasswordView(APIView):
permission_classes = [IsAuthenticated]
class SavedPointDetailController(Controller[PydanticSerializer]):
"""Detail endpoint for a single saved point owned by the current user."""
def post(self, request):
user = request.user
serializer = ChangePasswordSerializer(data=request.data)
def _get_object(self, pk: int):
# Filtering by user means another user's object reads as 404 (not 403),
# matching DRF's get_object() over a user-scoped queryset.
obj = SavedPoint.objects.filter(user=self.request.user, pk=pk).first()
if obj is None:
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'})
return obj
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def get(self, parsed_path: Path[PkPath]) -> SavedPointOut:
return SavedPointOut.model_validate(self._get_object(parsed_path.pk))
if not user.check_password(serializer.validated_data['old_password']):
return Response({'detail': 'Old password is incorrect'},
status=status.HTTP_400_BAD_REQUEST)
def put(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointIn]) -> SavedPointOut:
obj = self._get_object(parsed_path.pk)
_check_saved_point_unique(self.request.user, parsed_body.name, exclude_pk=obj.pk)
for field, value in parsed_body.model_dump().items():
setattr(obj, field, value)
obj.save()
return SavedPointOut.model_validate(obj)
user.set_password(serializer.validated_data['new_password'])
def patch(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointPatchIn]) -> SavedPointOut:
obj = self._get_object(parsed_path.pk)
updates = parsed_body.model_dump(exclude_unset=True)
if 'name' in updates:
_check_saved_point_unique(self.request.user, updates['name'], exclude_pk=obj.pk)
for field, value in updates.items():
setattr(obj, field, value)
obj.save()
return SavedPointOut.model_validate(obj)
@modify(status_code=HTTPStatus.NO_CONTENT)
def delete(self, parsed_path: Path[PkPath]) -> None:
self._get_object(parsed_path.pk).delete()
class PredictionTemplateListController(Controller[PydanticSerializer]):
"""Collection endpoint for the current user's templates (was PreditctionTemplateViewset).
NOTE: as before, there is no app-level uniqueness check; a duplicate
(user, name) hits the model's unique_together and surfaces as a DB error.
"""
def get(self) -> list[PredictionTemplateOut]:
qs = PreditctionTemplate.objects.filter(user=self.request.user)
return [PredictionTemplateOut.model_validate(obj) for obj in qs]
def post(self, parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut:
obj = PreditctionTemplate.objects.create(
user=self.request.user, **parsed_body.model_dump()
)
return PredictionTemplateOut.model_validate(obj)
class PredictionTemplateDetailController(Controller[PydanticSerializer]):
"""Detail endpoint for a single template owned by the current user."""
def _get_object(self, pk: int):
obj = PreditctionTemplate.objects.filter(user=self.request.user, pk=pk).first()
if obj is None:
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'})
return obj
def get(self, parsed_path: Path[PkPath]) -> PredictionTemplateOut:
return PredictionTemplateOut.model_validate(self._get_object(parsed_path.pk))
def put(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut:
obj = self._get_object(parsed_path.pk)
for field, value in parsed_body.model_dump().items():
setattr(obj, field, value)
obj.save()
return PredictionTemplateOut.model_validate(obj)
def patch(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplatePatchIn]) -> PredictionTemplateOut:
obj = self._get_object(parsed_path.pk)
for field, value in parsed_body.model_dump(exclude_unset=True).items():
setattr(obj, field, value)
obj.save()
return PredictionTemplateOut.model_validate(obj)
@modify(status_code=HTTPStatus.NO_CONTENT)
def delete(self, parsed_path: Path[PkPath]) -> None:
self._get_object(parsed_path.pk).delete()
class UserProfileController(Controller[PydanticSerializer]):
"""Read and partially update the current user's profile."""
def get(self) -> UserOut:
return UserOut.model_validate(self.request.user)
def patch(self, parsed_body: Body[UserUpdateIn]) -> UserOut:
user = self.request.user
# partial update: apply only the fields the client actually sent.
for field, value in parsed_body.model_dump(exclude_unset=True).items():
setattr(user, field, value)
user.save()
return Response({'detail': 'Password changed successfully'})
return UserOut.model_validate(user)
class DeleteAccountView(APIView):
permission_classes = [IsAuthenticated]
class ChangePasswordController(Controller[PydanticSerializer]):
"""Change the current user's password."""
def delete(self, request):
user = request.user
serializer = DeleteAccountSerializer(data=request.data)
# post() would default to 201; the original returned 200.
@modify(status_code=HTTPStatus.OK)
def post(self, parsed_body: Body[ChangePasswordIn]) -> DetailResponse:
user = self.request.user
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
if not user.check_password(parsed_body.old_password):
raise _api_error(
status_code=HTTPStatus.BAD_REQUEST,
body={'detail': 'Old password is incorrect'},
)
if not user.check_password(serializer.validated_data['password']):
return Response({'detail': 'Incorrect password'},
status=status.HTTP_400_BAD_REQUEST)
user.set_password(parsed_body.new_password)
user.save()
return DetailResponse(detail='Password changed successfully')
class DeleteAccountController(Controller[PydanticSerializer]):
"""Delete the current user's account and all their data."""
# DMR forbids a Body component on DELETE, but the original endpoint read the
# password from a DELETE request body. Preserve that contract by parsing the
# body manually instead of via Body[DeleteAccountIn].
def delete(self) -> DetailResponse:
user = self.request.user
try:
parsed_body = DeleteAccountIn(**json.loads(self.request.body or b'{}'))
except ValidationError as exc:
raise _api_error(status_code=HTTPStatus.BAD_REQUEST, body=json.loads(exc.json()))
except ValueError:
raise _api_error(status_code=HTTPStatus.BAD_REQUEST, body={'detail': 'Invalid request body.'})
if not user.check_password(parsed_body.password):
raise _api_error(
status_code=HTTPStatus.BAD_REQUEST,
body={'detail': 'Incorrect password'},
)
Prediction.objects.filter(user=user).delete()
SavedPoint.objects.filter(user=user).delete()
@ -307,35 +479,83 @@ class DeleteAccountView(APIView):
user.delete()
return Response({'detail': 'Account deleted successfully'})
return DetailResponse(detail='Account deleted successfully')
class DeleteUserDataView(APIView):
permission_classes = [IsAuthenticated]
class DeleteUserDataController(Controller[PydanticSerializer]):
"""Delete all of the current user's data without deleting the account."""
def delete(self, request):
user = request.user
def delete(self) -> DetailResponse:
user = self.request.user
Prediction.objects.filter(user=user).delete()
SavedPoint.objects.filter(user=user).delete()
PreditctionTemplate.objects.filter(user=user).delete()
return Response({'detail': 'All user data deleted successfully'})
return DetailResponse(detail='All user data deleted successfully')
class TokenManagementView(APIView):
permission_classes = [IsAuthenticated]
class ObtainTokenController(
ObtainTokensSyncController[
PydanticSerializer,
ObtainTokensPayload,
ObtainTokensResponse,
],
):
"""Exchange username/password for JWT access + refresh tokens.
def get(self, request):
Replaces DRF's obtain_auth_token. Approved drift: the token format and
semantics change from a single stored DRF token to stateless JWTs, so the
response is {access_token, refresh_token} instead of {token}.
"""
token, created = Token.objects.get_or_create(user=request.user)
return Response({"token": token.key})
auth = None # public: credentials are supplied in the request body
csrf_exempt = True
jwt_expiration = timedelta(hours=1)
jwt_refresh_expiration = timedelta(days=7)
def post(self, request):
def convert_auth_payload(self, payload):
return payload
Token.objects.filter(user=request.user).delete()
token = Token.objects.create(user=request.user)
return Response({"token": token.key})
def make_api_response(self):
now = datetime.now(timezone.utc)
return {
'access_token': self.create_jwt_token(
expiration=now + self.jwt_expiration,
token_type='access',
),
'refresh_token': self.create_jwt_token(
expiration=now + self.jwt_refresh_expiration,
token_type='refresh',
),
}
class TokenManagementController(Controller[PydanticSerializer]):
"""Issue a fresh JWT access token for the current user.
Was TokenManagementView (DRF stored-token get/regenerate). With stateless
JWTs there is nothing stored to fetch or delete, so both GET and POST mint a
new access token (approved drift). The {"token": ...} response shape is kept.
"""
jwt_expiration = timedelta(hours=1)
def get(self) -> TokenResponse:
return TokenResponse(token=self._issue_token())
# post() would default to 201; the original returned 200.
@modify(status_code=HTTPStatus.OK)
def post(self) -> TokenResponse:
return TokenResponse(token=self._issue_token())
def _issue_token(self) -> str:
now = datetime.now(timezone.utc)
return self.create_jwt_token(
subject=str(self.request.user.pk),
expiration=now + self.jwt_expiration,
token_type='access',
)