Compare commits

..

No commits in common. "modern-rest" and "master" have entirely different histories.

11 changed files with 572 additions and 1023 deletions

View file

@ -1,7 +1,8 @@
# django-modern-rest 0.6.0 requires Django 5.x (upgraded from 4.2). Django>=4.0,<5.0
Django>=5.0,<6.0 djangorestframework
django-modern-rest[jwt,pydantic]==0.6.0 djangorestframework-simplejwt
psycopg2-binary psycopg2-binary
drf-spectacular
requests requests
django-cors-headers django-cors-headers
Pillow Pillow

View file

@ -13,12 +13,6 @@ https://docs.djangoproject.com/en/4.2/ref/settings/
from pathlib import Path from pathlib import Path
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
from dmr.settings import Settings
from dmr.security.django_session import DjangoSessionSyncAuth
from dmr.security.jwt import JWTSyncAuth
from dmr.openapi import OpenAPIConfig
load_dotenv() load_dotenv()
# Build paths inside the project like this: BASE_DIR / 'subdir'. # Build paths inside the project like this: BASE_DIR / 'subdir'.
@ -57,10 +51,12 @@ INSTALLED_APPS = [
'django.contrib.sessions', 'django.contrib.sessions',
'django.contrib.messages', 'django.contrib.messages',
'django.contrib.staticfiles', 'django.contrib.staticfiles',
'rest_framework',
'rest_framework.authtoken',
'drf_spectacular',
'corsheaders', 'corsheaders',
'stratoflights_api.apps.StratoflightsApiConfig', 'stratoflights_api.apps.StratoflightsApiConfig',
'channels', 'channels',
'dmr', # required to serve OpenAPI docs static assets
] ]
MIDDLEWARE = [ MIDDLEWARE = [
@ -169,16 +165,25 @@ STATIC_URL = 'static/'
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
AUTH_USER_MODEL = 'stratoflights_api.User' AUTH_USER_MODEL = 'stratoflights_api.User'
# django-modern-rest configuration. REST_FRAMEWORK = {
# Global auth mirrors the previous DRF default (Token + Session); the DRF 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
# TokenAuthentication is replaced by DMR's JWT (approved drift). Providing 'PAGE_SIZE': 100,
# several auth instances means at least one of them must succeed, so this
# also enforces a 401 by default like DRF's global IsAuthenticated did. 'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
# Public endpoints opt out per-endpoint with @modify(auth=None).
DMR_SETTINGS = { 'DEFAULT_AUTHENTICATION_CLASSES': [
Settings.auth: [JWTSyncAuth(), DjangoSessionSyncAuth()], 'rest_framework.authentication.TokenAuthentication',
Settings.validate_responses: not PRODUCTION, 'rest_framework.authentication.SessionAuthentication',
Settings.openapi_config: OpenAPIConfig(title='Stratoflights API', version='1.0.0'), ],
'DEFAULT_PERMISSION_CLASSES': [
'rest_framework.permissions.IsAuthenticated',
# 'rest_framework.permissions.AllowAny',
],
'DEFAULT_RENDERER_CLASSES': [
'rest_framework.renderers.JSONRenderer',
],
} }

View file

@ -1,17 +1,27 @@
"""URL configuration for the stratoflights project.""" """
URL configuration for stratoflights project.
The `urlpatterns` list routes URLs to views. For more information please see:
https://docs.djangoproject.com/en/4.2/topics/http/urls/
Examples:
Function views
1. Add an import: from my_app import views
2. Add a URL to urlpatterns: path('', views.home, name='home')
Class-based views
1. Add an import: from other_app.views import Home
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
Including another URLconf
1. Import the include() function: from django.urls import include, path
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
"""
from django.contrib import admin from django.contrib import admin
from django.urls import include from django.urls import path, include
from dmr.openapi import build_schema from drf_spectacular.views import SpectacularAPIView
from dmr.openapi.views import OpenAPIJsonView, SwaggerView from drf_spectacular.views import SpectacularSwaggerView
from dmr.routing import path
from stratoflights_api.urls import router
schema = build_schema(router)
urlpatterns = [ urlpatterns = [
path('admin/', admin.site.urls), path('admin/', admin.site.urls),
path(router.prefix, include(router.urls)), path('api/', include('stratoflights_api.urls')),
path('api/schema/', OpenAPIJsonView.as_view(schema), name='schema'), path('api/schema/', SpectacularAPIView.as_view(), name='schema'),
path('api/docs/', SwaggerView.as_view(schema), name='docs'), path('api/docs/', SpectacularSwaggerView.as_view(url_name='schema'), name='docs'),
] ]

View file

@ -18,19 +18,17 @@ class TelemetryConsumer(AsyncWebsocketConsumer):
async def receive(self, text_data): async def receive(self, text_data):
from pydantic import ValidationError from .serializers import TelemetryPacketSerializer
from .dtos import TelemetryIn
if not self.write_enabled: if not self.write_enabled:
await self.send(text_data=json.dumps({"error": "Read-only mode"})) await self.send(text_data=json.dumps({"error": "Read-only mode"}))
return return
data = json.loads(text_data) data = json.loads(text_data)
serializer = TelemetryPacketSerializer(data=data)
try: if not serializer.is_valid():
TelemetryIn(**data) await self.send(text_data=json.dumps({"error": serializer.errors}))
except ValidationError as exc:
await self.send(text_data=json.dumps({"error": json.loads(exc.json())}))
return return
saved_data = await self.save_telemetry(data) saved_data = await self.save_telemetry(data)
@ -53,6 +51,15 @@ class TelemetryConsumer(AsyncWebsocketConsumer):
async def telemetry_message(self, event): async def telemetry_message(self, event):
await self.send(text_data=json.dumps(event["data"])) await self.send(text_data=json.dumps(event["data"]))
@database_sync_to_async
def get_user_from_token(self, token_key, Token):
from rest_framework.authtoken.models import Token
User = get_user_model()
try:
token = Token.objects.select_related("user").get(key=token_key)
return token.user
except Token.DoesNotExist:
return None
@database_sync_to_async @database_sync_to_async
def save_telemetry(self, data): def save_telemetry(self, data):
@ -92,20 +99,12 @@ class StationTelemetryConsumer(TelemetryConsumer):
write_enabled = True write_enabled = True
async def connect(self): async def connect(self):
from django.conf import settings from rest_framework.authtoken.models import Token
from dmr.security.jwt.token import JWToken
token_key = self.scope["query_string"].decode().split("token=")[-1] token_key = self.scope["query_string"].decode().split("token=")[-1]
try: try:
decoded = JWToken.decode( token = await database_sync_to_async(Token.objects.select_related("user").get)(key=token_key)
encoded_token=token_key, self.scope["user"] = token.user
secret=settings.SECRET_KEY, except Token.DoesNotExist:
algorithm='HS256',
)
self.scope["user"] = await database_sync_to_async(
get_user_model().objects.get
)(pk=decoded.sub)
except Exception:
await self.close() await self.close()
return return

View file

@ -0,0 +1,17 @@
from rest_framework.pagination import LimitOffsetPagination
from rest_framework.response import Response
class CustomLimitOffsetPagination(LimitOffsetPagination):
limit_query_param = 'limit'
offset_query_param = 'skip'
max_limit = 100
default_limit = 10
def get_paginated_response(self, data):
return Response({
'total': self.count,
'limit': self.limit,
'skip': self.offset,
'predictions': data
})

View file

@ -1,333 +0,0 @@
"""Pydantic DTOs replacing the DRF serializers from serializers.py.
Mapping is 1:1 with the former serializers. Output DTOs enable ``from_attributes``
so an endpoint can build them straight from a Django model instance via
``SomeOut.model_validate(obj)`` (the equivalent of ``Serializer(obj).data``).
Field-level and cross-field validation that DRF kept in ``validate_<field>()`` /
``validate()`` is reproduced with pydantic ``field_validator`` / ``model_validator``.
Business logic DRF kept in ``create()`` (base64 curve decoding, ORM writes, and
resolving related objects from their PKs) is intentionally NOT here -- per DMR's
split it moves into the endpoints (Step 2.9).
"""
import uuid
from datetime import datetime
from typing import Any, Optional
from django.contrib.auth.password_validation import validate_password
from django.core.validators import validate_email
from django.core.exceptions import ValidationError as DjangoValidationError
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
)
from .validators import validate_custom_curve, rate_clip
# Profile constants (moved verbatim from serializers.py).
PROFILE_STANDARD = "standard_profile"
PROFILE_FLOAT = "float_profile"
PROFILE_REVERSE = "reverse_profile"
PROFILE_CUSTOM = "custom_profile"
LATEST_DATASET_KEYWORD = "latest"
SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM]
# --- Prediction request (was PredictionRequestSerializer) ---------------------
class PredictionRequest(BaseModel):
launch_latitude: float = Field(ge=-90, le=90)
launch_longitude: float = Field(ge=0, le=360)
launch_datetime: datetime
launch_altitude: Optional[float] = None
format: str = "json"
profile: str = PROFILE_STANDARD
dataset: str = LATEST_DATASET_KEYWORD
# profile-dependent fields
ascent_rate: Optional[float] = Field(default=None, ge=0.01)
descent_rate: Optional[float] = Field(default=None, ge=0.01)
burst_altitude: Optional[float] = None
float_altitude: Optional[float] = None
stop_datetime: Optional[datetime] = None
ascent_curve: Optional[str] = None
descent_curve: Optional[str] = None
interpolate: bool = False
# Related objects are accepted as PKs; existence is resolved in the endpoint
# (was PrimaryKeyRelatedField(queryset=...)).
start_point: Optional[int] = None
rate_profile: Optional[int] = None
template: Optional[int] = None
@field_validator("profile")
@classmethod
def _validate_profile(cls, value: str) -> str:
if value not in SUPPORTED_PROFILES:
raise ValueError(f'"{value}" is not a valid choice.')
return value
@model_validator(mode="after")
def _validate_cross_fields(self) -> "PredictionRequest":
launch_alt = self.launch_altitude if self.launch_altitude is not None else 0
if self.profile == PROFILE_STANDARD:
if self.burst_altitude is None:
raise ValueError("burst_altitude is required for standard profile.")
if self.burst_altitude <= launch_alt:
raise ValueError("burst_altitude must be greater than launch_altitude.")
elif self.profile == PROFILE_FLOAT:
if self.float_altitude is None or self.float_altitude <= launch_alt:
raise ValueError("float_altitude must be greater than launch_altitude.")
if self.stop_datetime is None or self.stop_datetime <= self.launch_datetime:
raise ValueError("stop_datetime must be later than launch_datetime.")
elif self.profile == PROFILE_CUSTOM:
if self.ascent_curve is None or not validate_custom_curve(self.ascent_curve):
raise ValueError("Invalid ascent_curve.")
if self.descent_curve is None or not validate_custom_curve(self.descent_curve):
raise ValueError("Invalid descent_curve.")
if self.burst_altitude is None or self.burst_altitude <= launch_alt:
raise ValueError("burst_altitude must be greater than launch_altitude.")
# custom clipping logic (was in validate())
if self.ascent_rate is not None:
self.ascent_rate = rate_clip(self.ascent_rate)
if self.descent_rate is not None:
self.descent_rate = rate_clip(self.descent_rate)
return self
# --- Prediction outputs -------------------------------------------------------
class PredictionOut(BaseModel):
"""Was PredictionSerializer."""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
created_at: datetime
updated_at: datetime
result: Any
class PredictionListOut(BaseModel):
"""Was PredictionListSerializer."""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
created_at: datetime
updated_at: datetime
start_point: Optional[int] = Field(default=None, validation_alias="start_point_id")
template: Optional[int] = Field(default=None, validation_alias="template_id")
rate_profile: Optional[int] = Field(default=None, validation_alias="rate_profile_id")
class PredictionDetailOut(BaseModel):
"""Was PredictionDetailSerializer."""
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
created_at: datetime
updated_at: datetime
result: Any
start_point: Optional[int] = Field(default=None, validation_alias="start_point_id")
template: Optional[int] = Field(default=None, validation_alias="template_id")
rate_profile: Optional[int] = Field(default=None, validation_alias="rate_profile_id")
class PredictionCreateOut(BaseModel):
"""Custom create() response shape: {id, created_at, result} (not the full model)."""
id: uuid.UUID
created_at: datetime
result: Any
# --- Telemetry (was TelemetryPacketSerializer) --------------------------------
class TelemetryIn(BaseModel):
# timestamp is required (model BigIntegerField has no default), matching the
# former ModelSerializer; the endpoint still overwrites it with server time.
timestamp: int
lat: float = 0.0
lon: float = 0.0
alt: float = 0.0
payload: Any = Field(default_factory=dict)
class TelemetryOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: uuid.UUID
timestamp: int
lat: float
lon: float
alt: float
payload: Any
# --- SavedPoint (was SavedPointSerializer) ------------------------------------
# `user` was a HiddenField(CurrentUserDefault) -> set from request in the endpoint,
# not part of the wire input/output. The unique_together (user, name) check that
# DRF did via UniqueTogetherValidator is reproduced in the endpoint (Step 2.8).
class SavedPointIn(BaseModel):
name: str
lat: float = 0.0
lon: float = 0.0
alt: float = 0.0
class SavedPointPatchIn(BaseModel):
# All-optional variant for PATCH (partial_update). Only set fields apply.
name: Optional[str] = None
lat: Optional[float] = None
lon: Optional[float] = None
alt: Optional[float] = None
class SavedPointOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
name: str
lat: float
lon: float
alt: float
# --- SavedRateProfile (was SavedRateProfileSerializer) ------------------------
# NOTE: no view referenced this serializer in routing; kept for parity.
class SavedRateProfileIn(BaseModel):
name: str
type: str = "ascent"
rate_profile_data: Any = Field(default_factory=dict)
class SavedRateProfileOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int
name: str
type: str
rate_profile_data: Any
# --- PredictionTemplate (was PreditctionTemplateSerializer) -------------------
class PredictionTemplateIn(BaseModel):
# protected_namespaces=() avoids pydantic's warning about the `model` field.
model_config = ConfigDict(protected_namespaces=())
name: str
is_default: bool = False
description: Optional[str] = None
prediction_mode: str = ""
model: str = ""
dataset: str = ""
flight_parameters: Any = Field(default_factory=dict)
class PredictionTemplatePatchIn(BaseModel):
# All-optional variant for PATCH (partial_update). Only set fields apply.
model_config = ConfigDict(protected_namespaces=())
name: Optional[str] = None
is_default: Optional[bool] = None
description: Optional[str] = None
prediction_mode: Optional[str] = None
model: Optional[str] = None
dataset: Optional[str] = None
flight_parameters: Optional[Any] = None
class PredictionTemplateOut(BaseModel):
model_config = ConfigDict(from_attributes=True, protected_namespaces=())
id: int
name: str
is_default: bool
description: Optional[str] = None
prediction_mode: str
model: str
dataset: str
flight_parameters: Any
# --- User (was UserSerializer) ------------------------------------------------
class UserOut(BaseModel):
model_config = ConfigDict(from_attributes=True)
username: str
email: str
first_name: str
last_name: str
class UserUpdateIn(BaseModel):
# `username` was read_only -> excluded from input. PATCH is partial, so all
# fields are optional; the endpoint applies only fields that were set.
email: Optional[str] = None
first_name: Optional[str] = None
last_name: Optional[str] = None
@field_validator("email")
@classmethod
def _validate_email(cls, value: Optional[str]) -> Optional[str]:
if value is None:
return value
try:
validate_email(value)
except DjangoValidationError:
raise ValueError("Invalid email format")
return value
# --- Password / account (were ChangePasswordSerializer / DeleteAccountSerializer)
class ChangePasswordIn(BaseModel):
old_password: str
new_password: str
@field_validator("new_password")
@classmethod
def _validate_new_password(cls, value: str) -> str:
# validate_password raises Django's ValidationError (not a ValueError),
# which pydantic would not convert into a 4xx -- re-raise as ValueError.
try:
validate_password(value)
except DjangoValidationError as exc:
raise ValueError("; ".join(exc.messages))
return value
class DeleteAccountIn(BaseModel):
password: str
# --- Auth / session response shapes (were inline JsonResponse dicts) ----------
class DetailResponse(BaseModel):
detail: str
class SessionResponse(BaseModel):
isAuthenticated: bool
class WhoAmIResponse(BaseModel):
username: str
class TokenResponse(BaseModel):
token: str
# --- Shared path parameters ---------------------------------------------------
class PkPath(BaseModel):
pk: int
class UuidPkPath(BaseModel):
pk: uuid.UUID

View file

@ -1,120 +0,0 @@
"""Limit/offset pagination for the prediction list endpoints.
DMR has no built-in limit/offset paginator, so (as the docs suggest) we keep our
own envelope model -- matching the former ``CustomLimitOffsetPagination`` output
exactly -- plus a helper that slices the queryset. The query params and envelope
keys are preserved verbatim: ``limit`` / ``skip`` in, and
``{total, limit, skip, predictions}`` out.
"""
from http import HTTPStatus
from typing import Optional
from urllib.parse import urlsplit, urlunsplit
from django.core.paginator import InvalidPage, Paginator
from django.http import QueryDict
from pydantic import BaseModel
from dmr.response import APIError
from .dtos import PredictionOut, TelemetryOut
DEFAULT_LIMIT = 10 # was CustomLimitOffsetPagination.default_limit
MAX_LIMIT = 100 # was CustomLimitOffsetPagination.max_limit
PAGE_SIZE = 100 # global REST_FRAMEWORK PAGE_SIZE (PageNumberPagination)
class PaginationQuery(BaseModel):
# Param names preserved: limit_query_param='limit', offset_query_param='skip'.
limit: int = DEFAULT_LIMIT
skip: int = 0
class PredictionListUserQuery(BaseModel):
"""Query params for PredictionViewSet.list_user: filters + limit/offset."""
satellite_id: Optional[str] = None
created_from: Optional[str] = None
created_till: Optional[str] = None
limit: int = DEFAULT_LIMIT
skip: int = 0
class PredictionPage(BaseModel):
total: int
limit: int
skip: int
predictions: list[PredictionOut]
class TelemetryPage(BaseModel):
count: int
next: Optional[str] = None
previous: Optional[str] = None
results: list[TelemetryOut]
def _replace_query_param(url: str, key: str, value) -> str:
scheme, netloc, path, query, fragment = urlsplit(url)
query_dict = QueryDict(query, mutable=True)
query_dict[key] = value
return urlunsplit((scheme, netloc, path, query_dict.urlencode(), fragment))
def _remove_query_param(url: str, key: str) -> str:
scheme, netloc, path, query, fragment = urlsplit(url)
query_dict = QueryDict(query, mutable=True)
query_dict.pop(key, None)
return urlunsplit((scheme, netloc, path, query_dict.urlencode(), fragment))
def page_number_paginate(request, queryset, page_size: int = PAGE_SIZE):
"""Reproduce DRF PageNumberPagination.
Returns ``(count, next_link, previous_link, object_list)``. Raises a 404
"Invalid page." like DRF on an out-of-range / non-integer page. Honours
``page=last`` and builds absolute next/previous links off the current URL.
"""
paginator = Paginator(queryset, page_size)
page_number = request.GET.get('page', 1)
if page_number in ('last',):
page_number = paginator.num_pages
try:
page = paginator.page(page_number)
except InvalidPage:
raise APIError({'detail': 'Invalid page.'}, status_code=HTTPStatus.NOT_FOUND)
url = request.build_absolute_uri()
next_link = None
if page.has_next():
next_link = _replace_query_param(url, 'page', page.next_page_number())
previous_link = None
if page.has_previous():
previous_number = page.previous_page_number()
if previous_number == 1:
previous_link = _remove_query_param(url, 'page')
else:
previous_link = _replace_query_param(url, 'page', previous_number)
return paginator.count, next_link, previous_link, list(page.object_list)
def paginate_predictions(queryset, query: PaginationQuery) -> PredictionPage:
"""Slice ``queryset`` and build the prediction page envelope.
Mirrors the bounds DRF's LimitOffsetPagination enforced: limit falls back to
the default when non-positive and is capped at MAX_LIMIT; skip floors at 0.
"""
limit = query.limit if query.limit > 0 else DEFAULT_LIMIT
limit = min(limit, MAX_LIMIT)
skip = query.skip if query.skip >= 0 else 0
total = queryset.count()
page = list(queryset[skip:skip + limit])
return PredictionPage(
total=total,
limit=limit,
skip=skip,
predictions=[PredictionOut.model_validate(obj) for obj in page],
)

View file

@ -0,0 +1,16 @@
from rest_framework.permissions import BasePermission, SAFE_METHODS
class ReadOnlyOrAuthenticated(BasePermission):
def has_permission(self, request, view):
return (
request.method in SAFE_METHODS or
request.user and request.user.is_authenticated
)
class IsOwner(BasePermission):
def has_object_permission(self, request, view, obj):
return obj.user == request.user
def has_permission(self, request, view):
return request.user and request.user.is_authenticated

View file

@ -0,0 +1,185 @@
from rest_framework import serializers
from .models import Prediction, SavedPoint, SavedRateProfile, PreditctionTemplate
from datetime import datetime
from django.contrib.auth.password_validation import validate_password
from django.core.validators import validate_email
from django.core.exceptions import ValidationError as DjangoValidationError
from django.contrib.auth import get_user_model
from .validators import (
validate_custom_curve, rate_clip,
_rfc3339_to_timestamp, base64_to_curve
)
class PredictionSerializer(serializers.ModelSerializer):
class Meta:
model = Prediction
fields = ['id', 'created_at', 'updated_at', 'result']
User = get_user_model()
PROFILE_STANDARD = "standard_profile"
PROFILE_FLOAT = "float_profile"
PROFILE_REVERSE = "reverse_profile"
PROFILE_CUSTOM = "custom_profile"
LATEST_DATASET_KEYWORD = "latest"
SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM]
class PredictionRequestSerializer(serializers.Serializer):
launch_latitude = serializers.FloatField(min_value=-90, max_value=90)
launch_longitude = serializers.FloatField(min_value=0, max_value=360)
launch_datetime = serializers.DateTimeField()
launch_altitude = serializers.FloatField(required=False)
format = serializers.CharField(default="json")
profile = serializers.ChoiceField(choices=SUPPORTED_PROFILES, default=PROFILE_STANDARD)
dataset = serializers.CharField(default=LATEST_DATASET_KEYWORD)
# --- профиль-dependent поля ---
ascent_rate = serializers.FloatField(required=False, min_value=0.01)
descent_rate = serializers.FloatField(required=False, min_value=0.01)
burst_altitude = serializers.FloatField(required=False)
float_altitude = serializers.FloatField(required=False)
stop_datetime = serializers.DateTimeField(required=False)
ascent_curve = serializers.CharField(required=False)
descent_curve = serializers.CharField(required=False)
interpolate = serializers.BooleanField(required=False, default=False)
start_point = serializers.PrimaryKeyRelatedField(
queryset=SavedPoint.objects.all(), required=False, allow_null=True
)
rate_profile = serializers.PrimaryKeyRelatedField(
queryset=SavedRateProfile.objects.all(), required=False, allow_null=True
)
template = serializers.PrimaryKeyRelatedField(
queryset=PreditctionTemplate.objects.all(), required=False, allow_null=True
)
def validate(self, data):
profile = data.get("profile", PROFILE_STANDARD)
launch_alt = data.get("launch_altitude", 0)
if profile == PROFILE_STANDARD:
if 'burst_altitude' not in data:
raise serializers.ValidationError("burst_altitude is required for standard profile.")
if data['burst_altitude'] <= launch_alt:
raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.")
elif profile == PROFILE_FLOAT:
if 'float_altitude' not in data or data['float_altitude'] <= launch_alt:
raise serializers.ValidationError("float_altitude must be greater than launch_altitude.")
if 'stop_datetime' not in data or data['stop_datetime'] <= data['launch_datetime']:
raise serializers.ValidationError("stop_datetime must be later than launch_datetime.")
elif profile == PROFILE_CUSTOM:
if 'ascent_curve' not in data or not validate_custom_curve(data['ascent_curve']):
raise serializers.ValidationError("Invalid ascent_curve.")
if 'descent_curve' not in data or not validate_custom_curve(data['descent_curve']):
raise serializers.ValidationError("Invalid descent_curve.")
if 'burst_altitude' not in data or data['burst_altitude'] <= launch_alt:
raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.")
# кастомная логика clipping'а
if 'ascent_rate' in data:
data['ascent_rate'] = rate_clip(data['ascent_rate'])
if 'descent_rate' in data:
data['descent_rate'] = rate_clip(data['descent_rate'])
return data
def create(self, validated_data):
if 'ascent_curve' in validated_data:
validated_data['ascent_curve'] = base64_to_curve(validated_data['ascent_curve'])
if 'descent_curve' in validated_data:
validated_data['descent_curve'] = base64_to_curve(validated_data['descent_curve'])
prediction = Prediction(
user=validated_data.get('user'),
request=validated_data.get('request', {}),
result=validated_data.get('result', {}),
start_point=validated_data.get('start_point'),
template=validated_data.get('template'),
rate_profile=validated_data.get('rate_profile')
)
prediction.save()
return prediction
class PredictionListSerializer(serializers.ModelSerializer):
class Meta:
model = Prediction
fields = ["id", "created_at", "updated_at", "start_point", "template", "rate_profile"]
class PredictionDetailSerializer(serializers.ModelSerializer):
class Meta:
model = Prediction
fields = ["id", "created_at", "updated_at", "result", "start_point", "template", "rate_profile"]
from rest_framework import serializers
from .models import TelemetryPacket
class TelemetryPacketSerializer(serializers.ModelSerializer):
class Meta:
model = TelemetryPacket
fields = ['id', 'timestamp', 'lat', 'lon', 'alt', 'payload']
read_only_fields = ['id']
class SavedPointSerializer(serializers.ModelSerializer):
user = serializers.HiddenField(
default=serializers.CurrentUserDefault()
)
class Meta:
model = SavedPoint
fields = ['user', 'id', 'name', 'lat', 'lon', 'alt']
read_only_fields = ['id']
validators = [
serializers.UniqueTogetherValidator(
queryset=SavedPoint.objects.all(),
fields=['user', 'name'],
message="A saved point with this name already exists for the user."
)
]
class SavedRateProfileSerializer(serializers.ModelSerializer):
class Meta:
model = SavedRateProfile
fields = ['id', 'name', 'type', 'rate_profile_data']
read_only_fields = ['id']
class PreditctionTemplateSerializer(serializers.ModelSerializer):
class Meta:
model = PreditctionTemplate
fields = ['id', 'name', 'is_default', 'description', 'prediction_mode', 'model', 'dataset', 'flight_parameters']
read_only_fields = ['id']
class UserSerializer(serializers.ModelSerializer):
class Meta:
model = User
fields = ['username', 'email', 'first_name', 'last_name']
extra_kwargs = {
'username': {'read_only': True}
}
def validate_email(self, value):
try:
validate_email(value)
except DjangoValidationError:
raise serializers.ValidationError("Invalid email format")
return value
class ChangePasswordSerializer(serializers.Serializer):
old_password = serializers.CharField(required=True)
new_password = serializers.CharField(required=True)
def validate_new_password(self, value):
validate_password(value)
return value
class DeleteAccountSerializer(serializers.Serializer):
password = serializers.CharField(required=True)

View file

@ -1,59 +1,48 @@
from dmr.routing import Router, path from django.urls import path
from rest_framework.routers import DefaultRouter
from rest_framework.authtoken.views import obtain_auth_token
from .views import ( from .views import (
PredictionCollectionController, PredictionViewSet,
PredictionListUserController, SavedPointViewset,
PredictionHistoryController, PreditctionTemplateViewset,
PredictionDetailController, TelemetryListCreateView,
PredictionDeleteController, get_csrf,
SavedPointListController, login_view,
SavedPointDetailController, logout_view,
PredictionTemplateListController, SessionView,
PredictionTemplateDetailController, WhoAmIView,
TelemetryController, UserProfileView,
CsrfController, ChangePasswordView,
LoginController, TokenManagementView,
LogoutController, DeleteUserDataView,
SessionController, DeleteAccountView
WhoAmIController,
UserProfileController,
ChangePasswordController,
ObtainTokenController,
TokenManagementController,
DeleteUserDataController,
DeleteAccountController,
) )
# A Router (prefix + routes) is required so build_schema() can generate the
# OpenAPI document; see stratoflights/urls.py.
router = Router(
'api/',
[
path("csrf/", CsrfController.as_view(), name='api-csrf'),
path('token', ObtainTokenController.as_view(), name='get_token'),
path("login/", LoginController.as_view(), name='api-login'),
path("logout/", LogoutController.as_view(), name='api-logout'),
path("session/", SessionController.as_view(), name='api-session'),
path("whoami/", WhoAmIController.as_view(), name='api-whoami'),
path("<uuid:pk>/telemetry/", TelemetryController.as_view(), name="create_telemetry"),
path("profile/", UserProfileController.as_view(), name='api-profile'),
path("profile/change-password/", ChangePasswordController.as_view(), name='api-change-password'),
path("profile/token/", TokenManagementController.as_view(), name='api-token'),
path("profile/delete-data/", DeleteUserDataController.as_view(), name='api-delete-data'),
path("profile/delete-account/", DeleteAccountController.as_view(), name='api-delete-account'),
# Saved points (was SavedPointViewset via DefaultRouter).
path("saved-points/", SavedPointListController.as_view(), name='saved-points-list'),
path("saved-points/<int:pk>/", SavedPointDetailController.as_view(), name='saved-points-detail'),
# Prediction templates (was PreditctionTemplateViewset via DefaultRouter).
path("saved-templates/", PredictionTemplateListController.as_view(), name='saved-templates-list'),
path("saved-templates/<int:pk>/", PredictionTemplateDetailController.as_view(), name='saved-templates-detail'),
# Predictions (was PredictionViewSet via DefaultRouter).
path("predictions/", PredictionCollectionController.as_view(), name='predictions-list'),
path("predictions/list_user/", PredictionListUserController.as_view(), name='predictions-list-user'),
path("predictions/history/", PredictionHistoryController.as_view(), name='predictions-history'),
path("predictions/<uuid:pk>/detail/", PredictionDetailController.as_view(), name='predictions-detail'),
path("predictions/<uuid:pk>/delete/", PredictionDeleteController.as_view(), name='predictions-delete'),
],
)
urlpatterns = router.urls router = DefaultRouter()
router.register(r'predictions', PredictionViewSet, basename='predictions')
router.register(r'saved-points', SavedPointViewset, basename='saved-points')
router.register(r'saved-templates', PreditctionTemplateViewset, basename='saved-templates')
urlpatterns = [
path("csrf/", get_csrf, name='api-csrf'),
path('token', obtain_auth_token, name = 'get_token'),
path("login/", login_view, name='api-login'),
path("logout/", logout_view, name='api-logout'),
path("session/", SessionView.as_view(), name='api-session'),
path("whoami/", WhoAmIView.as_view(), name='api-whoami'),
path("<uuid:pk>/telemetry/", TelemetryListCreateView.as_view(), name="create_telemetry"),
path('csrf/', get_csrf, name='api-csrf'),
path('login/', login_view, name='api-login'),
path('logout/', logout_view, name='api-logout'),
path('session/', SessionView.as_view(), name='api-session'),
path('whoami/', WhoAmIView.as_view(), name='api-whoami'),
path("profile/", UserProfileView.as_view(), name='api-profile'),
path("profile/change-password/", ChangePasswordView.as_view(), name='api-change-password'),
path("profile/token/", TokenManagementView.as_view(), name='api-token'),
path("profile/delete-data/", DeleteUserDataView.as_view(), name='api-delete-data'),
path("profile/delete-account/", DeleteAccountView.as_view(), name='api-delete-account'),
]
urlpatterns += router.urls

View file

@ -1,477 +1,305 @@
import requests import requests
import time import time
import json import json
from http import HTTPStatus from rest_framework import status, generics, permissions
from rest_framework.response import Response
from rest_framework.views import APIView
from rest_framework.viewsets import ModelViewSet, ViewSet, GenericViewSet
from rest_framework.exceptions import APIException
from rest_framework.permissions import IsAuthenticated, AllowAny
from rest_framework.authentication import SessionAuthentication, BasicAuthentication, TokenAuthentication
from rest_framework.decorators import api_view, permission_classes, authentication_classes, action
from rest_framework.authtoken.models import Token
from django.utils import timezone
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
from django.http import JsonResponse
from django.contrib.auth import authenticate, login, logout, get_user_model from django.contrib.auth import authenticate, login, logout, get_user_model
from django.middleware.csrf import get_token from django.middleware.csrf import get_token
from django.core.exceptions import ValidationError
from django.utils.dateparse import parse_datetime from django.utils.dateparse import parse_datetime
from .models import Prediction, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket from .models import Prediction, User, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket
from .serializers import PredictionSerializer, TelemetryPacketSerializer, PredictionRequestSerializer, PredictionListSerializer, PredictionDetailSerializer, SavedPointSerializer, SavedRateProfileSerializer, PreditctionTemplateSerializer, UserSerializer, ChangePasswordSerializer, DeleteAccountSerializer
from .services.tawhiri import TawhiriClient from .services.tawhiri import TawhiriClient
from datetime import datetime, timedelta, timezone from drf_spectacular.utils import extend_schema
from .permissions import ReadOnlyOrAuthenticated, IsOwner
from pydantic import ValidationError from .custom_pagination import CustomLimitOffsetPagination
from dmr import Controller, modify from datetime import datetime
from dmr.components import Body, Path, Query
from dmr.plugins.pydantic import PydanticSerializer
from dmr.response import APIError
from dmr.security.jwt.views import (
ObtainTokensPayload,
ObtainTokensResponse,
ObtainTokensSyncController,
)
from .dtos import (
DetailResponse,
SessionResponse,
WhoAmIResponse,
UserOut,
UserUpdateIn,
ChangePasswordIn,
DeleteAccountIn,
TokenResponse,
PkPath,
UuidPkPath,
SavedPointIn,
SavedPointPatchIn,
SavedPointOut,
PredictionTemplateIn,
PredictionTemplatePatchIn,
PredictionTemplateOut,
PredictionRequest,
PredictionOut,
PredictionListOut,
PredictionDetailOut,
PredictionCreateOut,
TelemetryIn,
TelemetryOut,
)
from .pagination import (
PredictionListUserQuery,
PredictionPage,
paginate_predictions,
TelemetryPage,
page_number_paginate,
)
from .validators import base64_to_curve
User = get_user_model() User = get_user_model()
def _api_error(status_code, body, headers=None): def get_prediction_from_tawhiri(params):
"""Adapter for DMR's APIError(raw_data, *, status_code, ...) signature.
Lets call sites keep the (status_code=, body=) style; the body is DMR's base_url = "https://fly.stratonautica.ru/api/v2"
first positional ``raw_data`` argument. response = requests.get(base_url, params=params)
"""
return APIError(body, status_code=status_code, headers=headers) if response.status_code == 200:
return response.json() # получаем результат предсказания
else:
raise Exception(
f"Tawhiri error: {response.status_code} {response.text}")
def _resolve_related(model, pk, field_name): class PredictionViewSet(GenericViewSet):
"""Resolve a related object by PK (was PrimaryKeyRelatedField(queryset=...)). permission_classes = [IsAuthenticated]
pagination_class = CustomLimitOffsetPagination
The original queryset was ``.all()`` (not user-scoped); a missing PK yields a def list(self, request):
400 matching DRF's "Invalid pk" error shape. queryset = Prediction.objects.filter(user=request.user)
""" return Response(PredictionSerializer(queryset, many=True).data)
if pk is None:
return None
obj = model.objects.filter(pk=pk).first()
if obj is None:
raise _api_error(
status_code=HTTPStatus.BAD_REQUEST,
body={field_name: [f'Invalid pk "{pk}" - object does not exist.']},
)
return obj
def create(self, request):
serializer = PredictionRequestSerializer(data=request.data)
class PredictionCollectionController(Controller[PydanticSerializer]): if not serializer.is_valid():
"""`predictions/` -- list the user's predictions and create a new one.""" return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def get(self) -> list[PredictionOut]: validated_data = serializer.validated_data
queryset = Prediction.objects.filter(user=self.request.user)
return [PredictionOut.model_validate(obj) for obj in queryset]
def post(self, parsed_body: Body[PredictionRequest]) -> PredictionCreateOut:
user = self.request.user
# Resolve related objects before calling Tawhiri, so an invalid PK still
# fails with 400 first (as DRF validation did).
start_point = _resolve_related(SavedPoint, parsed_body.start_point, 'start_point')
template = _resolve_related(PreditctionTemplate, parsed_body.template, 'template')
rate_profile = _resolve_related(SavedRateProfile, parsed_body.rate_profile, 'rate_profile')
try: try:
prediction_result = TawhiriClient.get_prediction(parsed_body.model_dump()) prediction_result = TawhiriClient.get_prediction(validated_data)
except requests.RequestException as exc:
raise _api_error(
status_code=HTTPStatus.BAD_GATEWAY,
body={'error': f'Tawhiri error: {str(exc)}'},
)
# Carried over from the old serializer.create(): curves are decoded but except requests.RequestException as e:
# never persisted (the model has no curve fields). Kept for parity -- return Response({"error": f"Tawhiri error: {str(e)}"}, status=status.HTTP_502_BAD_GATEWAY)
# including that a malformed curve here raises (then surfaces as 500).
if parsed_body.ascent_curve is not None:
base64_to_curve(parsed_body.ascent_curve)
if parsed_body.descent_curve is not None:
base64_to_curve(parsed_body.descent_curve)
prediction = Prediction( # prediction = Prediction.objects.create(
user=user, # result=prediction_result, user=request.user, request=request.data, validated_data=validated_data)
request=json.loads(self.request.body), prediction = serializer.save(
result=prediction_result, user=request.user,
start_point=start_point,
template=template,
rate_profile=rate_profile,
)
prediction.save()
return PredictionCreateOut(
id=prediction.id,
created_at=prediction.created_at,
result=prediction_result, result=prediction_result,
request=request.data
) )
return Response({
"id": prediction.id,
"created_at": prediction.created_at,
"result": prediction_result
}, status=status.HTTP_201_CREATED)
class PredictionListUserController(Controller[PydanticSerializer]): @action(detail=False, methods=['get'])
"""`predictions/list_user/` -- filtered, paginated list of the user's predictions.""" def list_user(self, request):
user = request.user
satellite_id = request.query_params.get('satellite_id')
created_from = request.query_params.get('created_from')
created_till = request.query_params.get('created_till')
def get(self, parsed_query: Query[PredictionListUserQuery]) -> PredictionPage: filters = {
user = self.request.user 'user': user,
filters = {'user': user} }
if parsed_query.created_from: if created_from:
filters['created_at__gte'] = parse_datetime(parsed_query.created_from) filters['created_at__gte'] = parse_datetime(created_from)
if parsed_query.created_till:
filters['created_at__lte'] = parse_datetime(parsed_query.created_till) if created_till:
if parsed_query.satellite_id: filters['created_at__lte'] = parse_datetime(created_till)
if not user.satellites.filter(id=parsed_query.satellite_id).exists():
raise _api_error(status_code=HTTPStatus.FORBIDDEN, body={'detail': 'Access denied'}) if satellite_id:
filters['satellite_id'] = parsed_query.satellite_id if not user.satellites.filter(id=satellite_id).exists():
return Response({'detail': 'Access denied'}, status=403)
filters['satellite_id'] = satellite_id
queryset = Prediction.objects.filter(**filters) queryset = Prediction.objects.filter(**filters)
return paginate_predictions(queryset, parsed_query) queryset = self.filter_queryset(queryset)
page = self.paginate_queryset(queryset)
class PredictionHistoryController(Controller[PydanticSerializer]): if page is not None:
"""`predictions/history/` -- compact list of the user's predictions.""" serializer = PredictionSerializer(page, many=True)
return self.get_paginated_response(serializer.data)
def get(self) -> list[PredictionListOut]: serializer = PredictionSerializer(queryset, many=True)
queryset = Prediction.objects.filter(user=self.request.user) return Response(serializer.data)
return [PredictionListOut.model_validate(obj) for obj in queryset]
@action(detail=False, methods=["get"])
def history(self, request):
queryset = Prediction.objects.filter(user=request.user)
return Response(PredictionListSerializer(queryset, many=True).data)
class PredictionDetailController(Controller[PydanticSerializer]): @action(detail=True, methods=["get"])
"""`predictions/<pk>/detail/` -- retrieve a single prediction.""" def detail(self, request, pk=None):
def get(self, parsed_path: Path[UuidPkPath]) -> PredictionDetailOut:
prediction = Prediction.objects.filter( prediction = Prediction.objects.filter(
user=self.request.user, pk=parsed_path.pk).first() user=request.user, pk=pk).first()
if prediction is None: if not prediction:
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'}) return Response({'detail': 'Not found'}, status=404)
return PredictionDetailOut.model_validate(prediction) return Response(PredictionDetailSerializer(prediction).data)
@action(detail=True, methods=["delete"])
class PredictionDeleteController(Controller[PydanticSerializer]): def delete(self, request, pk=None):
"""`predictions/<pk>/delete/` -- delete a single prediction."""
@modify(status_code=HTTPStatus.NO_CONTENT)
def delete(self, parsed_path: Path[UuidPkPath]) -> None:
prediction = Prediction.objects.filter( prediction = Prediction.objects.filter(
user=self.request.user, pk=parsed_path.pk).first() user=request.user, pk=pk).first()
if prediction is None: if not prediction:
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'}) return Response({'detail': 'Not found'}, status=404)
prediction.delete() prediction.delete()
return Response(status=204)
class TelemetryController(Controller[PydanticSerializer]): class TelemetryListCreateView(generics.ListCreateAPIView):
"""`<uuid:pk>/telemetry/` -- list and ingest telemetry for a satellite. serializer_class = TelemetryPacketSerializer
permission_classes = [permissions.AllowAny]
Public (was AllowAny). GET preserves the global PageNumberPagination envelope def get_queryset(self):
{count, next, previous, results} with PAGE_SIZE=100. qs = TelemetryPacket.objects.filter(satellite_id=self.kwargs["pk"])
"""
auth = None # public (was AllowAny) from_ts = self.request.query_params.get("from")
csrf_exempt = True # DRF views bypass Django CSRF; ingestion is anonymous till_ts = self.request.query_params.get("till")
def get(self, parsed_path: Path[UuidPkPath]) -> TelemetryPage:
qs = TelemetryPacket.objects.filter(satellite_id=parsed_path.pk)
from_ts = self.request.GET.get('from')
till_ts = self.request.GET.get('till')
if from_ts: if from_ts:
qs = qs.filter(timestamp__gte=int(from_ts)) qs = qs.filter(timestamp__gte=int(from_ts))
if till_ts: if till_ts:
qs = qs.filter(timestamp__lte=int(till_ts)) qs = qs.filter(timestamp__lte=int(till_ts))
qs = qs.order_by('-timestamp') return qs.order_by("-timestamp")
count, next_link, previous_link, page_objects = page_number_paginate(self.request, qs) def post(self, request, pk):
return TelemetryPage( serializer = TelemetryPacketSerializer(data=request.data)
count=count, if not serializer.is_valid():
next=next_link, return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
previous=previous_link,
results=[TelemetryOut.model_validate(obj) for obj in page_objects], validated_data = serializer.validated_data
TelemetryPacket.objects.create(timestamp=time.time(),
satellite=Satellite.objects.get(id=pk),
lat=validated_data["lat"],
lon=validated_data["lon"],
alt=validated_data["alt"],
payload=validated_data['payload'],
) )
return Response(serializer.errors, status=status.HTTP_201_CREATED)
def post(self, parsed_path: Path[UuidPkPath], parsed_body: Body[TelemetryIn]) -> TelemetryOut:
# Bug fix (approved): the original returned serializer.errors on success;
# we return the created packet instead. timestamp is still server-set.
packet = TelemetryPacket.objects.create(
timestamp=time.time(),
satellite=Satellite.objects.get(id=parsed_path.pk),
lat=parsed_body.lat,
lon=parsed_body.lon,
alt=parsed_body.alt,
payload=parsed_body.payload,
)
return TelemetryOut.model_validate(packet)
class SessionController(Controller[PydanticSerializer]): class SessionView(APIView):
"""Report whether the current request is authenticated.""" permission_classes = [IsAuthenticated]
def get(self) -> SessionResponse: @staticmethod
return SessionResponse(isAuthenticated=True) def get(request, format=None):
return JsonResponse({'isAuthenticated': True})
class WhoAmIController(Controller[PydanticSerializer]): class WhoAmIView(APIView):
"""Return the current user's username.""" permission_classes = [IsAuthenticated]
def get(self) -> WhoAmIResponse: @staticmethod
return WhoAmIResponse(username=self.request.user.username) def get(request, format=None):
return JsonResponse({'username': request.user.username})
class CsrfController(Controller[PydanticSerializer]): @extend_schema(methods=["GET"], description="Get CSRF token")
"""Get CSRF token.""" @csrf_exempt
@api_view(["GET"])
auth = None # public endpoint (was AllowAny) @permission_classes([AllowAny])
csrf_exempt = True def get_csrf(request):
response = JsonResponse({'detail': 'CSRF cookie set'})
def get(self) -> DetailResponse: response['X-CSRFToken'] = get_token(request)
token = get_token(self.request) return response
return self.to_response(
DetailResponse(detail='CSRF cookie set'),
headers={'X-CSRFToken': token},
)
class LoginController(Controller[PydanticSerializer]): @extend_schema(methods=["POST"], description="Login user")
"""Login user.""" @csrf_exempt
@api_view(["POST"])
auth = None # public endpoint (was AllowAny) @authentication_classes([BasicAuthentication])
csrf_exempt = True @permission_classes([AllowAny])
def login_view(request):
# post() would default to 201; the original returned 200. data = json.loads(request.body)
@modify(status_code=HTTPStatus.OK)
def post(self) -> DetailResponse:
data = json.loads(self.request.body)
username = data.get('username') username = data.get('username')
password = data.get('password') password = data.get('password')
if username is None or password is None: if username is None or password is None:
raise _api_error( return JsonResponse({'detail': 'Please provide username and password.'}, status=400)
status_code=HTTPStatus.BAD_REQUEST,
body={'detail': 'Please provide username and password.'},
)
user = authenticate(username=username, password=password) user = authenticate(username=username, password=password)
if user is None: if user is None:
raise _api_error( return JsonResponse({'detail': 'Invalid credentials.'}, status=400)
status_code=HTTPStatus.BAD_REQUEST,
body={'detail': 'Invalid credentials.'},
)
login(self.request, user) login(request, user)
return DetailResponse(detail='Successfully logged in.') return JsonResponse({'detail': 'Successfully logged in.'})
class LogoutController(Controller[PydanticSerializer]): @extend_schema(methods=["POST"], description="Logout user")
"""Logout user.""" @api_view(["POST"])
@permission_classes([AllowAny])
def logout_view(request):
if not request.user.is_authenticated:
return JsonResponse({'detail': 'You\'re not logged in.'}, status=400)
auth = None # public endpoint (was AllowAny); checks auth state manually logout(request)
return JsonResponse({'detail': 'Successfully logged out.'})
@modify(status_code=HTTPStatus.OK)
def post(self) -> DetailResponse:
if not self.request.user.is_authenticated:
raise _api_error(
status_code=HTTPStatus.BAD_REQUEST,
body={'detail': "You're not logged in."},
)
logout(self.request)
return DetailResponse(detail='Successfully logged out.')
_SAVED_POINT_DUPLICATE = 'A saved point with this name already exists for the user.' class SavedPointViewset(ModelViewSet):
permission_classes = [IsOwner]
serializer_class = SavedPointSerializer
pagination_class = None
def get_queryset(self):
return SavedPoint.objects.filter(user=self.request.user)
def perform_create(self, serializer):
serializer.save(user=self.request.user)
def _check_saved_point_unique(user, name, exclude_pk=None): class PreditctionTemplateViewset(ModelViewSet):
"""Reproduce the former SavedPointSerializer UniqueTogetherValidator.""" permission_classes = [IsOwner]
qs = SavedPoint.objects.filter(user=user, name=name) serializer_class = PreditctionTemplateSerializer
if exclude_pk is not None: pagination_class = None
qs = qs.exclude(pk=exclude_pk)
if qs.exists(): def get_queryset(self):
raise _api_error( return PreditctionTemplate.objects.filter(user=self.request.user)
status_code=HTTPStatus.BAD_REQUEST,
body={'non_field_errors': [_SAVED_POINT_DUPLICATE]}, def perform_create(self, serializer):
) serializer.save(user=self.request.user)
class SavedPointListController(Controller[PydanticSerializer]): class UserProfileView(APIView):
"""Collection endpoint for the current user's saved points (was SavedPointViewset).""" permission_classes = [IsAuthenticated]
def get(self) -> list[SavedPointOut]: def get(self, request):
qs = SavedPoint.objects.filter(user=self.request.user) serializer = UserSerializer(request.user)
return [SavedPointOut.model_validate(obj) for obj in qs] return Response(serializer.data)
def post(self, parsed_body: Body[SavedPointIn]) -> SavedPointOut: def patch(self, request):
user = self.request.user user = request.user
_check_saved_point_unique(user, parsed_body.name) serializer = UserSerializer(user, data=request.data, partial=True)
obj = SavedPoint.objects.create(user=user, **parsed_body.model_dump())
return SavedPointOut.model_validate(obj) if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
serializer.save()
return Response(serializer.data)
class SavedPointDetailController(Controller[PydanticSerializer]): class ChangePasswordView(APIView):
"""Detail endpoint for a single saved point owned by the current user.""" permission_classes = [IsAuthenticated]
def _get_object(self, pk: int): def post(self, request):
# Filtering by user means another user's object reads as 404 (not 403), user = request.user
# matching DRF's get_object() over a user-scoped queryset. serializer = ChangePasswordSerializer(data=request.data)
obj = SavedPoint.objects.filter(user=self.request.user, pk=pk).first()
if obj is None:
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'})
return obj
def get(self, parsed_path: Path[PkPath]) -> SavedPointOut: if not serializer.is_valid():
return SavedPointOut.model_validate(self._get_object(parsed_path.pk)) return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
def put(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointIn]) -> SavedPointOut: if not user.check_password(serializer.validated_data['old_password']):
obj = self._get_object(parsed_path.pk) return Response({'detail': 'Old password is incorrect'},
_check_saved_point_unique(self.request.user, parsed_body.name, exclude_pk=obj.pk) status=status.HTTP_400_BAD_REQUEST)
for field, value in parsed_body.model_dump().items():
setattr(obj, field, value)
obj.save()
return SavedPointOut.model_validate(obj)
def patch(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointPatchIn]) -> SavedPointOut: user.set_password(serializer.validated_data['new_password'])
obj = self._get_object(parsed_path.pk)
updates = parsed_body.model_dump(exclude_unset=True)
if 'name' in updates:
_check_saved_point_unique(self.request.user, updates['name'], exclude_pk=obj.pk)
for field, value in updates.items():
setattr(obj, field, value)
obj.save()
return SavedPointOut.model_validate(obj)
@modify(status_code=HTTPStatus.NO_CONTENT)
def delete(self, parsed_path: Path[PkPath]) -> None:
self._get_object(parsed_path.pk).delete()
class PredictionTemplateListController(Controller[PydanticSerializer]):
"""Collection endpoint for the current user's templates (was PreditctionTemplateViewset).
NOTE: as before, there is no app-level uniqueness check; a duplicate
(user, name) hits the model's unique_together and surfaces as a DB error.
"""
def get(self) -> list[PredictionTemplateOut]:
qs = PreditctionTemplate.objects.filter(user=self.request.user)
return [PredictionTemplateOut.model_validate(obj) for obj in qs]
def post(self, parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut:
obj = PreditctionTemplate.objects.create(
user=self.request.user, **parsed_body.model_dump()
)
return PredictionTemplateOut.model_validate(obj)
class PredictionTemplateDetailController(Controller[PydanticSerializer]):
"""Detail endpoint for a single template owned by the current user."""
def _get_object(self, pk: int):
obj = PreditctionTemplate.objects.filter(user=self.request.user, pk=pk).first()
if obj is None:
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'})
return obj
def get(self, parsed_path: Path[PkPath]) -> PredictionTemplateOut:
return PredictionTemplateOut.model_validate(self._get_object(parsed_path.pk))
def put(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut:
obj = self._get_object(parsed_path.pk)
for field, value in parsed_body.model_dump().items():
setattr(obj, field, value)
obj.save()
return PredictionTemplateOut.model_validate(obj)
def patch(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplatePatchIn]) -> PredictionTemplateOut:
obj = self._get_object(parsed_path.pk)
for field, value in parsed_body.model_dump(exclude_unset=True).items():
setattr(obj, field, value)
obj.save()
return PredictionTemplateOut.model_validate(obj)
@modify(status_code=HTTPStatus.NO_CONTENT)
def delete(self, parsed_path: Path[PkPath]) -> None:
self._get_object(parsed_path.pk).delete()
class UserProfileController(Controller[PydanticSerializer]):
"""Read and partially update the current user's profile."""
def get(self) -> UserOut:
return UserOut.model_validate(self.request.user)
def patch(self, parsed_body: Body[UserUpdateIn]) -> UserOut:
user = self.request.user
# partial update: apply only the fields the client actually sent.
for field, value in parsed_body.model_dump(exclude_unset=True).items():
setattr(user, field, value)
user.save() user.save()
return UserOut.model_validate(user) return Response({'detail': 'Password changed successfully'})
class ChangePasswordController(Controller[PydanticSerializer]): class DeleteAccountView(APIView):
"""Change the current user's password.""" permission_classes = [IsAuthenticated]
# post() would default to 201; the original returned 200. def delete(self, request):
@modify(status_code=HTTPStatus.OK) user = request.user
def post(self, parsed_body: Body[ChangePasswordIn]) -> DetailResponse: serializer = DeleteAccountSerializer(data=request.data)
user = self.request.user
if not user.check_password(parsed_body.old_password): if not serializer.is_valid():
raise _api_error( return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
status_code=HTTPStatus.BAD_REQUEST,
body={'detail': 'Old password is incorrect'},
)
user.set_password(parsed_body.new_password) if not user.check_password(serializer.validated_data['password']):
user.save() return Response({'detail': 'Incorrect password'},
return DetailResponse(detail='Password changed successfully') status=status.HTTP_400_BAD_REQUEST)
class DeleteAccountController(Controller[PydanticSerializer]):
"""Delete the current user's account and all their data."""
# DMR forbids a Body component on DELETE, but the original endpoint read the
# password from a DELETE request body. Preserve that contract by parsing the
# body manually instead of via Body[DeleteAccountIn].
def delete(self) -> DetailResponse:
user = self.request.user
try:
parsed_body = DeleteAccountIn(**json.loads(self.request.body or b'{}'))
except ValidationError as exc:
raise _api_error(status_code=HTTPStatus.BAD_REQUEST, body=json.loads(exc.json()))
except ValueError:
raise _api_error(status_code=HTTPStatus.BAD_REQUEST, body={'detail': 'Invalid request body.'})
if not user.check_password(parsed_body.password):
raise _api_error(
status_code=HTTPStatus.BAD_REQUEST,
body={'detail': 'Incorrect password'},
)
Prediction.objects.filter(user=user).delete() Prediction.objects.filter(user=user).delete()
SavedPoint.objects.filter(user=user).delete() SavedPoint.objects.filter(user=user).delete()
@ -479,83 +307,35 @@ class DeleteAccountController(Controller[PydanticSerializer]):
user.delete() user.delete()
return DetailResponse(detail='Account deleted successfully') return Response({'detail': 'Account deleted successfully'})
class DeleteUserDataController(Controller[PydanticSerializer]): class DeleteUserDataView(APIView):
"""Delete all of the current user's data without deleting the account.""" permission_classes = [IsAuthenticated]
def delete(self) -> DetailResponse: def delete(self, request):
user = self.request.user user = request.user
Prediction.objects.filter(user=user).delete() Prediction.objects.filter(user=user).delete()
SavedPoint.objects.filter(user=user).delete() SavedPoint.objects.filter(user=user).delete()
PreditctionTemplate.objects.filter(user=user).delete() PreditctionTemplate.objects.filter(user=user).delete()
return DetailResponse(detail='All user data deleted successfully') return Response({'detail': 'All user data deleted successfully'})
class ObtainTokenController( class TokenManagementView(APIView):
ObtainTokensSyncController[ permission_classes = [IsAuthenticated]
PydanticSerializer,
ObtainTokensPayload,
ObtainTokensResponse,
],
):
"""Exchange username/password for JWT access + refresh tokens.
Replaces DRF's obtain_auth_token. Approved drift: the token format and def get(self, request):
semantics change from a single stored DRF token to stateless JWTs, so the
response is {access_token, refresh_token} instead of {token}.
"""
auth = None # public: credentials are supplied in the request body token, created = Token.objects.get_or_create(user=request.user)
csrf_exempt = True return Response({"token": token.key})
jwt_expiration = timedelta(hours=1)
jwt_refresh_expiration = timedelta(days=7)
def convert_auth_payload(self, payload): def post(self, request):
return payload
def make_api_response(self): Token.objects.filter(user=request.user).delete()
now = datetime.now(timezone.utc) token = Token.objects.create(user=request.user)
return { return Response({"token": token.key})
'access_token': self.create_jwt_token(
expiration=now + self.jwt_expiration,
token_type='access',
),
'refresh_token': self.create_jwt_token(
expiration=now + self.jwt_refresh_expiration,
token_type='refresh',
),
}
class TokenManagementController(Controller[PydanticSerializer]):
"""Issue a fresh JWT access token for the current user.
Was TokenManagementView (DRF stored-token get/regenerate). With stateless
JWTs there is nothing stored to fetch or delete, so both GET and POST mint a
new access token (approved drift). The {"token": ...} response shape is kept.
"""
jwt_expiration = timedelta(hours=1)
def get(self) -> TokenResponse:
return TokenResponse(token=self._issue_token())
# post() would default to 201; the original returned 200.
@modify(status_code=HTTPStatus.OK)
def post(self) -> TokenResponse:
return TokenResponse(token=self._issue_token())
def _issue_token(self) -> str:
now = datetime.now(timezone.utc)
return self.create_jwt_token(
subject=str(self.request.user.pk),
expiration=now + self.jwt_expiration,
token_type='access',
)