Compare commits
No commits in common. "modern-rest" and "master" have entirely different histories.
modern-res
...
master
11 changed files with 572 additions and 1023 deletions
|
|
@ -1,7 +1,8 @@
|
||||||
# django-modern-rest 0.6.0 requires Django 5.x (upgraded from 4.2).
|
Django>=4.0,<5.0
|
||||||
Django>=5.0,<6.0
|
djangorestframework
|
||||||
django-modern-rest[jwt,pydantic]==0.6.0
|
djangorestframework-simplejwt
|
||||||
psycopg2-binary
|
psycopg2-binary
|
||||||
|
drf-spectacular
|
||||||
requests
|
requests
|
||||||
django-cors-headers
|
django-cors-headers
|
||||||
Pillow
|
Pillow
|
||||||
|
|
|
||||||
|
|
@ -13,12 +13,6 @@ https://docs.djangoproject.com/en/4.2/ref/settings/
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from dmr.settings import Settings
|
|
||||||
from dmr.security.django_session import DjangoSessionSyncAuth
|
|
||||||
from dmr.security.jwt import JWTSyncAuth
|
|
||||||
from dmr.openapi import OpenAPIConfig
|
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
# Build paths inside the project like this: BASE_DIR / 'subdir'.
|
# Build paths inside the project like this: BASE_DIR / 'subdir'.
|
||||||
|
|
@ -57,10 +51,12 @@ INSTALLED_APPS = [
|
||||||
'django.contrib.sessions',
|
'django.contrib.sessions',
|
||||||
'django.contrib.messages',
|
'django.contrib.messages',
|
||||||
'django.contrib.staticfiles',
|
'django.contrib.staticfiles',
|
||||||
|
'rest_framework',
|
||||||
|
'rest_framework.authtoken',
|
||||||
|
'drf_spectacular',
|
||||||
'corsheaders',
|
'corsheaders',
|
||||||
'stratoflights_api.apps.StratoflightsApiConfig',
|
'stratoflights_api.apps.StratoflightsApiConfig',
|
||||||
'channels',
|
'channels',
|
||||||
'dmr', # required to serve OpenAPI docs static assets
|
|
||||||
]
|
]
|
||||||
|
|
||||||
MIDDLEWARE = [
|
MIDDLEWARE = [
|
||||||
|
|
@ -169,16 +165,25 @@ STATIC_URL = 'static/'
|
||||||
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
|
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
|
||||||
AUTH_USER_MODEL = 'stratoflights_api.User'
|
AUTH_USER_MODEL = 'stratoflights_api.User'
|
||||||
|
|
||||||
# django-modern-rest configuration.
|
REST_FRAMEWORK = {
|
||||||
# Global auth mirrors the previous DRF default (Token + Session); the DRF
|
'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
|
||||||
# TokenAuthentication is replaced by DMR's JWT (approved drift). Providing
|
'PAGE_SIZE': 100,
|
||||||
# several auth instances means at least one of them must succeed, so this
|
|
||||||
# also enforces a 401 by default like DRF's global IsAuthenticated did.
|
'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
|
||||||
# Public endpoints opt out per-endpoint with @modify(auth=None).
|
|
||||||
DMR_SETTINGS = {
|
'DEFAULT_AUTHENTICATION_CLASSES': [
|
||||||
Settings.auth: [JWTSyncAuth(), DjangoSessionSyncAuth()],
|
'rest_framework.authentication.TokenAuthentication',
|
||||||
Settings.validate_responses: not PRODUCTION,
|
'rest_framework.authentication.SessionAuthentication',
|
||||||
Settings.openapi_config: OpenAPIConfig(title='Stratoflights API', version='1.0.0'),
|
],
|
||||||
|
|
||||||
|
'DEFAULT_PERMISSION_CLASSES': [
|
||||||
|
'rest_framework.permissions.IsAuthenticated',
|
||||||
|
# 'rest_framework.permissions.AllowAny',
|
||||||
|
],
|
||||||
|
|
||||||
|
'DEFAULT_RENDERER_CLASSES': [
|
||||||
|
'rest_framework.renderers.JSONRenderer',
|
||||||
|
],
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,17 +1,27 @@
|
||||||
"""URL configuration for the stratoflights project."""
|
"""
|
||||||
|
URL configuration for stratoflights project.
|
||||||
|
|
||||||
|
The `urlpatterns` list routes URLs to views. For more information please see:
|
||||||
|
https://docs.djangoproject.com/en/4.2/topics/http/urls/
|
||||||
|
Examples:
|
||||||
|
Function views
|
||||||
|
1. Add an import: from my_app import views
|
||||||
|
2. Add a URL to urlpatterns: path('', views.home, name='home')
|
||||||
|
Class-based views
|
||||||
|
1. Add an import: from other_app.views import Home
|
||||||
|
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
|
||||||
|
Including another URLconf
|
||||||
|
1. Import the include() function: from django.urls import include, path
|
||||||
|
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
|
||||||
|
"""
|
||||||
from django.contrib import admin
|
from django.contrib import admin
|
||||||
from django.urls import include
|
from django.urls import path, include
|
||||||
from dmr.openapi import build_schema
|
from drf_spectacular.views import SpectacularAPIView
|
||||||
from dmr.openapi.views import OpenAPIJsonView, SwaggerView
|
from drf_spectacular.views import SpectacularSwaggerView
|
||||||
from dmr.routing import path
|
|
||||||
|
|
||||||
from stratoflights_api.urls import router
|
|
||||||
|
|
||||||
schema = build_schema(router)
|
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
path('admin/', admin.site.urls),
|
path('admin/', admin.site.urls),
|
||||||
path(router.prefix, include(router.urls)),
|
path('api/', include('stratoflights_api.urls')),
|
||||||
path('api/schema/', OpenAPIJsonView.as_view(schema), name='schema'),
|
path('api/schema/', SpectacularAPIView.as_view(), name='schema'),
|
||||||
path('api/docs/', SwaggerView.as_view(schema), name='docs'),
|
path('api/docs/', SpectacularSwaggerView.as_view(url_name='schema'), name='docs'),
|
||||||
]
|
]
|
||||||
|
|
@ -18,19 +18,17 @@ class TelemetryConsumer(AsyncWebsocketConsumer):
|
||||||
|
|
||||||
async def receive(self, text_data):
|
async def receive(self, text_data):
|
||||||
|
|
||||||
from pydantic import ValidationError
|
from .serializers import TelemetryPacketSerializer
|
||||||
from .dtos import TelemetryIn
|
|
||||||
|
|
||||||
if not self.write_enabled:
|
if not self.write_enabled:
|
||||||
await self.send(text_data=json.dumps({"error": "Read-only mode"}))
|
await self.send(text_data=json.dumps({"error": "Read-only mode"}))
|
||||||
return
|
return
|
||||||
|
|
||||||
data = json.loads(text_data)
|
data = json.loads(text_data)
|
||||||
|
serializer = TelemetryPacketSerializer(data=data)
|
||||||
|
|
||||||
try:
|
if not serializer.is_valid():
|
||||||
TelemetryIn(**data)
|
await self.send(text_data=json.dumps({"error": serializer.errors}))
|
||||||
except ValidationError as exc:
|
|
||||||
await self.send(text_data=json.dumps({"error": json.loads(exc.json())}))
|
|
||||||
return
|
return
|
||||||
|
|
||||||
saved_data = await self.save_telemetry(data)
|
saved_data = await self.save_telemetry(data)
|
||||||
|
|
@ -53,6 +51,15 @@ class TelemetryConsumer(AsyncWebsocketConsumer):
|
||||||
|
|
||||||
async def telemetry_message(self, event):
|
async def telemetry_message(self, event):
|
||||||
await self.send(text_data=json.dumps(event["data"]))
|
await self.send(text_data=json.dumps(event["data"]))
|
||||||
|
@database_sync_to_async
|
||||||
|
def get_user_from_token(self, token_key, Token):
|
||||||
|
from rest_framework.authtoken.models import Token
|
||||||
|
User = get_user_model()
|
||||||
|
try:
|
||||||
|
token = Token.objects.select_related("user").get(key=token_key)
|
||||||
|
return token.user
|
||||||
|
except Token.DoesNotExist:
|
||||||
|
return None
|
||||||
|
|
||||||
@database_sync_to_async
|
@database_sync_to_async
|
||||||
def save_telemetry(self, data):
|
def save_telemetry(self, data):
|
||||||
|
|
@ -92,20 +99,12 @@ class StationTelemetryConsumer(TelemetryConsumer):
|
||||||
write_enabled = True
|
write_enabled = True
|
||||||
|
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
from django.conf import settings
|
from rest_framework.authtoken.models import Token
|
||||||
from dmr.security.jwt.token import JWToken
|
|
||||||
|
|
||||||
token_key = self.scope["query_string"].decode().split("token=")[-1]
|
token_key = self.scope["query_string"].decode().split("token=")[-1]
|
||||||
try:
|
try:
|
||||||
decoded = JWToken.decode(
|
token = await database_sync_to_async(Token.objects.select_related("user").get)(key=token_key)
|
||||||
encoded_token=token_key,
|
self.scope["user"] = token.user
|
||||||
secret=settings.SECRET_KEY,
|
except Token.DoesNotExist:
|
||||||
algorithm='HS256',
|
|
||||||
)
|
|
||||||
self.scope["user"] = await database_sync_to_async(
|
|
||||||
get_user_model().objects.get
|
|
||||||
)(pk=decoded.sub)
|
|
||||||
except Exception:
|
|
||||||
await self.close()
|
await self.close()
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
|
||||||
17
stratoflights_api/custom_pagination.py
Normal file
17
stratoflights_api/custom_pagination.py
Normal file
|
|
@ -0,0 +1,17 @@
|
||||||
|
from rest_framework.pagination import LimitOffsetPagination
|
||||||
|
from rest_framework.response import Response
|
||||||
|
|
||||||
|
class CustomLimitOffsetPagination(LimitOffsetPagination):
|
||||||
|
limit_query_param = 'limit'
|
||||||
|
offset_query_param = 'skip'
|
||||||
|
max_limit = 100
|
||||||
|
default_limit = 10
|
||||||
|
|
||||||
|
|
||||||
|
def get_paginated_response(self, data):
|
||||||
|
return Response({
|
||||||
|
'total': self.count,
|
||||||
|
'limit': self.limit,
|
||||||
|
'skip': self.offset,
|
||||||
|
'predictions': data
|
||||||
|
})
|
||||||
|
|
@ -1,333 +0,0 @@
|
||||||
"""Pydantic DTOs replacing the DRF serializers from serializers.py.
|
|
||||||
|
|
||||||
Mapping is 1:1 with the former serializers. Output DTOs enable ``from_attributes``
|
|
||||||
so an endpoint can build them straight from a Django model instance via
|
|
||||||
``SomeOut.model_validate(obj)`` (the equivalent of ``Serializer(obj).data``).
|
|
||||||
|
|
||||||
Field-level and cross-field validation that DRF kept in ``validate_<field>()`` /
|
|
||||||
``validate()`` is reproduced with pydantic ``field_validator`` / ``model_validator``.
|
|
||||||
Business logic DRF kept in ``create()`` (base64 curve decoding, ORM writes, and
|
|
||||||
resolving related objects from their PKs) is intentionally NOT here -- per DMR's
|
|
||||||
split it moves into the endpoints (Step 2.9).
|
|
||||||
"""
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Optional
|
|
||||||
|
|
||||||
from django.contrib.auth.password_validation import validate_password
|
|
||||||
from django.core.validators import validate_email
|
|
||||||
from django.core.exceptions import ValidationError as DjangoValidationError
|
|
||||||
from pydantic import (
|
|
||||||
BaseModel,
|
|
||||||
ConfigDict,
|
|
||||||
Field,
|
|
||||||
field_validator,
|
|
||||||
model_validator,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .validators import validate_custom_curve, rate_clip
|
|
||||||
|
|
||||||
|
|
||||||
# Profile constants (moved verbatim from serializers.py).
|
|
||||||
PROFILE_STANDARD = "standard_profile"
|
|
||||||
PROFILE_FLOAT = "float_profile"
|
|
||||||
PROFILE_REVERSE = "reverse_profile"
|
|
||||||
PROFILE_CUSTOM = "custom_profile"
|
|
||||||
LATEST_DATASET_KEYWORD = "latest"
|
|
||||||
SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM]
|
|
||||||
|
|
||||||
|
|
||||||
# --- Prediction request (was PredictionRequestSerializer) ---------------------
|
|
||||||
class PredictionRequest(BaseModel):
|
|
||||||
launch_latitude: float = Field(ge=-90, le=90)
|
|
||||||
launch_longitude: float = Field(ge=0, le=360)
|
|
||||||
launch_datetime: datetime
|
|
||||||
launch_altitude: Optional[float] = None
|
|
||||||
format: str = "json"
|
|
||||||
profile: str = PROFILE_STANDARD
|
|
||||||
dataset: str = LATEST_DATASET_KEYWORD
|
|
||||||
|
|
||||||
# profile-dependent fields
|
|
||||||
ascent_rate: Optional[float] = Field(default=None, ge=0.01)
|
|
||||||
descent_rate: Optional[float] = Field(default=None, ge=0.01)
|
|
||||||
burst_altitude: Optional[float] = None
|
|
||||||
float_altitude: Optional[float] = None
|
|
||||||
stop_datetime: Optional[datetime] = None
|
|
||||||
ascent_curve: Optional[str] = None
|
|
||||||
descent_curve: Optional[str] = None
|
|
||||||
interpolate: bool = False
|
|
||||||
# Related objects are accepted as PKs; existence is resolved in the endpoint
|
|
||||||
# (was PrimaryKeyRelatedField(queryset=...)).
|
|
||||||
start_point: Optional[int] = None
|
|
||||||
rate_profile: Optional[int] = None
|
|
||||||
template: Optional[int] = None
|
|
||||||
|
|
||||||
@field_validator("profile")
|
|
||||||
@classmethod
|
|
||||||
def _validate_profile(cls, value: str) -> str:
|
|
||||||
if value not in SUPPORTED_PROFILES:
|
|
||||||
raise ValueError(f'"{value}" is not a valid choice.')
|
|
||||||
return value
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def _validate_cross_fields(self) -> "PredictionRequest":
|
|
||||||
launch_alt = self.launch_altitude if self.launch_altitude is not None else 0
|
|
||||||
|
|
||||||
if self.profile == PROFILE_STANDARD:
|
|
||||||
if self.burst_altitude is None:
|
|
||||||
raise ValueError("burst_altitude is required for standard profile.")
|
|
||||||
if self.burst_altitude <= launch_alt:
|
|
||||||
raise ValueError("burst_altitude must be greater than launch_altitude.")
|
|
||||||
|
|
||||||
elif self.profile == PROFILE_FLOAT:
|
|
||||||
if self.float_altitude is None or self.float_altitude <= launch_alt:
|
|
||||||
raise ValueError("float_altitude must be greater than launch_altitude.")
|
|
||||||
if self.stop_datetime is None or self.stop_datetime <= self.launch_datetime:
|
|
||||||
raise ValueError("stop_datetime must be later than launch_datetime.")
|
|
||||||
|
|
||||||
elif self.profile == PROFILE_CUSTOM:
|
|
||||||
if self.ascent_curve is None or not validate_custom_curve(self.ascent_curve):
|
|
||||||
raise ValueError("Invalid ascent_curve.")
|
|
||||||
if self.descent_curve is None or not validate_custom_curve(self.descent_curve):
|
|
||||||
raise ValueError("Invalid descent_curve.")
|
|
||||||
if self.burst_altitude is None or self.burst_altitude <= launch_alt:
|
|
||||||
raise ValueError("burst_altitude must be greater than launch_altitude.")
|
|
||||||
|
|
||||||
# custom clipping logic (was in validate())
|
|
||||||
if self.ascent_rate is not None:
|
|
||||||
self.ascent_rate = rate_clip(self.ascent_rate)
|
|
||||||
if self.descent_rate is not None:
|
|
||||||
self.descent_rate = rate_clip(self.descent_rate)
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
# --- Prediction outputs -------------------------------------------------------
|
|
||||||
class PredictionOut(BaseModel):
|
|
||||||
"""Was PredictionSerializer."""
|
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: uuid.UUID
|
|
||||||
created_at: datetime
|
|
||||||
updated_at: datetime
|
|
||||||
result: Any
|
|
||||||
|
|
||||||
|
|
||||||
class PredictionListOut(BaseModel):
|
|
||||||
"""Was PredictionListSerializer."""
|
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: uuid.UUID
|
|
||||||
created_at: datetime
|
|
||||||
updated_at: datetime
|
|
||||||
start_point: Optional[int] = Field(default=None, validation_alias="start_point_id")
|
|
||||||
template: Optional[int] = Field(default=None, validation_alias="template_id")
|
|
||||||
rate_profile: Optional[int] = Field(default=None, validation_alias="rate_profile_id")
|
|
||||||
|
|
||||||
|
|
||||||
class PredictionDetailOut(BaseModel):
|
|
||||||
"""Was PredictionDetailSerializer."""
|
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: uuid.UUID
|
|
||||||
created_at: datetime
|
|
||||||
updated_at: datetime
|
|
||||||
result: Any
|
|
||||||
start_point: Optional[int] = Field(default=None, validation_alias="start_point_id")
|
|
||||||
template: Optional[int] = Field(default=None, validation_alias="template_id")
|
|
||||||
rate_profile: Optional[int] = Field(default=None, validation_alias="rate_profile_id")
|
|
||||||
|
|
||||||
|
|
||||||
class PredictionCreateOut(BaseModel):
|
|
||||||
"""Custom create() response shape: {id, created_at, result} (not the full model)."""
|
|
||||||
|
|
||||||
id: uuid.UUID
|
|
||||||
created_at: datetime
|
|
||||||
result: Any
|
|
||||||
|
|
||||||
|
|
||||||
# --- Telemetry (was TelemetryPacketSerializer) --------------------------------
|
|
||||||
class TelemetryIn(BaseModel):
|
|
||||||
# timestamp is required (model BigIntegerField has no default), matching the
|
|
||||||
# former ModelSerializer; the endpoint still overwrites it with server time.
|
|
||||||
timestamp: int
|
|
||||||
lat: float = 0.0
|
|
||||||
lon: float = 0.0
|
|
||||||
alt: float = 0.0
|
|
||||||
payload: Any = Field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class TelemetryOut(BaseModel):
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: uuid.UUID
|
|
||||||
timestamp: int
|
|
||||||
lat: float
|
|
||||||
lon: float
|
|
||||||
alt: float
|
|
||||||
payload: Any
|
|
||||||
|
|
||||||
|
|
||||||
# --- SavedPoint (was SavedPointSerializer) ------------------------------------
|
|
||||||
# `user` was a HiddenField(CurrentUserDefault) -> set from request in the endpoint,
|
|
||||||
# not part of the wire input/output. The unique_together (user, name) check that
|
|
||||||
# DRF did via UniqueTogetherValidator is reproduced in the endpoint (Step 2.8).
|
|
||||||
class SavedPointIn(BaseModel):
|
|
||||||
name: str
|
|
||||||
lat: float = 0.0
|
|
||||||
lon: float = 0.0
|
|
||||||
alt: float = 0.0
|
|
||||||
|
|
||||||
|
|
||||||
class SavedPointPatchIn(BaseModel):
|
|
||||||
# All-optional variant for PATCH (partial_update). Only set fields apply.
|
|
||||||
name: Optional[str] = None
|
|
||||||
lat: Optional[float] = None
|
|
||||||
lon: Optional[float] = None
|
|
||||||
alt: Optional[float] = None
|
|
||||||
|
|
||||||
|
|
||||||
class SavedPointOut(BaseModel):
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: int
|
|
||||||
name: str
|
|
||||||
lat: float
|
|
||||||
lon: float
|
|
||||||
alt: float
|
|
||||||
|
|
||||||
|
|
||||||
# --- SavedRateProfile (was SavedRateProfileSerializer) ------------------------
|
|
||||||
# NOTE: no view referenced this serializer in routing; kept for parity.
|
|
||||||
class SavedRateProfileIn(BaseModel):
|
|
||||||
name: str
|
|
||||||
type: str = "ascent"
|
|
||||||
rate_profile_data: Any = Field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class SavedRateProfileOut(BaseModel):
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: int
|
|
||||||
name: str
|
|
||||||
type: str
|
|
||||||
rate_profile_data: Any
|
|
||||||
|
|
||||||
|
|
||||||
# --- PredictionTemplate (was PreditctionTemplateSerializer) -------------------
|
|
||||||
class PredictionTemplateIn(BaseModel):
|
|
||||||
# protected_namespaces=() avoids pydantic's warning about the `model` field.
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
|
||||||
|
|
||||||
name: str
|
|
||||||
is_default: bool = False
|
|
||||||
description: Optional[str] = None
|
|
||||||
prediction_mode: str = ""
|
|
||||||
model: str = ""
|
|
||||||
dataset: str = ""
|
|
||||||
flight_parameters: Any = Field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class PredictionTemplatePatchIn(BaseModel):
|
|
||||||
# All-optional variant for PATCH (partial_update). Only set fields apply.
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
|
||||||
|
|
||||||
name: Optional[str] = None
|
|
||||||
is_default: Optional[bool] = None
|
|
||||||
description: Optional[str] = None
|
|
||||||
prediction_mode: Optional[str] = None
|
|
||||||
model: Optional[str] = None
|
|
||||||
dataset: Optional[str] = None
|
|
||||||
flight_parameters: Optional[Any] = None
|
|
||||||
|
|
||||||
|
|
||||||
class PredictionTemplateOut(BaseModel):
|
|
||||||
model_config = ConfigDict(from_attributes=True, protected_namespaces=())
|
|
||||||
|
|
||||||
id: int
|
|
||||||
name: str
|
|
||||||
is_default: bool
|
|
||||||
description: Optional[str] = None
|
|
||||||
prediction_mode: str
|
|
||||||
model: str
|
|
||||||
dataset: str
|
|
||||||
flight_parameters: Any
|
|
||||||
|
|
||||||
|
|
||||||
# --- User (was UserSerializer) ------------------------------------------------
|
|
||||||
class UserOut(BaseModel):
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
username: str
|
|
||||||
email: str
|
|
||||||
first_name: str
|
|
||||||
last_name: str
|
|
||||||
|
|
||||||
|
|
||||||
class UserUpdateIn(BaseModel):
|
|
||||||
# `username` was read_only -> excluded from input. PATCH is partial, so all
|
|
||||||
# fields are optional; the endpoint applies only fields that were set.
|
|
||||||
email: Optional[str] = None
|
|
||||||
first_name: Optional[str] = None
|
|
||||||
last_name: Optional[str] = None
|
|
||||||
|
|
||||||
@field_validator("email")
|
|
||||||
@classmethod
|
|
||||||
def _validate_email(cls, value: Optional[str]) -> Optional[str]:
|
|
||||||
if value is None:
|
|
||||||
return value
|
|
||||||
try:
|
|
||||||
validate_email(value)
|
|
||||||
except DjangoValidationError:
|
|
||||||
raise ValueError("Invalid email format")
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
# --- Password / account (were ChangePasswordSerializer / DeleteAccountSerializer)
|
|
||||||
class ChangePasswordIn(BaseModel):
|
|
||||||
old_password: str
|
|
||||||
new_password: str
|
|
||||||
|
|
||||||
@field_validator("new_password")
|
|
||||||
@classmethod
|
|
||||||
def _validate_new_password(cls, value: str) -> str:
|
|
||||||
# validate_password raises Django's ValidationError (not a ValueError),
|
|
||||||
# which pydantic would not convert into a 4xx -- re-raise as ValueError.
|
|
||||||
try:
|
|
||||||
validate_password(value)
|
|
||||||
except DjangoValidationError as exc:
|
|
||||||
raise ValueError("; ".join(exc.messages))
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class DeleteAccountIn(BaseModel):
|
|
||||||
password: str
|
|
||||||
|
|
||||||
|
|
||||||
# --- Auth / session response shapes (were inline JsonResponse dicts) ----------
|
|
||||||
class DetailResponse(BaseModel):
|
|
||||||
detail: str
|
|
||||||
|
|
||||||
|
|
||||||
class SessionResponse(BaseModel):
|
|
||||||
isAuthenticated: bool
|
|
||||||
|
|
||||||
|
|
||||||
class WhoAmIResponse(BaseModel):
|
|
||||||
username: str
|
|
||||||
|
|
||||||
|
|
||||||
class TokenResponse(BaseModel):
|
|
||||||
token: str
|
|
||||||
|
|
||||||
|
|
||||||
# --- Shared path parameters ---------------------------------------------------
|
|
||||||
class PkPath(BaseModel):
|
|
||||||
pk: int
|
|
||||||
|
|
||||||
|
|
||||||
class UuidPkPath(BaseModel):
|
|
||||||
pk: uuid.UUID
|
|
||||||
|
|
@ -1,120 +0,0 @@
|
||||||
"""Limit/offset pagination for the prediction list endpoints.
|
|
||||||
|
|
||||||
DMR has no built-in limit/offset paginator, so (as the docs suggest) we keep our
|
|
||||||
own envelope model -- matching the former ``CustomLimitOffsetPagination`` output
|
|
||||||
exactly -- plus a helper that slices the queryset. The query params and envelope
|
|
||||||
keys are preserved verbatim: ``limit`` / ``skip`` in, and
|
|
||||||
``{total, limit, skip, predictions}`` out.
|
|
||||||
"""
|
|
||||||
from http import HTTPStatus
|
|
||||||
from typing import Optional
|
|
||||||
from urllib.parse import urlsplit, urlunsplit
|
|
||||||
|
|
||||||
from django.core.paginator import InvalidPage, Paginator
|
|
||||||
from django.http import QueryDict
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from dmr.response import APIError
|
|
||||||
|
|
||||||
from .dtos import PredictionOut, TelemetryOut
|
|
||||||
|
|
||||||
DEFAULT_LIMIT = 10 # was CustomLimitOffsetPagination.default_limit
|
|
||||||
MAX_LIMIT = 100 # was CustomLimitOffsetPagination.max_limit
|
|
||||||
PAGE_SIZE = 100 # global REST_FRAMEWORK PAGE_SIZE (PageNumberPagination)
|
|
||||||
|
|
||||||
|
|
||||||
class PaginationQuery(BaseModel):
|
|
||||||
# Param names preserved: limit_query_param='limit', offset_query_param='skip'.
|
|
||||||
limit: int = DEFAULT_LIMIT
|
|
||||||
skip: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class PredictionListUserQuery(BaseModel):
|
|
||||||
"""Query params for PredictionViewSet.list_user: filters + limit/offset."""
|
|
||||||
|
|
||||||
satellite_id: Optional[str] = None
|
|
||||||
created_from: Optional[str] = None
|
|
||||||
created_till: Optional[str] = None
|
|
||||||
limit: int = DEFAULT_LIMIT
|
|
||||||
skip: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class PredictionPage(BaseModel):
|
|
||||||
total: int
|
|
||||||
limit: int
|
|
||||||
skip: int
|
|
||||||
predictions: list[PredictionOut]
|
|
||||||
|
|
||||||
|
|
||||||
class TelemetryPage(BaseModel):
|
|
||||||
count: int
|
|
||||||
next: Optional[str] = None
|
|
||||||
previous: Optional[str] = None
|
|
||||||
results: list[TelemetryOut]
|
|
||||||
|
|
||||||
|
|
||||||
def _replace_query_param(url: str, key: str, value) -> str:
|
|
||||||
scheme, netloc, path, query, fragment = urlsplit(url)
|
|
||||||
query_dict = QueryDict(query, mutable=True)
|
|
||||||
query_dict[key] = value
|
|
||||||
return urlunsplit((scheme, netloc, path, query_dict.urlencode(), fragment))
|
|
||||||
|
|
||||||
|
|
||||||
def _remove_query_param(url: str, key: str) -> str:
|
|
||||||
scheme, netloc, path, query, fragment = urlsplit(url)
|
|
||||||
query_dict = QueryDict(query, mutable=True)
|
|
||||||
query_dict.pop(key, None)
|
|
||||||
return urlunsplit((scheme, netloc, path, query_dict.urlencode(), fragment))
|
|
||||||
|
|
||||||
|
|
||||||
def page_number_paginate(request, queryset, page_size: int = PAGE_SIZE):
|
|
||||||
"""Reproduce DRF PageNumberPagination.
|
|
||||||
|
|
||||||
Returns ``(count, next_link, previous_link, object_list)``. Raises a 404
|
|
||||||
"Invalid page." like DRF on an out-of-range / non-integer page. Honours
|
|
||||||
``page=last`` and builds absolute next/previous links off the current URL.
|
|
||||||
"""
|
|
||||||
paginator = Paginator(queryset, page_size)
|
|
||||||
page_number = request.GET.get('page', 1)
|
|
||||||
if page_number in ('last',):
|
|
||||||
page_number = paginator.num_pages
|
|
||||||
try:
|
|
||||||
page = paginator.page(page_number)
|
|
||||||
except InvalidPage:
|
|
||||||
raise APIError({'detail': 'Invalid page.'}, status_code=HTTPStatus.NOT_FOUND)
|
|
||||||
|
|
||||||
url = request.build_absolute_uri()
|
|
||||||
|
|
||||||
next_link = None
|
|
||||||
if page.has_next():
|
|
||||||
next_link = _replace_query_param(url, 'page', page.next_page_number())
|
|
||||||
|
|
||||||
previous_link = None
|
|
||||||
if page.has_previous():
|
|
||||||
previous_number = page.previous_page_number()
|
|
||||||
if previous_number == 1:
|
|
||||||
previous_link = _remove_query_param(url, 'page')
|
|
||||||
else:
|
|
||||||
previous_link = _replace_query_param(url, 'page', previous_number)
|
|
||||||
|
|
||||||
return paginator.count, next_link, previous_link, list(page.object_list)
|
|
||||||
|
|
||||||
|
|
||||||
def paginate_predictions(queryset, query: PaginationQuery) -> PredictionPage:
|
|
||||||
"""Slice ``queryset`` and build the prediction page envelope.
|
|
||||||
|
|
||||||
Mirrors the bounds DRF's LimitOffsetPagination enforced: limit falls back to
|
|
||||||
the default when non-positive and is capped at MAX_LIMIT; skip floors at 0.
|
|
||||||
"""
|
|
||||||
limit = query.limit if query.limit > 0 else DEFAULT_LIMIT
|
|
||||||
limit = min(limit, MAX_LIMIT)
|
|
||||||
skip = query.skip if query.skip >= 0 else 0
|
|
||||||
|
|
||||||
total = queryset.count()
|
|
||||||
page = list(queryset[skip:skip + limit])
|
|
||||||
return PredictionPage(
|
|
||||||
total=total,
|
|
||||||
limit=limit,
|
|
||||||
skip=skip,
|
|
||||||
predictions=[PredictionOut.model_validate(obj) for obj in page],
|
|
||||||
)
|
|
||||||
16
stratoflights_api/permissions.py
Normal file
16
stratoflights_api/permissions.py
Normal file
|
|
@ -0,0 +1,16 @@
|
||||||
|
from rest_framework.permissions import BasePermission, SAFE_METHODS
|
||||||
|
|
||||||
|
class ReadOnlyOrAuthenticated(BasePermission):
|
||||||
|
def has_permission(self, request, view):
|
||||||
|
return (
|
||||||
|
request.method in SAFE_METHODS or
|
||||||
|
request.user and request.user.is_authenticated
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IsOwner(BasePermission):
|
||||||
|
def has_object_permission(self, request, view, obj):
|
||||||
|
return obj.user == request.user
|
||||||
|
|
||||||
|
def has_permission(self, request, view):
|
||||||
|
return request.user and request.user.is_authenticated
|
||||||
185
stratoflights_api/serializers.py
Normal file
185
stratoflights_api/serializers.py
Normal file
|
|
@ -0,0 +1,185 @@
|
||||||
|
from rest_framework import serializers
|
||||||
|
from .models import Prediction, SavedPoint, SavedRateProfile, PreditctionTemplate
|
||||||
|
from datetime import datetime
|
||||||
|
from django.contrib.auth.password_validation import validate_password
|
||||||
|
from django.core.validators import validate_email
|
||||||
|
from django.core.exceptions import ValidationError as DjangoValidationError
|
||||||
|
from django.contrib.auth import get_user_model
|
||||||
|
from .validators import (
|
||||||
|
validate_custom_curve, rate_clip,
|
||||||
|
_rfc3339_to_timestamp, base64_to_curve
|
||||||
|
)
|
||||||
|
|
||||||
|
class PredictionSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = Prediction
|
||||||
|
fields = ['id', 'created_at', 'updated_at', 'result']
|
||||||
|
|
||||||
|
User = get_user_model()
|
||||||
|
|
||||||
|
PROFILE_STANDARD = "standard_profile"
|
||||||
|
PROFILE_FLOAT = "float_profile"
|
||||||
|
PROFILE_REVERSE = "reverse_profile"
|
||||||
|
PROFILE_CUSTOM = "custom_profile"
|
||||||
|
LATEST_DATASET_KEYWORD = "latest"
|
||||||
|
SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM]
|
||||||
|
|
||||||
|
|
||||||
|
class PredictionRequestSerializer(serializers.Serializer):
|
||||||
|
launch_latitude = serializers.FloatField(min_value=-90, max_value=90)
|
||||||
|
launch_longitude = serializers.FloatField(min_value=0, max_value=360)
|
||||||
|
launch_datetime = serializers.DateTimeField()
|
||||||
|
launch_altitude = serializers.FloatField(required=False)
|
||||||
|
format = serializers.CharField(default="json")
|
||||||
|
profile = serializers.ChoiceField(choices=SUPPORTED_PROFILES, default=PROFILE_STANDARD)
|
||||||
|
dataset = serializers.CharField(default=LATEST_DATASET_KEYWORD)
|
||||||
|
|
||||||
|
# --- профиль-dependent поля ---
|
||||||
|
ascent_rate = serializers.FloatField(required=False, min_value=0.01)
|
||||||
|
descent_rate = serializers.FloatField(required=False, min_value=0.01)
|
||||||
|
burst_altitude = serializers.FloatField(required=False)
|
||||||
|
float_altitude = serializers.FloatField(required=False)
|
||||||
|
stop_datetime = serializers.DateTimeField(required=False)
|
||||||
|
ascent_curve = serializers.CharField(required=False)
|
||||||
|
descent_curve = serializers.CharField(required=False)
|
||||||
|
interpolate = serializers.BooleanField(required=False, default=False)
|
||||||
|
start_point = serializers.PrimaryKeyRelatedField(
|
||||||
|
queryset=SavedPoint.objects.all(), required=False, allow_null=True
|
||||||
|
)
|
||||||
|
rate_profile = serializers.PrimaryKeyRelatedField(
|
||||||
|
queryset=SavedRateProfile.objects.all(), required=False, allow_null=True
|
||||||
|
)
|
||||||
|
template = serializers.PrimaryKeyRelatedField(
|
||||||
|
queryset=PreditctionTemplate.objects.all(), required=False, allow_null=True
|
||||||
|
)
|
||||||
|
|
||||||
|
def validate(self, data):
|
||||||
|
profile = data.get("profile", PROFILE_STANDARD)
|
||||||
|
launch_alt = data.get("launch_altitude", 0)
|
||||||
|
|
||||||
|
if profile == PROFILE_STANDARD:
|
||||||
|
if 'burst_altitude' not in data:
|
||||||
|
raise serializers.ValidationError("burst_altitude is required for standard profile.")
|
||||||
|
if data['burst_altitude'] <= launch_alt:
|
||||||
|
raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.")
|
||||||
|
|
||||||
|
elif profile == PROFILE_FLOAT:
|
||||||
|
if 'float_altitude' not in data or data['float_altitude'] <= launch_alt:
|
||||||
|
raise serializers.ValidationError("float_altitude must be greater than launch_altitude.")
|
||||||
|
if 'stop_datetime' not in data or data['stop_datetime'] <= data['launch_datetime']:
|
||||||
|
raise serializers.ValidationError("stop_datetime must be later than launch_datetime.")
|
||||||
|
|
||||||
|
elif profile == PROFILE_CUSTOM:
|
||||||
|
if 'ascent_curve' not in data or not validate_custom_curve(data['ascent_curve']):
|
||||||
|
raise serializers.ValidationError("Invalid ascent_curve.")
|
||||||
|
if 'descent_curve' not in data or not validate_custom_curve(data['descent_curve']):
|
||||||
|
raise serializers.ValidationError("Invalid descent_curve.")
|
||||||
|
if 'burst_altitude' not in data or data['burst_altitude'] <= launch_alt:
|
||||||
|
raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.")
|
||||||
|
|
||||||
|
# кастомная логика clipping'а
|
||||||
|
if 'ascent_rate' in data:
|
||||||
|
data['ascent_rate'] = rate_clip(data['ascent_rate'])
|
||||||
|
if 'descent_rate' in data:
|
||||||
|
data['descent_rate'] = rate_clip(data['descent_rate'])
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
def create(self, validated_data):
|
||||||
|
if 'ascent_curve' in validated_data:
|
||||||
|
validated_data['ascent_curve'] = base64_to_curve(validated_data['ascent_curve'])
|
||||||
|
if 'descent_curve' in validated_data:
|
||||||
|
validated_data['descent_curve'] = base64_to_curve(validated_data['descent_curve'])
|
||||||
|
|
||||||
|
prediction = Prediction(
|
||||||
|
user=validated_data.get('user'),
|
||||||
|
request=validated_data.get('request', {}),
|
||||||
|
result=validated_data.get('result', {}),
|
||||||
|
start_point=validated_data.get('start_point'),
|
||||||
|
template=validated_data.get('template'),
|
||||||
|
rate_profile=validated_data.get('rate_profile')
|
||||||
|
)
|
||||||
|
prediction.save()
|
||||||
|
|
||||||
|
return prediction
|
||||||
|
|
||||||
|
|
||||||
|
class PredictionListSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = Prediction
|
||||||
|
fields = ["id", "created_at", "updated_at", "start_point", "template", "rate_profile"]
|
||||||
|
|
||||||
|
|
||||||
|
class PredictionDetailSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = Prediction
|
||||||
|
fields = ["id", "created_at", "updated_at", "result", "start_point", "template", "rate_profile"]
|
||||||
|
|
||||||
|
|
||||||
|
from rest_framework import serializers
|
||||||
|
from .models import TelemetryPacket
|
||||||
|
|
||||||
|
class TelemetryPacketSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = TelemetryPacket
|
||||||
|
fields = ['id', 'timestamp', 'lat', 'lon', 'alt', 'payload']
|
||||||
|
read_only_fields = ['id']
|
||||||
|
|
||||||
|
|
||||||
|
class SavedPointSerializer(serializers.ModelSerializer):
|
||||||
|
user = serializers.HiddenField(
|
||||||
|
default=serializers.CurrentUserDefault()
|
||||||
|
)
|
||||||
|
class Meta:
|
||||||
|
model = SavedPoint
|
||||||
|
fields = ['user', 'id', 'name', 'lat', 'lon', 'alt']
|
||||||
|
read_only_fields = ['id']
|
||||||
|
|
||||||
|
validators = [
|
||||||
|
serializers.UniqueTogetherValidator(
|
||||||
|
queryset=SavedPoint.objects.all(),
|
||||||
|
fields=['user', 'name'],
|
||||||
|
message="A saved point with this name already exists for the user."
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class SavedRateProfileSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = SavedRateProfile
|
||||||
|
fields = ['id', 'name', 'type', 'rate_profile_data']
|
||||||
|
read_only_fields = ['id']
|
||||||
|
|
||||||
|
|
||||||
|
class PreditctionTemplateSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = PreditctionTemplate
|
||||||
|
fields = ['id', 'name', 'is_default', 'description', 'prediction_mode', 'model', 'dataset', 'flight_parameters']
|
||||||
|
read_only_fields = ['id']
|
||||||
|
|
||||||
|
|
||||||
|
class UserSerializer(serializers.ModelSerializer):
|
||||||
|
class Meta:
|
||||||
|
model = User
|
||||||
|
fields = ['username', 'email', 'first_name', 'last_name']
|
||||||
|
extra_kwargs = {
|
||||||
|
'username': {'read_only': True}
|
||||||
|
}
|
||||||
|
|
||||||
|
def validate_email(self, value):
|
||||||
|
try:
|
||||||
|
validate_email(value)
|
||||||
|
except DjangoValidationError:
|
||||||
|
raise serializers.ValidationError("Invalid email format")
|
||||||
|
return value
|
||||||
|
|
||||||
|
class ChangePasswordSerializer(serializers.Serializer):
|
||||||
|
old_password = serializers.CharField(required=True)
|
||||||
|
new_password = serializers.CharField(required=True)
|
||||||
|
|
||||||
|
def validate_new_password(self, value):
|
||||||
|
validate_password(value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
class DeleteAccountSerializer(serializers.Serializer):
|
||||||
|
password = serializers.CharField(required=True)
|
||||||
|
|
@ -1,59 +1,48 @@
|
||||||
from dmr.routing import Router, path
|
from django.urls import path
|
||||||
|
from rest_framework.routers import DefaultRouter
|
||||||
|
from rest_framework.authtoken.views import obtain_auth_token
|
||||||
from .views import (
|
from .views import (
|
||||||
PredictionCollectionController,
|
PredictionViewSet,
|
||||||
PredictionListUserController,
|
SavedPointViewset,
|
||||||
PredictionHistoryController,
|
PreditctionTemplateViewset,
|
||||||
PredictionDetailController,
|
TelemetryListCreateView,
|
||||||
PredictionDeleteController,
|
get_csrf,
|
||||||
SavedPointListController,
|
login_view,
|
||||||
SavedPointDetailController,
|
logout_view,
|
||||||
PredictionTemplateListController,
|
SessionView,
|
||||||
PredictionTemplateDetailController,
|
WhoAmIView,
|
||||||
TelemetryController,
|
UserProfileView,
|
||||||
CsrfController,
|
ChangePasswordView,
|
||||||
LoginController,
|
TokenManagementView,
|
||||||
LogoutController,
|
DeleteUserDataView,
|
||||||
SessionController,
|
DeleteAccountView
|
||||||
WhoAmIController,
|
|
||||||
UserProfileController,
|
|
||||||
ChangePasswordController,
|
|
||||||
ObtainTokenController,
|
|
||||||
TokenManagementController,
|
|
||||||
DeleteUserDataController,
|
|
||||||
DeleteAccountController,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# A Router (prefix + routes) is required so build_schema() can generate the
|
|
||||||
# OpenAPI document; see stratoflights/urls.py.
|
|
||||||
router = Router(
|
|
||||||
'api/',
|
|
||||||
[
|
|
||||||
path("csrf/", CsrfController.as_view(), name='api-csrf'),
|
|
||||||
path('token', ObtainTokenController.as_view(), name='get_token'),
|
|
||||||
path("login/", LoginController.as_view(), name='api-login'),
|
|
||||||
path("logout/", LogoutController.as_view(), name='api-logout'),
|
|
||||||
path("session/", SessionController.as_view(), name='api-session'),
|
|
||||||
path("whoami/", WhoAmIController.as_view(), name='api-whoami'),
|
|
||||||
path("<uuid:pk>/telemetry/", TelemetryController.as_view(), name="create_telemetry"),
|
|
||||||
path("profile/", UserProfileController.as_view(), name='api-profile'),
|
|
||||||
path("profile/change-password/", ChangePasswordController.as_view(), name='api-change-password'),
|
|
||||||
path("profile/token/", TokenManagementController.as_view(), name='api-token'),
|
|
||||||
path("profile/delete-data/", DeleteUserDataController.as_view(), name='api-delete-data'),
|
|
||||||
path("profile/delete-account/", DeleteAccountController.as_view(), name='api-delete-account'),
|
|
||||||
# Saved points (was SavedPointViewset via DefaultRouter).
|
|
||||||
path("saved-points/", SavedPointListController.as_view(), name='saved-points-list'),
|
|
||||||
path("saved-points/<int:pk>/", SavedPointDetailController.as_view(), name='saved-points-detail'),
|
|
||||||
# Prediction templates (was PreditctionTemplateViewset via DefaultRouter).
|
|
||||||
path("saved-templates/", PredictionTemplateListController.as_view(), name='saved-templates-list'),
|
|
||||||
path("saved-templates/<int:pk>/", PredictionTemplateDetailController.as_view(), name='saved-templates-detail'),
|
|
||||||
# Predictions (was PredictionViewSet via DefaultRouter).
|
|
||||||
path("predictions/", PredictionCollectionController.as_view(), name='predictions-list'),
|
|
||||||
path("predictions/list_user/", PredictionListUserController.as_view(), name='predictions-list-user'),
|
|
||||||
path("predictions/history/", PredictionHistoryController.as_view(), name='predictions-history'),
|
|
||||||
path("predictions/<uuid:pk>/detail/", PredictionDetailController.as_view(), name='predictions-detail'),
|
|
||||||
path("predictions/<uuid:pk>/delete/", PredictionDeleteController.as_view(), name='predictions-delete'),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
urlpatterns = router.urls
|
router = DefaultRouter()
|
||||||
|
router.register(r'predictions', PredictionViewSet, basename='predictions')
|
||||||
|
router.register(r'saved-points', SavedPointViewset, basename='saved-points')
|
||||||
|
router.register(r'saved-templates', PreditctionTemplateViewset, basename='saved-templates')
|
||||||
|
|
||||||
|
|
||||||
|
urlpatterns = [
|
||||||
|
path("csrf/", get_csrf, name='api-csrf'),
|
||||||
|
path('token', obtain_auth_token, name = 'get_token'),
|
||||||
|
path("login/", login_view, name='api-login'),
|
||||||
|
path("logout/", logout_view, name='api-logout'),
|
||||||
|
path("session/", SessionView.as_view(), name='api-session'),
|
||||||
|
path("whoami/", WhoAmIView.as_view(), name='api-whoami'),
|
||||||
|
path("<uuid:pk>/telemetry/", TelemetryListCreateView.as_view(), name="create_telemetry"),
|
||||||
|
path('csrf/', get_csrf, name='api-csrf'),
|
||||||
|
path('login/', login_view, name='api-login'),
|
||||||
|
path('logout/', logout_view, name='api-logout'),
|
||||||
|
path('session/', SessionView.as_view(), name='api-session'),
|
||||||
|
path('whoami/', WhoAmIView.as_view(), name='api-whoami'),
|
||||||
|
path("profile/", UserProfileView.as_view(), name='api-profile'),
|
||||||
|
path("profile/change-password/", ChangePasswordView.as_view(), name='api-change-password'),
|
||||||
|
path("profile/token/", TokenManagementView.as_view(), name='api-token'),
|
||||||
|
path("profile/delete-data/", DeleteUserDataView.as_view(), name='api-delete-data'),
|
||||||
|
path("profile/delete-account/", DeleteAccountView.as_view(), name='api-delete-account'),
|
||||||
|
]
|
||||||
|
urlpatterns += router.urls
|
||||||
|
|
|
||||||
|
|
@ -1,477 +1,305 @@
|
||||||
import requests
|
import requests
|
||||||
import time
|
import time
|
||||||
import json
|
import json
|
||||||
from http import HTTPStatus
|
from rest_framework import status, generics, permissions
|
||||||
|
from rest_framework.response import Response
|
||||||
|
from rest_framework.views import APIView
|
||||||
|
from rest_framework.viewsets import ModelViewSet, ViewSet, GenericViewSet
|
||||||
|
from rest_framework.exceptions import APIException
|
||||||
|
from rest_framework.permissions import IsAuthenticated, AllowAny
|
||||||
|
from rest_framework.authentication import SessionAuthentication, BasicAuthentication, TokenAuthentication
|
||||||
|
from rest_framework.decorators import api_view, permission_classes, authentication_classes, action
|
||||||
|
from rest_framework.authtoken.models import Token
|
||||||
|
from django.utils import timezone
|
||||||
|
from django.views.decorators.csrf import csrf_exempt
|
||||||
|
from django.utils.decorators import method_decorator
|
||||||
|
from django.http import JsonResponse
|
||||||
from django.contrib.auth import authenticate, login, logout, get_user_model
|
from django.contrib.auth import authenticate, login, logout, get_user_model
|
||||||
from django.middleware.csrf import get_token
|
from django.middleware.csrf import get_token
|
||||||
|
from django.core.exceptions import ValidationError
|
||||||
from django.utils.dateparse import parse_datetime
|
from django.utils.dateparse import parse_datetime
|
||||||
from .models import Prediction, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket
|
from .models import Prediction, User, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket
|
||||||
|
from .serializers import PredictionSerializer, TelemetryPacketSerializer, PredictionRequestSerializer, PredictionListSerializer, PredictionDetailSerializer, SavedPointSerializer, SavedRateProfileSerializer, PreditctionTemplateSerializer, UserSerializer, ChangePasswordSerializer, DeleteAccountSerializer
|
||||||
from .services.tawhiri import TawhiriClient
|
from .services.tawhiri import TawhiriClient
|
||||||
from datetime import datetime, timedelta, timezone
|
from drf_spectacular.utils import extend_schema
|
||||||
|
from .permissions import ReadOnlyOrAuthenticated, IsOwner
|
||||||
from pydantic import ValidationError
|
from .custom_pagination import CustomLimitOffsetPagination
|
||||||
from dmr import Controller, modify
|
from datetime import datetime
|
||||||
from dmr.components import Body, Path, Query
|
|
||||||
from dmr.plugins.pydantic import PydanticSerializer
|
|
||||||
from dmr.response import APIError
|
|
||||||
from dmr.security.jwt.views import (
|
|
||||||
ObtainTokensPayload,
|
|
||||||
ObtainTokensResponse,
|
|
||||||
ObtainTokensSyncController,
|
|
||||||
)
|
|
||||||
from .dtos import (
|
|
||||||
DetailResponse,
|
|
||||||
SessionResponse,
|
|
||||||
WhoAmIResponse,
|
|
||||||
UserOut,
|
|
||||||
UserUpdateIn,
|
|
||||||
ChangePasswordIn,
|
|
||||||
DeleteAccountIn,
|
|
||||||
TokenResponse,
|
|
||||||
PkPath,
|
|
||||||
UuidPkPath,
|
|
||||||
SavedPointIn,
|
|
||||||
SavedPointPatchIn,
|
|
||||||
SavedPointOut,
|
|
||||||
PredictionTemplateIn,
|
|
||||||
PredictionTemplatePatchIn,
|
|
||||||
PredictionTemplateOut,
|
|
||||||
PredictionRequest,
|
|
||||||
PredictionOut,
|
|
||||||
PredictionListOut,
|
|
||||||
PredictionDetailOut,
|
|
||||||
PredictionCreateOut,
|
|
||||||
TelemetryIn,
|
|
||||||
TelemetryOut,
|
|
||||||
)
|
|
||||||
from .pagination import (
|
|
||||||
PredictionListUserQuery,
|
|
||||||
PredictionPage,
|
|
||||||
paginate_predictions,
|
|
||||||
TelemetryPage,
|
|
||||||
page_number_paginate,
|
|
||||||
)
|
|
||||||
from .validators import base64_to_curve
|
|
||||||
|
|
||||||
User = get_user_model()
|
User = get_user_model()
|
||||||
|
|
||||||
|
|
||||||
def _api_error(status_code, body, headers=None):
|
def get_prediction_from_tawhiri(params):
|
||||||
"""Adapter for DMR's APIError(raw_data, *, status_code, ...) signature.
|
|
||||||
|
|
||||||
Lets call sites keep the (status_code=, body=) style; the body is DMR's
|
base_url = "https://fly.stratonautica.ru/api/v2"
|
||||||
first positional ``raw_data`` argument.
|
response = requests.get(base_url, params=params)
|
||||||
"""
|
|
||||||
return APIError(body, status_code=status_code, headers=headers)
|
if response.status_code == 200:
|
||||||
|
return response.json() # получаем результат предсказания
|
||||||
|
else:
|
||||||
|
raise Exception(
|
||||||
|
f"Tawhiri error: {response.status_code} {response.text}")
|
||||||
|
|
||||||
|
|
||||||
def _resolve_related(model, pk, field_name):
|
class PredictionViewSet(GenericViewSet):
|
||||||
"""Resolve a related object by PK (was PrimaryKeyRelatedField(queryset=...)).
|
permission_classes = [IsAuthenticated]
|
||||||
|
pagination_class = CustomLimitOffsetPagination
|
||||||
|
|
||||||
The original queryset was ``.all()`` (not user-scoped); a missing PK yields a
|
def list(self, request):
|
||||||
400 matching DRF's "Invalid pk" error shape.
|
queryset = Prediction.objects.filter(user=request.user)
|
||||||
"""
|
return Response(PredictionSerializer(queryset, many=True).data)
|
||||||
if pk is None:
|
|
||||||
return None
|
|
||||||
obj = model.objects.filter(pk=pk).first()
|
|
||||||
if obj is None:
|
|
||||||
raise _api_error(
|
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
|
||||||
body={field_name: [f'Invalid pk "{pk}" - object does not exist.']},
|
|
||||||
)
|
|
||||||
return obj
|
|
||||||
|
|
||||||
|
def create(self, request):
|
||||||
|
serializer = PredictionRequestSerializer(data=request.data)
|
||||||
|
|
||||||
class PredictionCollectionController(Controller[PydanticSerializer]):
|
if not serializer.is_valid():
|
||||||
"""`predictions/` -- list the user's predictions and create a new one."""
|
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
def get(self) -> list[PredictionOut]:
|
validated_data = serializer.validated_data
|
||||||
queryset = Prediction.objects.filter(user=self.request.user)
|
|
||||||
return [PredictionOut.model_validate(obj) for obj in queryset]
|
|
||||||
|
|
||||||
def post(self, parsed_body: Body[PredictionRequest]) -> PredictionCreateOut:
|
|
||||||
user = self.request.user
|
|
||||||
|
|
||||||
# Resolve related objects before calling Tawhiri, so an invalid PK still
|
|
||||||
# fails with 400 first (as DRF validation did).
|
|
||||||
start_point = _resolve_related(SavedPoint, parsed_body.start_point, 'start_point')
|
|
||||||
template = _resolve_related(PreditctionTemplate, parsed_body.template, 'template')
|
|
||||||
rate_profile = _resolve_related(SavedRateProfile, parsed_body.rate_profile, 'rate_profile')
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
prediction_result = TawhiriClient.get_prediction(parsed_body.model_dump())
|
prediction_result = TawhiriClient.get_prediction(validated_data)
|
||||||
except requests.RequestException as exc:
|
|
||||||
raise _api_error(
|
|
||||||
status_code=HTTPStatus.BAD_GATEWAY,
|
|
||||||
body={'error': f'Tawhiri error: {str(exc)}'},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Carried over from the old serializer.create(): curves are decoded but
|
except requests.RequestException as e:
|
||||||
# never persisted (the model has no curve fields). Kept for parity --
|
return Response({"error": f"Tawhiri error: {str(e)}"}, status=status.HTTP_502_BAD_GATEWAY)
|
||||||
# including that a malformed curve here raises (then surfaces as 500).
|
|
||||||
if parsed_body.ascent_curve is not None:
|
|
||||||
base64_to_curve(parsed_body.ascent_curve)
|
|
||||||
if parsed_body.descent_curve is not None:
|
|
||||||
base64_to_curve(parsed_body.descent_curve)
|
|
||||||
|
|
||||||
prediction = Prediction(
|
# prediction = Prediction.objects.create(
|
||||||
user=user,
|
# result=prediction_result, user=request.user, request=request.data, validated_data=validated_data)
|
||||||
request=json.loads(self.request.body),
|
prediction = serializer.save(
|
||||||
result=prediction_result,
|
user=request.user,
|
||||||
start_point=start_point,
|
|
||||||
template=template,
|
|
||||||
rate_profile=rate_profile,
|
|
||||||
)
|
|
||||||
prediction.save()
|
|
||||||
|
|
||||||
return PredictionCreateOut(
|
|
||||||
id=prediction.id,
|
|
||||||
created_at=prediction.created_at,
|
|
||||||
result=prediction_result,
|
result=prediction_result,
|
||||||
|
request=request.data
|
||||||
)
|
)
|
||||||
|
|
||||||
|
return Response({
|
||||||
|
"id": prediction.id,
|
||||||
|
"created_at": prediction.created_at,
|
||||||
|
"result": prediction_result
|
||||||
|
}, status=status.HTTP_201_CREATED)
|
||||||
|
|
||||||
class PredictionListUserController(Controller[PydanticSerializer]):
|
@action(detail=False, methods=['get'])
|
||||||
"""`predictions/list_user/` -- filtered, paginated list of the user's predictions."""
|
def list_user(self, request):
|
||||||
|
user = request.user
|
||||||
|
satellite_id = request.query_params.get('satellite_id')
|
||||||
|
created_from = request.query_params.get('created_from')
|
||||||
|
created_till = request.query_params.get('created_till')
|
||||||
|
|
||||||
def get(self, parsed_query: Query[PredictionListUserQuery]) -> PredictionPage:
|
filters = {
|
||||||
user = self.request.user
|
'user': user,
|
||||||
filters = {'user': user}
|
}
|
||||||
|
|
||||||
if parsed_query.created_from:
|
if created_from:
|
||||||
filters['created_at__gte'] = parse_datetime(parsed_query.created_from)
|
filters['created_at__gte'] = parse_datetime(created_from)
|
||||||
if parsed_query.created_till:
|
|
||||||
filters['created_at__lte'] = parse_datetime(parsed_query.created_till)
|
if created_till:
|
||||||
if parsed_query.satellite_id:
|
filters['created_at__lte'] = parse_datetime(created_till)
|
||||||
if not user.satellites.filter(id=parsed_query.satellite_id).exists():
|
|
||||||
raise _api_error(status_code=HTTPStatus.FORBIDDEN, body={'detail': 'Access denied'})
|
if satellite_id:
|
||||||
filters['satellite_id'] = parsed_query.satellite_id
|
if not user.satellites.filter(id=satellite_id).exists():
|
||||||
|
return Response({'detail': 'Access denied'}, status=403)
|
||||||
|
|
||||||
|
filters['satellite_id'] = satellite_id
|
||||||
|
|
||||||
queryset = Prediction.objects.filter(**filters)
|
queryset = Prediction.objects.filter(**filters)
|
||||||
return paginate_predictions(queryset, parsed_query)
|
queryset = self.filter_queryset(queryset)
|
||||||
|
|
||||||
|
page = self.paginate_queryset(queryset)
|
||||||
|
|
||||||
class PredictionHistoryController(Controller[PydanticSerializer]):
|
if page is not None:
|
||||||
"""`predictions/history/` -- compact list of the user's predictions."""
|
serializer = PredictionSerializer(page, many=True)
|
||||||
|
return self.get_paginated_response(serializer.data)
|
||||||
|
|
||||||
def get(self) -> list[PredictionListOut]:
|
serializer = PredictionSerializer(queryset, many=True)
|
||||||
queryset = Prediction.objects.filter(user=self.request.user)
|
return Response(serializer.data)
|
||||||
return [PredictionListOut.model_validate(obj) for obj in queryset]
|
|
||||||
|
|
||||||
|
@action(detail=False, methods=["get"])
|
||||||
|
def history(self, request):
|
||||||
|
queryset = Prediction.objects.filter(user=request.user)
|
||||||
|
return Response(PredictionListSerializer(queryset, many=True).data)
|
||||||
|
|
||||||
class PredictionDetailController(Controller[PydanticSerializer]):
|
@action(detail=True, methods=["get"])
|
||||||
"""`predictions/<pk>/detail/` -- retrieve a single prediction."""
|
def detail(self, request, pk=None):
|
||||||
|
|
||||||
def get(self, parsed_path: Path[UuidPkPath]) -> PredictionDetailOut:
|
|
||||||
prediction = Prediction.objects.filter(
|
prediction = Prediction.objects.filter(
|
||||||
user=self.request.user, pk=parsed_path.pk).first()
|
user=request.user, pk=pk).first()
|
||||||
if prediction is None:
|
if not prediction:
|
||||||
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'})
|
return Response({'detail': 'Not found'}, status=404)
|
||||||
return PredictionDetailOut.model_validate(prediction)
|
return Response(PredictionDetailSerializer(prediction).data)
|
||||||
|
|
||||||
|
@action(detail=True, methods=["delete"])
|
||||||
class PredictionDeleteController(Controller[PydanticSerializer]):
|
def delete(self, request, pk=None):
|
||||||
"""`predictions/<pk>/delete/` -- delete a single prediction."""
|
|
||||||
|
|
||||||
@modify(status_code=HTTPStatus.NO_CONTENT)
|
|
||||||
def delete(self, parsed_path: Path[UuidPkPath]) -> None:
|
|
||||||
prediction = Prediction.objects.filter(
|
prediction = Prediction.objects.filter(
|
||||||
user=self.request.user, pk=parsed_path.pk).first()
|
user=request.user, pk=pk).first()
|
||||||
if prediction is None:
|
if not prediction:
|
||||||
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'})
|
return Response({'detail': 'Not found'}, status=404)
|
||||||
prediction.delete()
|
prediction.delete()
|
||||||
|
return Response(status=204)
|
||||||
|
|
||||||
|
|
||||||
class TelemetryController(Controller[PydanticSerializer]):
|
class TelemetryListCreateView(generics.ListCreateAPIView):
|
||||||
"""`<uuid:pk>/telemetry/` -- list and ingest telemetry for a satellite.
|
serializer_class = TelemetryPacketSerializer
|
||||||
|
permission_classes = [permissions.AllowAny]
|
||||||
|
|
||||||
Public (was AllowAny). GET preserves the global PageNumberPagination envelope
|
def get_queryset(self):
|
||||||
{count, next, previous, results} with PAGE_SIZE=100.
|
qs = TelemetryPacket.objects.filter(satellite_id=self.kwargs["pk"])
|
||||||
"""
|
|
||||||
|
|
||||||
auth = None # public (was AllowAny)
|
from_ts = self.request.query_params.get("from")
|
||||||
csrf_exempt = True # DRF views bypass Django CSRF; ingestion is anonymous
|
till_ts = self.request.query_params.get("till")
|
||||||
|
|
||||||
def get(self, parsed_path: Path[UuidPkPath]) -> TelemetryPage:
|
|
||||||
qs = TelemetryPacket.objects.filter(satellite_id=parsed_path.pk)
|
|
||||||
|
|
||||||
from_ts = self.request.GET.get('from')
|
|
||||||
till_ts = self.request.GET.get('till')
|
|
||||||
|
|
||||||
if from_ts:
|
if from_ts:
|
||||||
qs = qs.filter(timestamp__gte=int(from_ts))
|
qs = qs.filter(timestamp__gte=int(from_ts))
|
||||||
if till_ts:
|
if till_ts:
|
||||||
qs = qs.filter(timestamp__lte=int(till_ts))
|
qs = qs.filter(timestamp__lte=int(till_ts))
|
||||||
|
|
||||||
qs = qs.order_by('-timestamp')
|
return qs.order_by("-timestamp")
|
||||||
|
|
||||||
count, next_link, previous_link, page_objects = page_number_paginate(self.request, qs)
|
def post(self, request, pk):
|
||||||
return TelemetryPage(
|
serializer = TelemetryPacketSerializer(data=request.data)
|
||||||
count=count,
|
if not serializer.is_valid():
|
||||||
next=next_link,
|
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||||
previous=previous_link,
|
|
||||||
results=[TelemetryOut.model_validate(obj) for obj in page_objects],
|
|
||||||
)
|
|
||||||
|
|
||||||
def post(self, parsed_path: Path[UuidPkPath], parsed_body: Body[TelemetryIn]) -> TelemetryOut:
|
validated_data = serializer.validated_data
|
||||||
# Bug fix (approved): the original returned serializer.errors on success;
|
|
||||||
# we return the created packet instead. timestamp is still server-set.
|
TelemetryPacket.objects.create(timestamp=time.time(),
|
||||||
packet = TelemetryPacket.objects.create(
|
satellite=Satellite.objects.get(id=pk),
|
||||||
timestamp=time.time(),
|
lat=validated_data["lat"],
|
||||||
satellite=Satellite.objects.get(id=parsed_path.pk),
|
lon=validated_data["lon"],
|
||||||
lat=parsed_body.lat,
|
alt=validated_data["alt"],
|
||||||
lon=parsed_body.lon,
|
payload=validated_data['payload'],
|
||||||
alt=parsed_body.alt,
|
)
|
||||||
payload=parsed_body.payload,
|
return Response(serializer.errors, status=status.HTTP_201_CREATED)
|
||||||
)
|
|
||||||
return TelemetryOut.model_validate(packet)
|
|
||||||
|
|
||||||
|
|
||||||
class SessionController(Controller[PydanticSerializer]):
|
class SessionView(APIView):
|
||||||
"""Report whether the current request is authenticated."""
|
permission_classes = [IsAuthenticated]
|
||||||
|
|
||||||
def get(self) -> SessionResponse:
|
@staticmethod
|
||||||
return SessionResponse(isAuthenticated=True)
|
def get(request, format=None):
|
||||||
|
return JsonResponse({'isAuthenticated': True})
|
||||||
|
|
||||||
|
|
||||||
class WhoAmIController(Controller[PydanticSerializer]):
|
class WhoAmIView(APIView):
|
||||||
"""Return the current user's username."""
|
permission_classes = [IsAuthenticated]
|
||||||
|
|
||||||
def get(self) -> WhoAmIResponse:
|
@staticmethod
|
||||||
return WhoAmIResponse(username=self.request.user.username)
|
def get(request, format=None):
|
||||||
|
return JsonResponse({'username': request.user.username})
|
||||||
|
|
||||||
|
|
||||||
class CsrfController(Controller[PydanticSerializer]):
|
@extend_schema(methods=["GET"], description="Get CSRF token")
|
||||||
"""Get CSRF token."""
|
@csrf_exempt
|
||||||
|
@api_view(["GET"])
|
||||||
auth = None # public endpoint (was AllowAny)
|
@permission_classes([AllowAny])
|
||||||
csrf_exempt = True
|
def get_csrf(request):
|
||||||
|
response = JsonResponse({'detail': 'CSRF cookie set'})
|
||||||
def get(self) -> DetailResponse:
|
response['X-CSRFToken'] = get_token(request)
|
||||||
token = get_token(self.request)
|
return response
|
||||||
return self.to_response(
|
|
||||||
DetailResponse(detail='CSRF cookie set'),
|
|
||||||
headers={'X-CSRFToken': token},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LoginController(Controller[PydanticSerializer]):
|
@extend_schema(methods=["POST"], description="Login user")
|
||||||
"""Login user."""
|
@csrf_exempt
|
||||||
|
@api_view(["POST"])
|
||||||
|
@authentication_classes([BasicAuthentication])
|
||||||
|
@permission_classes([AllowAny])
|
||||||
|
def login_view(request):
|
||||||
|
data = json.loads(request.body)
|
||||||
|
username = data.get('username')
|
||||||
|
password = data.get('password')
|
||||||
|
|
||||||
auth = None # public endpoint (was AllowAny)
|
if username is None or password is None:
|
||||||
csrf_exempt = True
|
return JsonResponse({'detail': 'Please provide username and password.'}, status=400)
|
||||||
|
|
||||||
# post() would default to 201; the original returned 200.
|
user = authenticate(username=username, password=password)
|
||||||
@modify(status_code=HTTPStatus.OK)
|
if user is None:
|
||||||
def post(self) -> DetailResponse:
|
return JsonResponse({'detail': 'Invalid credentials.'}, status=400)
|
||||||
data = json.loads(self.request.body)
|
|
||||||
username = data.get('username')
|
|
||||||
password = data.get('password')
|
|
||||||
|
|
||||||
if username is None or password is None:
|
login(request, user)
|
||||||
raise _api_error(
|
return JsonResponse({'detail': 'Successfully logged in.'})
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
|
||||||
body={'detail': 'Please provide username and password.'},
|
|
||||||
)
|
|
||||||
|
|
||||||
user = authenticate(username=username, password=password)
|
|
||||||
if user is None:
|
|
||||||
raise _api_error(
|
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
|
||||||
body={'detail': 'Invalid credentials.'},
|
|
||||||
)
|
|
||||||
|
|
||||||
login(self.request, user)
|
|
||||||
return DetailResponse(detail='Successfully logged in.')
|
|
||||||
|
|
||||||
|
|
||||||
class LogoutController(Controller[PydanticSerializer]):
|
@extend_schema(methods=["POST"], description="Logout user")
|
||||||
"""Logout user."""
|
@api_view(["POST"])
|
||||||
|
@permission_classes([AllowAny])
|
||||||
|
def logout_view(request):
|
||||||
|
if not request.user.is_authenticated:
|
||||||
|
return JsonResponse({'detail': 'You\'re not logged in.'}, status=400)
|
||||||
|
|
||||||
auth = None # public endpoint (was AllowAny); checks auth state manually
|
logout(request)
|
||||||
|
return JsonResponse({'detail': 'Successfully logged out.'})
|
||||||
@modify(status_code=HTTPStatus.OK)
|
|
||||||
def post(self) -> DetailResponse:
|
|
||||||
if not self.request.user.is_authenticated:
|
|
||||||
raise _api_error(
|
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
|
||||||
body={'detail': "You're not logged in."},
|
|
||||||
)
|
|
||||||
|
|
||||||
logout(self.request)
|
|
||||||
return DetailResponse(detail='Successfully logged out.')
|
|
||||||
|
|
||||||
|
|
||||||
_SAVED_POINT_DUPLICATE = 'A saved point with this name already exists for the user.'
|
class SavedPointViewset(ModelViewSet):
|
||||||
|
permission_classes = [IsOwner]
|
||||||
|
serializer_class = SavedPointSerializer
|
||||||
|
pagination_class = None
|
||||||
|
|
||||||
|
def get_queryset(self):
|
||||||
|
return SavedPoint.objects.filter(user=self.request.user)
|
||||||
|
|
||||||
|
def perform_create(self, serializer):
|
||||||
|
serializer.save(user=self.request.user)
|
||||||
|
|
||||||
|
|
||||||
def _check_saved_point_unique(user, name, exclude_pk=None):
|
class PreditctionTemplateViewset(ModelViewSet):
|
||||||
"""Reproduce the former SavedPointSerializer UniqueTogetherValidator."""
|
permission_classes = [IsOwner]
|
||||||
qs = SavedPoint.objects.filter(user=user, name=name)
|
serializer_class = PreditctionTemplateSerializer
|
||||||
if exclude_pk is not None:
|
pagination_class = None
|
||||||
qs = qs.exclude(pk=exclude_pk)
|
|
||||||
if qs.exists():
|
def get_queryset(self):
|
||||||
raise _api_error(
|
return PreditctionTemplate.objects.filter(user=self.request.user)
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
|
||||||
body={'non_field_errors': [_SAVED_POINT_DUPLICATE]},
|
def perform_create(self, serializer):
|
||||||
)
|
serializer.save(user=self.request.user)
|
||||||
|
|
||||||
|
|
||||||
class SavedPointListController(Controller[PydanticSerializer]):
|
class UserProfileView(APIView):
|
||||||
"""Collection endpoint for the current user's saved points (was SavedPointViewset)."""
|
permission_classes = [IsAuthenticated]
|
||||||
|
|
||||||
def get(self) -> list[SavedPointOut]:
|
def get(self, request):
|
||||||
qs = SavedPoint.objects.filter(user=self.request.user)
|
serializer = UserSerializer(request.user)
|
||||||
return [SavedPointOut.model_validate(obj) for obj in qs]
|
return Response(serializer.data)
|
||||||
|
|
||||||
def post(self, parsed_body: Body[SavedPointIn]) -> SavedPointOut:
|
def patch(self, request):
|
||||||
user = self.request.user
|
user = request.user
|
||||||
_check_saved_point_unique(user, parsed_body.name)
|
serializer = UserSerializer(user, data=request.data, partial=True)
|
||||||
obj = SavedPoint.objects.create(user=user, **parsed_body.model_dump())
|
|
||||||
return SavedPointOut.model_validate(obj)
|
if not serializer.is_valid():
|
||||||
|
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
|
serializer.save()
|
||||||
|
return Response(serializer.data)
|
||||||
|
|
||||||
|
|
||||||
class SavedPointDetailController(Controller[PydanticSerializer]):
|
class ChangePasswordView(APIView):
|
||||||
"""Detail endpoint for a single saved point owned by the current user."""
|
permission_classes = [IsAuthenticated]
|
||||||
|
|
||||||
def _get_object(self, pk: int):
|
def post(self, request):
|
||||||
# Filtering by user means another user's object reads as 404 (not 403),
|
user = request.user
|
||||||
# matching DRF's get_object() over a user-scoped queryset.
|
serializer = ChangePasswordSerializer(data=request.data)
|
||||||
obj = SavedPoint.objects.filter(user=self.request.user, pk=pk).first()
|
|
||||||
if obj is None:
|
|
||||||
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'})
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def get(self, parsed_path: Path[PkPath]) -> SavedPointOut:
|
if not serializer.is_valid():
|
||||||
return SavedPointOut.model_validate(self._get_object(parsed_path.pk))
|
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
def put(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointIn]) -> SavedPointOut:
|
if not user.check_password(serializer.validated_data['old_password']):
|
||||||
obj = self._get_object(parsed_path.pk)
|
return Response({'detail': 'Old password is incorrect'},
|
||||||
_check_saved_point_unique(self.request.user, parsed_body.name, exclude_pk=obj.pk)
|
status=status.HTTP_400_BAD_REQUEST)
|
||||||
for field, value in parsed_body.model_dump().items():
|
|
||||||
setattr(obj, field, value)
|
|
||||||
obj.save()
|
|
||||||
return SavedPointOut.model_validate(obj)
|
|
||||||
|
|
||||||
def patch(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointPatchIn]) -> SavedPointOut:
|
user.set_password(serializer.validated_data['new_password'])
|
||||||
obj = self._get_object(parsed_path.pk)
|
|
||||||
updates = parsed_body.model_dump(exclude_unset=True)
|
|
||||||
if 'name' in updates:
|
|
||||||
_check_saved_point_unique(self.request.user, updates['name'], exclude_pk=obj.pk)
|
|
||||||
for field, value in updates.items():
|
|
||||||
setattr(obj, field, value)
|
|
||||||
obj.save()
|
|
||||||
return SavedPointOut.model_validate(obj)
|
|
||||||
|
|
||||||
@modify(status_code=HTTPStatus.NO_CONTENT)
|
|
||||||
def delete(self, parsed_path: Path[PkPath]) -> None:
|
|
||||||
self._get_object(parsed_path.pk).delete()
|
|
||||||
|
|
||||||
|
|
||||||
class PredictionTemplateListController(Controller[PydanticSerializer]):
|
|
||||||
"""Collection endpoint for the current user's templates (was PreditctionTemplateViewset).
|
|
||||||
|
|
||||||
NOTE: as before, there is no app-level uniqueness check; a duplicate
|
|
||||||
(user, name) hits the model's unique_together and surfaces as a DB error.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get(self) -> list[PredictionTemplateOut]:
|
|
||||||
qs = PreditctionTemplate.objects.filter(user=self.request.user)
|
|
||||||
return [PredictionTemplateOut.model_validate(obj) for obj in qs]
|
|
||||||
|
|
||||||
def post(self, parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut:
|
|
||||||
obj = PreditctionTemplate.objects.create(
|
|
||||||
user=self.request.user, **parsed_body.model_dump()
|
|
||||||
)
|
|
||||||
return PredictionTemplateOut.model_validate(obj)
|
|
||||||
|
|
||||||
|
|
||||||
class PredictionTemplateDetailController(Controller[PydanticSerializer]):
|
|
||||||
"""Detail endpoint for a single template owned by the current user."""
|
|
||||||
|
|
||||||
def _get_object(self, pk: int):
|
|
||||||
obj = PreditctionTemplate.objects.filter(user=self.request.user, pk=pk).first()
|
|
||||||
if obj is None:
|
|
||||||
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'})
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def get(self, parsed_path: Path[PkPath]) -> PredictionTemplateOut:
|
|
||||||
return PredictionTemplateOut.model_validate(self._get_object(parsed_path.pk))
|
|
||||||
|
|
||||||
def put(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut:
|
|
||||||
obj = self._get_object(parsed_path.pk)
|
|
||||||
for field, value in parsed_body.model_dump().items():
|
|
||||||
setattr(obj, field, value)
|
|
||||||
obj.save()
|
|
||||||
return PredictionTemplateOut.model_validate(obj)
|
|
||||||
|
|
||||||
def patch(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplatePatchIn]) -> PredictionTemplateOut:
|
|
||||||
obj = self._get_object(parsed_path.pk)
|
|
||||||
for field, value in parsed_body.model_dump(exclude_unset=True).items():
|
|
||||||
setattr(obj, field, value)
|
|
||||||
obj.save()
|
|
||||||
return PredictionTemplateOut.model_validate(obj)
|
|
||||||
|
|
||||||
@modify(status_code=HTTPStatus.NO_CONTENT)
|
|
||||||
def delete(self, parsed_path: Path[PkPath]) -> None:
|
|
||||||
self._get_object(parsed_path.pk).delete()
|
|
||||||
|
|
||||||
|
|
||||||
class UserProfileController(Controller[PydanticSerializer]):
|
|
||||||
"""Read and partially update the current user's profile."""
|
|
||||||
|
|
||||||
def get(self) -> UserOut:
|
|
||||||
return UserOut.model_validate(self.request.user)
|
|
||||||
|
|
||||||
def patch(self, parsed_body: Body[UserUpdateIn]) -> UserOut:
|
|
||||||
user = self.request.user
|
|
||||||
# partial update: apply only the fields the client actually sent.
|
|
||||||
for field, value in parsed_body.model_dump(exclude_unset=True).items():
|
|
||||||
setattr(user, field, value)
|
|
||||||
user.save()
|
user.save()
|
||||||
return UserOut.model_validate(user)
|
return Response({'detail': 'Password changed successfully'})
|
||||||
|
|
||||||
|
|
||||||
class ChangePasswordController(Controller[PydanticSerializer]):
|
class DeleteAccountView(APIView):
|
||||||
"""Change the current user's password."""
|
permission_classes = [IsAuthenticated]
|
||||||
|
|
||||||
# post() would default to 201; the original returned 200.
|
def delete(self, request):
|
||||||
@modify(status_code=HTTPStatus.OK)
|
user = request.user
|
||||||
def post(self, parsed_body: Body[ChangePasswordIn]) -> DetailResponse:
|
serializer = DeleteAccountSerializer(data=request.data)
|
||||||
user = self.request.user
|
|
||||||
|
|
||||||
if not user.check_password(parsed_body.old_password):
|
if not serializer.is_valid():
|
||||||
raise _api_error(
|
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
|
||||||
body={'detail': 'Old password is incorrect'},
|
|
||||||
)
|
|
||||||
|
|
||||||
user.set_password(parsed_body.new_password)
|
if not user.check_password(serializer.validated_data['password']):
|
||||||
user.save()
|
return Response({'detail': 'Incorrect password'},
|
||||||
return DetailResponse(detail='Password changed successfully')
|
status=status.HTTP_400_BAD_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
class DeleteAccountController(Controller[PydanticSerializer]):
|
|
||||||
"""Delete the current user's account and all their data."""
|
|
||||||
|
|
||||||
# DMR forbids a Body component on DELETE, but the original endpoint read the
|
|
||||||
# password from a DELETE request body. Preserve that contract by parsing the
|
|
||||||
# body manually instead of via Body[DeleteAccountIn].
|
|
||||||
def delete(self) -> DetailResponse:
|
|
||||||
user = self.request.user
|
|
||||||
|
|
||||||
try:
|
|
||||||
parsed_body = DeleteAccountIn(**json.loads(self.request.body or b'{}'))
|
|
||||||
except ValidationError as exc:
|
|
||||||
raise _api_error(status_code=HTTPStatus.BAD_REQUEST, body=json.loads(exc.json()))
|
|
||||||
except ValueError:
|
|
||||||
raise _api_error(status_code=HTTPStatus.BAD_REQUEST, body={'detail': 'Invalid request body.'})
|
|
||||||
|
|
||||||
if not user.check_password(parsed_body.password):
|
|
||||||
raise _api_error(
|
|
||||||
status_code=HTTPStatus.BAD_REQUEST,
|
|
||||||
body={'detail': 'Incorrect password'},
|
|
||||||
)
|
|
||||||
|
|
||||||
Prediction.objects.filter(user=user).delete()
|
Prediction.objects.filter(user=user).delete()
|
||||||
SavedPoint.objects.filter(user=user).delete()
|
SavedPoint.objects.filter(user=user).delete()
|
||||||
|
|
@ -479,83 +307,35 @@ class DeleteAccountController(Controller[PydanticSerializer]):
|
||||||
|
|
||||||
user.delete()
|
user.delete()
|
||||||
|
|
||||||
return DetailResponse(detail='Account deleted successfully')
|
return Response({'detail': 'Account deleted successfully'})
|
||||||
|
|
||||||
|
|
||||||
class DeleteUserDataController(Controller[PydanticSerializer]):
|
class DeleteUserDataView(APIView):
|
||||||
"""Delete all of the current user's data without deleting the account."""
|
permission_classes = [IsAuthenticated]
|
||||||
|
|
||||||
def delete(self) -> DetailResponse:
|
def delete(self, request):
|
||||||
user = self.request.user
|
user = request.user
|
||||||
|
|
||||||
Prediction.objects.filter(user=user).delete()
|
Prediction.objects.filter(user=user).delete()
|
||||||
SavedPoint.objects.filter(user=user).delete()
|
SavedPoint.objects.filter(user=user).delete()
|
||||||
PreditctionTemplate.objects.filter(user=user).delete()
|
PreditctionTemplate.objects.filter(user=user).delete()
|
||||||
|
|
||||||
return DetailResponse(detail='All user data deleted successfully')
|
return Response({'detail': 'All user data deleted successfully'})
|
||||||
|
|
||||||
|
|
||||||
class ObtainTokenController(
|
class TokenManagementView(APIView):
|
||||||
ObtainTokensSyncController[
|
permission_classes = [IsAuthenticated]
|
||||||
PydanticSerializer,
|
|
||||||
ObtainTokensPayload,
|
|
||||||
ObtainTokensResponse,
|
|
||||||
],
|
|
||||||
):
|
|
||||||
"""Exchange username/password for JWT access + refresh tokens.
|
|
||||||
|
|
||||||
Replaces DRF's obtain_auth_token. Approved drift: the token format and
|
def get(self, request):
|
||||||
semantics change from a single stored DRF token to stateless JWTs, so the
|
|
||||||
response is {access_token, refresh_token} instead of {token}.
|
|
||||||
"""
|
|
||||||
|
|
||||||
auth = None # public: credentials are supplied in the request body
|
token, created = Token.objects.get_or_create(user=request.user)
|
||||||
csrf_exempt = True
|
return Response({"token": token.key})
|
||||||
jwt_expiration = timedelta(hours=1)
|
|
||||||
jwt_refresh_expiration = timedelta(days=7)
|
|
||||||
|
|
||||||
def convert_auth_payload(self, payload):
|
def post(self, request):
|
||||||
return payload
|
|
||||||
|
|
||||||
def make_api_response(self):
|
Token.objects.filter(user=request.user).delete()
|
||||||
now = datetime.now(timezone.utc)
|
token = Token.objects.create(user=request.user)
|
||||||
return {
|
return Response({"token": token.key})
|
||||||
'access_token': self.create_jwt_token(
|
|
||||||
expiration=now + self.jwt_expiration,
|
|
||||||
token_type='access',
|
|
||||||
),
|
|
||||||
'refresh_token': self.create_jwt_token(
|
|
||||||
expiration=now + self.jwt_refresh_expiration,
|
|
||||||
token_type='refresh',
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class TokenManagementController(Controller[PydanticSerializer]):
|
|
||||||
"""Issue a fresh JWT access token for the current user.
|
|
||||||
|
|
||||||
Was TokenManagementView (DRF stored-token get/regenerate). With stateless
|
|
||||||
JWTs there is nothing stored to fetch or delete, so both GET and POST mint a
|
|
||||||
new access token (approved drift). The {"token": ...} response shape is kept.
|
|
||||||
"""
|
|
||||||
|
|
||||||
jwt_expiration = timedelta(hours=1)
|
|
||||||
|
|
||||||
def get(self) -> TokenResponse:
|
|
||||||
return TokenResponse(token=self._issue_token())
|
|
||||||
|
|
||||||
# post() would default to 201; the original returned 200.
|
|
||||||
@modify(status_code=HTTPStatus.OK)
|
|
||||||
def post(self) -> TokenResponse:
|
|
||||||
return TokenResponse(token=self._issue_token())
|
|
||||||
|
|
||||||
def _issue_token(self) -> str:
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
return self.create_jwt_token(
|
|
||||||
subject=str(self.request.user.pk),
|
|
||||||
expiration=now + self.jwt_expiration,
|
|
||||||
token_type='access',
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue