565 lines
21 KiB
Python
565 lines
21 KiB
Python
import json
import time
from datetime import datetime, timedelta, timezone
from http import HTTPStatus

import requests
from django.contrib.auth import authenticate, login, logout, get_user_model
from django.db import transaction
from django.middleware.csrf import get_token
from django.utils.dateparse import parse_datetime
from dmr import Controller, modify
from dmr.components import Body, Path, Query
from dmr.plugins.pydantic import PydanticSerializer
from dmr.response import APIError
from dmr.security.jwt.views import (
    ObtainTokensPayload,
    ObtainTokensResponse,
    ObtainTokensSyncController,
)
from pydantic import ValidationError

from .dtos import (
    DetailResponse,
    SessionResponse,
    WhoAmIResponse,
    UserOut,
    UserUpdateIn,
    ChangePasswordIn,
    DeleteAccountIn,
    TokenResponse,
    PkPath,
    UuidPkPath,
    SavedPointIn,
    SavedPointPatchIn,
    SavedPointOut,
    PredictionTemplateIn,
    PredictionTemplatePatchIn,
    PredictionTemplateOut,
    PredictionRequest,
    PredictionOut,
    PredictionListOut,
    PredictionDetailOut,
    PredictionCreateOut,
    TelemetryIn,
    TelemetryOut,
)
from .models import Prediction, Satellite, SavedPoint, SavedRateProfile, PreditctionTemplate, TelemetryPacket
from .pagination import (
    PredictionListUserQuery,
    PredictionPage,
    paginate_predictions,
    TelemetryPage,
    page_number_paginate,
)
from .services.tawhiri import TawhiriClient
from .validators import base64_to_curve
|
|
|
|
# The active user model, resolved once at import time via get_user_model().
# NOTE(review): not referenced by any view visible in this module chunk --
# confirm it is still needed before removing.
User = get_user_model()
|
|
|
|
|
|
def _resolve_related(model, pk, field_name):
|
|
"""Resolve a related object by PK (was PrimaryKeyRelatedField(queryset=...)).
|
|
|
|
The original queryset was ``.all()`` (not user-scoped); a missing PK yields a
|
|
400 matching DRF's "Invalid pk" error shape.
|
|
"""
|
|
if pk is None:
|
|
return None
|
|
obj = model.objects.filter(pk=pk).first()
|
|
if obj is None:
|
|
raise APIError(
|
|
status_code=HTTPStatus.BAD_REQUEST,
|
|
body={field_name: [f'Invalid pk "{pk}" - object does not exist.']},
|
|
)
|
|
return obj
|
|
|
|
|
|
class PredictionCollectionController(Controller[PydanticSerializer]):
    """`predictions/` -- list the user's predictions and create a new one."""

    def get(self) -> list[PredictionOut]:
        """Return every prediction owned by the requesting user."""
        queryset = Prediction.objects.filter(user=self.request.user)
        return [PredictionOut.model_validate(obj) for obj in queryset]

    def post(self, parsed_body: Body[PredictionRequest]) -> PredictionCreateOut:
        """Run a Tawhiri prediction and store it for the requesting user.

        Raises:
            APIError: 400 when a referenced ``start_point``/``template``/
                ``rate_profile`` PK does not exist; 502 when the Tawhiri
                request fails.
        """
        user = self.request.user

        # Resolve related objects before calling Tawhiri, so an invalid PK still
        # fails with 400 first (as DRF validation did).
        start_point = _resolve_related(SavedPoint, parsed_body.start_point, 'start_point')
        template = _resolve_related(PreditctionTemplate, parsed_body.template, 'template')
        rate_profile = _resolve_related(SavedRateProfile, parsed_body.rate_profile, 'rate_profile')

        try:
            prediction_result = TawhiriClient.get_prediction(parsed_body.model_dump())
        except requests.RequestException as exc:
            # Chain the cause so the underlying network error stays visible in
            # tracebacks/logs instead of being replaced by the APIError alone.
            raise APIError(
                status_code=HTTPStatus.BAD_GATEWAY,
                body={'error': f'Tawhiri error: {exc}'},
            ) from exc

        # Carried over from the old serializer.create(): curves are decoded but
        # never persisted (the model has no curve fields). Kept for parity --
        # including that a malformed curve here raises (then surfaces as 500).
        if parsed_body.ascent_curve is not None:
            base64_to_curve(parsed_body.ascent_curve)
        if parsed_body.descent_curve is not None:
            base64_to_curve(parsed_body.descent_curve)

        prediction = Prediction(
            user=user,
            # The raw JSON body is stored verbatim as the request record.
            request=json.loads(self.request.body),
            result=prediction_result,
            start_point=start_point,
            template=template,
            rate_profile=rate_profile,
        )
        prediction.save()

        return PredictionCreateOut(
            id=prediction.id,
            created_at=prediction.created_at,
            result=prediction_result,
        )
|
|
|
|
|
|
class PredictionListUserController(Controller[PydanticSerializer]):
    """`predictions/list_user/` -- filtered, paginated list of the user's predictions."""

    @staticmethod
    def _parse_datetime_param(value, field_name):
        """Parse a ``created_*`` query value or raise a 400 APIError.

        ``parse_datetime`` returns ``None`` for text that is not datetime-shaped
        and raises ``ValueError`` for well-formed but impossible values. Either
        outcome used to reach the queryset (``__gte=None`` is rejected by Django)
        and surface as a 500.
        """
        try:
            parsed = parse_datetime(value)
        except ValueError:
            parsed = None
        if parsed is None:
            raise APIError(
                status_code=HTTPStatus.BAD_REQUEST,
                body={field_name: ['Invalid datetime.']},
            )
        return parsed

    def get(self, parsed_query: Query[PredictionListUserQuery]) -> PredictionPage:
        """Return one page of the caller's predictions, optionally filtered.

        Raises:
            APIError: 400 when ``created_from``/``created_till`` is not a valid
                datetime; 403 when ``satellite_id`` is not one of the caller's
                satellites.
        """
        user = self.request.user
        filters = {'user': user}

        if parsed_query.created_from:
            filters['created_at__gte'] = self._parse_datetime_param(
                parsed_query.created_from, 'created_from')
        if parsed_query.created_till:
            filters['created_at__lte'] = self._parse_datetime_param(
                parsed_query.created_till, 'created_till')
        if parsed_query.satellite_id:
            # Only allow filtering by a satellite linked to the caller.
            if not user.satellites.filter(id=parsed_query.satellite_id).exists():
                raise APIError(status_code=HTTPStatus.FORBIDDEN, body={'detail': 'Access denied'})
            filters['satellite_id'] = parsed_query.satellite_id

        queryset = Prediction.objects.filter(**filters)
        return paginate_predictions(queryset, parsed_query)
|
|
|
|
|
|
class PredictionHistoryController(Controller[PydanticSerializer]):
    """`predictions/history/` -- compact list of the user's predictions."""

    def get(self) -> list[PredictionListOut]:
        """Serialize each of the caller's predictions with the compact schema."""
        owned = Prediction.objects.filter(user=self.request.user)
        return list(map(PredictionListOut.model_validate, owned))
|
|
|
|
|
|
class PredictionDetailController(Controller[PydanticSerializer]):
    """`predictions/<pk>/detail/` -- retrieve a single prediction."""

    def get(self, parsed_path: Path[UuidPkPath]) -> PredictionDetailOut:
        """Return the caller's prediction ``pk``; anyone else's reads as 404."""
        found = (
            Prediction.objects
            .filter(user=self.request.user, pk=parsed_path.pk)
            .first()
        )
        if found is not None:
            return PredictionDetailOut.model_validate(found)
        raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'})
|
|
|
|
|
|
class PredictionDeleteController(Controller[PydanticSerializer]):
    """`predictions/<pk>/delete/` -- delete a single prediction."""

    @modify(status_code=HTTPStatus.NO_CONTENT)
    def delete(self, parsed_path: Path[UuidPkPath]) -> None:
        """Delete the caller's prediction ``pk`` (204), or 404 when it is not theirs."""
        owned = Prediction.objects.filter(user=self.request.user, pk=parsed_path.pk)
        target = owned.first()
        if target is None:
            raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found'})
        # Delete the instance (not the queryset) so a Model.delete() override still runs.
        target.delete()
|
|
|
|
|
|
class TelemetryController(Controller[PydanticSerializer]):
    """`<uuid:pk>/telemetry/` -- list and ingest telemetry for a satellite.

    Public (was AllowAny). GET preserves the global PageNumberPagination envelope
    {count, next, previous, results} with PAGE_SIZE=100.
    """

    auth = None  # public (was AllowAny)
    csrf_exempt = True  # DRF views bypass Django CSRF; ingestion is anonymous

    @staticmethod
    def _int_query_param(query, name):
        """Return query parameter ``name`` as an int, or ``None`` when absent/empty.

        Raises:
            APIError: 400 for a non-numeric value (previously an uncaught
                ValueError that surfaced as a 500).
        """
        raw = query.get(name)
        if not raw:
            return None
        try:
            return int(raw)
        except ValueError:
            raise APIError(
                status_code=HTTPStatus.BAD_REQUEST,
                body={'detail': f"Query parameter '{name}' must be an integer."},
            ) from None

    def get(self, parsed_path: Path[UuidPkPath]) -> TelemetryPage:
        """Return one page of the satellite's packets, newest first.

        Optional ``from``/``till`` query parameters bound ``timestamp`` inclusively.
        """
        qs = TelemetryPacket.objects.filter(satellite_id=parsed_path.pk)

        from_ts = self._int_query_param(self.request.GET, 'from')
        till_ts = self._int_query_param(self.request.GET, 'till')

        if from_ts is not None:
            qs = qs.filter(timestamp__gte=from_ts)
        if till_ts is not None:
            qs = qs.filter(timestamp__lte=till_ts)

        qs = qs.order_by('-timestamp')

        count, next_link, previous_link, page_objects = page_number_paginate(self.request, qs)
        return TelemetryPage(
            count=count,
            next=next_link,
            previous=previous_link,
            results=[TelemetryOut.model_validate(obj) for obj in page_objects],
        )

    def post(self, parsed_path: Path[UuidPkPath], parsed_body: Body[TelemetryIn]) -> TelemetryOut:
        """Store one telemetry packet for satellite ``pk`` and return it.

        Raises:
            APIError: 404 when the satellite does not exist (previously an
                uncaught DoesNotExist, i.e. a 500).
        """
        satellite = Satellite.objects.filter(id=parsed_path.pk).first()
        if satellite is None:
            raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'})

        # Bug fix (approved): the original returned serializer.errors on success;
        # we return the created packet instead. timestamp is still server-set.
        packet = TelemetryPacket.objects.create(
            # Whole unix seconds: the telemetry model keeps an integer timestamp
            # (BigIntegerField), so a float here would only be truncated on save
            # while the in-memory packet serialized below kept the fraction.
            timestamp=int(time.time()),
            satellite=satellite,
            lat=parsed_body.lat,
            lon=parsed_body.lon,
            alt=parsed_body.alt,
            payload=parsed_body.payload,
        )
        return TelemetryOut.model_validate(packet)
|
|
|
|
|
|
class SessionController(Controller[PydanticSerializer]):
    """Report whether the current request is authenticated."""

    def get(self) -> SessionResponse:
        """Always answer ``isAuthenticated=True``.

        NOTE(review): relies on the controller's default ``auth`` rejecting
        anonymous callers before this runs (public controllers in this module
        opt out with ``auth = None``) -- confirm.
        """
        payload = {'isAuthenticated': True}
        return SessionResponse(**payload)
|
|
|
|
|
|
class WhoAmIController(Controller[PydanticSerializer]):
    """Return the current user's username."""

    def get(self) -> WhoAmIResponse:
        """Echo back the authenticated caller's ``username``."""
        current_user = self.request.user
        return WhoAmIResponse(username=current_user.username)
|
|
|
|
|
|
class CsrfController(Controller[PydanticSerializer]):
    """Get CSRF token."""

    auth = None  # public endpoint (was AllowAny)
    csrf_exempt = True

    def get(self) -> DetailResponse:
        """Issue the CSRF cookie and also expose the token in ``X-CSRFToken``."""
        # get_token() marks the request so Django sets the CSRF cookie on the response.
        csrf_value = get_token(self.request)
        body = DetailResponse(detail='CSRF cookie set')
        return self.to_response(body, headers={'X-CSRFToken': csrf_value})
|
|
|
|
|
|
class LoginController(Controller[PydanticSerializer]):
    """Login user."""

    auth = None  # public endpoint (was AllowAny)
    csrf_exempt = True

    # post() would default to 201; the original returned 200.
    @modify(status_code=HTTPStatus.OK)
    def post(self) -> DetailResponse:
        """Authenticate username/password from the JSON body and start a session.

        Raises:
            APIError: 400 when the body is not a JSON object, when either
                credential is missing, or when the credentials are wrong.
        """
        try:
            data = json.loads(self.request.body)
        except ValueError:
            # JSONDecodeError/UnicodeDecodeError (incl. an empty body) were a 500.
            data = None
        if not isinstance(data, dict):
            # A non-object body (e.g. a list) would otherwise fail on .get().
            raise APIError(
                status_code=HTTPStatus.BAD_REQUEST,
                body={'detail': 'Invalid request body.'},
            )

        username = data.get('username')
        password = data.get('password')

        if username is None or password is None:
            raise APIError(
                status_code=HTTPStatus.BAD_REQUEST,
                body={'detail': 'Please provide username and password.'},
            )

        # Forward the request so authentication backends receive it.
        user = authenticate(self.request, username=username, password=password)
        if user is None:
            raise APIError(
                status_code=HTTPStatus.BAD_REQUEST,
                body={'detail': 'Invalid credentials.'},
            )

        login(self.request, user)
        return DetailResponse(detail='Successfully logged in.')
|
|
|
|
|
|
class LogoutController(Controller[PydanticSerializer]):
    """Logout user."""

    auth = None  # public endpoint (was AllowAny); checks auth state manually

    @modify(status_code=HTTPStatus.OK)
    def post(self) -> DetailResponse:
        """End the caller's session; 400 when nobody is logged in."""
        is_logged_in = self.request.user.is_authenticated
        if is_logged_in:
            logout(self.request)
            return DetailResponse(detail='Successfully logged out.')
        raise APIError(
            status_code=HTTPStatus.BAD_REQUEST,
            body={'detail': "You're not logged in."},
        )
|
|
|
|
|
|
# Error text that _check_saved_point_unique() returns under 'non_field_errors'
# when a user already has a saved point with the requested name.
_SAVED_POINT_DUPLICATE = 'A saved point with this name already exists for the user.'
|
|
|
|
|
|
def _check_saved_point_unique(user, name, exclude_pk=None):
    """Reject a (user, name) pair that another SavedPoint already uses.

    Stands in for the old SavedPointSerializer UniqueTogetherValidator.

    Args:
        user: Owner whose saved points are searched.
        name: Candidate saved-point name.
        exclude_pk: PK of the row being updated, so it does not clash with itself.

    Raises:
        APIError: 400 with ``{'non_field_errors': [...]}`` on a duplicate.
    """
    clashes = SavedPoint.objects.filter(user=user, name=name)
    if exclude_pk is not None:
        clashes = clashes.exclude(pk=exclude_pk)
    if not clashes.exists():
        return
    raise APIError(
        status_code=HTTPStatus.BAD_REQUEST,
        body={'non_field_errors': [_SAVED_POINT_DUPLICATE]},
    )
|
|
|
|
|
|
class SavedPointListController(Controller[PydanticSerializer]):
    """Collection endpoint for the current user's saved points (was SavedPointViewset)."""

    def get(self) -> list[SavedPointOut]:
        """List every saved point owned by the caller."""
        return [
            SavedPointOut.model_validate(point)
            for point in SavedPoint.objects.filter(user=self.request.user)
        ]

    def post(self, parsed_body: Body[SavedPointIn]) -> SavedPointOut:
        """Create a saved point for the caller once the name is known to be free."""
        owner = self.request.user
        _check_saved_point_unique(owner, parsed_body.name)
        fields = parsed_body.model_dump()
        created = SavedPoint.objects.create(user=owner, **fields)
        return SavedPointOut.model_validate(created)
|
|
|
|
|
|
class SavedPointDetailController(Controller[PydanticSerializer]):
    """Detail endpoint for a single saved point owned by the current user."""

    def _get_object(self, pk: int):
        """Fetch the caller's saved point ``pk`` or raise a 404 APIError.

        Scoping the lookup by user makes another user's point indistinguishable
        from a missing one (404, not 403), as DRF's get_object() over a
        user-scoped queryset behaved.
        """
        point = SavedPoint.objects.filter(user=self.request.user, pk=pk).first()
        if point is not None:
            return point
        raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'})

    @staticmethod
    def _assign_and_save(point, values):
        """Set every ``values`` item as an attribute on ``point``, save it, and serialize it."""
        for attr, new_value in values.items():
            setattr(point, attr, new_value)
        point.save()
        return SavedPointOut.model_validate(point)

    def get(self, parsed_path: Path[PkPath]) -> SavedPointOut:
        """Return one of the caller's saved points."""
        point = self._get_object(parsed_path.pk)
        return SavedPointOut.model_validate(point)

    def put(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointIn]) -> SavedPointOut:
        """Replace all fields of the saved point, keeping names unique per user."""
        point = self._get_object(parsed_path.pk)
        _check_saved_point_unique(self.request.user, parsed_body.name, exclude_pk=point.pk)
        return self._assign_and_save(point, parsed_body.model_dump())

    def patch(self, parsed_path: Path[PkPath], parsed_body: Body[SavedPointPatchIn]) -> SavedPointOut:
        """Update only the fields the client sent; re-check uniqueness if ``name`` changes."""
        point = self._get_object(parsed_path.pk)
        changes = parsed_body.model_dump(exclude_unset=True)
        if 'name' in changes:
            _check_saved_point_unique(self.request.user, changes['name'], exclude_pk=point.pk)
        return self._assign_and_save(point, changes)

    @modify(status_code=HTTPStatus.NO_CONTENT)
    def delete(self, parsed_path: Path[PkPath]) -> None:
        """Delete the saved point and answer 204."""
        doomed = self._get_object(parsed_path.pk)
        doomed.delete()
|
|
|
|
|
|
class PredictionTemplateListController(Controller[PydanticSerializer]):
    """Collection endpoint for the current user's templates (was PreditctionTemplateViewset).

    NOTE: as before, there is no app-level uniqueness check; a duplicate
    (user, name) hits the model's unique_together and surfaces as a DB error.
    """

    def get(self) -> list[PredictionTemplateOut]:
        """List every template owned by the caller."""
        owned = PreditctionTemplate.objects.filter(user=self.request.user)
        return list(map(PredictionTemplateOut.model_validate, owned))

    def post(self, parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut:
        """Create a template owned by the caller from the validated body."""
        values = parsed_body.model_dump()
        template = PreditctionTemplate.objects.create(user=self.request.user, **values)
        return PredictionTemplateOut.model_validate(template)
|
|
|
|
|
|
class PredictionTemplateDetailController(Controller[PydanticSerializer]):
    """Detail endpoint for a single template owned by the current user."""

    def _get_object(self, pk: int):
        """Return the caller's template ``pk``; anything else raises a 404 APIError."""
        template = PreditctionTemplate.objects.filter(user=self.request.user, pk=pk).first()
        if template is not None:
            return template
        raise APIError(status_code=HTTPStatus.NOT_FOUND, body={'detail': 'Not found.'})

    @staticmethod
    def _assign_and_save(template, values):
        """Set every ``values`` item as an attribute on ``template``, save it, and serialize it."""
        for attr, new_value in values.items():
            setattr(template, attr, new_value)
        template.save()
        return PredictionTemplateOut.model_validate(template)

    def get(self, parsed_path: Path[PkPath]) -> PredictionTemplateOut:
        """Return one of the caller's templates."""
        template = self._get_object(parsed_path.pk)
        return PredictionTemplateOut.model_validate(template)

    def put(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplateIn]) -> PredictionTemplateOut:
        """Overwrite every field of the template from the body."""
        template = self._get_object(parsed_path.pk)
        return self._assign_and_save(template, parsed_body.model_dump())

    def patch(self, parsed_path: Path[PkPath], parsed_body: Body[PredictionTemplatePatchIn]) -> PredictionTemplateOut:
        """Update only the template fields the client actually sent."""
        template = self._get_object(parsed_path.pk)
        return self._assign_and_save(template, parsed_body.model_dump(exclude_unset=True))

    @modify(status_code=HTTPStatus.NO_CONTENT)
    def delete(self, parsed_path: Path[PkPath]) -> None:
        """Delete the template and answer 204."""
        doomed = self._get_object(parsed_path.pk)
        doomed.delete()
|
|
|
|
|
|
class UserProfileController(Controller[PydanticSerializer]):
    """Read and partially update the current user's profile."""

    def get(self) -> UserOut:
        """Serialize the authenticated caller."""
        me = self.request.user
        return UserOut.model_validate(me)

    def patch(self, parsed_body: Body[UserUpdateIn]) -> UserOut:
        """Apply only the profile fields present in the request, then save."""
        me = self.request.user
        provided = parsed_body.model_dump(exclude_unset=True)
        for attr in provided:
            setattr(me, attr, provided[attr])
        me.save()
        return UserOut.model_validate(me)
|
|
|
|
|
|
class ChangePasswordController(Controller[PydanticSerializer]):
    """Change the current user's password."""

    # post() would default to 201; the original returned 200.
    @modify(status_code=HTTPStatus.OK)
    def post(self, parsed_body: Body[ChangePasswordIn]) -> DetailResponse:
        """Verify ``old_password``, then hash and store ``new_password``.

        NOTE(review): Django invalidates existing sessions (this one included)
        once the password hash changes unless update_session_auth_hash() is
        called -- confirm that logging session users out here is intended.
        """
        account = self.request.user
        if account.check_password(parsed_body.old_password):
            account.set_password(parsed_body.new_password)
            account.save()
            return DetailResponse(detail='Password changed successfully')
        raise APIError(
            status_code=HTTPStatus.BAD_REQUEST,
            body={'detail': 'Old password is incorrect'},
        )
|
|
|
|
|
|
class DeleteAccountController(Controller[PydanticSerializer]):
    """Delete the current user's account and all their data."""

    # DMR forbids a Body component on DELETE, but the original endpoint read the
    # password from a DELETE request body. Preserve that contract by parsing the
    # body manually instead of via Body[DeleteAccountIn].
    def delete(self) -> DetailResponse:
        """Check the password from the JSON body, then delete the user's data and account.

        Raises:
            APIError: 400 when the body is not valid JSON, does not validate
                against DeleteAccountIn, or carries the wrong password.
        """
        user = self.request.user

        try:
            raw = json.loads(self.request.body or b'{}')
        except ValueError:
            raise APIError(
                status_code=HTTPStatus.BAD_REQUEST,
                body={'detail': 'Invalid request body.'},
            ) from None
        try:
            # model_validate (rather than DeleteAccountIn(**raw)) turns a
            # non-object JSON body such as `[]` into a 400 validation error
            # instead of a TypeError/500.
            parsed_body = DeleteAccountIn.model_validate(raw)
        except ValidationError as exc:
            raise APIError(
                status_code=HTTPStatus.BAD_REQUEST,
                body=json.loads(exc.json()),
            ) from None

        if not user.check_password(parsed_body.password):
            raise APIError(
                status_code=HTTPStatus.BAD_REQUEST,
                body={'detail': 'Incorrect password'},
            )

        # One transaction, so a failure part-way through (e.g. a protected
        # foreign key) cannot leave the account with only some data removed.
        # Predictions reference saved points/templates, so they go first.
        with transaction.atomic():
            Prediction.objects.filter(user=user).delete()
            SavedPoint.objects.filter(user=user).delete()
            PreditctionTemplate.objects.filter(user=user).delete()
            user.delete()

        return DetailResponse(detail='Account deleted successfully')
|
|
|
|
|
|
class DeleteUserDataController(Controller[PydanticSerializer]):
    """Delete all of the current user's data without deleting the account."""

    def delete(self) -> DetailResponse:
        """Remove the caller's predictions, saved points and templates.

        The three deletes share one transaction, so a failure part-way through
        (e.g. a protected foreign key) cannot leave the data half-deleted.
        """
        user = self.request.user

        # Predictions reference saved points/templates, so they are deleted first.
        with transaction.atomic():
            Prediction.objects.filter(user=user).delete()
            SavedPoint.objects.filter(user=user).delete()
            PreditctionTemplate.objects.filter(user=user).delete()

        return DetailResponse(detail='All user data deleted successfully')
|
|
|
|
|
|
class ObtainTokenController(
    ObtainTokensSyncController[
        PydanticSerializer,
        ObtainTokensPayload,
        ObtainTokensResponse,
    ],
):
    """Exchange username/password for JWT access + refresh tokens.

    Replaces DRF's obtain_auth_token. Approved drift: the token format and
    semantics change from a single stored DRF token to stateless JWTs, so the
    response is {access_token, refresh_token} instead of {token}.
    """

    auth = None  # public: credentials are supplied in the request body
    csrf_exempt = True
    jwt_expiration = timedelta(hours=1)
    jwt_refresh_expiration = timedelta(days=7)

    def convert_auth_payload(self, payload):
        """Return the auth payload unchanged; no conversion is needed here."""
        return payload

    def make_api_response(self):
        """Mint an access/refresh JWT pair whose expiries share one UTC 'now'."""
        issued_at = datetime.now(timezone.utc)
        access = self.create_jwt_token(
            expiration=issued_at + self.jwt_expiration,
            token_type='access',
        )
        refresh = self.create_jwt_token(
            expiration=issued_at + self.jwt_refresh_expiration,
            token_type='refresh',
        )
        return {'access_token': access, 'refresh_token': refresh}
|
|
|
|
|
|
class TokenManagementController(Controller[PydanticSerializer]):
    """Issue a fresh JWT access token for the current user.

    Was TokenManagementView (DRF stored-token get/regenerate). With stateless
    JWTs there is nothing stored to fetch or delete, so both GET and POST mint a
    new access token (approved drift). The {"token": ...} response shape is kept.
    """

    jwt_expiration = timedelta(hours=1)

    def get(self) -> TokenResponse:
        """Mint and return a fresh access token."""
        fresh = self._issue_token()
        return TokenResponse(token=fresh)

    # post() would default to 201; the original returned 200.
    @modify(status_code=HTTPStatus.OK)
    def post(self) -> TokenResponse:
        """Same as GET: mint and return a fresh access token."""
        fresh = self._issue_token()
        return TokenResponse(token=fresh)

    def _issue_token(self) -> str:
        """Sign an access JWT for the current user, valid for ``jwt_expiration``.

        NOTE(review): create_jwt_token is not defined on this class, and its base
        is a plain Controller rather than ObtainTokensSyncController -- confirm
        that Controller actually provides this method.
        """
        expires_at = datetime.now(timezone.utc) + self.jwt_expiration
        return self.create_jwt_token(
            subject=str(self.request.user.pk),
            expiration=expires_at,
            token_type='access',
        )
|
|
|
|
|
|
|
|
# class PredictionCreateView(APIView):
|
|
# permission_classes = [IsAuthenticated]
|
|
|
|
# class TelemetryPacket(models.Model):
|
|
# id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
|
|
# satellite = models.ForeignKey(Satellite, on_delete=models.CASCADE, related_name="telemetry")
|
|
# timestamp = models.BigIntegerField() # unix time
|
|
# lat = models.FloatField()
|
|
# lon = models.FloatField()
|
|
# alt = models.FloatField()
|
|
# payload = models.JSONField(null=True, blank=True)
|
|
# created_at = models.DateTimeField(auto_now_add=True)
|
|
# fields = ['id', 'timestamp', 'lat', 'lon', 'alt', 'payload']
|