added validation, tawhiri request creation

This commit is contained in:
afanasyev.aa 2025-04-05 03:05:30 +09:00
parent 456551cd4e
commit 2aef4d4756
12 changed files with 184 additions and 18 deletions

Binary file not shown.

Binary file not shown.

View file

@ -1,7 +1,73 @@
from rest_framework import serializers from rest_framework import serializers
from .models import Prediction from .models import Prediction
from datetime import datetime
from .validators import (
validate_custom_curve, rate_clip,
_rfc3339_to_timestamp, base64_to_curve
)
class PredictionSerializer(serializers.ModelSerializer): class PredictionSerializer(serializers.ModelSerializer):
class Meta: class Meta:
model = Prediction model = Prediction
fields = ['id', 'created_at', 'updated_at', 'result'] fields = ['id', 'created_at', 'updated_at', 'result']
PROFILE_STANDARD = "standard_profile"
PROFILE_FLOAT = "float"
PROFILE_REVERSE = "reverse"
PROFILE_CUSTOM = "custom"
LATEST_DATASET_KEYWORD = "latest"
SUPPORTED_PROFILES = [PROFILE_STANDARD, PROFILE_FLOAT, PROFILE_REVERSE, PROFILE_CUSTOM]
class PredictionRequestSerializer(serializers.Serializer):
launch_latitude = serializers.FloatField(min_value=-90, max_value=90)
launch_longitude = serializers.FloatField(min_value=0, max_value=360)
launch_datetime = serializers.DateTimeField()
launch_altitude = serializers.FloatField(required=False)
format = serializers.CharField(default="json")
profile = serializers.ChoiceField(choices=SUPPORTED_PROFILES, default=PROFILE_STANDARD)
dataset = serializers.CharField(default=LATEST_DATASET_KEYWORD)
# --- профиль-dependent поля ---
ascent_rate = serializers.FloatField(required=False, min_value=0.01)
descent_rate = serializers.FloatField(required=False, min_value=0.01)
burst_altitude = serializers.FloatField(required=False)
float_altitude = serializers.FloatField(required=False)
stop_datetime = serializers.DateTimeField(required=False)
ascent_curve = serializers.CharField(required=False)
descent_curve = serializers.CharField(required=False)
interpolate = serializers.BooleanField(required=False, default=False)
def validate(self, data):
profile = data.get("profile", PROFILE_STANDARD)
launch_alt = data.get("launch_altitude", 0)
if profile == PROFILE_STANDARD:
if 'burst_altitude' not in data:
raise serializers.ValidationError("burst_altitude is required for standard profile.")
if data['burst_altitude'] <= launch_alt:
raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.")
elif profile == PROFILE_FLOAT:
if 'float_altitude' not in data or data['float_altitude'] <= launch_alt:
raise serializers.ValidationError("float_altitude must be greater than launch_altitude.")
if 'stop_datetime' not in data or data['stop_datetime'] <= data['launch_datetime']:
raise serializers.ValidationError("stop_datetime must be later than launch_datetime.")
elif profile == PROFILE_CUSTOM:
if 'ascent_curve' not in data or not validate_custom_curve(data['ascent_curve']):
raise serializers.ValidationError("Invalid ascent_curve.")
if 'descent_curve' not in data or not validate_custom_curve(data['descent_curve']):
raise serializers.ValidationError("Invalid descent_curve.")
if 'burst_altitude' not in data or data['burst_altitude'] <= launch_alt:
raise serializers.ValidationError("burst_altitude must be greater than launch_altitude.")
# кастомная логика clipping'а
if 'ascent_rate' in data:
data['ascent_rate'] = rate_clip(data['ascent_rate'])
if 'descent_rate' in data:
data['descent_rate'] = rate_clip(data['descent_rate'])
return data

Binary file not shown.

52
api/services/tawhiri.py Normal file
View file

@ -0,0 +1,52 @@
import requests
from urllib.parse import urlencode
from datetime import datetime
from typing import Any
from zoneinfo import ZoneInfo
from collections import OrderedDict
class TawhiriClient:
BASE_URL = "https://fly.stratonautica.ru/api/v2/"
TIMEOUT = 15
@staticmethod
def _convert_value(value: Any) -> Any:
if isinstance(value, datetime):
return value.isoformat().replace("+00:00", "Z")
return value
@classmethod
def get_prediction(cls, params: dict) -> dict:
url = cls.build_url(params)
print("🔍 URL:", url)
response = requests.get(url, timeout=cls.TIMEOUT)
response.raise_for_status()
return response.json()
@classmethod
def build_url(cls, params: dict) -> str:
query = OrderedDict()
query["profile"] = params.get("profile")
query["launch_datetime"] = cls._convert_value(params.get("launch_datetime"))
query["launch_latitude"] = params.get("launch_latitude")
query["launch_longitude"] = params.get("launch_longitude")
query["launch_altitude"] = params.get("launch_altitude", 0)
query["ascent_rate"] = params.get("ascent_rate")
query["burst_altitude"] = params.get("burst_altitude")
query["descent_rate"] = params.get("descent_rate")
query["interpolate"] = str(params.get("interpolate", False)).lower()
query["dataset"] = cls._convert_value(params.get("dataset"))
query["format"] = params.get("format", "json")
query["pred_type"] = "single" # <-- в конце
filtered = {k: v for k, v in query.items() if v is not None}
return f"{cls.BASE_URL}?{urlencode(filtered)}"
@classmethod
def get_prediction(cls, params: dict) -> dict:
url = cls.build_url(params)
response = requests.get(url, timeout=cls.TIMEOUT)
response.raise_for_status()
return response.json()

31
api/validators.py Normal file
View file

@ -0,0 +1,31 @@
import base64
import json
from datetime import datetime
def rate_clip(rate):
"""Ограничивает допустимые значения скорости (например, 0.1 ≤ x ≤ 100)"""
return min(max(rate, 0.1), 100.0)
def _rfc3339_to_timestamp(value):
"""Парсинг RFC 3339 строки в datetime"""
return datetime.fromisoformat(value.replace("Z", "+00:00"))
def base64_to_curve(encoded):
"""Декодирует base64-encoded curve"""
try:
decoded = base64.b64decode(encoded).decode('utf-8')
return json.loads(decoded)
except Exception as e:
raise ValueError(f"Invalid curve format: {e}")
def validate_custom_curve(curve):
"""Проверяет, что curve имеет ожидаемую структуру (например, список точек)"""
try:
points = base64_to_curve(curve)
return isinstance(points, list) and all(isinstance(p, list) and len(p) == 2 for p in points)
except Exception:
return False

View file

@ -3,9 +3,14 @@ from rest_framework.response import Response
from rest_framework.views import APIView from rest_framework.views import APIView
from django.utils import timezone from django.utils import timezone
from .models import Prediction, User, UserPrediction from .models import Prediction, User, UserPrediction
from .serializers import PredictionSerializer from .serializers import PredictionSerializer, PredictionRequestSerializer
from rest_framework.permissions import IsAuthenticated from rest_framework.permissions import IsAuthenticated
import requests import requests
from django.views.decorators.csrf import csrf_exempt
from django.utils.decorators import method_decorator
from rest_framework.permissions import AllowAny
from .services.tawhiri import TawhiriClient
def get_prediction_from_tawhiri(params): def get_prediction_from_tawhiri(params):
base_url = "https://fly.stratonautica.ru/api/v2" base_url = "https://fly.stratonautica.ru/api/v2"
@ -15,25 +20,31 @@ def get_prediction_from_tawhiri(params):
return response.json() # получаем результат предсказания return response.json() # получаем результат предсказания
else: else:
raise Exception(f"Tawhiri error: {response.status_code} {response.text}") raise Exception(f"Tawhiri error: {response.status_code} {response.text}")
class PredictionCreateView(APIView): class PredictionCreateView(APIView):
def post(self, request): permission_classes = [AllowAny]
user_id = request.data.get('user_id')
user = User.objects.get(id=user_id) def post(self, request):
serializer = PredictionRequestSerializer(data=request.data)
if not serializer.is_valid():
return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
validated_data = serializer.validated_data
# Передаём остальные параметры (кроме user_id) в Tawhiri
tawhiri_params = {k: v for k, v in request.data.items() if k != 'user_id'}
try: try:
prediction_result = get_prediction_from_tawhiri(tawhiri_params) prediction_result = TawhiriClient.get_prediction(validated_data)
except Exception as e: except requests.RequestException as e:
return Response({"error": str(e)}, status=500) print("❌ Tawhiri error:", str(e), e.response.text if e.response else "no response")
return Response({"error": f"Tawhiri error: {str(e)}"}, status=status.HTTP_502_BAD_GATEWAY)
prediction = Prediction.objects.create(result=prediction_result) prediction = Prediction.objects.create(result=prediction_result)
UserPrediction.objects.create(user=user, prediction=prediction) UserPrediction.objects.create(user=request.user, prediction=prediction, created_at=timezone.now())
return Response(PredictionSerializer(prediction).data) return Response({
"id": prediction.id,
"created_at": prediction.created_at,
"result": prediction_result
}, status=status.HTTP_201_CREATED)
class PredictionListView(APIView): class PredictionListView(APIView):
def get(self, request): def get(self, request):
@ -59,5 +70,5 @@ class PredictionDeleteView(APIView):
except Prediction.DoesNotExist: except Prediction.DoesNotExist:
return Response({"error": "Not found"}, status=404) return Response({"error": "Not found"}, status=404)
class PredictionCreateView(APIView): #class PredictionCreateView(APIView):
permission_classes = [IsAuthenticated] #permission_classes = [IsAuthenticated]

Binary file not shown.

View file

@ -40,10 +40,12 @@ INSTALLED_APPS = [
'rest_framework', 'rest_framework',
'rest_framework.authtoken', 'rest_framework.authtoken',
'drf_spectacular', 'drf_spectacular',
'corsheaders',
'api' 'api'
] ]
MIDDLEWARE = [ MIDDLEWARE = [
'corsheaders.middleware.CorsMiddleware',
'django.middleware.security.SecurityMiddleware', 'django.middleware.security.SecurityMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware', 'django.contrib.sessions.middleware.SessionMiddleware',
'django.middleware.common.CommonMiddleware', 'django.middleware.common.CommonMiddleware',
@ -53,6 +55,8 @@ MIDDLEWARE = [
'django.middleware.clickjacking.XFrameOptionsMiddleware', 'django.middleware.clickjacking.XFrameOptionsMiddleware',
] ]
CORS_ALLOW_ALL_ORIGINS = True
ROOT_URLCONF = 'testapi.urls' ROOT_URLCONF = 'testapi.urls'
TEMPLATES = [ TEMPLATES = [
@ -138,9 +142,11 @@ REST_FRAMEWORK = {
# ВАШИ НАСТРОЙКИ # ВАШИ НАСТРОЙКИ
'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema', 'DEFAULT_SCHEMA_CLASS': 'drf_spectacular.openapi.AutoSchema',
'DEFAULT_AUTHENTICATION_CLASSES': [ 'DEFAULT_AUTHENTICATION_CLASSES': [
'rest_framework.authentication.TokenAuthentication', # 'rest_framework.authentication.TokenAuthentication',
# 'rest_framework.permissions.AllowAny',
], ],
'DEFAULT_PERMISSION_CLASSES': [ 'DEFAULT_PERMISSION_CLASSES': [
'rest_framework.permissions.IsAuthenticated', #'rest_framework.permissions.IsAuthenticated',
'rest_framework.permissions.AllowAny',
] ]
} }