added validation, tawhiri request creation

This commit is contained in:
afanasyev.aa 2025-04-05 03:05:30 +09:00
parent 456551cd4e
commit 2aef4d4756
12 changed files with 184 additions and 18 deletions

Binary file not shown.

Binary file not shown.

View file

@ -1,7 +1,73 @@
from rest_framework import serializers
from .models import Prediction
from datetime import datetime
from .validators import (
validate_custom_curve, rate_clip,
_rfc3339_to_timestamp, base64_to_curve
)
class PredictionSerializer(serializers.ModelSerializer):
class Meta:
model = Prediction
fields = ['id', 'created_at', 'updated_at', 'result']
PROFILE_STANDARD = "standard_profile"
PROFILE_FLOAT = "float"
PROFILE_REVERSE = "reverse"
PROFILE_CUSTOM = "custom"
LATEST_DATASET_KEYWORD = "latest"
SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM]
class PredictionRequestSerializer(serializers.Serializer):
launch_latitude = serializers.FloatField(min_value=-90, max_value=90)
launch_longitude = serializers.FloatField(min_value=0, max_value=360)
launch_datetime = serializers.DateTimeField()
launch_altitude = serializers.FloatField(required=False)
format = serializers.CharField(default="json")
profile = serializers.ChoiceField(choices=SUPPORTED_PROFILES, default=PROFILE_STANDARD)
dataset = serializers.CharField(default=LATEST_DATASET_KEYWORD)
# --- профиль-dependent поля ---
ascent_rate = serializers.FloatField(required=False, min_value=0.01)
descent_rate = serializers.FloatField(required=False, min_value=0.01)
burst_altitude = serializers.FloatField(required=False)
float_altitude = serializers.FloatField(required=False)
stop_datetime = serializers.DateTimeField(required=False)
ascent_curve = serializers.CharField(required=False)
descent_curve = serializers.CharField(required=False)
interpolate = serializers.BooleanField(required=False, default=False)
def validate(self, data):
profile = data.get("profile", PROFILE_STANDARD)
launch_alt = data.get("launch_altitude", 0)
if profile == PROFILE_STANDARD:
if 'burst_altitude' not in data:
raise serializers.ValidationError("burst_altitude is required for standard profile.")
if data['burst_altitude'] <= launch_alt:
raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.")
elif profile == PROFILE_FLOAT:
if 'float_altitude' not in data or data['float_altitude'] <= launch_alt:
raise serializers.ValidationError("float_altitude must be greater than launch_altitude.")
if 'stop_datetime' not in data or data['stop_datetime'] <= data['launch_datetime']:
raise serializers.ValidationError("stop_datetime must be later than launch_datetime.")
elif profile == PROFILE_CUSTOM:
if 'ascent_curve' not in data or not validate_custom_curve(data['ascent_curve']):
raise serializers.ValidationError("Invalid ascent_curve.")
if 'descent_curve' not in data or not validate_custom_curve(data['descent_curve']):
raise serializers.ValidationError("Invalid descent_curve.")
if 'burst_altitude' not in data or data['burst_altitude'] <= launch_alt:
raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.")
# кастомная логика clipping'а
if 'ascent_rate' in data:
data['ascent_rate'] = rate_clip(data['ascent_rate'])
if 'descent_rate' in data:
data['descent_rate'] = rate_clip(data['descent_rate'])
return data

Binary file not shown.

52
api/services/tawhiri.py Normal file
View file

@ -0,0 +1,52 @@
import requests
from urllib.parse import urlencode
from datetime import datetime
from typing import Any
from zoneinfo import ZoneInfo
from collections import OrderedDict
class TawhiriClient:
BASE_URL = "https://fly.stratonautica.ru/api/v2/"
TIMEOUT = 15
@staticmethod
def _convert_value(value: Any) -> Any:
if isinstance(value, datetime):
return value.isoformat().replace("+00:00", "Z")
return value
@classmethod
def get_prediction(cls, params: dict) -> dict:
url = cls.build_url(params)
print("🔍 URL:", url)
response = requests.get(url, timeout=cls.TIMEOUT)
response.raise_for_status()
return response.json()
@classmethod
def build_url(cls, params: dict) -> str:
query = OrderedDict()
query["profile"] = params.get("profile")
query["launch_datetime"] = cls._convert_value(params.get("launch_datetime"))
query["launch_latitude"] = params.get("launch_latitude")
query["launch_longitude"] = params.get("launch_longitude")
query["launch_altitude"] = params.get("launch_altitude", 0)
query["ascent_rate"] = params.get("ascent_rate")
query["burst_altitude"] = params.get("burst_altitude")
query["descent_rate"] = params.get("descent_rate")
query["interpolate"] = str(params.get("interpolate", False)).lower()
query["dataset"] = cls._convert_value(params.get("dataset"))
query["format"] = params.get("format", "json")
query["pred_type"] = "single" # <-- в конце
filtered = {k: v for k, v in query.items() if v is not None}
return f"{cls.BASE_URL}?{urlencode(filtered)}"
@classmethod
def get_prediction(cls, params: dict) -> dict:
url = cls.build_url(params)
response = requests.get(url, timeout=cls.TIMEOUT)
response.raise_for_status()
return response.json()

31
api/validators.py Normal file
View file

@ -0,0 +1,31 @@
import base64
import json
from datetime import datetime
def rate_clip(rate):
"""Ограничивает допустимые значения скорости (например, 0.1 ≤ x ≤ 100)"""
return min(max(rate, 0.1), 100.0)
def _rfc3339_to_timestamp(value):
"""Парсинг RFC 3339 строки в datetime"""
return datetime.fromisoformat(value.replace("Z", "+00:00"))
def base64_to_curve(encoded):
"""Декодирует base64-encoded curve"""
try:
decoded = base64.b64decode(encoded).decode('utf-8')
return json.loads(decoded)
except Exception as e:
raise ValueError(f"Invalid curve format: {e}")
def validate_custom_curve(curve):
"""Проверяет, что curve имеет ожидаемую структуру (например, список точек)"""
try:
points = base64_to_curve(curve)
return isinstance(points, list) and all(isinstance(p, list) and len(p) == 2 for p in points)
except Exception:
return False

View file

@ -3,9 +3,14 @@ from rest_framework.response import Response
from rest_framework.views import APIView
from django.utils import timezone
from .models import Prediction, User, UserPrediction
from .serializers import PredictionSerializer
from .serializers import PredictionSerializer, PredictionRequestSerializer
from rest_framework.permissions import IsAuthenticated
import requests
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
from rest_framework.permissions import AllowAny
from .services.tawhiri import TawhiriClient
def get_prediction_from_tawhiri(params):
base_url = "https://fly.stratonautica.ru/api/v2"
@ -16,24 +21,30 @@ def get_prediction_from_tawhiri(params):
else:
raise Exception(f"Tawhiri error: {response.status_code} {response.text}")
class PredictionCreateView(APIView):
def post(self, request):
user_id = request.data.get('user_id')
user = User.objects.get(id=user_id)
permission_classes = [AllowAny]
# Передаём остальные параметры (кроме user_id) в Tawhiri
tawhiri_params = {k: v for k, v in request.data.items() if k != 'user_id'}
def post(self, request):
serializer = PredictionRequestSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
validated_data = serializer.validated_data
try:
prediction_result = get_prediction_from_tawhiri(tawhiri_params)
except Exception as e:
return Response({"error": str(e)}, status=500)
prediction_result = TawhiriClient.get_prediction(validated_data)
except requests.RequestException as e:
print("❌ Tawhiri error:", str(e), e.response.text if e.response else "no response")
return Response({"error": f"Tawhiri error: {str(e)}"}, status=status.HTTP_502_BAD_GATEWAY)
prediction = Prediction.objects.create(result=prediction_result)
UserPrediction.objects.create(user=user, prediction=prediction)
UserPrediction.objects.create(user=request.user, prediction=prediction, created_at=timezone.now())
return Response(PredictionSerializer(prediction).data)
return Response({
"id": prediction.id,
"created_at": prediction.created_at,
"result": prediction_result
}, status=status.HTTP_201_CREATED)
class PredictionListView(APIView):
def get(self, request):
@ -59,5 +70,5 @@ class PredictionDeleteView(APIView):
except Prediction.DoesNotExist:
return Response({"error": "Not found"}, status=404)
class PredictionCreateView(APIView):
permission_classes = [IsAuthenticated]
#class PredictionCreateView(APIView):
#permission_classes = [IsAuthenticated]

Binary file not shown.

View file

@ -40,10 +40,12 @@ INSTALLED_APPS = [
'rest_framework',
'rest_framework.authtoken',
'drf_spectacular',
'corsheaders',
'api'
]
MIDDLEWARE = [
'corsheaders.middleware.CorsMiddleware',
'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware',
@ -53,6 +55,8 @@ MIDDLEWARE = [
'django.middleware.clickjacking.XFrameOptionsMiddleware',
]
CORS_ALLOW_ALL_ORIGINS = True
ROOT_URLCONF = 'testapi.urls'
TEMPLATES = [
@ -138,9 +142,11 @@ REST_FRAMEWORK = {
# ВАШИ НАСТРОЙКИ
'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
'DEFAULT_AUTHENTICATION_CLASSES': [
'rest_framework.authentication.TokenAuthentication',
# 'rest_framework.authentication.TokenAuthentication',
# 'rest_framework.permissions.AllowAny',
],
'DEFAULT_PERMISSION_CLASSES': [
'rest_framework.permissions.IsAuthenticated',
#'rest_framework.permissions.IsAuthenticated',
'rest_framework.permissions.AllowAny',
]
}