migrated to modern-rest
This commit is contained in:
parent
d9a92569f0
commit
8e44c4501a
11 changed files with 1014 additions and 572 deletions
|
|
@ -1,341 +1,552 @@
|
|||
import requests
|
||||
import time
|
||||
import json
|
||||
from rest_framework import status, generics, permissions
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.views import APIView
|
||||
from rest_framework.viewsets import ModelViewSet, ViewSet, GenericViewSet
|
||||
from rest_framework.exceptions import APIException
|
||||
from rest_framework.permissions import IsAuthenticated, AllowAny
|
||||
from rest_framework.authentication import SessionAuthentication, BasicAuthentication, TokenAuthentication
|
||||
from rest_framework.decorators import api_view, permission_classes, authentication_classes, action
|
||||
from rest_framework.authtoken.models import Token
|
||||
from django.utils import timezone
|
||||
from django.views.decorators.csrf import csrf_exempt
|
||||
from django.utils.decorators import method_decorator
|
||||
from django.http import JsonResponse
|
||||
from http import HTTPStatus
|
||||
from django.contrib.auth import authenticate, login, logout, get_user_model
|
||||
from django.middleware.csrf import get_token
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.utils.dateparse import parse_datetime
|
||||
from .models import Prediction, User, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket
|
||||
from .serializers import PredictionSerializer, TelemetryPacketSerializer, PredictionRequestSerializer, PredictionListSerializer, PredictionDetailSerializer, SavedPointSerializer, SavedRateProfileSerializer, PreditctionTemplateSerializer, UserSerializer, ChangePasswordSerializer, DeleteAccountSerializer
|
||||
from .models import Prediction, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket
|
||||
from .services.tawhiri import TawhiriClient
|
||||
from drf_spectacular.utils import extend_schema
|
||||
from .permissions import ReadOnlyOrAuthenticated, IsOwner
|
||||
from .custom_pagination import CustomLimitOffsetPagination
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from pydantic import ValidationError
|
||||
from dmr import Controller, modify
|
||||
from dmr.components import Body, Path, Query
|
||||
from dmr.plugins.pydantic import PydanticSerializer
|
||||
from dmr.response import APIError
|
||||
from dmr.security.jwt.views import (
|
||||
ObtainTokensPayload,
|
||||
ObtainTokensResponse,
|
||||
ObtainTokensSyncController,
|
||||
)
|
||||
from .dtos import (
|
||||
DetailResponse,
|
||||
SessionResponse,
|
||||
WhoAmIResponse,
|
||||
UserOut,
|
||||
UserUpdateIn,
|
||||
ChangePasswordIn,
|
||||
DeleteAccountIn,
|
||||
TokenResponse,
|
||||
PkPath,
|
||||
UuidPkPath,
|
||||
SavedPointIn,
|
||||
SavedPointPatchIn,
|
||||
SavedPointOut,
|
||||
PredictionTemplateIn,
|
||||
PredictionTemplatePatchIn,
|
||||
PredictionTemplateOut,
|
||||
PredictionRequest,
|
||||
PredictionOut,
|
||||
PredictionListOut,
|
||||
PredictionDetailOut,
|
||||
PredictionCreateOut,
|
||||
TelemetryIn,
|
||||
TelemetryOut,
|
||||
)
|
||||
from .pagination import (
|
||||
PredictionListUserQuery,
|
||||
PredictionPage,
|
||||
paginate_predictions,
|
||||
TelemetryPage,
|
||||
page_number_paginate,
|
||||
)
|
||||
from .validators import base64_to_curve
|
||||
|
||||
User = get_user_model()
|
||||
|
||||
|
||||
def get_prediction_from_tawhiri(params):
|
||||
def _resolve_related(model, pk, field_name):
|
||||
"""Resolve a related object by PK (was PrimaryKeyRelatedField(queryset=...)).
|
||||
|
||||
base_url = "https://fly.stratonautica.ru/api/v2"
|
||||
response = requests.get(base_url, params=params)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json() # получаем результат предсказания
|
||||
else:
|
||||
raise Exception(
|
||||
f"Tawhiri error: {response.status_code} {response.text}")
|
||||
The original queryset was ``.all()`` (not user-scoped); a missing PK yields a
|
||||
400 matching DRF's "Invalid pk" error shape.
|
||||
"""
|
||||
if pk is None:
|
||||
return None
|
||||
obj = model.objects.filter(pk=pk).first()
|
||||
if obj is None:
|
||||
raise APIError(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
body={field_name: [f'Invalid pk "{pk}" - object does not exist.']},
|
||||
)
|
||||
return obj
|
||||
|
||||
|
||||
class PredictionViewSet(GenericViewSet):
|
||||
permission_classes = [IsAuthenticated]
|
||||
pagination_class = CustomLimitOffsetPagination
|
||||
class PredictionCollectionController(Controller[PydanticSerializer]):
|
||||
"""`predictions/` -- list the user's predictions and create a new one."""
|
||||
|
||||
def list(self, request):
|
||||
queryset = Prediction.objects.filter(user=request.user)
|
||||
return Response(PredictionSerializer(queryset, many=True).data)
|
||||
|
||||
def create(self, request):
|
||||
serializer = PredictionRequestSerializer(data=request.data)
|
||||
def get(self) -> list[PredictionOut]:
|
||||
queryset = Prediction.objects.filter(user=self.request.user)
|
||||
return [PredictionOut.model_validate(obj) for obj in queryset]
|
||||
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
def post(self, parsed_body: Body[PredictionRequest]) -> PredictionCreateOut:
|
||||
user = self.request.user
|
||||
|
||||
validated_data = serializer.validated_data
|
||||
# Resolve related objects before calling Tawhiri, so an invalid PK still
|
||||
# fails with 400 first (as DRF validation did).
|
||||
start_point = _resolve_related(SavedPoint, parsed_body.start_point, 'start_point')
|
||||
template = _resolve_related(PreditctionTemplate, parsed_body.template, 'template')
|
||||
rate_profile = _resolve_related(SavedRateProfile, parsed_body.rate_profile, 'rate_profile')
|
||||
|
||||
try:
|
||||
prediction_result = TawhiriClient.get_prediction(validated_data)
|
||||
prediction_result = TawhiriClient.get_prediction(parsed_body.model_dump())
|
||||
except requests.RequestException as exc:
|
||||
raise APIError(
|
||||
status_code=HTTPStatus.BAD_GATEWAY,
|
||||
body={'error': f'Tawhiri error: {str(exc)}'},
|
||||
)
|
||||
|
||||
except requests.RequestException as e:
|
||||
return Response({"error": f"Tawhiri error: {str(e)}"}, status=status.HTTP_502_BAD_GATEWAY)
|
||||
# Carried over from the old serializer.create(): curves are decoded but
|
||||
# never persisted (the model has no curve fields). Kept for parity --
|
||||
# including that a malformed curve here raises (then surfaces as 500).
|
||||
if parsed_body.ascent_curve is not None:
|
||||
base64_to_curve(parsed_body.ascent_curve)
|
||||
if parsed_body.descent_curve is not None:
|
||||
base64_to_curve(parsed_body.descent_curve)
|
||||
|
||||
# prediction = Prediction.objects.create(
|
||||
# result=prediction_result, user=request.user, request=request.data, validated_data=validated_data)
|
||||
prediction = serializer.save(
|
||||
user=request.user,
|
||||
prediction = Prediction(
|
||||
user=user,
|
||||
request=json.loads(self.request.body),
|
||||
result=prediction_result,
|
||||
start_point=start_point,
|
||||
template=template,
|
||||
rate_profile=rate_profile,
|
||||
)
|
||||
prediction.save()
|
||||
|
||||
return PredictionCreateOut(
|
||||
id=prediction.id,
|
||||
created_at=prediction.created_at,
|
||||
result=prediction_result,
|
||||
request=request.data
|
||||
)
|
||||
|
||||
return Response({
|
||||
"id": prediction.id,
|
||||
"created_at": prediction.created_at,
|
||||
"result": prediction_result
|
||||
}, status=status.HTTP_201_CREATED)
|
||||
|
||||
@action(detail=False, methods=['get'])
|
||||
def list_user(self, request):
|
||||
user = request.user
|
||||
satellite_id = request.query_params.get('satellite_id')
|
||||
created_from = request.query_params.get('created_from')
|
||||
created_till = request.query_params.get('created_till')
|
||||
class PredictionListUserController(Controller[PydanticSerializer]):
|
||||
"""`predictions/list_user/` -- filtered, paginated list of the user's predictions."""
|
||||
|
||||
filters = {
|
||||
'user': user,
|
||||
}
|
||||
def get(self, parsed_query: Query[PredictionListUserQuery]) -> PredictionPage:
|
||||
user = self.request.user
|
||||
filters = {'user': user}
|
||||
|
||||
if created_from:
|
||||
filters['created_at__gte'] = parse_datetime(created_from)
|
||||
|
||||
if created_till:
|
||||
filters['created_at__lte'] = parse_datetime(created_till)
|
||||
|
||||
if satellite_id:
|
||||
if not user.satellites.filter(id=satellite_id).exists():
|
||||
return Response({'detail': 'Access denied'}, status=403)
|
||||
|
||||
filters['satellite_id'] = satellite_id
|
||||
if parsed_query.created_from:
|
||||
filters['created_at__gte'] = parse_datetime(parsed_query.created_from)
|
||||
if parsed_query.created_till:
|
||||
filters['created_at__lte'] = parse_datetime(parsed_query.created_till)
|
||||
if parsed_query.satellite_id:
|
||||
if not user.satellites.filter(id=parsed_query.satellite_id).exists():
|
||||
raise APIError(status_code=HTTPStatus.FORBIDDEN, body={'detail': 'Access denied'})
|
||||
filters['satellite_id'] = parsed_query.satellite_id
|
||||
|
||||
queryset = Prediction.objects.filter(**filters)
|
||||
queryset = self.filter_queryset(queryset)
|
||||
|
||||
page = self.paginate_queryset(queryset)
|
||||
return paginate_predictions(queryset, parsed_query)
|
||||
|
||||
if page is not None:
|
||||
serializer = PredictionSerializer(page, many=True)
|
||||
return self.get_paginated_response(serializer.data)
|
||||
|
||||
serializer = PredictionSerializer(queryset, many=True)
|
||||
return Response(serializer.data)
|
||||
|
||||
@action(detail=False, methods=["get"])
|
||||
def history(self, request):
|
||||
queryset = Prediction.objects.filter(user=request.user)
|
||||
return Response(PredictionListSerializer(queryset, many=True).data)
|
||||
class PredictionHistoryController(Controller[PydanticSerializer]):
|
||||
"""`predictions/history/` -- compact list of the user's predictions."""
|
||||
|
||||
@action(detail=True, methods=["get"])
|
||||
def detail(self, request, pk=None):
|
||||
def get(self) -> list[PredictionListOut]:
|
||||
queryset = Prediction.objects.filter(user=self.request.user)
|
||||
return [PredictionListOut.model_validate(obj) for obj in queryset]
|
||||
|
||||
|
||||
class PredictionDetailController(Controller[PydanticSerializer]):
|
||||
"""`predictions/<pk>/detail/` -- retrieve a single prediction."""
|
||||
|
||||
def get(self, parsed_path: Path[UuidPkPath]) -> PredictionDetailOut:
|
||||
prediction = Prediction.objects.filter(
|
||||
user=request.user, pk=pk).first()
|
||||
if not prediction:
|
||||
return Response({'detail': 'Not found'}, status=404)
|
||||
return Response(PredictionDetailSerializer(prediction).data)
|
||||
user=self.request.user, pk=parsed_path.pk).first()
|
||||
if prediction is None:
|
||||
raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'})
|
||||
return PredictionDetailOut.model_validate(prediction)
|
||||
|
||||
@action(detail=True, methods=["delete"])
|
||||
def delete(self, request, pk=None):
|
||||
|
||||
class PredictionDeleteController(Controller[PydanticSerializer]):
|
||||
"""`predictions/<pk>/delete/` -- delete a single prediction."""
|
||||
|
||||
@modify(status_code=HTTPStatus.NO_CONTENT)
|
||||
def delete(self, parsed_path: Path[UuidPkPath]) -> None:
|
||||
prediction = Prediction.objects.filter(
|
||||
user=request.user, pk=pk).first()
|
||||
if not prediction:
|
||||
return Response({'detail': 'Not found'}, status=404)
|
||||
user=self.request.user, pk=parsed_path.pk).first()
|
||||
if prediction is None:
|
||||
raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'})
|
||||
prediction.delete()
|
||||
return Response(status=204)
|
||||
|
||||
|
||||
class TelemetryListCreateView(generics.ListCreateAPIView):
|
||||
serializer_class = TelemetryPacketSerializer
|
||||
permission_classes = [permissions.AllowAny]
|
||||
class TelemetryController(Controller[PydanticSerializer]):
|
||||
"""`<uuid:pk>/telemetry/` -- list and ingest telemetry for a satellite.
|
||||
|
||||
def get_queryset(self):
|
||||
qs = TelemetryPacket.objects.filter(satellite_id=self.kwargs["pk"])
|
||||
Public (was AllowAny). GET preserves the global PageNumberPagination envelope
|
||||
{count, next, previous, results} with PAGE_SIZE=100.
|
||||
"""
|
||||
|
||||
from_ts = self.request.query_params.get("from")
|
||||
till_ts = self.request.query_params.get("till")
|
||||
auth = None # public (was AllowAny)
|
||||
csrf_exempt = True # DRF views bypass Django CSRF; ingestion is anonymous
|
||||
|
||||
def get(self, parsed_path: Path[UuidPkPath]) -> TelemetryPage:
|
||||
qs = TelemetryPacket.objects.filter(satellite_id=parsed_path.pk)
|
||||
|
||||
from_ts = self.request.GET.get('from')
|
||||
till_ts = self.request.GET.get('till')
|
||||
|
||||
if from_ts:
|
||||
qs = qs.filter(timestamp__gte=int(from_ts))
|
||||
if till_ts:
|
||||
qs = qs.filter(timestamp__lte=int(till_ts))
|
||||
|
||||
return qs.order_by("-timestamp")
|
||||
qs = qs.order_by('-timestamp')
|
||||
|
||||
def post(self, request, pk):
|
||||
serializer = TelemetryPacketSerializer(data=request.data)
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
count, next_link, previous_link, page_objects = page_number_paginate(self.request, qs)
|
||||
return TelemetryPage(
|
||||
count=count,
|
||||
next=next_link,
|
||||
previous=previous_link,
|
||||
results=[TelemetryOut.model_validate(obj) for obj in page_objects],
|
||||
)
|
||||
|
||||
validated_data = serializer.validated_data
|
||||
|
||||
TelemetryPacket.objects.create(timestamp=time.time(),
|
||||
satellite=Satellite.objects.get(id=pk),
|
||||
lat=validated_data["lat"],
|
||||
lon=validated_data["lon"],
|
||||
alt=validated_data["alt"],
|
||||
payload=validated_data['payload'],
|
||||
)
|
||||
return Response(serializer.errors, status=status.HTTP_201_CREATED)
|
||||
def post(self, parsed_path: Path[UuidPkPath], parsed_body: Body[TelemetryIn]) -> TelemetryOut:
|
||||
# Bug fix (approved): the original returned serializer.errors on success;
|
||||
# we return the created packet instead. timestamp is still server-set.
|
||||
packet = TelemetryPacket.objects.create(
|
||||
timestamp=time.time(),
|
||||
satellite=Satellite.objects.get(id=parsed_path.pk),
|
||||
lat=parsed_body.lat,
|
||||
lon=parsed_body.lon,
|
||||
alt=parsed_body.alt,
|
||||
payload=parsed_body.payload,
|
||||
)
|
||||
return TelemetryOut.model_validate(packet)
|
||||
|
||||
|
||||
class SessionView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
class SessionController(Controller[PydanticSerializer]):
|
||||
"""Report whether the current request is authenticated."""
|
||||
|
||||
@staticmethod
|
||||
def get(request, format=None):
|
||||
return JsonResponse({'isAuthenticated': True})
|
||||
def get(self) -> SessionResponse:
|
||||
return SessionResponse(isAuthenticated=True)
|
||||
|
||||
|
||||
class WhoAmIView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
class WhoAmIController(Controller[PydanticSerializer]):
|
||||
"""Return the current user's username."""
|
||||
|
||||
@staticmethod
|
||||
def get(request, format=None):
|
||||
return JsonResponse({'username': request.user.username})
|
||||
def get(self) -> WhoAmIResponse:
|
||||
return WhoAmIResponse(username=self.request.user.username)
|
||||
|
||||
|
||||
@extend_schema(methods=["GET"], description="Get CSRF token")
|
||||
@csrf_exempt
|
||||
@api_view(["GET"])
|
||||
@permission_classes([AllowAny])
|
||||
def get_csrf(request):
|
||||
response = JsonResponse({'detail': 'CSRF cookie set'})
|
||||
response['X-CSRFToken'] = get_token(request)
|
||||
return response
|
||||
class CsrfController(Controller[PydanticSerializer]):
|
||||
"""Get CSRF token."""
|
||||
|
||||
auth = None # public endpoint (was AllowAny)
|
||||
csrf_exempt = True
|
||||
|
||||
def get(self) -> DetailResponse:
|
||||
token = get_token(self.request)
|
||||
return self.to_response(
|
||||
DetailResponse(detail='CSRF cookie set'),
|
||||
headers={'X-CSRFToken': token},
|
||||
)
|
||||
|
||||
|
||||
@extend_schema(methods=["POST"], description="Login user")
|
||||
@csrf_exempt
|
||||
@api_view(["POST"])
|
||||
@authentication_classes([BasicAuthentication])
|
||||
@permission_classes([AllowAny])
|
||||
def login_view(request):
|
||||
data = json.loads(request.body)
|
||||
username = data.get('username')
|
||||
password = data.get('password')
|
||||
class LoginController(Controller[PydanticSerializer]):
|
||||
"""Login user."""
|
||||
|
||||
if username is None or password is None:
|
||||
return JsonResponse({'detail': 'Please provide username and password.'}, status=400)
|
||||
auth = None # public endpoint (was AllowAny)
|
||||
csrf_exempt = True
|
||||
|
||||
user = authenticate(username=username, password=password)
|
||||
if user is None:
|
||||
return JsonResponse({'detail': 'Invalid credentials.'}, status=400)
|
||||
# post() would default to 201; the original returned 200.
|
||||
@modify(status_code=HTTPStatus.OK)
|
||||
def post(self) -> DetailResponse:
|
||||
data = json.loads(self.request.body)
|
||||
username = data.get('username')
|
||||
password = data.get('password')
|
||||
|
||||
login(request, user)
|
||||
return JsonResponse({'detail': 'Successfully logged in.'})
|
||||
if username is None or password is None:
|
||||
raise APIError(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
body={'detail': 'Please provide username and password.'},
|
||||
)
|
||||
|
||||
user = authenticate(username=username, password=password)
|
||||
if user is None:
|
||||
raise APIError(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
body={'detail': 'Invalid credentials.'},
|
||||
)
|
||||
|
||||
login(self.request, user)
|
||||
return DetailResponse(detail='Successfully logged in.')
|
||||
|
||||
|
||||
@extend_schema(methods=["POST"], description="Logout user")
|
||||
@api_view(["POST"])
|
||||
@permission_classes([AllowAny])
|
||||
def logout_view(request):
|
||||
if not request.user.is_authenticated:
|
||||
return JsonResponse({'detail': 'You\'re not logged in.'}, status=400)
|
||||
class LogoutController(Controller[PydanticSerializer]):
|
||||
"""Logout user."""
|
||||
|
||||
logout(request)
|
||||
return JsonResponse({'detail': 'Successfully logged out.'})
|
||||
auth = None # public endpoint (was AllowAny); checks auth state manually
|
||||
|
||||
@modify(status_code=HTTPStatus.OK)
|
||||
def post(self) -> DetailResponse:
|
||||
if not self.request.user.is_authenticated:
|
||||
raise APIError(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
body={'detail': "You're not logged in."},
|
||||
)
|
||||
|
||||
logout(self.request)
|
||||
return DetailResponse(detail='Successfully logged out.')
|
||||
|
||||
|
||||
class SavedPointViewset(ModelViewSet):
|
||||
permission_classes = [IsOwner]
|
||||
serializer_class = SavedPointSerializer
|
||||
pagination_class = None
|
||||
|
||||
def get_queryset(self):
|
||||
return SavedPoint.objects.filter(user=self.request.user)
|
||||
|
||||
def perform_create(self, serializer):
|
||||
serializer.save(user=self.request.user)
|
||||
_SAVED_POINT_DUPLICATE = 'A saved point with this name already exists for the user.'
|
||||
|
||||
|
||||
class PreditctionTemplateViewset(ModelViewSet):
|
||||
permission_classes = [IsOwner]
|
||||
serializer_class = PreditctionTemplateSerializer
|
||||
pagination_class = None
|
||||
|
||||
def get_queryset(self):
|
||||
return PreditctionTemplate.objects.filter(user=self.request.user)
|
||||
|
||||
def perform_create(self, serializer):
|
||||
serializer.save(user=self.request.user)
|
||||
def _check_saved_point_unique(user, name, exclude_pk=None):
|
||||
"""Reproduce the former SavedPointSerializer UniqueTogetherValidator."""
|
||||
qs = SavedPoint.objects.filter(user=user, name=name)
|
||||
if exclude_pk is not None:
|
||||
qs = qs.exclude(pk=exclude_pk)
|
||||
if qs.exists():
|
||||
raise APIError(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
body={'non_field_errors': [_SAVED_POINT_DUPLICATE]},
|
||||
)
|
||||
|
||||
|
||||
class UserProfileView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
class SavedPointListController(Controller[PydanticSerializer]):
|
||||
"""Collection endpoint for the current user's saved points (was SavedPointViewset)."""
|
||||
|
||||
def get(self, request):
|
||||
serializer = UserSerializer(request.user)
|
||||
return Response(serializer.data)
|
||||
def get(self) -> list[SavedPointOut]:
|
||||
qs = SavedPoint.objects.filter(user=self.request.user)
|
||||
return [SavedPointOut.model_validate(obj) for obj in qs]
|
||||
|
||||
def patch(self, request):
|
||||
user = request.user
|
||||
serializer = UserSerializer(user, data=request.data, partial=True)
|
||||
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
serializer.save()
|
||||
return Response(serializer.data)
|
||||
def post(self, parsed_body: Body[SavedPointIn]) -> SavedPointOut:
|
||||
user = self.request.user
|
||||
_check_saved_point_unique(user, parsed_body.name)
|
||||
obj = SavedPoint.objects.create(user=user, **parsed_body.model_dump())
|
||||
return SavedPointOut.model_validate(obj)
|
||||
|
||||
|
||||
class ChangePasswordView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
class SavedPointDetailController(Controller[PydanticSerializer]):
|
||||
"""Detail endpoint for a single saved point owned by the current user."""
|
||||
|
||||
def post(self, request):
|
||||
user = request.user
|
||||
serializer = ChangePasswordSerializer(data=request.data)
|
||||
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if not user.check_password(serializer.validated_data['old_password']):
|
||||
return Response({'detail': 'Old password is incorrect'},
|
||||
status=status.HTTP_400_BAD_REQUEST)
|
||||
def _get_object(self, pk: int):
|
||||
# Filtering by user means another user's object reads as 404 (not 403),
|
||||
# matching DRF's get_object() over a user-scoped queryset.
|
||||
obj = SavedPoint.objects.filter(user=self.request.user, pk=pk).first()
|
||||
if obj is None:
|
||||
raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'})
|
||||
return obj
|
||||
|
||||
user.set_password(serializer.validated_data['new_password'])
|
||||
def get(self, parsed_path: Path[PkPath]) -> SavedPointOut:
|
||||
return SavedPointOut.model_validate(self._get_object(parsed_path.pk))
|
||||
|
||||
def put(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointIn]) -> SavedPointOut:
|
||||
obj = self._get_object(parsed_path.pk)
|
||||
_check_saved_point_unique(self.request.user, parsed_body.name, exclude_pk=obj.pk)
|
||||
for field, value in parsed_body.model_dump().items():
|
||||
setattr(obj, field, value)
|
||||
obj.save()
|
||||
return SavedPointOut.model_validate(obj)
|
||||
|
||||
def patch(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointPatchIn]) -> SavedPointOut:
|
||||
obj = self._get_object(parsed_path.pk)
|
||||
updates = parsed_body.model_dump(exclude_unset=True)
|
||||
if 'name' in updates:
|
||||
_check_saved_point_unique(self.request.user, updates['name'], exclude_pk=obj.pk)
|
||||
for field, value in updates.items():
|
||||
setattr(obj, field, value)
|
||||
obj.save()
|
||||
return SavedPointOut.model_validate(obj)
|
||||
|
||||
@modify(status_code=HTTPStatus.NO_CONTENT)
|
||||
def delete(self, parsed_path: Path[PkPath]) -> None:
|
||||
self._get_object(parsed_path.pk).delete()
|
||||
|
||||
|
||||
class PredictionTemplateListController(Controller[PydanticSerializer]):
|
||||
"""Collection endpoint for the current user's templates (was PreditctionTemplateViewset).
|
||||
|
||||
NOTE: as before, there is no app-level uniqueness check; a duplicate
|
||||
(user, name) hits the model's unique_together and surfaces as a DB error.
|
||||
"""
|
||||
|
||||
def get(self) -> list[PredictionTemplateOut]:
|
||||
qs = PreditctionTemplate.objects.filter(user=self.request.user)
|
||||
return [PredictionTemplateOut.model_validate(obj) for obj in qs]
|
||||
|
||||
def post(self, parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut:
|
||||
obj = PreditctionTemplate.objects.create(
|
||||
user=self.request.user, **parsed_body.model_dump()
|
||||
)
|
||||
return PredictionTemplateOut.model_validate(obj)
|
||||
|
||||
|
||||
class PredictionTemplateDetailController(Controller[PydanticSerializer]):
|
||||
"""Detail endpoint for a single template owned by the current user."""
|
||||
|
||||
def _get_object(self, pk: int):
|
||||
obj = PreditctionTemplate.objects.filter(user=self.request.user, pk=pk).first()
|
||||
if obj is None:
|
||||
raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'})
|
||||
return obj
|
||||
|
||||
def get(self, parsed_path: Path[PkPath]) -> PredictionTemplateOut:
|
||||
return PredictionTemplateOut.model_validate(self._get_object(parsed_path.pk))
|
||||
|
||||
def put(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut:
|
||||
obj = self._get_object(parsed_path.pk)
|
||||
for field, value in parsed_body.model_dump().items():
|
||||
setattr(obj, field, value)
|
||||
obj.save()
|
||||
return PredictionTemplateOut.model_validate(obj)
|
||||
|
||||
def patch(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplatePatchIn]) -> PredictionTemplateOut:
|
||||
obj = self._get_object(parsed_path.pk)
|
||||
for field, value in parsed_body.model_dump(exclude_unset=True).items():
|
||||
setattr(obj, field, value)
|
||||
obj.save()
|
||||
return PredictionTemplateOut.model_validate(obj)
|
||||
|
||||
@modify(status_code=HTTPStatus.NO_CONTENT)
|
||||
def delete(self, parsed_path: Path[PkPath]) -> None:
|
||||
self._get_object(parsed_path.pk).delete()
|
||||
|
||||
|
||||
class UserProfileController(Controller[PydanticSerializer]):
|
||||
"""Read and partially update the current user's profile."""
|
||||
|
||||
def get(self) -> UserOut:
|
||||
return UserOut.model_validate(self.request.user)
|
||||
|
||||
def patch(self, parsed_body: Body[UserUpdateIn]) -> UserOut:
|
||||
user = self.request.user
|
||||
# partial update: apply only the fields the client actually sent.
|
||||
for field, value in parsed_body.model_dump(exclude_unset=True).items():
|
||||
setattr(user, field, value)
|
||||
user.save()
|
||||
return Response({'detail': 'Password changed successfully'})
|
||||
return UserOut.model_validate(user)
|
||||
|
||||
|
||||
class DeleteAccountView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
class ChangePasswordController(Controller[PydanticSerializer]):
|
||||
"""Change the current user's password."""
|
||||
|
||||
# post() would default to 201; the original returned 200.
|
||||
@modify(status_code=HTTPStatus.OK)
|
||||
def post(self, parsed_body: Body[ChangePasswordIn]) -> DetailResponse:
|
||||
user = self.request.user
|
||||
|
||||
if not user.check_password(parsed_body.old_password):
|
||||
raise APIError(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
body={'detail': 'Old password is incorrect'},
|
||||
)
|
||||
|
||||
user.set_password(parsed_body.new_password)
|
||||
user.save()
|
||||
return DetailResponse(detail='Password changed successfully')
|
||||
|
||||
|
||||
class DeleteAccountController(Controller[PydanticSerializer]):
|
||||
"""Delete the current user's account and all their data."""
|
||||
|
||||
# DMR forbids a Body component on DELETE, but the original endpoint read the
|
||||
# password from a DELETE request body. Preserve that contract by parsing the
|
||||
# body manually instead of via Body[DeleteAccountIn].
|
||||
def delete(self) -> DetailResponse:
|
||||
user = self.request.user
|
||||
|
||||
try:
|
||||
parsed_body = DeleteAccountIn(**json.loads(self.request.body or b'{}'))
|
||||
except ValidationError as exc:
|
||||
raise APIError(status_code=HTTPStatus.BAD_REQUEST, body=json.loads(exc.json()))
|
||||
except ValueError:
|
||||
raise APIError(status_code=HTTPStatus.BAD_REQUEST, body={'detail': 'Invalid request body.'})
|
||||
|
||||
if not user.check_password(parsed_body.password):
|
||||
raise APIError(
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
body={'detail': 'Incorrect password'},
|
||||
)
|
||||
|
||||
def delete(self, request):
|
||||
user = request.user
|
||||
serializer = DeleteAccountSerializer(data=request.data)
|
||||
|
||||
if not serializer.is_valid():
|
||||
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
if not user.check_password(serializer.validated_data['password']):
|
||||
return Response({'detail': 'Incorrect password'},
|
||||
status=status.HTTP_400_BAD_REQUEST)
|
||||
|
||||
Prediction.objects.filter(user=user).delete()
|
||||
SavedPoint.objects.filter(user=user).delete()
|
||||
PreditctionTemplate.objects.filter(user=user).delete()
|
||||
|
||||
|
||||
user.delete()
|
||||
|
||||
return Response({'detail': 'Account deleted successfully'})
|
||||
|
||||
return DetailResponse(detail='Account deleted successfully')
|
||||
|
||||
|
||||
class DeleteUserDataView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
class DeleteUserDataController(Controller[PydanticSerializer]):
|
||||
"""Delete all of the current user's data without deleting the account."""
|
||||
|
||||
def delete(self) -> DetailResponse:
|
||||
user = self.request.user
|
||||
|
||||
def delete(self, request):
|
||||
user = request.user
|
||||
|
||||
Prediction.objects.filter(user=user).delete()
|
||||
SavedPoint.objects.filter(user=user).delete()
|
||||
PreditctionTemplate.objects.filter(user=user).delete()
|
||||
|
||||
return Response({'detail': 'All user data deleted successfully'})
|
||||
|
||||
return DetailResponse(detail='All user data deleted successfully')
|
||||
|
||||
|
||||
class TokenManagementView(APIView):
|
||||
permission_classes = [IsAuthenticated]
|
||||
class ObtainTokenController(
|
||||
ObtainTokensSyncController[
|
||||
PydanticSerializer,
|
||||
ObtainTokensPayload,
|
||||
ObtainTokensResponse,
|
||||
],
|
||||
):
|
||||
"""Exchange username/password for JWT access + refresh tokens.
|
||||
|
||||
def get(self, request):
|
||||
|
||||
token, created = Token.objects.get_or_create(user=request.user)
|
||||
return Response({"token": token.key})
|
||||
Replaces DRF's obtain_auth_token. Approved drift: the token format and
|
||||
semantics change from a single stored DRF token to stateless JWTs, so the
|
||||
response is {access_token, refresh_token} instead of {token}.
|
||||
"""
|
||||
|
||||
def post(self, request):
|
||||
|
||||
Token.objects.filter(user=request.user).delete()
|
||||
token = Token.objects.create(user=request.user)
|
||||
return Response({"token": token.key})
|
||||
auth = None # public: credentials are supplied in the request body
|
||||
csrf_exempt = True
|
||||
jwt_expiration = timedelta(hours=1)
|
||||
jwt_refresh_expiration = timedelta(days=7)
|
||||
|
||||
def convert_auth_payload(self, payload):
|
||||
return payload
|
||||
|
||||
def make_api_response(self):
|
||||
now = datetime.now(timezone.utc)
|
||||
return {
|
||||
'access_token': self.create_jwt_token(
|
||||
expiration=now + self.jwt_expiration,
|
||||
token_type='access',
|
||||
),
|
||||
'refresh_token': self.create_jwt_token(
|
||||
expiration=now + self.jwt_refresh_expiration,
|
||||
token_type='refresh',
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
class TokenManagementController(Controller[PydanticSerializer]):
|
||||
"""Issue a fresh JWT access token for the current user.
|
||||
|
||||
Was TokenManagementView (DRF stored-token get/regenerate). With stateless
|
||||
JWTs there is nothing stored to fetch or delete, so both GET and POST mint a
|
||||
new access token (approved drift). The {"token": ...} response shape is kept.
|
||||
"""
|
||||
|
||||
jwt_expiration = timedelta(hours=1)
|
||||
|
||||
def get(self) -> TokenResponse:
|
||||
return TokenResponse(token=self._issue_token())
|
||||
|
||||
# post() would default to 201; the original returned 200.
|
||||
@modify(status_code=HTTPStatus.OK)
|
||||
def post(self) -> TokenResponse:
|
||||
return TokenResponse(token=self._issue_token())
|
||||
|
||||
def _issue_token(self) -> str:
|
||||
now = datetime.now(timezone.utc)
|
||||
return self.create_jwt_token(
|
||||
subject=str(self.request.user.pk),
|
||||
expiration=now + self.jwt_expiration,
|
||||
token_type='access',
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue