diff --git a/requirements.txt b/requirements.txt index f91db5a..2672235 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,8 @@ -# django-modern-rest 0.6.0 requires Django 5.x (upgraded from 4.2). -Django>=5.0,<6.0 -django-modern-rest[jwt,pydantic]==0.6.0 +Django>=4.0,<5.0 +djangorestframework +djangorestframework-simplejwt psycopg2-binary +drf-spectacular requests django-cors-headers Pillow diff --git a/stratoflights/settings.py b/stratoflights/settings.py index decb5ef..bb5b63e 100644 --- a/stratoflights/settings.py +++ b/stratoflights/settings.py @@ -13,12 +13,6 @@ https://docs.djangoproject.com/en/4.2/ref/settings/ from pathlib import Path import os from dotenv import load_dotenv - -from dmr.settings import Settings -from dmr.security.django_session import DjangoSessionSyncAuth -from dmr.security.jwt import JWTSyncAuth -from dmr.openapi import OpenAPIConfig - load_dotenv() # Build paths inside the project like this: BASE_DIR / 'subdir'. @@ -57,10 +51,12 @@ INSTALLED_APPS = [ 'django.contrib.sessions', 'django.contrib.messages', 'django.contrib.staticfiles', + 'rest_framework', + 'rest_framework.authtoken', + 'drf_spectacular', 'corsheaders', 'stratoflights_api.apps.StratoflightsApiConfig', 'channels', - 'dmr', # required to serve OpenAPI docs static assets ] MIDDLEWARE = [ @@ -169,16 +165,25 @@ STATIC_URL = 'static/' DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' AUTH_USER_MODEL = 'stratoflights_api.User' -# django-modern-rest configuration. -# Global auth mirrors the previous DRF default (Token + Session); the DRF -# TokenAuthentication is replaced by DMR's JWT (approved drift). Providing -# several auth instances means at least one of them must succeed, so this -# also enforces a 401 by default like DRF's global IsAuthenticated did. -# Public endpoints opt out per-endpoint with @modify(auth=None). -DMR_SETTINGS = { - Settings.auth: [JWTSyncAuth(), DjangoSessionSyncAuth()], - Settings.validate_responses: not PRODUCTION, - Settings.openapi_config: OpenAPIConfig(title='Stratoflights API', version='1.0.0'), +REST_FRAMEWORK = { + 'DEFAULT_PAGINATION_CLASS': 'rest_framework.pagination.PageNumberPagination', + 'PAGE_SIZE': 100, + + 'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema', + + 'DEFAULT_AUTHENTICATION_CLASSES': [ + 'rest_framework.authentication.TokenAuthentication', + 'rest_framework.authentication.SessionAuthentication', + ], + + 'DEFAULT_PERMISSION_CLASSES': [ + 'rest_framework.permissions.IsAuthenticated', + # 'rest_framework.permissions.AllowAny', + ], + + 'DEFAULT_RENDERER_CLASSES': [ + 'rest_framework.renderers.JSONRenderer', + ], } diff --git a/stratoflights/urls.py b/stratoflights/urls.py index 0541cd2..d97bf9b 100644 --- a/stratoflights/urls.py +++ b/stratoflights/urls.py @@ -1,17 +1,27 @@ -"""URL configuration for the stratoflights project.""" +""" +URL configuration for stratoflights project. + +The `urlpatterns` list routes URLs to views. For more information please see: + https://docs.djangoproject.com/en/4.2/topics/http/urls/ +Examples: +Function views + 1. Add an import: from my_app import views + 2. Add a URL to urlpatterns: path('', views.home, name='home') +Class-based views + 1. Add an import: from other_app.views import Home + 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') +Including another URLconf + 1. Import the include() function: from django.urls import include, path + 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) +""" from django.contrib import admin -from django.urls import include -from dmr.openapi import build_schema -from dmr.openapi.views import OpenAPIJsonView, SwaggerView -from dmr.routing import path - -from stratoflights_api.urls import router - -schema = build_schema(router) +from django.urls import path, include +from drf_spectacular.views import SpectacularAPIView +from drf_spectacular.views import SpectacularSwaggerView urlpatterns = [ path('admin/', admin.site.urls), - path(router.prefix, include(router.urls)), - path('api/schema/', OpenAPIJsonView.as_view(schema), name='schema'), - path('api/docs/', SwaggerView.as_view(schema), name='docs'), -] + path('api/', include('stratoflights_api.urls')), + path('api/schema/', SpectacularAPIView.as_view(), name='schema'), + path('api/docs/', SpectacularSwaggerView.as_view(url_name='schema'), name='docs'), +] \ No newline at end of file diff --git a/stratoflights_api/consumers.py b/stratoflights_api/consumers.py index 3c44e16..66dce9b 100644 --- a/stratoflights_api/consumers.py +++ b/stratoflights_api/consumers.py @@ -18,19 +18,17 @@ class TelemetryConsumer(AsyncWebsocketConsumer): async def receive(self, text_data): - from pydantic import ValidationError - from .dtos import TelemetryIn + from .serializers import TelemetryPacketSerializer if not self.write_enabled: await self.send(text_data=json.dumps({"error": "Read-only mode"})) return data = json.loads(text_data) + serializer = TelemetryPacketSerializer(data=data) - try: - TelemetryIn(**data) - except ValidationError as exc: - await self.send(text_data=json.dumps({"error": json.loads(exc.json())})) + if not serializer.is_valid(): + await self.send(text_data=json.dumps({"error": serializer.errors})) return saved_data = await self.save_telemetry(data) @@ -53,6 +51,15 @@ class TelemetryConsumer(AsyncWebsocketConsumer): async def telemetry_message(self, event): await self.send(text_data=json.dumps(event["data"])) + @database_sync_to_async + def get_user_from_token(self, token_key, Token): + from rest_framework.authtoken.models import Token + User = get_user_model() + try: + token = Token.objects.select_related("user").get(key=token_key) + return token.user + except Token.DoesNotExist: + return None @database_sync_to_async def save_telemetry(self, data): @@ -92,20 +99,12 @@ class StationTelemetryConsumer(TelemetryConsumer): write_enabled = True async def connect(self): - from django.conf import settings - from dmr.security.jwt.token import JWToken - + from rest_framework.authtoken.models import Token token_key = self.scope["query_string"].decode().split("token=")[-1] try: - decoded = JWToken.decode( - encoded_token=token_key, - secret=settings.SECRET_KEY, - algorithm='HS256', - ) - self.scope["user"] = await database_sync_to_async( - get_user_model().objects.get - )(pk=decoded.sub) - except Exception: + token = await database_sync_to_async(Token.objects.select_related("user").get)(key=token_key) + self.scope["user"] = token.user + except Token.DoesNotExist: await self.close() return diff --git a/stratoflights_api/custom_pagination.py b/stratoflights_api/custom_pagination.py new file mode 100644 index 0000000..6d37de8 --- /dev/null +++ b/stratoflights_api/custom_pagination.py @@ -0,0 +1,17 @@ +from rest_framework.pagination import LimitOffsetPagination +from rest_framework.response import Response + +class CustomLimitOffsetPagination(LimitOffsetPagination): + limit_query_param = 'limit' + offset_query_param = 'skip' + max_limit = 100 + default_limit = 10 + + + def get_paginated_response(self, data): + return Response({ + 'total': self.count, + 'limit': self.limit, + 'skip': self.offset, + 'predictions': data + }) \ No newline at end of file diff --git a/stratoflights_api/dtos.py b/stratoflights_api/dtos.py deleted file mode 100644 index 3178f93..0000000 --- a/stratoflights_api/dtos.py +++ /dev/null @@ -1,333 +0,0 @@ -"""Pydantic DTOs replacing the DRF serializers from serializers.py. - -Mapping is 1:1 with the former serializers. Output DTOs enable ``from_attributes`` -so an endpoint can build them straight from a Django model instance via -``SomeOut.model_validate(obj)`` (the equivalent of ``Serializer(obj).data``). - -Field-level and cross-field validation that DRF kept in ``validate_()`` / -``validate()`` is reproduced with pydantic ``field_validator`` / ``model_validator``. -Business logic DRF kept in ``create()`` (base64 curve decoding, ORM writes, and -resolving related objects from their PKs) is intentionally NOT here -- per DMR's -split it moves into the endpoints (Step 2.9). -""" -import uuid -from datetime import datetime -from typing import Any, Optional - -from django.contrib.auth.password_validation import validate_password -from django.core.validators import validate_email -from django.core.exceptions import ValidationError as DjangoValidationError -from pydantic import ( - BaseModel, - ConfigDict, - Field, - field_validator, - model_validator, -) - -from .validators import validate_custom_curve, rate_clip - - -# Profile constants (moved verbatim from serializers.py). -PROFILE_STANDARD = "standard_profile" -PROFILE_FLOAT = "float_profile" -PROFILE_REVERSE = "reverse_profile" -PROFILE_CUSTOM = "custom_profile" -LATEST_DATASET_KEYWORD = "latest" -SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM] - - -# --- Prediction request (was PredictionRequestSerializer) --------------------- -class PredictionRequest(BaseModel): - launch_latitude: float = Field(ge=-90, le=90) - launch_longitude: float = Field(ge=0, le=360) - launch_datetime: datetime - launch_altitude: Optional[float] = None - format: str = "json" - profile: str = PROFILE_STANDARD - dataset: str = LATEST_DATASET_KEYWORD - - # profile-dependent fields - ascent_rate: Optional[float] = Field(default=None, ge=0.01) - descent_rate: Optional[float] = Field(default=None, ge=0.01) - burst_altitude: Optional[float] = None - float_altitude: Optional[float] = None - stop_datetime: Optional[datetime] = None - ascent_curve: Optional[str] = None - descent_curve: Optional[str] = None - interpolate: bool = False - # Related objects are accepted as PKs; existence is resolved in the endpoint - # (was PrimaryKeyRelatedField(queryset=...)). - start_point: Optional[int] = None - rate_profile: Optional[int] = None - template: Optional[int] = None - - @field_validator("profile") - @classmethod - def _validate_profile(cls, value: str) -> str: - if value not in SUPPORTED_PROFILES: - raise ValueError(f'"{value}" is not a valid choice.') - return value - - @model_validator(mode="after") - def _validate_cross_fields(self) -> "PredictionRequest": - launch_alt = self.launch_altitude if self.launch_altitude is not None else 0 - - if self.profile == PROFILE_STANDARD: - if self.burst_altitude is None: - raise ValueError("burst_altitude is required for standard profile.") - if self.burst_altitude <= launch_alt: - raise ValueError("burst_altitude must be greater than launch_altitude.") - - elif self.profile == PROFILE_FLOAT: - if self.float_altitude is None or self.float_altitude <= launch_alt: - raise ValueError("float_altitude must be greater than launch_altitude.") - if self.stop_datetime is None or self.stop_datetime <= self.launch_datetime: - raise ValueError("stop_datetime must be later than launch_datetime.") - - elif self.profile == PROFILE_CUSTOM: - if self.ascent_curve is None or not validate_custom_curve(self.ascent_curve): - raise ValueError("Invalid ascent_curve.") - if self.descent_curve is None or not validate_custom_curve(self.descent_curve): - raise ValueError("Invalid descent_curve.") - if self.burst_altitude is None or self.burst_altitude <= launch_alt: - raise ValueError("burst_altitude must be greater than launch_altitude.") - - # custom clipping logic (was in validate()) - if self.ascent_rate is not None: - self.ascent_rate = rate_clip(self.ascent_rate) - if self.descent_rate is not None: - self.descent_rate = rate_clip(self.descent_rate) - - return self - - -# --- Prediction outputs ------------------------------------------------------- -class PredictionOut(BaseModel): - """Was PredictionSerializer.""" - - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - created_at: datetime - updated_at: datetime - result: Any - - -class PredictionListOut(BaseModel): - """Was PredictionListSerializer.""" - - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - created_at: datetime - updated_at: datetime - start_point: Optional[int] = Field(default=None, validation_alias="start_point_id") - template: Optional[int] = Field(default=None, validation_alias="template_id") - rate_profile: Optional[int] = Field(default=None, validation_alias="rate_profile_id") - - -class PredictionDetailOut(BaseModel): - """Was PredictionDetailSerializer.""" - - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - created_at: datetime - updated_at: datetime - result: Any - start_point: Optional[int] = Field(default=None, validation_alias="start_point_id") - template: Optional[int] = Field(default=None, validation_alias="template_id") - rate_profile: Optional[int] = Field(default=None, validation_alias="rate_profile_id") - - -class PredictionCreateOut(BaseModel): - """Custom create() response shape: {id, created_at, result} (not the full model).""" - - id: uuid.UUID - created_at: datetime - result: Any - - -# --- Telemetry (was TelemetryPacketSerializer) -------------------------------- -class TelemetryIn(BaseModel): - # timestamp is required (model BigIntegerField has no default), matching the - # former ModelSerializer; the endpoint still overwrites it with server time. - timestamp: int - lat: float = 0.0 - lon: float = 0.0 - alt: float = 0.0 - payload: Any = Field(default_factory=dict) - - -class TelemetryOut(BaseModel): - model_config = ConfigDict(from_attributes=True) - - id: uuid.UUID - timestamp: int - lat: float - lon: float - alt: float - payload: Any - - -# --- SavedPoint (was SavedPointSerializer) ------------------------------------ -# `user` was a HiddenField(CurrentUserDefault) -> set from request in the endpoint, -# not part of the wire input/output. The unique_together (user, name) check that -# DRF did via UniqueTogetherValidator is reproduced in the endpoint (Step 2.8). -class SavedPointIn(BaseModel): - name: str - lat: float = 0.0 - lon: float = 0.0 - alt: float = 0.0 - - -class SavedPointPatchIn(BaseModel): - # All-optional variant for PATCH (partial_update). Only set fields apply. - name: Optional[str] = None - lat: Optional[float] = None - lon: Optional[float] = None - alt: Optional[float] = None - - -class SavedPointOut(BaseModel): - model_config = ConfigDict(from_attributes=True) - - id: int - name: str - lat: float - lon: float - alt: float - - -# --- SavedRateProfile (was SavedRateProfileSerializer) ------------------------ -# NOTE: no view referenced this serializer in routing; kept for parity. -class SavedRateProfileIn(BaseModel): - name: str - type: str = "ascent" - rate_profile_data: Any = Field(default_factory=dict) - - -class SavedRateProfileOut(BaseModel): - model_config = ConfigDict(from_attributes=True) - - id: int - name: str - type: str - rate_profile_data: Any - - -# --- PredictionTemplate (was PreditctionTemplateSerializer) ------------------- -class PredictionTemplateIn(BaseModel): - # protected_namespaces=() avoids pydantic's warning about the `model` field. - model_config = ConfigDict(protected_namespaces=()) - - name: str - is_default: bool = False - description: Optional[str] = None - prediction_mode: str = "" - model: str = "" - dataset: str = "" - flight_parameters: Any = Field(default_factory=dict) - - -class PredictionTemplatePatchIn(BaseModel): - # All-optional variant for PATCH (partial_update). Only set fields apply. - model_config = ConfigDict(protected_namespaces=()) - - name: Optional[str] = None - is_default: Optional[bool] = None - description: Optional[str] = None - prediction_mode: Optional[str] = None - model: Optional[str] = None - dataset: Optional[str] = None - flight_parameters: Optional[Any] = None - - -class PredictionTemplateOut(BaseModel): - model_config = ConfigDict(from_attributes=True, protected_namespaces=()) - - id: int - name: str - is_default: bool - description: Optional[str] = None - prediction_mode: str - model: str - dataset: str - flight_parameters: Any - - -# --- User (was UserSerializer) ------------------------------------------------ -class UserOut(BaseModel): - model_config = ConfigDict(from_attributes=True) - - username: str - email: str - first_name: str - last_name: str - - -class UserUpdateIn(BaseModel): - # `username` was read_only -> excluded from input. PATCH is partial, so all - # fields are optional; the endpoint applies only fields that were set. - email: Optional[str] = None - first_name: Optional[str] = None - last_name: Optional[str] = None - - @field_validator("email") - @classmethod - def _validate_email(cls, value: Optional[str]) -> Optional[str]: - if value is None: - return value - try: - validate_email(value) - except DjangoValidationError: - raise ValueError("Invalid email format") - return value - - -# --- Password / account (were ChangePasswordSerializer / DeleteAccountSerializer) -class ChangePasswordIn(BaseModel): - old_password: str - new_password: str - - @field_validator("new_password") - @classmethod - def _validate_new_password(cls, value: str) -> str: - # validate_password raises Django's ValidationError (not a ValueError), - # which pydantic would not convert into a 4xx -- re-raise as ValueError. - try: - validate_password(value) - except DjangoValidationError as exc: - raise ValueError("; ".join(exc.messages)) - return value - - -class DeleteAccountIn(BaseModel): - password: str - - -# --- Auth / session response shapes (were inline JsonResponse dicts) ---------- -class DetailResponse(BaseModel): - detail: str - - -class SessionResponse(BaseModel): - isAuthenticated: bool - - -class WhoAmIResponse(BaseModel): - username: str - - -class TokenResponse(BaseModel): - token: str - - -# --- Shared path parameters --------------------------------------------------- -class PkPath(BaseModel): - pk: int - - -class UuidPkPath(BaseModel): - pk: uuid.UUID diff --git a/stratoflights_api/pagination.py b/stratoflights_api/pagination.py deleted file mode 100644 index 6f1fbaf..0000000 --- a/stratoflights_api/pagination.py +++ /dev/null @@ -1,120 +0,0 @@ -"""Limit/offset pagination for the prediction list endpoints. - -DMR has no built-in limit/offset paginator, so (as the docs suggest) we keep our -own envelope model -- matching the former ``CustomLimitOffsetPagination`` output -exactly -- plus a helper that slices the queryset. The query params and envelope -keys are preserved verbatim: ``limit`` / ``skip`` in, and -``{total, limit, skip, predictions}`` out. -""" -from http import HTTPStatus -from typing import Optional -from urllib.parse import urlsplit, urlunsplit - -from django.core.paginator import InvalidPage, Paginator -from django.http import QueryDict -from pydantic import BaseModel - -from dmr.response import APIError - -from .dtos import PredictionOut, TelemetryOut - -DEFAULT_LIMIT = 10 # was CustomLimitOffsetPagination.default_limit -MAX_LIMIT = 100 # was CustomLimitOffsetPagination.max_limit -PAGE_SIZE = 100 # global REST_FRAMEWORK PAGE_SIZE (PageNumberPagination) - - -class PaginationQuery(BaseModel): - # Param names preserved: limit_query_param='limit', offset_query_param='skip'. - limit: int = DEFAULT_LIMIT - skip: int = 0 - - -class PredictionListUserQuery(BaseModel): - """Query params for PredictionViewSet.list_user: filters + limit/offset.""" - - satellite_id: Optional[str] = None - created_from: Optional[str] = None - created_till: Optional[str] = None - limit: int = DEFAULT_LIMIT - skip: int = 0 - - -class PredictionPage(BaseModel): - total: int - limit: int - skip: int - predictions: list[PredictionOut] - - -class TelemetryPage(BaseModel): - count: int - next: Optional[str] = None - previous: Optional[str] = None - results: list[TelemetryOut] - - -def _replace_query_param(url: str, key: str, value) -> str: - scheme, netloc, path, query, fragment = urlsplit(url) - query_dict = QueryDict(query, mutable=True) - query_dict[key] = value - return urlunsplit((scheme, netloc, path, query_dict.urlencode(), fragment)) - - -def _remove_query_param(url: str, key: str) -> str: - scheme, netloc, path, query, fragment = urlsplit(url) - query_dict = QueryDict(query, mutable=True) - query_dict.pop(key, None) - return urlunsplit((scheme, netloc, path, query_dict.urlencode(), fragment)) - - -def page_number_paginate(request, queryset, page_size: int = PAGE_SIZE): - """Reproduce DRF PageNumberPagination. - - Returns ``(count, next_link, previous_link, object_list)``. Raises a 404 - "Invalid page." like DRF on an out-of-range / non-integer page. Honours - ``page=last`` and builds absolute next/previous links off the current URL. - """ - paginator = Paginator(queryset, page_size) - page_number = request.GET.get('page', 1) - if page_number in ('last',): - page_number = paginator.num_pages - try: - page = paginator.page(page_number) - except InvalidPage: - raise APIError({'detail': 'Invalid page.'}, status_code=HTTPStatus.NOT_FOUND) - - url = request.build_absolute_uri() - - next_link = None - if page.has_next(): - next_link = _replace_query_param(url, 'page', page.next_page_number()) - - previous_link = None - if page.has_previous(): - previous_number = page.previous_page_number() - if previous_number == 1: - previous_link = _remove_query_param(url, 'page') - else: - previous_link = _replace_query_param(url, 'page', previous_number) - - return paginator.count, next_link, previous_link, list(page.object_list) - - -def paginate_predictions(queryset, query: PaginationQuery) -> PredictionPage: - """Slice ``queryset`` and build the prediction page envelope. - - Mirrors the bounds DRF's LimitOffsetPagination enforced: limit falls back to - the default when non-positive and is capped at MAX_LIMIT; skip floors at 0. - """ - limit = query.limit if query.limit > 0 else DEFAULT_LIMIT - limit = min(limit, MAX_LIMIT) - skip = query.skip if query.skip >= 0 else 0 - - total = queryset.count() - page = list(queryset[skip:skip + limit]) - return PredictionPage( - total=total, - limit=limit, - skip=skip, - predictions=[PredictionOut.model_validate(obj) for obj in page], - ) diff --git a/stratoflights_api/permissions.py b/stratoflights_api/permissions.py new file mode 100644 index 0000000..84c2ed4 --- /dev/null +++ b/stratoflights_api/permissions.py @@ -0,0 +1,16 @@ +from rest_framework.permissions import BasePermission, SAFE_METHODS + +class ReadOnlyOrAuthenticated(BasePermission): + def has_permission(self, request, view): + return ( + request.method in SAFE_METHODS or + request.user and request.user.is_authenticated + ) + + +class IsOwner(BasePermission): + def has_object_permission(self, request, view, obj): + return obj.user == request.user + + def has_permission(self, request, view): + return request.user and request.user.is_authenticated \ No newline at end of file diff --git a/stratoflights_api/serializers.py b/stratoflights_api/serializers.py new file mode 100644 index 0000000..dacdf57 --- /dev/null +++ b/stratoflights_api/serializers.py @@ -0,0 +1,185 @@ +from rest_framework import serializers +from .models import Prediction, SavedPoint, SavedRateProfile, PreditctionTemplate +from datetime import datetime +from django.contrib.auth.password_validation import validate_password +from django.core.validators import validate_email +from django.core.exceptions import ValidationError as DjangoValidationError +from django.contrib.auth import get_user_model +from .validators import ( + validate_custom_curve, rate_clip, + _rfc3339_to_timestamp, base64_to_curve +) + +class PredictionSerializer(serializers.ModelSerializer): + class Meta: + model = Prediction + fields = ['id', 'created_at', 'updated_at', 'result'] + +User = get_user_model() + +PROFILE_STANDARD = "standard_profile" +PROFILE_FLOAT = "float_profile" +PROFILE_REVERSE = "reverse_profile" +PROFILE_CUSTOM = "custom_profile" +LATEST_DATASET_KEYWORD = "latest" +SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM] + + +class PredictionRequestSerializer(serializers.Serializer): + launch_latitude = serializers.FloatField(min_value=-90, max_value=90) + launch_longitude = serializers.FloatField(min_value=0, max_value=360) + launch_datetime = serializers.DateTimeField() + launch_altitude = serializers.FloatField(required=False) + format = serializers.CharField(default="json") + profile = serializers.ChoiceField(choices=SUPPORTED_PROFILES, default=PROFILE_STANDARD) + dataset = serializers.CharField(default=LATEST_DATASET_KEYWORD) + + # --- профиль-dependent поля --- + ascent_rate = serializers.FloatField(required=False, min_value=0.01) + descent_rate = serializers.FloatField(required=False, min_value=0.01) + burst_altitude = serializers.FloatField(required=False) + float_altitude = serializers.FloatField(required=False) + stop_datetime = serializers.DateTimeField(required=False) + ascent_curve = serializers.CharField(required=False) + descent_curve = serializers.CharField(required=False) + interpolate = serializers.BooleanField(required=False, default=False) + start_point = serializers.PrimaryKeyRelatedField( + queryset=SavedPoint.objects.all(), required=False, allow_null=True + ) + rate_profile = serializers.PrimaryKeyRelatedField( + queryset=SavedRateProfile.objects.all(), required=False, allow_null=True + ) + template = serializers.PrimaryKeyRelatedField( + queryset=PreditctionTemplate.objects.all(), required=False, allow_null=True + ) + + def validate(self, data): + profile = data.get("profile", PROFILE_STANDARD) + launch_alt = data.get("launch_altitude", 0) + + if profile == PROFILE_STANDARD: + if 'burst_altitude' not in data: + raise serializers.ValidationError("burst_altitude is required for standard profile.") + if data['burst_altitude'] <= launch_alt: + raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.") + + elif profile == PROFILE_FLOAT: + if 'float_altitude' not in data or data['float_altitude'] <= launch_alt: + raise serializers.ValidationError("float_altitude must be greater than launch_altitude.") + if 'stop_datetime' not in data or data['stop_datetime'] <= data['launch_datetime']: + raise serializers.ValidationError("stop_datetime must be later than launch_datetime.") + + elif profile == PROFILE_CUSTOM: + if 'ascent_curve' not in data or not validate_custom_curve(data['ascent_curve']): + raise serializers.ValidationError("Invalid ascent_curve.") + if 'descent_curve' not in data or not validate_custom_curve(data['descent_curve']): + raise serializers.ValidationError("Invalid descent_curve.") + if 'burst_altitude' not in data or data['burst_altitude'] <= launch_alt: + raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.") + + # кастомная логика clipping'а + if 'ascent_rate' in data: + data['ascent_rate'] = rate_clip(data['ascent_rate']) + if 'descent_rate' in data: + data['descent_rate'] = rate_clip(data['descent_rate']) + + return data + + def create(self, validated_data): + if 'ascent_curve' in validated_data: + validated_data['ascent_curve'] = base64_to_curve(validated_data['ascent_curve']) + if 'descent_curve' in validated_data: + validated_data['descent_curve'] = base64_to_curve(validated_data['descent_curve']) + + prediction = Prediction( + user=validated_data.get('user'), + request=validated_data.get('request', {}), + result=validated_data.get('result', {}), + start_point=validated_data.get('start_point'), + template=validated_data.get('template'), + rate_profile=validated_data.get('rate_profile') + ) + prediction.save() + + return prediction + + +class PredictionListSerializer(serializers.ModelSerializer): + class Meta: + model = Prediction + fields = ["id", "created_at", "updated_at", "start_point", "template", "rate_profile"] + + +class PredictionDetailSerializer(serializers.ModelSerializer): + class Meta: + model = Prediction + fields = ["id", "created_at", "updated_at", "result", "start_point", "template", "rate_profile"] + + +from rest_framework import serializers +from .models import TelemetryPacket + +class TelemetryPacketSerializer(serializers.ModelSerializer): + class Meta: + model = TelemetryPacket + fields = ['id', 'timestamp', 'lat', 'lon', 'alt', 'payload'] + read_only_fields = ['id'] + + +class SavedPointSerializer(serializers.ModelSerializer): + user = serializers.HiddenField( + default=serializers.CurrentUserDefault() + ) + class Meta: + model = SavedPoint + fields = ['user', 'id', 'name', 'lat', 'lon', 'alt'] + read_only_fields = ['id'] + + validators = [ + serializers.UniqueTogetherValidator( + queryset=SavedPoint.objects.all(), + fields=['user', 'name'], + message="A saved point with this name already exists for the user." + ) + ] + + +class SavedRateProfileSerializer(serializers.ModelSerializer): + class Meta: + model = SavedRateProfile + fields = ['id', 'name', 'type', 'rate_profile_data'] + read_only_fields = ['id'] + + +class PreditctionTemplateSerializer(serializers.ModelSerializer): + class Meta: + model = PreditctionTemplate + fields = ['id', 'name', 'is_default', 'description', 'prediction_mode', 'model', 'dataset', 'flight_parameters'] + read_only_fields = ['id'] + + +class UserSerializer(serializers.ModelSerializer): + class Meta: + model = User + fields = ['username', 'email', 'first_name', 'last_name'] + extra_kwargs = { + 'username': {'read_only': True} + } + + def validate_email(self, value): + try: + validate_email(value) + except DjangoValidationError: + raise serializers.ValidationError("Invalid email format") + return value + +class ChangePasswordSerializer(serializers.Serializer): + old_password = serializers.CharField(required=True) + new_password = serializers.CharField(required=True) + + def validate_new_password(self, value): + validate_password(value) + return value + +class DeleteAccountSerializer(serializers.Serializer): + password = serializers.CharField(required=True) \ No newline at end of file diff --git a/stratoflights_api/urls.py b/stratoflights_api/urls.py index 9eb84dc..3a438c5 100644 --- a/stratoflights_api/urls.py +++ b/stratoflights_api/urls.py @@ -1,59 +1,48 @@ -from dmr.routing import Router, path +from django.urls import path +from rest_framework.routers import DefaultRouter +from rest_framework.authtoken.views import obtain_auth_token from .views import ( - PredictionCollectionController, - PredictionListUserController, - PredictionHistoryController, - PredictionDetailController, - PredictionDeleteController, - SavedPointListController, - SavedPointDetailController, - PredictionTemplateListController, - PredictionTemplateDetailController, - TelemetryController, - CsrfController, - LoginController, - LogoutController, - SessionController, - WhoAmIController, - UserProfileController, - ChangePasswordController, - ObtainTokenController, - TokenManagementController, - DeleteUserDataController, - DeleteAccountController, + PredictionViewSet, + SavedPointViewset, + PreditctionTemplateViewset, + TelemetryListCreateView, + get_csrf, + login_view, + logout_view, + SessionView, + WhoAmIView, + UserProfileView, + ChangePasswordView, + TokenManagementView, + DeleteUserDataView, + DeleteAccountView ) -# A Router (prefix + routes) is required so build_schema() can generate the -# OpenAPI document; see stratoflights/urls.py. -router = Router( - 'api/', - [ - path("csrf/", CsrfController.as_view(), name='api-csrf'), - path('token', ObtainTokenController.as_view(), name='get_token'), - path("login/", LoginController.as_view(), name='api-login'), - path("logout/", LogoutController.as_view(), name='api-logout'), - path("session/", SessionController.as_view(), name='api-session'), - path("whoami/", WhoAmIController.as_view(), name='api-whoami'), - path("/telemetry/", TelemetryController.as_view(), name="create_telemetry"), - path("profile/", UserProfileController.as_view(), name='api-profile'), - path("profile/change-password/", ChangePasswordController.as_view(), name='api-change-password'), - path("profile/token/", TokenManagementController.as_view(), name='api-token'), - path("profile/delete-data/", DeleteUserDataController.as_view(), name='api-delete-data'), - path("profile/delete-account/", DeleteAccountController.as_view(), name='api-delete-account'), - # Saved points (was SavedPointViewset via DefaultRouter). - path("saved-points/", SavedPointListController.as_view(), name='saved-points-list'), - path("saved-points//", SavedPointDetailController.as_view(), name='saved-points-detail'), - # Prediction templates (was PreditctionTemplateViewset via DefaultRouter). - path("saved-templates/", PredictionTemplateListController.as_view(), name='saved-templates-list'), - path("saved-templates//", PredictionTemplateDetailController.as_view(), name='saved-templates-detail'), - # Predictions (was PredictionViewSet via DefaultRouter). - path("predictions/", PredictionCollectionController.as_view(), name='predictions-list'), - path("predictions/list_user/", PredictionListUserController.as_view(), name='predictions-list-user'), - path("predictions/history/", PredictionHistoryController.as_view(), name='predictions-history'), - path("predictions//detail/", PredictionDetailController.as_view(), name='predictions-detail'), - path("predictions//delete/", PredictionDeleteController.as_view(), name='predictions-delete'), - ], -) -urlpatterns = router.urls +router = DefaultRouter() +router.register(r'predictions', PredictionViewSet, basename='predictions') +router.register(r'saved-points', SavedPointViewset, basename='saved-points') +router.register(r'saved-templates', PreditctionTemplateViewset, basename='saved-templates') + + +urlpatterns = [ + path("csrf/", get_csrf, name='api-csrf'), + path('token', obtain_auth_token, name = 'get_token'), + path("login/", login_view, name='api-login'), + path("logout/", logout_view, name='api-logout'), + path("session/", SessionView.as_view(), name='api-session'), + path("whoami/", WhoAmIView.as_view(), name='api-whoami'), + path("/telemetry/", TelemetryListCreateView.as_view(), name="create_telemetry"), + path('csrf/', get_csrf, name='api-csrf'), + path('login/', login_view, name='api-login'), + path('logout/', logout_view, name='api-logout'), + path('session/', SessionView.as_view(), name='api-session'), + path('whoami/', WhoAmIView.as_view(), name='api-whoami'), + path("profile/", UserProfileView.as_view(), name='api-profile'), + path("profile/change-password/", ChangePasswordView.as_view(), name='api-change-password'), + path("profile/token/", TokenManagementView.as_view(), name='api-token'), + path("profile/delete-data/", DeleteUserDataView.as_view(), name='api-delete-data'), + path("profile/delete-account/", DeleteAccountView.as_view(), name='api-delete-account'), +] +urlpatterns += router.urls diff --git a/stratoflights_api/views.py b/stratoflights_api/views.py index 7de0cf3..319f2a1 100644 --- a/stratoflights_api/views.py +++ b/stratoflights_api/views.py @@ -1,561 +1,341 @@ import requests import time import json -from http import HTTPStatus +from rest_framework import status, generics, permissions +from rest_framework.response import Response +from rest_framework.views import APIView +from rest_framework.viewsets import ModelViewSet, ViewSet, GenericViewSet +from rest_framework.exceptions import APIException +from rest_framework.permissions import IsAuthenticated, AllowAny +from rest_framework.authentication import SessionAuthentication, BasicAuthentication, TokenAuthentication +from rest_framework.decorators import api_view, permission_classes, authentication_classes, action +from rest_framework.authtoken.models import Token +from django.utils import timezone +from django.views.decorators.csrf import csrf_exempt +from django.utils.decorators import method_decorator +from django.http import JsonResponse from django.contrib.auth import authenticate, login, logout, get_user_model from django.middleware.csrf import get_token +from django.core.exceptions import ValidationError from django.utils.dateparse import parse_datetime -from .models import Prediction, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket +from .models import Prediction, User, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket +from .serializers import PredictionSerializer, TelemetryPacketSerializer, PredictionRequestSerializer, PredictionListSerializer, PredictionDetailSerializer, SavedPointSerializer, SavedRateProfileSerializer, PreditctionTemplateSerializer, UserSerializer, ChangePasswordSerializer, DeleteAccountSerializer from .services.tawhiri import TawhiriClient -from datetime import datetime, timedelta, timezone - -from pydantic import ValidationError -from dmr import Controller, modify -from dmr.components import Body, Path, Query -from dmr.plugins.pydantic import PydanticSerializer -from dmr.response import APIError -from dmr.security.jwt.views import ( - ObtainTokensPayload, - ObtainTokensResponse, - ObtainTokensSyncController, -) -from .dtos import ( - DetailResponse, - SessionResponse, - WhoAmIResponse, - UserOut, - UserUpdateIn, - ChangePasswordIn, - DeleteAccountIn, - TokenResponse, - PkPath, - UuidPkPath, - SavedPointIn, - SavedPointPatchIn, - SavedPointOut, - PredictionTemplateIn, - PredictionTemplatePatchIn, - PredictionTemplateOut, - PredictionRequest, - PredictionOut, - PredictionListOut, - PredictionDetailOut, - PredictionCreateOut, - TelemetryIn, - TelemetryOut, -) -from .pagination import ( - PredictionListUserQuery, - PredictionPage, - paginate_predictions, - TelemetryPage, - page_number_paginate, -) -from .validators import base64_to_curve - +from drf_spectacular.utils import extend_schema +from .permissions import ReadOnlyOrAuthenticated, IsOwner +from .custom_pagination import CustomLimitOffsetPagination +from datetime import datetime User = get_user_model() -def _api_error(status_code, body, headers=None): - """Adapter for DMR's APIError(raw_data, *, status_code, ...) signature. +def get_prediction_from_tawhiri(params): - Lets call sites keep the (status_code=, body=) style; the body is DMR's - first positional ``raw_data`` argument. - """ - return APIError(body, status_code=status_code, headers=headers) + base_url = "https://fly.stratonautica.ru/api/v2" + response = requests.get(base_url, params=params) + + if response.status_code == 200: + return response.json() # получаем результат предсказания + else: + raise Exception( + f"Tawhiri error: {response.status_code} {response.text}") -def _resolve_related(model, pk, field_name): - """Resolve a related object by PK (was PrimaryKeyRelatedField(queryset=...)). +class PredictionViewSet(GenericViewSet): + permission_classes = [IsAuthenticated] + pagination_class = CustomLimitOffsetPagination - The original queryset was ``.all()`` (not user-scoped); a missing PK yields a - 400 matching DRF's "Invalid pk" error shape. - """ - if pk is None: - return None - obj = model.objects.filter(pk=pk).first() - if obj is None: - raise _api_error( - status_code=HTTPStatus.BAD_REQUEST, - body={field_name: [f'Invalid pk "{pk}" - object does not exist.']}, - ) - return obj + def list(self, request): + queryset = Prediction.objects.filter(user=request.user) + return Response(PredictionSerializer(queryset, many=True).data) + + def create(self, request): + serializer = PredictionRequestSerializer(data=request.data) + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) -class PredictionCollectionController(Controller[PydanticSerializer]): - """`predictions/` -- list the user's predictions and create a new one.""" - - def get(self) -> list[PredictionOut]: - queryset = Prediction.objects.filter(user=self.request.user) - return [PredictionOut.model_validate(obj) for obj in queryset] - - def post(self, parsed_body: Body[PredictionRequest]) -> PredictionCreateOut: - user = self.request.user - - # Resolve related objects before calling Tawhiri, so an invalid PK still - # fails with 400 first (as DRF validation did). - start_point = _resolve_related(SavedPoint, parsed_body.start_point, 'start_point') - template = _resolve_related(PreditctionTemplate, parsed_body.template, 'template') - rate_profile = _resolve_related(SavedRateProfile, parsed_body.rate_profile, 'rate_profile') + validated_data = serializer.validated_data try: - prediction_result = TawhiriClient.get_prediction(parsed_body.model_dump()) - except requests.RequestException as exc: - raise _api_error( - status_code=HTTPStatus.BAD_GATEWAY, - body={'error': f'Tawhiri error: {str(exc)}'}, - ) + prediction_result = TawhiriClient.get_prediction(validated_data) - # Carried over from the old serializer.create(): curves are decoded but - # never persisted (the model has no curve fields). Kept for parity -- - # including that a malformed curve here raises (then surfaces as 500). - if parsed_body.ascent_curve is not None: - base64_to_curve(parsed_body.ascent_curve) - if parsed_body.descent_curve is not None: - base64_to_curve(parsed_body.descent_curve) + except requests.RequestException as e: + return Response({"error": f"Tawhiri error: {str(e)}"}, status=status.HTTP_502_BAD_GATEWAY) - prediction = Prediction( - user=user, - request=json.loads(self.request.body), - result=prediction_result, - start_point=start_point, - template=template, - rate_profile=rate_profile, - ) - prediction.save() - - return PredictionCreateOut( - id=prediction.id, - created_at=prediction.created_at, + # prediction = Prediction.objects.create( + # result=prediction_result, user=request.user, request=request.data, validated_data=validated_data) + prediction = serializer.save( + user=request.user, result=prediction_result, + request=request.data ) + return Response({ + "id": prediction.id, + "created_at": prediction.created_at, + "result": prediction_result + }, status=status.HTTP_201_CREATED) -class PredictionListUserController(Controller[PydanticSerializer]): - """`predictions/list_user/` -- filtered, paginated list of the user's predictions.""" + @action(detail=False, methods=['get']) + def list_user(self, request): + user = request.user + satellite_id = request.query_params.get('satellite_id') + created_from = request.query_params.get('created_from') + created_till = request.query_params.get('created_till') - def get(self, parsed_query: Query[PredictionListUserQuery]) -> PredictionPage: - user = self.request.user - filters = {'user': user} + filters = { + 'user': user, + } - if parsed_query.created_from: - filters['created_at__gte'] = parse_datetime(parsed_query.created_from) - if parsed_query.created_till: - filters['created_at__lte'] = parse_datetime(parsed_query.created_till) - if parsed_query.satellite_id: - if not user.satellites.filter(id=parsed_query.satellite_id).exists(): - raise _api_error(status_code=HTTPStatus.FORBIDDEN, body={'detail': 'Access denied'}) - filters['satellite_id'] = parsed_query.satellite_id + if created_from: + filters['created_at__gte'] = parse_datetime(created_from) + + if created_till: + filters['created_at__lte'] = parse_datetime(created_till) + + if satellite_id: + if not user.satellites.filter(id=satellite_id).exists(): + return Response({'detail': 'Access denied'}, status=403) + + filters['satellite_id'] = satellite_id queryset = Prediction.objects.filter(**filters) - return paginate_predictions(queryset, parsed_query) + queryset = self.filter_queryset(queryset) + + page = self.paginate_queryset(queryset) + if page is not None: + serializer = PredictionSerializer(page, many=True) + return self.get_paginated_response(serializer.data) + + serializer = PredictionSerializer(queryset, many=True) + return Response(serializer.data) -class PredictionHistoryController(Controller[PydanticSerializer]): - """`predictions/history/` -- compact list of the user's predictions.""" + @action(detail=False, methods=["get"]) + def history(self, request): + queryset = Prediction.objects.filter(user=request.user) + return Response(PredictionListSerializer(queryset, many=True).data) - def get(self) -> list[PredictionListOut]: - queryset = Prediction.objects.filter(user=self.request.user) - return [PredictionListOut.model_validate(obj) for obj in queryset] - - -class PredictionDetailController(Controller[PydanticSerializer]): - """`predictions//detail/` -- retrieve a single prediction.""" - - def get(self, parsed_path: Path[UuidPkPath]) -> PredictionDetailOut: + @action(detail=True, methods=["get"]) + def detail(self, request, pk=None): prediction = Prediction.objects.filter( - user=self.request.user, pk=parsed_path.pk).first() - if prediction is None: - raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'}) - return PredictionDetailOut.model_validate(prediction) + user=request.user, pk=pk).first() + if not prediction: + return Response({'detail': 'Not found'}, status=404) + return Response(PredictionDetailSerializer(prediction).data) - -class PredictionDeleteController(Controller[PydanticSerializer]): - """`predictions//delete/` -- delete a single prediction.""" - - @modify(status_code=HTTPStatus.NO_CONTENT) - def delete(self, parsed_path: Path[UuidPkPath]) -> None: + @action(detail=True, methods=["delete"]) + def delete(self, request, pk=None): prediction = Prediction.objects.filter( - user=self.request.user, pk=parsed_path.pk).first() - if prediction is None: - raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'}) + user=request.user, pk=pk).first() + if not prediction: + return Response({'detail': 'Not found'}, status=404) prediction.delete() + return Response(status=204) -class TelemetryController(Controller[PydanticSerializer]): - """`/telemetry/` -- list and ingest telemetry for a satellite. +class TelemetryListCreateView(generics.ListCreateAPIView): + serializer_class = TelemetryPacketSerializer + permission_classes = [permissions.AllowAny] - Public (was AllowAny). GET preserves the global PageNumberPagination envelope - {count, next, previous, results} with PAGE_SIZE=100. - """ + def get_queryset(self): + qs = TelemetryPacket.objects.filter(satellite_id=self.kwargs["pk"]) - auth = None # public (was AllowAny) - csrf_exempt = True # DRF views bypass Django CSRF; ingestion is anonymous - - def get(self, parsed_path: Path[UuidPkPath]) -> TelemetryPage: - qs = TelemetryPacket.objects.filter(satellite_id=parsed_path.pk) - - from_ts = self.request.GET.get('from') - till_ts = self.request.GET.get('till') + from_ts = self.request.query_params.get("from") + till_ts = self.request.query_params.get("till") if from_ts: qs = qs.filter(timestamp__gte=int(from_ts)) if till_ts: qs = qs.filter(timestamp__lte=int(till_ts)) - qs = qs.order_by('-timestamp') + return qs.order_by("-timestamp") - count, next_link, previous_link, page_objects = page_number_paginate(self.request, qs) - return TelemetryPage( - count=count, - next=next_link, - previous=previous_link, - results=[TelemetryOut.model_validate(obj) for obj in page_objects], - ) + def post(self, request, pk): + serializer = TelemetryPacketSerializer(data=request.data) + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) - def post(self, parsed_path: Path[UuidPkPath], parsed_body: Body[TelemetryIn]) -> TelemetryOut: - # Bug fix (approved): the original returned serializer.errors on success; - # we return the created packet instead. timestamp is still server-set. - packet = TelemetryPacket.objects.create( - timestamp=time.time(), - satellite=Satellite.objects.get(id=parsed_path.pk), - lat=parsed_body.lat, - lon=parsed_body.lon, - alt=parsed_body.alt, - payload=parsed_body.payload, - ) - return TelemetryOut.model_validate(packet) + validated_data = serializer.validated_data + + TelemetryPacket.objects.create(timestamp=time.time(), + satellite=Satellite.objects.get(id=pk), + lat=validated_data["lat"], + lon=validated_data["lon"], + alt=validated_data["alt"], + payload=validated_data['payload'], + ) + return Response(serializer.errors, status=status.HTTP_201_CREATED) -class SessionController(Controller[PydanticSerializer]): - """Report whether the current request is authenticated.""" +class SessionView(APIView): + permission_classes = [IsAuthenticated] - def get(self) -> SessionResponse: - return SessionResponse(isAuthenticated=True) + @staticmethod + def get(request, format=None): + return JsonResponse({'isAuthenticated': True}) -class WhoAmIController(Controller[PydanticSerializer]): - """Return the current user's username.""" +class WhoAmIView(APIView): + permission_classes = [IsAuthenticated] - def get(self) -> WhoAmIResponse: - return WhoAmIResponse(username=self.request.user.username) + @staticmethod + def get(request, format=None): + return JsonResponse({'username': request.user.username}) -class CsrfController(Controller[PydanticSerializer]): - """Get CSRF token.""" - - auth = None # public endpoint (was AllowAny) - csrf_exempt = True - - def get(self) -> DetailResponse: - token = get_token(self.request) - return self.to_response( - DetailResponse(detail='CSRF cookie set'), - headers={'X-CSRFToken': token}, - ) +@extend_schema(methods=["GET"], description="Get CSRF token") +@csrf_exempt +@api_view(["GET"]) +@permission_classes([AllowAny]) +def get_csrf(request): + response = JsonResponse({'detail': 'CSRF cookie set'}) + response['X-CSRFToken'] = get_token(request) + return response -class LoginController(Controller[PydanticSerializer]): - """Login user.""" +@extend_schema(methods=["POST"], description="Login user") +@csrf_exempt +@api_view(["POST"]) +@authentication_classes([BasicAuthentication]) +@permission_classes([AllowAny]) +def login_view(request): + data = json.loads(request.body) + username = data.get('username') + password = data.get('password') - auth = None # public endpoint (was AllowAny) - csrf_exempt = True + if username is None or password is None: + return JsonResponse({'detail': 'Please provide username and password.'}, status=400) - # post() would default to 201; the original returned 200. - @modify(status_code=HTTPStatus.OK) - def post(self) -> DetailResponse: - data = json.loads(self.request.body) - username = data.get('username') - password = data.get('password') + user = authenticate(username=username, password=password) + if user is None: + return JsonResponse({'detail': 'Invalid credentials.'}, status=400) - if username is None or password is None: - raise _api_error( - status_code=HTTPStatus.BAD_REQUEST, - body={'detail': 'Please provide username and password.'}, - ) - - user = authenticate(username=username, password=password) - if user is None: - raise _api_error( - status_code=HTTPStatus.BAD_REQUEST, - body={'detail': 'Invalid credentials.'}, - ) - - login(self.request, user) - return DetailResponse(detail='Successfully logged in.') + login(request, user) + return JsonResponse({'detail': 'Successfully logged in.'}) -class LogoutController(Controller[PydanticSerializer]): - """Logout user.""" +@extend_schema(methods=["POST"], description="Logout user") +@api_view(["POST"]) +@permission_classes([AllowAny]) +def logout_view(request): + if not request.user.is_authenticated: + return JsonResponse({'detail': 'You\'re not logged in.'}, status=400) - auth = None # public endpoint (was AllowAny); checks auth state manually - - @modify(status_code=HTTPStatus.OK) - def post(self) -> DetailResponse: - if not self.request.user.is_authenticated: - raise _api_error( - status_code=HTTPStatus.BAD_REQUEST, - body={'detail': "You're not logged in."}, - ) - - logout(self.request) - return DetailResponse(detail='Successfully logged out.') + logout(request) + return JsonResponse({'detail': 'Successfully logged out.'}) -_SAVED_POINT_DUPLICATE = 'A saved point with this name already exists for the user.' +class SavedPointViewset(ModelViewSet): + permission_classes = [IsOwner] + serializer_class = SavedPointSerializer + pagination_class = None + + def get_queryset(self): + return SavedPoint.objects.filter(user=self.request.user) + + def perform_create(self, serializer): + serializer.save(user=self.request.user) -def _check_saved_point_unique(user, name, exclude_pk=None): - """Reproduce the former SavedPointSerializer UniqueTogetherValidator.""" - qs = SavedPoint.objects.filter(user=user, name=name) - if exclude_pk is not None: - qs = qs.exclude(pk=exclude_pk) - if qs.exists(): - raise _api_error( - status_code=HTTPStatus.BAD_REQUEST, - body={'non_field_errors': [_SAVED_POINT_DUPLICATE]}, - ) +class PreditctionTemplateViewset(ModelViewSet): + permission_classes = [IsOwner] + serializer_class = PreditctionTemplateSerializer + pagination_class = None + + def get_queryset(self): + return PreditctionTemplate.objects.filter(user=self.request.user) + + def perform_create(self, serializer): + serializer.save(user=self.request.user) -class SavedPointListController(Controller[PydanticSerializer]): - """Collection endpoint for the current user's saved points (was SavedPointViewset).""" +class UserProfileView(APIView): + permission_classes = [IsAuthenticated] - def get(self) -> list[SavedPointOut]: - qs = SavedPoint.objects.filter(user=self.request.user) - return [SavedPointOut.model_validate(obj) for obj in qs] + def get(self, request): + serializer = UserSerializer(request.user) + return Response(serializer.data) - def post(self, parsed_body: Body[SavedPointIn]) -> SavedPointOut: - user = self.request.user - _check_saved_point_unique(user, parsed_body.name) - obj = SavedPoint.objects.create(user=user, **parsed_body.model_dump()) - return SavedPointOut.model_validate(obj) + def patch(self, request): + user = request.user + serializer = UserSerializer(user, data=request.data, partial=True) + + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + serializer.save() + return Response(serializer.data) -class SavedPointDetailController(Controller[PydanticSerializer]): - """Detail endpoint for a single saved point owned by the current user.""" +class ChangePasswordView(APIView): + permission_classes = [IsAuthenticated] - def _get_object(self, pk: int): - # Filtering by user means another user's object reads as 404 (not 403), - # matching DRF's get_object() over a user-scoped queryset. - obj = SavedPoint.objects.filter(user=self.request.user, pk=pk).first() - if obj is None: - raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'}) - return obj + def post(self, request): + user = request.user + serializer = ChangePasswordSerializer(data=request.data) + + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + if not user.check_password(serializer.validated_data['old_password']): + return Response({'detail': 'Old password is incorrect'}, + status=status.HTTP_400_BAD_REQUEST) - def get(self, parsed_path: Path[PkPath]) -> SavedPointOut: - return SavedPointOut.model_validate(self._get_object(parsed_path.pk)) - - def put(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointIn]) -> SavedPointOut: - obj = self._get_object(parsed_path.pk) - _check_saved_point_unique(self.request.user, parsed_body.name, exclude_pk=obj.pk) - for field, value in parsed_body.model_dump().items(): - setattr(obj, field, value) - obj.save() - return SavedPointOut.model_validate(obj) - - def patch(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointPatchIn]) -> SavedPointOut: - obj = self._get_object(parsed_path.pk) - updates = parsed_body.model_dump(exclude_unset=True) - if 'name' in updates: - _check_saved_point_unique(self.request.user, updates['name'], exclude_pk=obj.pk) - for field, value in updates.items(): - setattr(obj, field, value) - obj.save() - return SavedPointOut.model_validate(obj) - - @modify(status_code=HTTPStatus.NO_CONTENT) - def delete(self, parsed_path: Path[PkPath]) -> None: - self._get_object(parsed_path.pk).delete() - - -class PredictionTemplateListController(Controller[PydanticSerializer]): - """Collection endpoint for the current user's templates (was PreditctionTemplateViewset). - - NOTE: as before, there is no app-level uniqueness check; a duplicate - (user, name) hits the model's unique_together and surfaces as a DB error. - """ - - def get(self) -> list[PredictionTemplateOut]: - qs = PreditctionTemplate.objects.filter(user=self.request.user) - return [PredictionTemplateOut.model_validate(obj) for obj in qs] - - def post(self, parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut: - obj = PreditctionTemplate.objects.create( - user=self.request.user, **parsed_body.model_dump() - ) - return PredictionTemplateOut.model_validate(obj) - - -class PredictionTemplateDetailController(Controller[PydanticSerializer]): - """Detail endpoint for a single template owned by the current user.""" - - def _get_object(self, pk: int): - obj = PreditctionTemplate.objects.filter(user=self.request.user, pk=pk).first() - if obj is None: - raise _api_error(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'}) - return obj - - def get(self, parsed_path: Path[PkPath]) -> PredictionTemplateOut: - return PredictionTemplateOut.model_validate(self._get_object(parsed_path.pk)) - - def put(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut: - obj = self._get_object(parsed_path.pk) - for field, value in parsed_body.model_dump().items(): - setattr(obj, field, value) - obj.save() - return PredictionTemplateOut.model_validate(obj) - - def patch(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplatePatchIn]) -> PredictionTemplateOut: - obj = self._get_object(parsed_path.pk) - for field, value in parsed_body.model_dump(exclude_unset=True).items(): - setattr(obj, field, value) - obj.save() - return PredictionTemplateOut.model_validate(obj) - - @modify(status_code=HTTPStatus.NO_CONTENT) - def delete(self, parsed_path: Path[PkPath]) -> None: - self._get_object(parsed_path.pk).delete() - - -class UserProfileController(Controller[PydanticSerializer]): - """Read and partially update the current user's profile.""" - - def get(self) -> UserOut: - return UserOut.model_validate(self.request.user) - - def patch(self, parsed_body: Body[UserUpdateIn]) -> UserOut: - user = self.request.user - # partial update: apply only the fields the client actually sent. - for field, value in parsed_body.model_dump(exclude_unset=True).items(): - setattr(user, field, value) + user.set_password(serializer.validated_data['new_password']) user.save() - return UserOut.model_validate(user) + return Response({'detail': 'Password changed successfully'}) -class ChangePasswordController(Controller[PydanticSerializer]): - """Change the current user's password.""" - - # post() would default to 201; the original returned 200. - @modify(status_code=HTTPStatus.OK) - def post(self, parsed_body: Body[ChangePasswordIn]) -> DetailResponse: - user = self.request.user - - if not user.check_password(parsed_body.old_password): - raise _api_error( - status_code=HTTPStatus.BAD_REQUEST, - body={'detail': 'Old password is incorrect'}, - ) - - user.set_password(parsed_body.new_password) - user.save() - return DetailResponse(detail='Password changed successfully') - - -class DeleteAccountController(Controller[PydanticSerializer]): - """Delete the current user's account and all their data.""" - - # DMR forbids a Body component on DELETE, but the original endpoint read the - # password from a DELETE request body. Preserve that contract by parsing the - # body manually instead of via Body[DeleteAccountIn]. - def delete(self) -> DetailResponse: - user = self.request.user - - try: - parsed_body = DeleteAccountIn(**json.loads(self.request.body or b'{}')) - except ValidationError as exc: - raise _api_error(status_code=HTTPStatus.BAD_REQUEST, body=json.loads(exc.json())) - except ValueError: - raise _api_error(status_code=HTTPStatus.BAD_REQUEST, body={'detail': 'Invalid request body.'}) - - if not user.check_password(parsed_body.password): - raise _api_error( - status_code=HTTPStatus.BAD_REQUEST, - body={'detail': 'Incorrect password'}, - ) +class DeleteAccountView(APIView): + permission_classes = [IsAuthenticated] + def delete(self, request): + user = request.user + serializer = DeleteAccountSerializer(data=request.data) + + if not serializer.is_valid(): + return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST) + + if not user.check_password(serializer.validated_data['password']): + return Response({'detail': 'Incorrect password'}, + status=status.HTTP_400_BAD_REQUEST) + Prediction.objects.filter(user=user).delete() SavedPoint.objects.filter(user=user).delete() PreditctionTemplate.objects.filter(user=user).delete() - + user.delete() - - return DetailResponse(detail='Account deleted successfully') + + return Response({'detail': 'Account deleted successfully'}) -class DeleteUserDataController(Controller[PydanticSerializer]): - """Delete all of the current user's data without deleting the account.""" - - def delete(self) -> DetailResponse: - user = self.request.user +class DeleteUserDataView(APIView): + permission_classes = [IsAuthenticated] + def delete(self, request): + user = request.user + Prediction.objects.filter(user=user).delete() SavedPoint.objects.filter(user=user).delete() PreditctionTemplate.objects.filter(user=user).delete() - - return DetailResponse(detail='All user data deleted successfully') + + return Response({'detail': 'All user data deleted successfully'}) -class ObtainTokenController( - ObtainTokensSyncController[ - PydanticSerializer, - ObtainTokensPayload, - ObtainTokensResponse, - ], -): - """Exchange username/password for JWT access + refresh tokens. +class TokenManagementView(APIView): + permission_classes = [IsAuthenticated] - Replaces DRF's obtain_auth_token. Approved drift: the token format and - semantics change from a single stored DRF token to stateless JWTs, so the - response is {access_token, refresh_token} instead of {token}. - """ + def get(self, request): + + token, created = Token.objects.get_or_create(user=request.user) + return Response({"token": token.key}) - auth = None # public: credentials are supplied in the request body - csrf_exempt = True - jwt_expiration = timedelta(hours=1) - jwt_refresh_expiration = timedelta(days=7) - - def convert_auth_payload(self, payload): - return payload - - def make_api_response(self): - now = datetime.now(timezone.utc) - return { - 'access_token': self.create_jwt_token( - expiration=now + self.jwt_expiration, - token_type='access', - ), - 'refresh_token': self.create_jwt_token( - expiration=now + self.jwt_refresh_expiration, - token_type='refresh', - ), - } - - -class TokenManagementController(Controller[PydanticSerializer]): - """Issue a fresh JWT access token for the current user. - - Was TokenManagementView (DRF stored-token get/regenerate). With stateless - JWTs there is nothing stored to fetch or delete, so both GET and POST mint a - new access token (approved drift). The {"token": ...} response shape is kept. - """ - - jwt_expiration = timedelta(hours=1) - - def get(self) -> TokenResponse: - return TokenResponse(token=self._issue_token()) - - # post() would default to 201; the original returned 200. - @modify(status_code=HTTPStatus.OK) - def post(self) -> TokenResponse: - return TokenResponse(token=self._issue_token()) - - def _issue_token(self) -> str: - now = datetime.now(timezone.utc) - return self.create_jwt_token( - subject=str(self.request.user.pk), - expiration=now + self.jwt_expiration, - token_type='access', - ) + def post(self, request): + + Token.objects.filter(user=request.user).delete() + token = Token.objects.create(user=request.user) + return Response({"token": token.key})