Compare commits
No commits in common. "modern-rest" and "master" have entirely different histories.
modern-res
...
master
11 changed files with 572 additions and 1023 deletions
|
|
@ -1,7 +1,8 @@
|
|||
# django-modern-rest 0.6.0 requires Django 5.x (upgraded from 4.2).
|
||||
Django>=5.0,<6.0
|
||||
django-modern-rest[jwt,pydantic]==0.6.0
|
||||
Django>=4.0,<5.0
|
||||
djangorestframework
|
||||
djangorestframework-simplejwt
|
||||
psycopg2-binary
|
||||
drf-spectacular
|
||||
requests
|
||||
django-cors-headers
|
||||
Pillow
|
||||
|
|
|
|||
|
|
@ -13,12 +13,6 @@ https://docs.djangoproject.com/en/4.2/ref/settings/
|
|||
from pathlib import Path
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from dmr.settings import Settings
|
||||
from dmr.security.django_session import DjangoSessionSyncAuth
|
||||
from dmr.security.jwt import JWTSyncAuth
|
||||
from dmr.openapi import OpenAPIConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Build paths inside the project like this: BASE_DIR / 'subdir'.
|
||||
|
|
@ -57,10 +51,12 @@ INSTALLED_APPS = [
|
|||
'django.contrib.sessions',
|
||||
'django.contrib.messages',
|
||||
'django.contrib.staticfiles',
|
||||
'rest_framework',
|
||||
'rest_framework.authtoken',
|
||||
'drf_spectacular',
|
||||
'corsheaders',
|
||||
'stratoflights_api.apps.StratoflightsApiConfig',
|
||||
'channels',
|
||||
'dmr', # required to serve OpenAPI docs static assets
|
||||
]
|
||||
|
||||
MIDDLEWARE = [
|
||||
|
|
@ -169,16 +165,25 @@ STATIC_URL = 'static/'
|
|||
DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField'
|
||||
AUTH_USER_MODEL = 'stratoflights_api.User'
|
||||
|
||||
# django-modern-rest configuration.
|
||||
# Global auth mirrors the previous DRF default (Token + Session); the DRF
|
||||
# TokenAuthentication is replaced by DMR's JWT (approved drift). Providing
|
||||
# several auth instances means at least one of them must succeed, so this
|
||||
# also enforces a 401 by default like DRF's global IsAuthenticated did.
|
||||
# Public endpoints opt out per-endpoint with @modify(auth=None).
|
||||
DMR_SETTINGS = {
|
||||
Settings.auth: [JWTSyncAuth(), DjangoSessionSyncAuth()],
|
||||
Settings.validate_responses: not PRODUCTION,
|
||||
Settings.openapi_config: OpenAPIConfig(title='Stratoflights API', version='1.0.0'),
|
||||
REST_FRAMEWORK = {
|
||||
'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination',
|
||||
'PAGE_SIZE': 100,
|
||||
|
||||
'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
|
||||
|
||||
'DEFAULT_AUTHENTICATION_CLASSES': [
|
||||
'rest_framework.authentication.TokenAuthentication',
|
||||
'rest_framework.authentication.SessionAuthentication',
|
||||
],
|
||||
|
||||
'DEFAULT_PERMISSION_CLASSES': [
|
||||
'rest_framework.permissions.IsAuthenticated',
|
||||
# 'rest_framework.permissions.AllowAny',
|
||||
],
|
||||
|
||||
'DEFAULT_RENDERER_CLASSES': [
|
||||
'rest_framework.renderers.JSONRenderer',
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +1,27 @@
|
|||
"""URL configuration for the stratoflights project."""
|
||||
"""
|
||||
URL configuration for stratoflights project.
|
||||
|
||||
The `urlpatterns` list routes URLs to views. For more information please see:
|
||||
https://docs.djangoproject.com/en/4.2/topics/http/urls/
|
||||
Examples:
|
||||
Function views
|
||||
1. Add an import: from my_app import views
|
||||
2. Add a URL to urlpatterns: path('', views.home, name='home')
|
||||
Class-based views
|
||||
1. Add an import: from other_app.views import Home
|
||||
2. Add a URL to urlpatterns: path('', Home.as_view(), name='home')
|
||||
Including another URLconf
|
||||
1. Import the include() function: from django.urls import include, path
|
||||
2. Add a URL to urlpatterns: path('blog/', include('blog.urls'))
|
||||
"""
|
||||
from django.contrib import admin
|
||||
from django.urls import include
|
||||
from dmr.openapi import build_schema
|
||||
from dmr.openapi.views import OpenAPIJsonView, SwaggerView
|
||||
from dmr.routing import path
|
||||
|
||||
from stratoflights_api.urls import router
|
||||
|
||||
schema = build_schema(router)
|
||||
from django.urls import path, include
|
||||
from drf_spectacular.views import SpectacularAPIView
|
||||
from drf_spectacular.views import SpectacularSwaggerView
|
||||
|
||||
urlpatterns = [
|
||||
path('admin/', admin.site.urls),
|
||||
path(router.prefix, include(router.urls)),
|
||||
path('api/schema/', OpenAPIJsonView.as_view(schema), name='schema'),
|
||||
path('api/docs/', SwaggerView.as_view(schema), name='docs'),
|
||||
]
|
||||
path('api/', include('stratoflights_api.urls')),
|
||||
path('api/schema/', SpectacularAPIView.as_view(), name='schema'),
|
||||
path('api/docs/', SpectacularSwaggerView.as_view(url_name='schema'), name='docs'),
|
||||
]
|
||||
|
|
@ -18,19 +18,17 @@ class TelemetryConsumer(AsyncWebsocketConsumer):
|
|||
|
||||
async def receive(self, text_data):
|
||||
|
||||
from pydantic import ValidationError
|
||||
from .dtos import TelemetryIn
|
||||
from .serializers import TelemetryPacketSerializer
|
||||
|
||||
if not self.write_enabled:
|
||||
await self.send(text_data=json.dumps({"error": "Read-only mode"}))
|
||||
return
|
||||
|
||||
data = json.loads(text_data)
|
||||
serializer = TelemetryPacketSerializer(data=data)
|
||||
|
||||
try:
|
||||
TelemetryIn(**data)
|
||||
except ValidationError as exc:
|
||||
await self.send(text_data=json.dumps({"error": json.loads(exc.json())}))
|
||||
if not serializer.is_valid():
|
||||
await self.send(text_data=json.dumps({"error": serializer.errors}))
|
||||
return
|
||||
|
||||
saved_data = await self.save_telemetry(data)
|
||||
|
|
@ -53,6 +51,15 @@ class TelemetryConsumer(AsyncWebsocketConsumer):
|
|||
|
||||
async def telemetry_message(self, event):
|
||||
await self.send(text_data=json.dumps(event["data"]))
|
||||
@database_sync_to_async
|
||||
def get_user_from_token(self, token_key, Token):
|
||||
from rest_framework.authtoken.models import Token
|
||||
User = get_user_model()
|
||||
try:
|
||||
token = Token.objects.select_related("user").get(key=token_key)
|
||||
return token.user
|
||||
except Token.DoesNotExist:
|
||||
return None
|
||||
|
||||
@database_sync_to_async
|
||||
def save_telemetry(self, data):
|
||||
|
|
@ -92,20 +99,12 @@ class StationTelemetryConsumer(TelemetryConsumer):
|
|||
write_enabled = True
|
||||
|
||||
async def connect(self):
|
||||
from django.conf import settings
|
||||
from dmr.security.jwt.token import JWToken
|
||||
|
||||
from rest_framework.authtoken.models import Token
|
||||
token_key = self.scope["query_string"].decode().split("token=")[-1]
|
||||
try:
|
||||
decoded = JWToken.decode(
|
||||
encoded_token=token_key,
|
||||
secret=settings.SECRET_KEY,
|
||||
algorithm='HS256',
|
||||
)
|
||||
self.scope["user"] = await database_sync_to_async(
|
||||
get_user_model().objects.get
|
||||
)(pk=decoded.sub)
|
||||
except Exception:
|
||||
token = await database_sync_to_async(Token.objects.select_related("user").get)(key=token_key)
|
||||
self.scope["user"] = token.user
|
||||
except Token.DoesNotExist:
|
||||
await self.close()
|
||||
return
|
||||
|
||||
|
|
|
|||
17
stratoflights_api/custom_pagination.py
Normal file
17
stratoflights_api/custom_pagination.py
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
from rest_framework.pagination import LimitOffsetPagination
|
||||
from rest_framework.response import Response
|
||||
|
||||
class CustomLimitOffsetPagination(LimitOffsetPagination):
|
||||
limit_query_param = 'limit'
|
||||
offset_query_param = 'skip'
|
||||
max_limit = 100
|
||||
default_limit = 10
|
||||
|
||||
|
||||
def get_paginated_response(self, data):
|
||||
return Response({
|
||||
'total': self.count,
|
||||
'limit': self.limit,
|
||||
'skip': self.offset,
|
||||
'predictions': data
|
||||
})
|
||||
|
|
@ -1,333 +0,0 @@
|
|||
"""Pydantic DTOs replacing the DRF serializers from serializers.py.
|
||||
|
||||
Mapping is 1:1 with the former serializers. Output DTOs enable ``from_attributes``
|
||||
so an endpoint can build them straight from a Django model instance via
|
||||
``SomeOut.model_validate(obj)`` (the equivalent of ``Serializer(obj).data``).
|
||||
|
||||
Field-level and cross-field validation that DRF kept in ``validate_<field>()`` /
|
||||
``validate()`` is reproduced with pydantic ``field_validator`` / ``model_validator``.
|
||||
Business logic DRF kept in ``create()`` (base64 curve decoding, ORM writes, and
|
||||
resolving related objects from their PKs) is intentionally NOT here -- per DMR's
|
||||
split it moves into the endpoints (Step 2.9).
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
|
||||
from django.contrib.auth.password_validation import validate_password
|
||||
from django.core.validators import validate_email
|
||||
from django.core.exceptions import ValidationError as DjangoValidationError
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
|
||||
from .validators import validate_custom_curve, rate_clip
|
||||
|
||||
|
||||
# Profile constants (moved verbatim from serializers.py).
|
||||
PROFILE_STANDARD = "standard_profile"
|
||||
PROFILE_FLOAT = "float_profile"
|
||||
PROFILE_REVERSE = "reverse_profile"
|
||||
PROFILE_CUSTOM = "custom_profile"
|
||||
LATEST_DATASET_KEYWORD = "latest"
|
||||
SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM]
|
||||
|
||||
|
||||
# --- Prediction request (was PredictionRequestSerializer) ---------------------
|
||||
class PredictionRequest(BaseModel):
|
||||
launch_latitude: float = Field(ge=-90, le=90)
|
||||
launch_longitude: float = Field(ge=0, le=360)
|
||||
launch_datetime: datetime
|
||||
launch_altitude: Optional[float] = None
|
||||
format: str = "json"
|
||||
profile: str = PROFILE_STANDARD
|
||||
dataset: str = LATEST_DATASET_KEYWORD
|
||||
|
||||
# profile-dependent fields
|
||||
ascent_rate: Optional[float] = Field(default=None, ge=0.01)
|
||||
descent_rate: Optional[float] = Field(default=None, ge=0.01)
|
||||
burst_altitude: Optional[float] = None
|
||||
float_altitude: Optional[float] = None
|
||||
stop_datetime: Optional[datetime] = None
|
||||
ascent_curve: Optional[str] = None
|
||||
descent_curve: Optional[str] = None
|
||||
interpolate: bool = False
|
||||
# Related objects are accepted as PKs; existence is resolved in the endpoint
|
||||
# (was PrimaryKeyRelatedField(queryset=...)).
|
||||
start_point: Optional[int] = None
|
||||
rate_profile: Optional[int] = None
|
||||
template: Optional[int] = None
|
||||
|
||||
@field_validator("profile")
|
||||
@classmethod
|
||||
def _validate_profile(cls, value: str) -> str:
|
||||
if value not in SUPPORTED_PROFILES:
|
||||
raise ValueError(f'"{value}" is not a valid choice.')
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_cross_fields(self) -> "PredictionRequest":
|
||||
launch_alt = self.launch_altitude if self.launch_altitude is not None else 0
|
||||
|
||||
if self.profile == PROFILE_STANDARD:
|
||||
if self.burst_altitude is None:
|
||||
raise ValueError("burst_altitude is required for standard profile.")
|
||||
if self.burst_altitude <= launch_alt:
|
||||
raise ValueError("burst_altitude must be greater than launch_altitude.")
|
||||
|
||||
elif self.profile == PROFILE_FLOAT:
|
||||
if self.float_altitude is None or self.float_altitude <= launch_alt:
|
||||
raise ValueError("float_altitude must be greater than launch_altitude.")
|
||||
if self.stop_datetime is None or self.stop_datetime <= self.launch_datetime:
|
||||
raise ValueError("stop_datetime must be later than launch_datetime.")
|
||||
|
||||
elif self.profile == PROFILE_CUSTOM:
|
||||
if self.ascent_curve is None or not validate_custom_curve(self.ascent_curve):
|
||||
raise ValueError("Invalid ascent_curve.")
|
||||
if self.descent_curve is None or not validate_custom_curve(self.descent_curve):
|
||||
raise ValueError("Invalid descent_curve.")
|
||||
if self.burst_altitude is None or self.burst_altitude <= launch_alt:
|
||||
raise ValueError("burst_altitude must be greater than launch_altitude.")
|
||||
|
||||
# custom clipping logic (was in validate())
|
||||
if self.ascent_rate is not None:
|
||||
self.ascent_rate = rate_clip(self.ascent_rate)
|
||||
if self.descent_rate is not None:
|
||||
self.descent_rate = rate_clip(self.descent_rate)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
# --- Prediction outputs -------------------------------------------------------
|
||||
class PredictionOut(BaseModel):
|
||||
"""Was PredictionSerializer."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
result: Any
|
||||
|
||||
|
||||
class PredictionListOut(BaseModel):
|
||||
"""Was PredictionListSerializer."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
start_point: Optional[int] = Field(default=None, validation_alias="start_point_id")
|
||||
template: Optional[int] = Field(default=None, validation_alias="template_id")
|
||||
rate_profile: Optional[int] = Field(default=None, validation_alias="rate_profile_id")
|
||||
|
||||
|
||||
class PredictionDetailOut(BaseModel):
|
||||
"""Was PredictionDetailSerializer."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
result: Any
|
||||
start_point: Optional[int] = Field(default=None, validation_alias="start_point_id")
|
||||
template: Optional[int] = Field(default=None, validation_alias="template_id")
|
||||
rate_profile: Optional[int] = Field(default=None, validation_alias="rate_profile_id")
|
||||
|
||||
|
||||
class PredictionCreateOut(BaseModel):
|
||||
"""Custom create() response shape: {id, created_at, result} (not the full model)."""
|
||||
|
||||
id: uuid.UUID
|
||||
created_at: datetime
|
||||
result: Any
|
||||
|
||||
|
||||
# --- Telemetry (was TelemetryPacketSerializer) --------------------------------
|
||||
class TelemetryIn(BaseModel):
|
||||
# timestamp is required (model BigIntegerField has no default), matching the
|
||||
# former ModelSerializer; the endpoint still overwrites it with server time.
|
||||
timestamp: int
|
||||
lat: float = 0.0
|
||||
lon: float = 0.0
|
||||
alt: float = 0.0
|
||||
payload: Any = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TelemetryOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: uuid.UUID
|
||||
timestamp: int
|
||||
lat: float
|
||||
lon: float
|
||||
alt: float
|
||||
payload: Any
|
||||
|
||||
|
||||
# --- SavedPoint (was SavedPointSerializer) ------------------------------------
|
||||
# `user` was a HiddenField(CurrentUserDefault) -> set from request in the endpoint,
|
||||
# not part of the wire input/output. The unique_together (user, name) check that
|
||||
# DRF did via UniqueTogetherValidator is reproduced in the endpoint (Step 2.8).
|
||||
class SavedPointIn(BaseModel):
|
||||
name: str
|
||||
lat: float = 0.0
|
||||
lon: float = 0.0
|
||||
alt: float = 0.0
|
||||
|
||||
|
||||
class SavedPointPatchIn(BaseModel):
|
||||
# All-optional variant for PATCH (partial_update). Only set fields apply.
|
||||
name: Optional[str] = None
|
||||
lat: Optional[float] = None
|
||||
lon: Optional[float] = None
|
||||
alt: Optional[float] = None
|
||||
|
||||
|
||||
class SavedPointOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
name: str
|
||||
lat: float
|
||||
lon: float
|
||||
alt: float
|
||||
|
||||
|
||||
# --- SavedRateProfile (was SavedRateProfileSerializer) ------------------------
|
||||
# NOTE: no view referenced this serializer in routing; kept for parity.
|
||||
class SavedRateProfileIn(BaseModel):
|
||||
name: str
|
||||
type: str = "ascent"
|
||||
rate_profile_data: Any = Field(default_factory=dict)
|
||||
|
||||
|
||||
class SavedRateProfileOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: int
|
||||
name: str
|
||||
type: str
|
||||
rate_profile_data: Any
|
||||
|
||||
|
||||
# --- PredictionTemplate (was PreditctionTemplateSerializer) -------------------
|
||||
class PredictionTemplateIn(BaseModel):
|
||||
# protected_namespaces=() avoids pydantic's warning about the `model` field.
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
name: str
|
||||
is_default: bool = False
|
||||
description: Optional[str] = None
|
||||
prediction_mode: str = ""
|
||||
model: str = ""
|
||||
dataset: str = ""
|
||||
flight_parameters: Any = Field(default_factory=dict)
|
||||
|
||||
|
||||
class PredictionTemplatePatchIn(BaseModel):
|
||||
# All-optional variant for PATCH (partial_update). Only set fields apply.
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
name: Optional[str] = None
|
||||
is_default: Optional[bool] = None
|
||||
description: Optional[str] = None
|
||||
prediction_mode: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
dataset: Optional[str] = None
|
||||
flight_parameters: Optional[Any] = None
|
||||
|
||||
|
||||
class PredictionTemplateOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True, protected_namespaces=())
|
||||
|
||||
id: int
|
||||
name: str
|
||||
is_default: bool
|
||||
description: Optional[str] = None
|
||||
prediction_mode: str
|
||||
model: str
|
||||
dataset: str
|
||||
flight_parameters: Any
|
||||
|
||||
|
||||
# --- User (was UserSerializer) ------------------------------------------------
|
||||
class UserOut(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
username: str
|
||||
email: str
|
||||
first_name: str
|
||||
last_name: str
|
||||
|
||||
|
||||
class UserUpdateIn(BaseModel):
|
||||
# `username` was read_only -> excluded from input. PATCH is partial, so all
|
||||
# fields are optional; the endpoint applies only fields that were set.
|
||||
email: Optional[str] = None
|
||||
first_name: Optional[str] = None
|
||||
last_name: Optional[str] = None
|
||||
|
||||
@field_validator("email")
|
||||
@classmethod
|
||||
def _validate_email(cls, value: Optional[str]) -> Optional[str]:
|
||||
if value is None:
|
||||
return value
|
||||
try:
|
||||
validate_email(value)
|
||||
except DjangoValidationError:
|
||||
raise ValueError("Invalid email format")
|
||||
return value
|
||||
|
||||
|
||||
# --- Password / account (were ChangePasswordSerializer / DeleteAccountSerializer)
|
||||
class ChangePasswordIn(BaseModel):
|
||||
old_password: str
|
||||
new_password: str
|
||||
|
||||
@field_validator("new_password")
|
||||
@classmethod
|
||||
def _validate_new_password(cls, value: str) -> str:
|
||||
# validate_password raises Django's ValidationError (not a ValueError),
|
||||
# which pydantic would not convert into a 4xx -- re-raise as ValueError.
|
||||
try:
|
||||
validate_password(value)
|
||||
except DjangoValidationError as exc:
|
||||
raise ValueError("; ".join(exc.messages))
|
||||
return value
|
||||
|
||||
|
||||
class DeleteAccountIn(BaseModel):
|
||||
password: str
|
||||
|
||||
|
||||
# --- Auth / session response shapes (were inline JsonResponse dicts) ----------
|
||||
class DetailResponse(BaseModel):
|
||||
detail: str
|
||||
|
||||
|
||||
class SessionResponse(BaseModel):
|
||||
isAuthenticated: bool
|
||||
|
||||
|
||||
class WhoAmIResponse(BaseModel):
|
||||
username: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
|
||||
token: str
|
||||
|
||||
|
||||
# --- Shared path parameters ---------------------------------------------------
|
||||
class PkPath(BaseModel):
|
||||
pk: int
|
||||
|
||||
|
||||
class UuidPkPath(BaseModel):
|
||||
pk: uuid.UUID
|
||||
|
|
@ -1,120 +0,0 @@
|
|||
"""Limit/offset pagination for the prediction list endpoints.
|
||||
|
||||
DMR has no built-in limit/offset paginator, so (as the docs suggest) we keep our
|
||||
own envelope model -- matching the former ``CustomLimitOffsetPagination`` output
|
||||
exactly -- plus a helper that slices the queryset. The query params and envelope
|
||||
keys are preserved verbatim: ``limit`` / ``skip`` in, and
|
||||
``{total, limit, skip, predictions}`` out.
|
||||
"""
|
||||
from http import HTTPStatus
|
||||
from typing import Optional
|
||||
from urllib.parse import urlsplit, urlunsplit
|
||||
|
||||
from django.core.paginator import InvalidPage, Paginator
|
||||
from django.http import QueryDict
|
||||
from pydantic import BaseModel
|
||||
|
||||
from dmr.response import APIError
|
||||
|
||||
from .dtos import PredictionOut, TelemetryOut
|
||||
|
||||
DEFAULT_LIMIT = 10 # was CustomLimitOffsetPagination.default_limit
|
||||
MAX_LIMIT = 100 # was CustomLimitOffsetPagination.max_limit
|
||||
PAGE_SIZE = 100 # global REST_FRAMEWORK PAGE_SIZE (PageNumberPagination)
|
||||
|
||||
|
||||
class PaginationQuery(BaseModel):
|
||||
# Param names preserved: limit_query_param='limit', offset_query_param='skip'.
|
||||
limit: int = DEFAULT_LIMIT
|
||||
skip: int = 0
|
||||
|
||||
|
||||
class PredictionListUserQuery(BaseModel):
|
||||
"""Query params for PredictionViewSet.list_user: filters + limit/offset."""
|
||||
|
||||
satellite_id: Optional[str] = None
|
||||
created_from: Optional[str] = None
|
||||
created_till: Optional[str] = None
|
||||
limit: int = DEFAULT_LIMIT
|
||||
skip: int = 0
|
||||
|
||||
|
||||
class PredictionPage(BaseModel):
|
||||
total: int
|
||||
limit: int
|
||||
skip: int
|
||||
predictions: list[PredictionOut]
|
||||
|
||||
|
||||
class TelemetryPage(BaseModel):
|
||||
count: int
|
||||
next: Optional[str] = None
|
||||
previous: Optional[str] = None
|
||||
results: list[TelemetryOut]
|
||||
|
||||
|
||||
def _replace_query_param(url: str, key: str, value) -> str:
|
||||
scheme, netloc, path, query, fragment = urlsplit(url)
|
||||
query_dict = QueryDict(query, mutable=True)
|
||||
query_dict[key] = value
|
||||
return urlunsplit((scheme, netloc, path, query_dict.urlencode(), fragment))
|
||||
|
||||
|
||||
def _remove_query_param(url: str, key: str) -> str:
|
||||
scheme, netloc, path, query, fragment = urlsplit(url)
|
||||
query_dict = QueryDict(query, mutable=True)
|
||||
query_dict.pop(key, None)
|
||||
return urlunsplit((scheme, netloc, path, query_dict.urlencode(), fragment))
|
||||
|
||||
|
||||
def page_number_paginate(request, queryset, page_size: int = PAGE_SIZE):
|
||||
"""Reproduce DRF PageNumberPagination.
|
||||
|
||||
Returns ``(count, next_link, previous_link, object_list)``. Raises a 404
|
||||
"Invalid page." like DRF on an out-of-range / non-integer page. Honours
|
||||
``page=last`` and builds absolute next/previous links off the current URL.
|
||||
"""
|
||||
paginator = Paginator(queryset, page_size)
|
||||
page_number = request.GET.get('page', 1)
|
||||
if page_number in ('last',):
|
||||
page_number = paginator.num_pages
|
||||
try:
|
||||
page = paginator.page(page_number)
|
||||
except InvalidPage:
|
||||
raise APIError({'detail': 'Invalid page.'}, status_code=HTTPStatus.NOT_FOUND)
|
||||
|
||||
url = request.build_absolute_uri()
|
||||
|
||||
next_link = None
|
||||
if page.has_next():
|
||||
next_link = _replace_query_param(url, 'page', page.next_page_number())
|
||||
|
||||
previous_link = None
|
||||
if page.has_previous():
|
||||
previous_number = page.previous_page_number()
|
||||
if previous_number == 1:
|
||||
previous_link = _remove_query_param(url, 'page')
|
||||
else:
|
||||
previous_link = _replace_query_param(url, 'page', previous_number)
|
||||
|
||||
return paginator.count, next_link, previous_link, list(page.object_list)
|
||||
|
||||
|
||||
def paginate_predictions(queryset, query: PaginationQuery) -> PredictionPage:
|
||||
"""Slice ``queryset`` and build the prediction page envelope.
|
||||
|
||||
Mirrors the bounds DRF's LimitOffsetPagination enforced: limit falls back to
|
||||
the default when non-positive and is capped at MAX_LIMIT; skip floors at 0.
|
||||
"""
|
||||
limit = query.limit if query.limit > 0 else DEFAULT_LIMIT
|
||||
limit = min(limit, MAX_LIMIT)
|
||||
skip = query.skip if query.skip >= 0 else 0
|
||||
|
||||
total = queryset.count()
|
||||
page = list(queryset[skip:skip + limit])
|
||||
return PredictionPage(
|
||||
total=total,
|
||||
limit=limit,
|
||||
skip=skip,
|
||||
predictions=[PredictionOut.model_validate(obj) for obj in page],
|
||||
)
|
||||
16
stratoflights_api/permissions.py
Normal file
16
stratoflights_api/permissions.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
from rest_framework.permissions import BasePermission, SAFE_METHODS
|
||||
|
||||
class ReadOnlyOrAuthenticated(BasePermission):
|
||||
def has_permission(self, request, view):
|
||||
return (
|
||||
request.method in SAFE_METHODS or
|
||||
request.user and request.user.is_authenticated
|
||||
)
|
||||
|
||||
|
||||
class IsOwner(BasePermission):
|
||||
def has_object_permission(self, request, view, obj):
|
||||
return obj.user == request.user
|
||||
|
||||
def has_permission(self, request, view):
|
||||
return request.user and request.user.is_authenticated
|
||||
185
stratoflights_api/serializers.py
Normal file
185
stratoflights_api/serializers.py
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
from rest_framework import serializers
|
||||
from .models import Prediction, SavedPoint, SavedRateProfile, PreditctionTemplate
|
||||
from datetime import datetime
|
||||
from django.contrib.auth.password_validation import validate_password
|
||||
from django.core.validators import validate_email
|
||||
from django.core.exceptions import ValidationError as DjangoValidationError
|
||||
from django.contrib.auth import get_user_model
|
||||
from .validators import (
|
||||
validate_custom_curve, rate_clip,
|
||||
_rfc3339_to_timestamp, base64_to_curve
|
||||
)
|
||||
|
||||
class PredictionSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Prediction
|
||||
fields = ['id', 'created_at', 'updated_at', 'result']
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
PROFILE_STANDARD = "standard_profile"
|
||||
PROFILE_FLOAT = "float_profile"
|
||||
PROFILE_REVERSE = "reverse_profile"
|
||||
PROFILE_CUSTOM = "custom_profile"
|
||||
LATEST_DATASET_KEYWORD = "latest"
|
||||
SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM]
|
||||
|
||||
|
||||
class PredictionRequestSerializer(serializers.Serializer):
|
||||
launch_latitude = serializers.FloatField(min_value=-90, max_value=90)
|
||||
launch_longitude = serializers.FloatField(min_value=0, max_value=360)
|
||||
launch_datetime = serializers.DateTimeField()
|
||||
launch_altitude = serializers.FloatField(required=False)
|
||||
format = serializers.CharField(default="json")
|
||||
profile = serializers.ChoiceField(choices=SUPPORTED_PROFILES, default=PROFILE_STANDARD)
|
||||
dataset = serializers.CharField(default=LATEST_DATASET_KEYWORD)
|
||||
|
||||
# --- профиль-dependent поля ---
|
||||
ascent_rate = serializers.FloatField(required=False, min_value=0.01)
|
||||
descent_rate = serializers.FloatField(required=False, min_value=0.01)
|
||||
burst_altitude = serializers.FloatField(required=False)
|
||||
float_altitude = serializers.FloatField(required=False)
|
||||
stop_datetime = serializers.DateTimeField(required=False)
|
||||
ascent_curve = serializers.CharField(required=False)
|
||||
descent_curve = serializers.CharField(required=False)
|
||||
interpolate = serializers.BooleanField(required=False, default=False)
|
||||
start_point = serializers.PrimaryKeyRelatedField(
|
||||
queryset=SavedPoint.objects.all(), required=False, allow_null=True
|
||||
)
|
||||
rate_profile = serializers.PrimaryKeyRelatedField(
|
||||
queryset=SavedRateProfile.objects.all(), required=False, allow_null=True
|
||||
)
|
||||
template = serializers.PrimaryKeyRelatedField(
|
||||
queryset=PreditctionTemplate.objects.all(), required=False, allow_null=True
|
||||
)
|
||||
|
||||
def validate(self, data):
|
||||
profile = data.get("profile", PROFILE_STANDARD)
|
||||
launch_alt = data.get("launch_altitude", 0)
|
||||
|
||||
if profile == PROFILE_STANDARD:
|
||||
if 'burst_altitude' not in data:
|
||||
raise serializers.ValidationError("burst_altitude is required for standard profile.")
|
||||
if data['burst_altitude'] <= launch_alt:
|
||||
raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.")
|
||||
|
||||
elif profile == PROFILE_FLOAT:
|
||||
if 'float_altitude' not in data or data['float_altitude'] <= launch_alt:
|
||||
raise serializers.ValidationError("float_altitude must be greater than launch_altitude.")
|
||||
if 'stop_datetime' not in data or data['stop_datetime'] <= data['launch_datetime']:
|
||||
raise serializers.ValidationError("stop_datetime must be later than launch_datetime.")
|
||||
|
||||
elif profile == PROFILE_CUSTOM:
|
||||
if 'ascent_curve' not in data or not validate_custom_curve(data['ascent_curve']):
|
||||
raise serializers.ValidationError("Invalid ascent_curve.")
|
||||
if 'descent_curve' not in data or not validate_custom_curve(data['descent_curve']):
|
||||
raise serializers.ValidationError("Invalid descent_curve.")
|
||||
if 'burst_altitude' not in data or data['burst_altitude'] <= launch_alt:
|
||||
raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.")
|
||||
|
||||
# кастомная логика clipping'а
|
||||
if 'ascent_rate' in data:
|
||||
data['ascent_rate'] = rate_clip(data['ascent_rate'])
|
||||
if 'descent_rate' in data:
|
||||
data['descent_rate'] = rate_clip(data['descent_rate'])
|
||||
|
||||
return data
|
||||
|
||||
def create(self, validated_data):
|
||||
if 'ascent_curve' in validated_data:
|
||||
validated_data['ascent_curve'] = base64_to_curve(validated_data['ascent_curve'])
|
||||
if 'descent_curve' in validated_data:
|
||||
validated_data['descent_curve'] = base64_to_curve(validated_data['descent_curve'])
|
||||
|
||||
prediction = Prediction(
|
||||
user=validated_data.get('user'),
|
||||
request=validated_data.get('request', {}),
|
||||
result=validated_data.get('result', {}),
|
||||
start_point=validated_data.get('start_point'),
|
||||
template=validated_data.get('template'),
|
||||
rate_profile=validated_data.get('rate_profile')
|
||||
)
|
||||
prediction.save()
|
||||
|
||||
return prediction
|
||||
|
||||
|
||||
class PredictionListSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Prediction
|
||||
fields = ["id", "created_at", "updated_at", "start_point", "template", "rate_profile"]
|
||||
|
||||
|
||||
class PredictionDetailSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = Prediction
|
||||
fields = ["id", "created_at", "updated_at", "result", "start_point", "template", "rate_profile"]
|
||||
|
||||
|
||||
from rest_framework import serializers
|
||||
from .models import TelemetryPacket
|
||||
|
||||
class TelemetryPacketSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = TelemetryPacket
|
||||
fields = ['id', 'timestamp', 'lat', 'lon', 'alt', 'payload']
|
||||
read_only_fields = ['id']
|
||||
|
||||
|
||||
class SavedPointSerializer(serializers.ModelSerializer):
|
||||
user = serializers.HiddenField(
|
||||
default=serializers.CurrentUserDefault()
|
||||
)
|
||||
class Meta:
|
||||
model = SavedPoint
|
||||
fields = ['user', 'id', 'name', 'lat', 'lon', 'alt']
|
||||
read_only_fields = ['id']
|
||||
|
||||
validators = [
|
||||
serializers.UniqueTogetherValidator(
|
||||
queryset=SavedPoint.objects.all(),
|
||||
fields=['user', 'name'],
|
||||
message="A saved point with this name already exists for the user."
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class SavedRateProfileSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = SavedRateProfile
|
||||
fields = ['id', 'name', 'type', 'rate_profile_data']
|
||||
read_only_fields = ['id']
|
||||
|
||||
|
||||
class PreditctionTemplateSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = PreditctionTemplate
|
||||
fields = ['id', 'name', 'is_default', 'description', 'prediction_mode', 'model', 'dataset', 'flight_parameters']
|
||||
read_only_fields = ['id']
|
||||
|
||||
|
||||
class UserSerializer(serializers.ModelSerializer):
|
||||
class Meta:
|
||||
model = User
|
||||
fields = ['username', 'email', 'first_name', 'last_name']
|
||||
extra_kwargs = {
|
||||
'username': {'read_only': True}
|
||||
}
|
||||
|
||||
def validate_email(self, value):
|
||||
try:
|
||||
validate_email(value)
|
||||
except DjangoValidationError:
|
||||
raise serializers.ValidationError("Invalid email format")
|
||||
return value
|
||||
|
||||
class ChangePasswordSerializer(serializers.Serializer):
|
||||
old_password = serializers.CharField(required=True)
|
||||
new_password = serializers.CharField(required=True)
|
||||
|
||||
def validate_new_password(self, value):
|
||||
validate_password(value)
|
||||
return value
|
||||
|
||||
class DeleteAccountSerializer(serializers.Serializer):
|
||||
password = serializers.CharField(required=True)
|
||||
|
|
@ -1,59 +1,48 @@
|
|||
from dmr.routing import Router, path
|
||||
from django.urls import path
|
||||
from rest_framework.routers import DefaultRouter
|
||||
from rest_framework.authtoken.views import obtain_auth_token
|
||||
from .views import (
|
||||
PredictionCollectionController,
|
||||
PredictionListUserController,
|
||||
PredictionHistoryController,
|
||||
PredictionDetailController,
|
||||
PredictionDeleteController,
|
||||
SavedPointListController,
|
||||
SavedPointDetailController,
|
||||
PredictionTemplateListController,
|
||||
PredictionTemplateDetailController,
|
||||
TelemetryController,
|
||||
CsrfController,
|
||||
LoginController,
|
||||
LogoutController,
|
||||
SessionController,
|
||||
WhoAmIController,
|
||||
UserProfileController,
|
||||
ChangePasswordController,
|
||||
ObtainTokenController,
|
||||
TokenManagementController,
|
||||
DeleteUserDataController,
|
||||
DeleteAccountController,
|
||||
PredictionViewSet,
|
||||
SavedPointViewset,
|
||||
PreditctionTemplateViewset,
|
||||
TelemetryListCreateView,
|
||||
get_csrf,
|
||||
login_view,
|
||||
logout_view,
|
||||
SessionView,
|
||||
WhoAmIView,
|
||||
UserProfileView,
|
||||
ChangePasswordView,
|
||||
TokenManagementView,
|
||||
DeleteUserDataView,
|
||||
DeleteAccountView
|
||||
)
|
||||
|
||||
|
||||
# A Router (prefix + routes) is required so build_schema() can generate the
|
||||
# OpenAPI document; see stratoflights/urls.py.
|
||||
router = Router(
|
||||
'api/',
|
||||
[
|
||||
path("csrf/", CsrfController.as_view(), name='api-csrf'),
|
||||
path('token', ObtainTokenController.as_view(), name='get_token'),
|
||||
path("login/", LoginController.as_view(), name='api-login'),
|
||||
path("logout/", LogoutController.as_view(), name='api-logout'),
|
||||
path("session/", SessionController.as_view(), name='api-session'),
|
||||
path("whoami/", WhoAmIController.as_view(), name='api-whoami'),
|
||||
path("<uuid:pk>/telemetry/", TelemetryController.as_view(), name="create_telemetry"),
|
||||
path("profile/", UserProfileController.as_view(), name='api-profile'),
|
||||
path("profile/change-password/", ChangePasswordController.as_view(), name='api-change-password'),
|
||||
path("profile/token/", TokenManagementController.as_view(), name='api-token'),
|
||||
path("profile/delete-data/", DeleteUserDataController.as_view(), name='api-delete-data'),
|
||||
path("profile/delete-account/", DeleteAccountController.as_view(), name='api-delete-account'),
|
||||
# Saved points (was SavedPointViewset via DefaultRouter).
|
||||
path("saved-points/", SavedPointListController.as_view(), name='saved-points-list'),
|
||||
path("saved-points/<int:pk>/", SavedPointDetailController.as_view(), name='saved-points-detail'),
|
||||
# Prediction templates (was PreditctionTemplateViewset via DefaultRouter).
|
||||
path("saved-templates/", PredictionTemplateListController.as_view(), name='saved-templates-list'),
|
||||
path("saved-templates/<int:pk>/", PredictionTemplateDetailController.as_view(), name='saved-templates-detail'),
|
||||
# Predictions (was PredictionViewSet via DefaultRouter).
|
||||
path("predictions/", PredictionCollectionController.as_view(), name='predictions-list'),
|
||||
path("predictions/list_user/", PredictionListUserController.as_view(), name='predictions-list-user'),
|
||||
path("predictions/history/", PredictionHistoryController.as_view(), name='predictions-history'),
|
||||
path("predictions/<uuid:pk>/detail/", PredictionDetailController.as_view(), name='predictions-detail'),
|
||||
path("predictions/<uuid:pk>/delete/", PredictionDeleteController.as_view(), name='predictions-delete'),
|
||||
],
|
||||
)
|
||||
|
||||
urlpatterns = router.urls
|
||||
router = DefaultRouter()
|
||||
router.register(r'predictions', PredictionViewSet, basename='predictions')
|
||||
router.register(r'saved-points', SavedPointViewset, basename='saved-points')
|
||||
router.register(r'saved-templates', PreditctionTemplateViewset, basename='saved-templates')
|
||||
|
||||
|
||||
urlpatterns = [
|
||||
path("csrf/", get_csrf, name='api-csrf'),
|
||||
path('token', obtain_auth_token, name = 'get_token'),
|
||||
path("login/", login_view, name='api-login'),
|
||||
path("logout/", logout_view, name='api-logout'),
|
||||
path("session/", SessionView.as_view(), name='api-session'),
|
||||
path("whoami/", WhoAmIView.as_view(), name='api-whoami'),
|
||||
path("<uuid:pk>/telemetry/", TelemetryListCreateView.as_view(), name="create_telemetry"),
|
||||
path('csrf/', get_csrf, name='api-csrf'),
|
||||
path('login/', login_view, name='api-login'),
|
||||
path('logout/', logout_view, name='api-logout'),
|
||||
path('session/', SessionView.as_view(), name='api-session'),
|
||||
path('whoami/', WhoAmIView.as_view(), name='api-whoami'),
|
||||
path("profile/", UserProfileView.as_view(), name='api-profile'),
|
||||
path("profile/change-password/", ChangePasswordView.as_view(), name='api-change-password'),
|
||||
path("profile/token/", TokenManagementView.as_view(), name='api-token'),
|
||||
path("profile/delete-data/", DeleteUserDataView.as_view(), name='api-delete-data'),
|
||||
path("profile/delete-account/", DeleteAccountView.as_view(), name='api-delete-account'),
|
||||
]
|
||||
urlpatterns += router.urls
|
||||
|
|
|
|||
|
|
@ -1,561 +1,341 @@
|
|||
import requests
|
||||
import time
|
||||
import json
|
||||
from http import HTTPStatus
|
||||
from rest_framework import status, generics, permissions
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.viewsets import ModelViewSet, ViewSet, GenericViewSet
|
||||
from rest_framework.exceptions import APIException
|
||||
from rest_framework.permissions import IsAuthenticated, AllowAny
|
||||
from rest_framework.authentication import SessionAuthentication, BasicAuthentication, TokenAuthentication
|
||||
from rest_framework.decorators import api_view, permission_classes, authentication_classes, action
|
||||
from rest_framework.authtoken.models import Token
|
||||
from django.utils import timezone
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.http import JsonResponse
|
||||
from django.contrib.auth import authenticate, login, logout, get_user_model
|
||||
from django.middleware.csrf import get_token
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.utils.dateparse import parse_datetime
|
||||
from .models import Prediction, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket
|
||||
from .models import Prediction, User, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket
|
||||
from .serializers import PredictionSerializer, TelemetryPacketSerializer, PredictionRequestSerializer, PredictionListSerializer, PredictionDetailSerializer, SavedPointSerializer, SavedRateProfileSerializer, PreditctionTemplateSerializer, UserSerializer, ChangePasswordSerializer, DeleteAccountSerializer
|
||||
from .services.tawhiri import TawhiriClient
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from pydantic import ValidationError
|
||||
from dmr import Controller, modify
|
||||
from dmr.components import Body, Path, Query
|
||||
from dmr.plugins.pydantic import PydanticSerializer
|
||||
from dmr.response import APIError
|
||||
from dmr.security.jwt.views import (
|
||||
ObtainTokensPayload,
|
||||
ObtainTokensResponse,
|
||||
ObtainTokensSyncController,
|
||||
)
|
||||
from .dtos import (
|
||||
DetailResponse,
|
||||
SessionResponse,
|
||||
WhoAmIResponse,
|
||||
UserOut,
|
||||
UserUpdateIn,
|
||||
ChangePasswordIn,
|
||||
DeleteAccountIn,
|
||||
TokenResponse,
|
||||
PkPath,
|
||||
UuidPkPath,
|
||||
SavedPointIn,
|
||||
SavedPointPatchIn,
|
||||
SavedPointOut,
|
||||
PredictionTemplateIn,
|
||||
PredictionTemplatePatchIn,
|
||||
PredictionTemplateOut,
|
||||
PredictionRequest,
|
||||
PredictionOut,
|
||||
PredictionListOut,
|
||||
PredictionDetailOut,
|
||||
PredictionCreateOut,
|
||||
TelemetryIn,
|
||||
TelemetryOut,
|
||||
)
|
||||
from .pagination import (
|
||||
PredictionListUserQuery,
|
||||
PredictionPage,
|
||||
paginate_predictions,
|
||||
TelemetryPage,
|
||||
page_number_paginate,
|
||||
)
|
||||
from .validators import base64_to_curve
|
||||
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from .permissions import ReadOnlyOrAuthenticated, IsOwner
|
||||
from .custom_pagination import CustomLimitOffsetPagination
|
||||
from datetime import datetime
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
def _api_error(status_code, body, headers=None):
|
||||
"""Adapter for DMR's APIError(raw_data, *, status_code, ...) signature.
|
||||
def get_prediction_from_tawhiri(params):
|
||||
|
||||
Lets call sites keep the (status_code=, body=) style; the body is DMR's
|
||||
first positional ``raw_data`` argument.
|
||||
"""
|
||||
return APIError(body, status_code=status_code, headers=headers)
|
||||
base_url = "https://fly.stratonautica.ru/api/v2"
|
||||
response = requests.get(base_url, params=params)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json() # получаем результат предсказания
|
||||
else:
|
||||
raise Exception(
|
||||
f"Tawhiri error: {response.status_code} {response.text}")
|
||||
|
||||
|
||||
def _resolve_related(model, pk, field_name):
|
||||
"""Resolve a related object by PK (was PrimaryKeyRelatedField(queryset=...)).
|
||||
class PredictionViewSet(GenericViewSet):
|
||||
permission_classes = [IsAuthenticated]
|
||||
pagination_class = CustomLimitOffsetPagination
|
||||
|
||||
The original queryset was ``.all()`` (not user-scoped); a missing PK yields a
|
||||
400 matching DRF's "Invalid pk" error shape.
|
||||
"""
|
||||
if pk is None:
|
||||
return None
|
||||
obj = model.objects.filter(pk=pk).first()
|
||||
if obj is None:
|
||||
raise _api_error(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
body={field_name: [f'Invalid pk "{pk}" - object does not exist.']},
|
||||
)
|
||||
return obj
|
||||
def list(self, request):
|
||||
queryset = Prediction.objects.filter(user=request.user)
|
||||
return Response(PredictionSerializer(queryset, many=True).data)
|
||||
|
||||
def create(self, request):
|
||||
serializer = PredictionRequestSerializer(data=request.data)
|
||||
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
class PredictionCollectionController(Controller[PydanticSerializer]):
|
||||
"""`predictions/` -- list the user's predictions and create a new one."""
|
||||
|
||||
def get(self) -> list[PredictionOut]:
|
||||
queryset = Prediction.objects.filter(user=self.request.user)
|
||||
return [PredictionOut.model_validate(obj) for obj in queryset]
|
||||
|
||||
def post(self, parsed_body: Body[PredictionRequest]) -> PredictionCreateOut:
|
||||
user = self.request.user
|
||||
|
||||
# Resolve related objects before calling Tawhiri, so an invalid PK still
|
||||
# fails with 400 first (as DRF validation did).
|
||||
start_point = _resolve_related(SavedPoint, parsed_body.start_point, 'start_point')
|
||||
template = _resolve_related(PreditctionTemplate, parsed_body.template, 'template')
|
||||
rate_profile = _resolve_related(SavedRateProfile, parsed_body.rate_profile, 'rate_profile')
|
||||
validated_data = serializer.validated_data
|
||||
|
||||
try:
|
||||
prediction_result = TawhiriClient.get_prediction(parsed_body.model_dump())
|
||||
except requests.RequestException as exc:
|
||||
raise _api_error(
|
||||
status_code=HTTPStatus.BAD_GATEWAY,
|
||||
body={'error': f'Tawhiri error: {str(exc)}'},
|
||||
)
|
||||
prediction_result = TawhiriClient.get_prediction(validated_data)
|
||||
|
||||
# Carried over from the old serializer.create(): curves are decoded but
|
||||
# never persisted (the model has no curve fields). Kept for parity --
|
||||
# including that a malformed curve here raises (then surfaces as 500).
|
||||
if parsed_body.ascent_curve is not None:
|
||||
base64_to_curve(parsed_body.ascent_curve)
|
||||
if parsed_body.descent_curve is not None:
|
||||
base64_to_curve(parsed_body.descent_curve)
|
||||
except requests.RequestException as e:
|
||||
return Response({"error": f"Tawhiri error: {str(e)}"}, status=status.HTTP_502_BAD_GATEWAY)
|
||||
|
||||
prediction = Prediction(
|
||||
user=user,
|
||||
request=json.loads(self.request.body),
|
||||
result=prediction_result,
|
||||
start_point=start_point,
|
||||
template=template,
|
||||
rate_profile=rate_profile,
|
||||
)
|
||||
prediction.save()
|
||||
|
||||
return PredictionCreateOut(
|
||||
id=prediction.id,
|
||||
created_at=prediction.created_at,
|
||||
# prediction = Prediction.objects.create(
|
||||
# result=prediction_result, user=request.user, request=request.data, validated_data=validated_data)
|
||||
prediction = serializer.save(
|
||||
user=request.user,
|
||||
result=prediction_result,
|
||||
request=request.data
|
||||
)
|
||||
|
||||
return Response({
|
||||
"id": prediction.id,
|
||||
"created_at": prediction.created_at,
|
||||
"result": prediction_result
|
||||
}, status=status.HTTP_201_CREATED)
|
||||
|
||||
class PredictionListUserController(Controller[PydanticSerializer]):
|
||||
"""`predictions/list_user/` -- filtered, paginated list of the user's predictions."""
|
||||
@action(detail=False, methods=['get'])
|
||||
def list_user(self, request):
|
||||
user = request.user
|
||||
satellite_id = request.query_params.get('satellite_id')
|
||||
created_from = request.query_params.get('created_from')
|
||||
created_till = request.query_params.get('created_till')
|
||||
|
||||
def get(self, parsed_query: Query[PredictionListUserQuery]) -> PredictionPage:
|
||||
user = self.request.user
|
||||
filters = {'user': user}
|
||||
filters = {
|
||||
'user': user,
|
||||
}
|
||||
|
||||
if parsed_query.created_from:
|
||||
filters['created_at__gte'] = parse_datetime(parsed_query.created_from)
|
||||
if parsed_query.created_till:
|
||||
filters['created_at__lte'] = parse_datetime(parsed_query.created_till)
|
||||
if parsed_query.satellite_id:
|
||||
if not user.satellites.filter(id=parsed_query.satellite_id).exists():
|
||||
raise _api_error(status_code=HTTPStatus.FORBIDDEN, body={'detail': 'Access denied'})
|
||||
filters['satellite_id'] = parsed_query.satellite_id
|
||||
if created_from:
|
||||
filters['created_at__gte'] = parse_datetime(created_from)
|
||||
|
||||
if created_till:
|
||||
filters['created_at__lte'] = parse_datetime(created_till)
|
||||
|
||||
if satellite_id:
|
||||
if not user.satellites.filter(id=satellite_id).exists():
|
||||
return Response({'detail': 'Access denied'}, status=403)
|
||||
|
||||
filters['satellite_id'] = satellite_id
|
||||
|
||||
queryset = Prediction.objects.filter(**filters)
|
||||
return paginate_predictions(queryset, parsed_query)
|
||||
queryset = self.filter_queryset(queryset)
|
||||
|
||||
page = self.paginate_queryset(queryset)
|
||||
|
||||
if page is not None:
|
||||
serializer = PredictionSerializer(page, many=True)
|
||||
return self.get_paginated_response(serializer.data)
|
||||
|
||||
serializer = PredictionSerializer(queryset, many=True)
|
||||
return Response(serializer.data)
|
||||
|
||||
class PredictionHistoryController(Controller[PydanticSerializer]):
|
||||
"""`predictions/history/` -- compact list of the user's predictions."""
|
||||
@action(detail=False, methods=["get"])
|
||||
def history(self, request):
|
||||
queryset = Prediction.objects.filter(user=request.user)
|
||||
return Response(PredictionListSerializer(queryset, many=True).data)
|
||||
|
||||
def get(self) -> list[PredictionListOut]:
|
||||
queryset = Prediction.objects.filter(user=self.request.user)
|
||||
return [PredictionListOut.model_validate(obj) for obj in queryset]
|
||||
|
||||
|
||||
class PredictionDetailController(Controller[PydanticSerializer]):
|
||||
"""`predictions/<pk>/detail/` -- retrieve a single prediction."""
|
||||
|
||||
def get(self, parsed_path: Path[UuidPkPath]) -> PredictionDetailOut:
|
||||
@action(detail=True, methods=["get"])
|
||||
def detail(self, request, pk=None):
|
||||
prediction = Prediction.objects.filter(
|
||||
user=self.request.user, pk=parsed_path.pk).first()
|
||||
if prediction is None:
|
||||
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'})
|
||||
return PredictionDetailOut.model_validate(prediction)
|
||||
user=request.user, pk=pk).first()
|
||||
if not prediction:
|
||||
return Response({'detail': 'Not found'}, status=404)
|
||||
return Response(PredictionDetailSerializer(prediction).data)
|
||||
|
||||
|
||||
class PredictionDeleteController(Controller[PydanticSerializer]):
|
||||
"""`predictions/<pk>/delete/` -- delete a single prediction."""
|
||||
|
||||
@modify(status_code=HTTPStatus.NO_CONTENT)
|
||||
def delete(self, parsed_path: Path[UuidPkPath]) -> None:
|
||||
@action(detail=True, methods=["delete"])
|
||||
def delete(self, request, pk=None):
|
||||
prediction = Prediction.objects.filter(
|
||||
user=self.request.user, pk=parsed_path.pk).first()
|
||||
if prediction is None:
|
||||
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'})
|
||||
user=request.user, pk=pk).first()
|
||||
if not prediction:
|
||||
return Response({'detail': 'Not found'}, status=404)
|
||||
prediction.delete()
|
||||
return Response(status=204)
|
||||
|
||||
|
||||
class TelemetryController(Controller[PydanticSerializer]):
|
||||
"""`<uuid:pk>/telemetry/` -- list and ingest telemetry for a satellite.
|
||||
class TelemetryListCreateView(generics.ListCreateAPIView):
|
||||
serializer_class = TelemetryPacketSerializer
|
||||
permission_classes = [permissions.AllowAny]
|
||||
|
||||
Public (was AllowAny). GET preserves the global PageNumberPagination envelope
|
||||
{count, next, previous, results} with PAGE_SIZE=100.
|
||||
"""
|
||||
def get_queryset(self):
|
||||
qs = TelemetryPacket.objects.filter(satellite_id=self.kwargs["pk"])
|
||||
|
||||
auth = None # public (was AllowAny)
|
||||
csrf_exempt = True # DRF views bypass Django CSRF; ingestion is anonymous
|
||||
|
||||
def get(self, parsed_path: Path[UuidPkPath]) -> TelemetryPage:
|
||||
qs = TelemetryPacket.objects.filter(satellite_id=parsed_path.pk)
|
||||
|
||||
from_ts = self.request.GET.get('from')
|
||||
till_ts = self.request.GET.get('till')
|
||||
from_ts = self.request.query_params.get("from")
|
||||
till_ts = self.request.query_params.get("till")
|
||||
|
||||
if from_ts:
|
||||
qs = qs.filter(timestamp__gte=int(from_ts))
|
||||
if till_ts:
|
||||
qs = qs.filter(timestamp__lte=int(till_ts))
|
||||
|
||||
qs = qs.order_by('-timestamp')
|
||||
return qs.order_by("-timestamp")
|
||||
|
||||
count, next_link, previous_link, page_objects = page_number_paginate(self.request, qs)
|
||||
return TelemetryPage(
|
||||
count=count,
|
||||
next=next_link,
|
||||
previous=previous_link,
|
||||
results=[TelemetryOut.model_validate(obj) for obj in page_objects],
|
||||
)
|
||||
def post(self, request, pk):
|
||||
serializer = TelemetryPacketSerializer(data=request.data)
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
def post(self, parsed_path: Path[UuidPkPath], parsed_body: Body[TelemetryIn]) -> TelemetryOut:
|
||||
# Bug fix (approved): the original returned serializer.errors on success;
|
||||
# we return the created packet instead. timestamp is still server-set.
|
||||
packet = TelemetryPacket.objects.create(
|
||||
timestamp=time.time(),
|
||||
satellite=Satellite.objects.get(id=parsed_path.pk),
|
||||
lat=parsed_body.lat,
|
||||
lon=parsed_body.lon,
|
||||
alt=parsed_body.alt,
|
||||
payload=parsed_body.payload,
|
||||
)
|
||||
return TelemetryOut.model_validate(packet)
|
||||
validated_data = serializer.validated_data
|
||||
|
||||
TelemetryPacket.objects.create(timestamp=time.time(),
|
||||
satellite=Satellite.objects.get(id=pk),
|
||||
lat=validated_data["lat"],
|
||||
lon=validated_data["lon"],
|
||||
alt=validated_data["alt"],
|
||||
payload=validated_data['payload'],
|
||||
)
|
||||
return Response(serializer.errors, status=status.HTTP_201_CREATED)
|
||||
|
||||
|
||||
class SessionController(Controller[PydanticSerializer]):
|
||||
"""Report whether the current request is authenticated."""
|
||||
class SessionView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def get(self) -> SessionResponse:
|
||||
return SessionResponse(isAuthenticated=True)
|
||||
@staticmethod
|
||||
def get(request, format=None):
|
||||
return JsonResponse({'isAuthenticated': True})
|
||||
|
||||
|
||||
class WhoAmIController(Controller[PydanticSerializer]):
|
||||
"""Return the current user's username."""
|
||||
class WhoAmIView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def get(self) -> WhoAmIResponse:
|
||||
return WhoAmIResponse(username=self.request.user.username)
|
||||
@staticmethod
|
||||
def get(request, format=None):
|
||||
return JsonResponse({'username': request.user.username})
|
||||
|
||||
|
||||
class CsrfController(Controller[PydanticSerializer]):
|
||||
"""Get CSRF token."""
|
||||
|
||||
auth = None # public endpoint (was AllowAny)
|
||||
csrf_exempt = True
|
||||
|
||||
def get(self) -> DetailResponse:
|
||||
token = get_token(self.request)
|
||||
return self.to_response(
|
||||
DetailResponse(detail='CSRF cookie set'),
|
||||
headers={'X-CSRFToken': token},
|
||||
)
|
||||
@extend_schema(methods=["GET"], description="Get CSRF token")
|
||||
@csrf_exempt
|
||||
@api_view(["GET"])
|
||||
@permission_classes([AllowAny])
|
||||
def get_csrf(request):
|
||||
response = JsonResponse({'detail': 'CSRF cookie set'})
|
||||
response['X-CSRFToken'] = get_token(request)
|
||||
return response
|
||||
|
||||
|
||||
class LoginController(Controller[PydanticSerializer]):
|
||||
"""Login user."""
|
||||
@extend_schema(methods=["POST"], description="Login user")
|
||||
@csrf_exempt
|
||||
@api_view(["POST"])
|
||||
@authentication_classes([BasicAuthentication])
|
||||
@permission_classes([AllowAny])
|
||||
def login_view(request):
|
||||
data = json.loads(request.body)
|
||||
username = data.get('username')
|
||||
password = data.get('password')
|
||||
|
||||
auth = None # public endpoint (was AllowAny)
|
||||
csrf_exempt = True
|
||||
if username is None or password is None:
|
||||
return JsonResponse({'detail': 'Please provide username and password.'}, status=400)
|
||||
|
||||
# post() would default to 201; the original returned 200.
|
||||
@modify(status_code=HTTPStatus.OK)
|
||||
def post(self) -> DetailResponse:
|
||||
data = json.loads(self.request.body)
|
||||
username = data.get('username')
|
||||
password = data.get('password')
|
||||
user = authenticate(username=username, password=password)
|
||||
if user is None:
|
||||
return JsonResponse({'detail': 'Invalid credentials.'}, status=400)
|
||||
|
||||
if username is None or password is None:
|
||||
raise _api_error(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
body={'detail': 'Please provide username and password.'},
|
||||
)
|
||||
|
||||
user = authenticate(username=username, password=password)
|
||||
if user is None:
|
||||
raise _api_error(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
body={'detail': 'Invalid credentials.'},
|
||||
)
|
||||
|
||||
login(self.request, user)
|
||||
return DetailResponse(detail='Successfully logged in.')
|
||||
login(request, user)
|
||||
return JsonResponse({'detail': 'Successfully logged in.'})
|
||||
|
||||
|
||||
class LogoutController(Controller[PydanticSerializer]):
|
||||
"""Logout user."""
|
||||
@extend_schema(methods=["POST"], description="Logout user")
|
||||
@api_view(["POST"])
|
||||
@permission_classes([AllowAny])
|
||||
def logout_view(request):
|
||||
if not request.user.is_authenticated:
|
||||
return JsonResponse({'detail': 'You\'re not logged in.'}, status=400)
|
||||
|
||||
auth = None # public endpoint (was AllowAny); checks auth state manually
|
||||
|
||||
@modify(status_code=HTTPStatus.OK)
|
||||
def post(self) -> DetailResponse:
|
||||
if not self.request.user.is_authenticated:
|
||||
raise _api_error(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
body={'detail': "You're not logged in."},
|
||||
)
|
||||
|
||||
logout(self.request)
|
||||
return DetailResponse(detail='Successfully logged out.')
|
||||
logout(request)
|
||||
return JsonResponse({'detail': 'Successfully logged out.'})
|
||||
|
||||
|
||||
_SAVED_POINT_DUPLICATE = 'A saved point with this name already exists for the user.'
|
||||
class SavedPointViewset(ModelViewSet):
|
||||
permission_classes = [IsOwner]
|
||||
serializer_class = SavedPointSerializer
|
||||
pagination_class = None
|
||||
|
||||
def get_queryset(self):
|
||||
return SavedPoint.objects.filter(user=self.request.user)
|
||||
|
||||
def perform_create(self, serializer):
|
||||
serializer.save(user=self.request.user)
|
||||
|
||||
|
||||
def _check_saved_point_unique(user, name, exclude_pk=None):
|
||||
"""Reproduce the former SavedPointSerializer UniqueTogetherValidator."""
|
||||
qs = SavedPoint.objects.filter(user=user, name=name)
|
||||
if exclude_pk is not None:
|
||||
qs = qs.exclude(pk=exclude_pk)
|
||||
if qs.exists():
|
||||
raise _api_error(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
body={'non_field_errors': [_SAVED_POINT_DUPLICATE]},
|
||||
)
|
||||
class PreditctionTemplateViewset(ModelViewSet):
|
||||
permission_classes = [IsOwner]
|
||||
serializer_class = PreditctionTemplateSerializer
|
||||
pagination_class = None
|
||||
|
||||
def get_queryset(self):
|
||||
return PreditctionTemplate.objects.filter(user=self.request.user)
|
||||
|
||||
def perform_create(self, serializer):
|
||||
serializer.save(user=self.request.user)
|
||||
|
||||
|
||||
class SavedPointListController(Controller[PydanticSerializer]):
|
||||
"""Collection endpoint for the current user's saved points (was SavedPointViewset)."""
|
||||
class UserProfileView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def get(self) -> list[SavedPointOut]:
|
||||
qs = SavedPoint.objects.filter(user=self.request.user)
|
||||
return [SavedPointOut.model_validate(obj) for obj in qs]
|
||||
def get(self, request):
|
||||
serializer = UserSerializer(request.user)
|
||||
return Response(serializer.data)
|
||||
|
||||
def post(self, parsed_body: Body[SavedPointIn]) -> SavedPointOut:
|
||||
user = self.request.user
|
||||
_check_saved_point_unique(user, parsed_body.name)
|
||||
obj = SavedPoint.objects.create(user=user, **parsed_body.model_dump())
|
||||
return SavedPointOut.model_validate(obj)
|
||||
def patch(self, request):
|
||||
user = request.user
|
||||
serializer = UserSerializer(user, data=request.data, partial=True)
|
||||
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
serializer.save()
|
||||
return Response(serializer.data)
|
||||
|
||||
|
||||
class SavedPointDetailController(Controller[PydanticSerializer]):
|
||||
"""Detail endpoint for a single saved point owned by the current user."""
|
||||
class ChangePasswordView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def _get_object(self, pk: int):
|
||||
# Filtering by user means another user's object reads as 404 (not 403),
|
||||
# matching DRF's get_object() over a user-scoped queryset.
|
||||
obj = SavedPoint.objects.filter(user=self.request.user, pk=pk).first()
|
||||
if obj is None:
|
||||
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'})
|
||||
return obj
|
||||
def post(self, request):
|
||||
user = request.user
|
||||
serializer = ChangePasswordSerializer(data=request.data)
|
||||
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if not user.check_password(serializer.validated_data['old_password']):
|
||||
return Response({'detail': 'Old password is incorrect'},
|
||||
status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
def get(self, parsed_path: Path[PkPath]) -> SavedPointOut:
|
||||
return SavedPointOut.model_validate(self._get_object(parsed_path.pk))
|
||||
|
||||
def put(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointIn]) -> SavedPointOut:
|
||||
obj = self._get_object(parsed_path.pk)
|
||||
_check_saved_point_unique(self.request.user, parsed_body.name, exclude_pk=obj.pk)
|
||||
for field, value in parsed_body.model_dump().items():
|
||||
setattr(obj, field, value)
|
||||
obj.save()
|
||||
return SavedPointOut.model_validate(obj)
|
||||
|
||||
def patch(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointPatchIn]) -> SavedPointOut:
|
||||
obj = self._get_object(parsed_path.pk)
|
||||
updates = parsed_body.model_dump(exclude_unset=True)
|
||||
if 'name' in updates:
|
||||
_check_saved_point_unique(self.request.user, updates['name'], exclude_pk=obj.pk)
|
||||
for field, value in updates.items():
|
||||
setattr(obj, field, value)
|
||||
obj.save()
|
||||
return SavedPointOut.model_validate(obj)
|
||||
|
||||
@modify(status_code=HTTPStatus.NO_CONTENT)
|
||||
def delete(self, parsed_path: Path[PkPath]) -> None:
|
||||
self._get_object(parsed_path.pk).delete()
|
||||
|
||||
|
||||
class PredictionTemplateListController(Controller[PydanticSerializer]):
|
||||
"""Collection endpoint for the current user's templates (was PreditctionTemplateViewset).
|
||||
|
||||
NOTE: as before, there is no app-level uniqueness check; a duplicate
|
||||
(user, name) hits the model's unique_together and surfaces as a DB error.
|
||||
"""
|
||||
|
||||
def get(self) -> list[PredictionTemplateOut]:
|
||||
qs = PreditctionTemplate.objects.filter(user=self.request.user)
|
||||
return [PredictionTemplateOut.model_validate(obj) for obj in qs]
|
||||
|
||||
def post(self, parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut:
|
||||
obj = PreditctionTemplate.objects.create(
|
||||
user=self.request.user, **parsed_body.model_dump()
|
||||
)
|
||||
return PredictionTemplateOut.model_validate(obj)
|
||||
|
||||
|
||||
class PredictionTemplateDetailController(Controller[PydanticSerializer]):
|
||||
"""Detail endpoint for a single template owned by the current user."""
|
||||
|
||||
def _get_object(self, pk: int):
|
||||
obj = PreditctionTemplate.objects.filter(user=self.request.user, pk=pk).first()
|
||||
if obj is None:
|
||||
raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'})
|
||||
return obj
|
||||
|
||||
def get(self, parsed_path: Path[PkPath]) -> PredictionTemplateOut:
|
||||
return PredictionTemplateOut.model_validate(self._get_object(parsed_path.pk))
|
||||
|
||||
def put(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut:
|
||||
obj = self._get_object(parsed_path.pk)
|
||||
for field, value in parsed_body.model_dump().items():
|
||||
setattr(obj, field, value)
|
||||
obj.save()
|
||||
return PredictionTemplateOut.model_validate(obj)
|
||||
|
||||
def patch(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplatePatchIn]) -> PredictionTemplateOut:
|
||||
obj = self._get_object(parsed_path.pk)
|
||||
for field, value in parsed_body.model_dump(exclude_unset=True).items():
|
||||
setattr(obj, field, value)
|
||||
obj.save()
|
||||
return PredictionTemplateOut.model_validate(obj)
|
||||
|
||||
@modify(status_code=HTTPStatus.NO_CONTENT)
|
||||
def delete(self, parsed_path: Path[PkPath]) -> None:
|
||||
self._get_object(parsed_path.pk).delete()
|
||||
|
||||
|
||||
class UserProfileController(Controller[PydanticSerializer]):
|
||||
"""Read and partially update the current user's profile."""
|
||||
|
||||
def get(self) -> UserOut:
|
||||
return UserOut.model_validate(self.request.user)
|
||||
|
||||
def patch(self, parsed_body: Body[UserUpdateIn]) -> UserOut:
|
||||
user = self.request.user
|
||||
# partial update: apply only the fields the client actually sent.
|
||||
for field, value in parsed_body.model_dump(exclude_unset=True).items():
|
||||
setattr(user, field, value)
|
||||
user.set_password(serializer.validated_data['new_password'])
|
||||
user.save()
|
||||
return UserOut.model_validate(user)
|
||||
return Response({'detail': 'Password changed successfully'})
|
||||
|
||||
|
||||
class ChangePasswordController(Controller[PydanticSerializer]):
|
||||
"""Change the current user's password."""
|
||||
|
||||
# post() would default to 201; the original returned 200.
|
||||
@modify(status_code=HTTPStatus.OK)
|
||||
def post(self, parsed_body: Body[ChangePasswordIn]) -> DetailResponse:
|
||||
user = self.request.user
|
||||
|
||||
if not user.check_password(parsed_body.old_password):
|
||||
raise _api_error(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
body={'detail': 'Old password is incorrect'},
|
||||
)
|
||||
|
||||
user.set_password(parsed_body.new_password)
|
||||
user.save()
|
||||
return DetailResponse(detail='Password changed successfully')
|
||||
|
||||
|
||||
class DeleteAccountController(Controller[PydanticSerializer]):
|
||||
"""Delete the current user's account and all their data."""
|
||||
|
||||
# DMR forbids a Body component on DELETE, but the original endpoint read the
|
||||
# password from a DELETE request body. Preserve that contract by parsing the
|
||||
# body manually instead of via Body[DeleteAccountIn].
|
||||
def delete(self) -> DetailResponse:
|
||||
user = self.request.user
|
||||
|
||||
try:
|
||||
parsed_body = DeleteAccountIn(**json.loads(self.request.body or b'{}'))
|
||||
except ValidationError as exc:
|
||||
raise _api_error(status_code=HTTPStatus.BAD_REQUEST, body=json.loads(exc.json()))
|
||||
except ValueError:
|
||||
raise _api_error(status_code=HTTPStatus.BAD_REQUEST, body={'detail': 'Invalid request body.'})
|
||||
|
||||
if not user.check_password(parsed_body.password):
|
||||
raise _api_error(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
body={'detail': 'Incorrect password'},
|
||||
)
|
||||
class DeleteAccountView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def delete(self, request):
|
||||
user = request.user
|
||||
serializer = DeleteAccountSerializer(data=request.data)
|
||||
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if not user.check_password(serializer.validated_data['password']):
|
||||
return Response({'detail': 'Incorrect password'},
|
||||
status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
Prediction.objects.filter(user=user).delete()
|
||||
SavedPoint.objects.filter(user=user).delete()
|
||||
PreditctionTemplate.objects.filter(user=user).delete()
|
||||
|
||||
|
||||
user.delete()
|
||||
|
||||
return DetailResponse(detail='Account deleted successfully')
|
||||
|
||||
return Response({'detail': 'Account deleted successfully'})
|
||||
|
||||
|
||||
class DeleteUserDataController(Controller[PydanticSerializer]):
|
||||
"""Delete all of the current user's data without deleting the account."""
|
||||
|
||||
def delete(self) -> DetailResponse:
|
||||
user = self.request.user
|
||||
class DeleteUserDataView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
def delete(self, request):
|
||||
user = request.user
|
||||
|
||||
Prediction.objects.filter(user=user).delete()
|
||||
SavedPoint.objects.filter(user=user).delete()
|
||||
PreditctionTemplate.objects.filter(user=user).delete()
|
||||
|
||||
return DetailResponse(detail='All user data deleted successfully')
|
||||
|
||||
return Response({'detail': 'All user data deleted successfully'})
|
||||
|
||||
|
||||
class ObtainTokenController(
|
||||
ObtainTokensSyncController[
|
||||
PydanticSerializer,
|
||||
ObtainTokensPayload,
|
||||
ObtainTokensResponse,
|
||||
],
|
||||
):
|
||||
"""Exchange username/password for JWT access + refresh tokens.
|
||||
class TokenManagementView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
|
||||
Replaces DRF's obtain_auth_token. Approved drift: the token format and
|
||||
semantics change from a single stored DRF token to stateless JWTs, so the
|
||||
response is {access_token, refresh_token} instead of {token}.
|
||||
"""
|
||||
def get(self, request):
|
||||
|
||||
token, created = Token.objects.get_or_create(user=request.user)
|
||||
return Response({"token": token.key})
|
||||
|
||||
auth = None # public: credentials are supplied in the request body
|
||||
csrf_exempt = True
|
||||
jwt_expiration = timedelta(hours=1)
|
||||
jwt_refresh_expiration = timedelta(days=7)
|
||||
|
||||
def convert_auth_payload(self, payload):
|
||||
return payload
|
||||
|
||||
def make_api_response(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
return {
|
||||
'access_token': self.create_jwt_token(
|
||||
expiration=now + self.jwt_expiration,
|
||||
token_type='access',
|
||||
),
|
||||
'refresh_token': self.create_jwt_token(
|
||||
expiration=now + self.jwt_refresh_expiration,
|
||||
token_type='refresh',
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class TokenManagementController(Controller[PydanticSerializer]):
|
||||
"""Issue a fresh JWT access token for the current user.
|
||||
|
||||
Was TokenManagementView (DRF stored-token get/regenerate). With stateless
|
||||
JWTs there is nothing stored to fetch or delete, so both GET and POST mint a
|
||||
new access token (approved drift). The {"token": ...} response shape is kept.
|
||||
"""
|
||||
|
||||
jwt_expiration = timedelta(hours=1)
|
||||
|
||||
def get(self) -> TokenResponse:
|
||||
return TokenResponse(token=self._issue_token())
|
||||
|
||||
# post() would default to 201; the original returned 200.
|
||||
@modify(status_code=HTTPStatus.OK)
|
||||
def post(self) -> TokenResponse:
|
||||
return TokenResponse(token=self._issue_token())
|
||||
|
||||
def _issue_token(self) -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
return self.create_jwt_token(
|
||||
subject=str(self.request.user.pk),
|
||||
expiration=now + self.jwt_expiration,
|
||||
token_type='access',
|
||||
)
|
||||
def post(self, request):
|
||||
|
||||
Token.objects.filter(user=request.user).delete()
|
||||
token = Token.objects.create(user=request.user)
|
||||
return Response({"token": token.key})
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue