forked from gsn/predictor
feat: predictions
This commit is contained in:
parent
42e7924be9
commit
11be8f351f
42 changed files with 2221 additions and 516 deletions
280
scripts/test_predictor_vs_reference.py
Normal file
280
scripts/test_predictor_vs_reference.py
Normal file
|
|
@ -0,0 +1,280 @@
|
|||
#!/usr/bin/env python3
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import requests
|
||||
import json
|
||||
from typing import Any
|
||||
import base64
|
||||
|
||||
# --- Config ---
|
||||
LOCAL_API_URL = "http://localhost:8080/api/v1/prediction"
|
||||
REFERENCE_API_URL = (
|
||||
"https://fly.stratonautica.ru/api/v2/?profile=standard_profile&pred_type=single"
|
||||
"&launch_datetime=2025-06-25T13:28:00Z&launch_latitude=56.6992&launch_longitude=38.8247"
|
||||
"&launch_altitude=0&ascent_rate=5&burst_altitude=30000&descent_rate=5"
|
||||
)
|
||||
LOCAL_API_PAYLOAD = {
|
||||
"launch_latitude": 56.6992,
|
||||
"launch_longitude": 38.8247,
|
||||
"launch_datetime": "2025-06-25T13:28:00Z",
|
||||
"launch_altitude": 0,
|
||||
"profile": "standard_profile",
|
||||
"ascent_rate": 5,
|
||||
"burst_altitude": 30000,
|
||||
"descent_rate": 5,
|
||||
"format": "json"
|
||||
}
|
||||
READY_URL = "http://localhost:8080/ready"
|
||||
|
||||
# --- Utility functions ---
|
||||
def run_compose_up():
|
||||
print("[INFO] Running docker-compose down --remove-orphans ...")
|
||||
result = subprocess.run(["docker-compose", "down", "--remove-orphans"], capture_output=True)
|
||||
if result.returncode != 0:
|
||||
print("[ERROR] docker-compose down failed:", result.stderr.decode())
|
||||
sys.exit(1)
|
||||
print("[INFO] docker-compose down completed.")
|
||||
print("[INFO] Running docker-compose up -d ...")
|
||||
result = subprocess.run(["docker-compose", "up", "-d"], capture_output=True)
|
||||
if result.returncode != 0:
|
||||
print("[ERROR] docker-compose up failed:", result.stderr.decode())
|
||||
sys.exit(1)
|
||||
print("[INFO] docker-compose up -d completed.")
|
||||
return True
|
||||
|
||||
def wait_for_ready(timeout=900):
|
||||
print(f"[INFO] Waiting for {READY_URL} to be ready ...")
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
try:
|
||||
resp = requests.get(READY_URL, timeout=10)
|
||||
if resp.status_code == 200:
|
||||
data = resp.json()
|
||||
if data.get("status") == "ok":
|
||||
print("[INFO] Service is ready.")
|
||||
return
|
||||
else:
|
||||
print(f"[INFO] Not ready yet: {data}")
|
||||
else:
|
||||
print(f"[INFO] /ready returned status {resp.status_code}")
|
||||
except Exception as e:
|
||||
print(f"[INFO] Exception while polling /ready: {e}")
|
||||
time.sleep(10)
|
||||
print(f"[ERROR] Service did not become ready in {timeout} seconds.")
|
||||
sys.exit(1)
|
||||
|
||||
def fetch_reference():
|
||||
print(f"[INFO] Fetching reference prediction from {REFERENCE_API_URL}")
|
||||
resp = requests.get(REFERENCE_API_URL, timeout=60)
|
||||
if resp.status_code != 200:
|
||||
print(f"[ERROR] Reference API returned {resp.status_code}")
|
||||
sys.exit(1)
|
||||
return resp.json()
|
||||
|
||||
def fetch_local():
|
||||
print(f"[INFO] Fetching local prediction from {LOCAL_API_URL}")
|
||||
resp = requests.post(LOCAL_API_URL, json=LOCAL_API_PAYLOAD, timeout=120)
|
||||
if resp.status_code != 200:
|
||||
print(f"[ERROR] Local API returned {resp.status_code}: {resp.text}")
|
||||
sys.exit(1)
|
||||
return resp.json()
|
||||
|
||||
def compare_results(reference_data, local_data):
|
||||
"""Compare prediction results between reference and local APIs."""
|
||||
print("[INFO] Comparing results ...")
|
||||
|
||||
# Extract trajectory data
|
||||
ref_trajectory = reference_data.get('prediction', [{}])[0].get('trajectory', [])
|
||||
local_trajectory = local_data.get('prediction', [{}])[0].get('trajectory', [])
|
||||
|
||||
print(f"[DEBUG] Reference trajectory length: {len(ref_trajectory)}")
|
||||
print(f"[DEBUG] Local trajectory length: {len(local_trajectory)}")
|
||||
|
||||
# Show first 3 points from both APIs
|
||||
print("\n[DEBUG] First 3 points - Reference API:")
|
||||
for i, point in enumerate(ref_trajectory[:3]):
|
||||
print(f" [{i}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}")
|
||||
|
||||
print("\n[DEBUG] First 3 points - Local API:")
|
||||
for i, point in enumerate(local_trajectory[:3]):
|
||||
print(f" [{i}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}")
|
||||
|
||||
# Show last 3 points from both APIs
|
||||
print("\n[DEBUG] Last 3 points - Reference API:")
|
||||
for i, point in enumerate(ref_trajectory[-3:]):
|
||||
idx = len(ref_trajectory) - 3 + i
|
||||
print(f" [{idx}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}")
|
||||
|
||||
print("\n[DEBUG] Last 3 points - Local API:")
|
||||
for i, point in enumerate(local_trajectory[-3:]):
|
||||
idx = len(local_trajectory) - 3 + i
|
||||
print(f" [{idx}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}")
|
||||
|
||||
# Compare trajectory lengths
|
||||
if len(ref_trajectory) != len(local_trajectory):
|
||||
print(f"[DIFF] Trajectory length mismatch: {len(local_trajectory)} vs {len(ref_trajectory)}")
|
||||
return False
|
||||
|
||||
# Compare trajectory points
|
||||
min_len = min(len(ref_trajectory), len(local_trajectory))
|
||||
for i in range(min_len):
|
||||
ref_point = ref_trajectory[i]
|
||||
local_point = local_trajectory[i]
|
||||
|
||||
# Compare key fields
|
||||
for key in ['altitude', 'latitude', 'longitude']:
|
||||
ref_val = ref_point.get(key)
|
||||
local_val = local_point.get(key)
|
||||
|
||||
if ref_val is not None and local_val is not None:
|
||||
# Use relative tolerance for floating point comparison
|
||||
if abs(ref_val - local_val) > 0.1: # 0.1 degree/meter tolerance
|
||||
print(f"[DIFF] At idx {i}, key {key}: {local_val} != {ref_val}")
|
||||
return False
|
||||
|
||||
print("[SUCCESS] Results match!")
|
||||
return True
|
||||
|
||||
def test_custom_profile():
|
||||
"""Test custom profile with base64 encoded curve."""
|
||||
print("\n[TEST] Testing custom_profile...")
|
||||
|
||||
# Create a simple custom ascent curve (altitude vs time in seconds)
|
||||
curve_data = {
|
||||
"altitude": [0, 30000],
|
||||
"time": [0, 6000]
|
||||
}
|
||||
|
||||
curve_b64 = base64.b64encode(json.dumps(curve_data).encode()).decode()
|
||||
|
||||
# Test parameters for custom profile
|
||||
params = {
|
||||
"launch_latitude": 56.6992,
|
||||
"launch_longitude": 38.8247,
|
||||
"launch_datetime": "2025-06-25T13:28:00Z",
|
||||
"launch_altitude": 0,
|
||||
"profile": "custom_profile",
|
||||
"ascent_curve": curve_b64
|
||||
}
|
||||
|
||||
try:
|
||||
# Test local API
|
||||
local_resp = requests.post(
|
||||
"http://localhost:8080/api/v1/prediction",
|
||||
json=params,
|
||||
timeout=30
|
||||
)
|
||||
local_resp.raise_for_status()
|
||||
local_data = local_resp.json()
|
||||
|
||||
print(f"[INFO] Custom profile test - Local API returned {len(local_data.get('prediction', [{}])[0].get('trajectory', []))} trajectory points")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"[ERROR] Custom profile test failed: {e}")
|
||||
return False
|
||||
|
||||
def test_all_profiles():
|
||||
"""Test all available profiles."""
|
||||
profiles = [
|
||||
("standard_profile", "Standard profile test"),
|
||||
("float_profile", "Float profile test"),
|
||||
("reverse_profile", "Reverse profile test"),
|
||||
("custom_profile", "Custom profile test")
|
||||
]
|
||||
|
||||
results = {}
|
||||
|
||||
for profile, description in profiles:
|
||||
print(f"\n[TEST] {description}...")
|
||||
|
||||
if profile == "custom_profile":
|
||||
success = test_custom_profile()
|
||||
else:
|
||||
success = test_single_profile(profile)
|
||||
|
||||
results[profile] = success
|
||||
print(f"[RESULT] {profile}: {'PASS' if success else 'FAIL'}")
|
||||
|
||||
# Print summary
|
||||
print("\n" + "="*50)
|
||||
print("TEST SUMMARY")
|
||||
print("="*50)
|
||||
for profile, success in results.items():
|
||||
status = "PASS" if success else "FAIL"
|
||||
print(f"{profile:20} : {status}")
|
||||
|
||||
total_tests = len(results)
|
||||
passed_tests = sum(results.values())
|
||||
print(f"\nTotal tests: {total_tests}, Passed: {passed_tests}, Failed: {total_tests - passed_tests}")
|
||||
|
||||
return all(results.values())
|
||||
|
||||
def test_single_profile(profile):
|
||||
"""Test a single profile against reference API."""
|
||||
# Test parameters
|
||||
params = {
|
||||
"launch_latitude": 56.6992,
|
||||
"launch_longitude": 38.8247,
|
||||
"launch_datetime": "2025-06-25T13:28:00Z",
|
||||
"launch_altitude": 0,
|
||||
"profile": profile,
|
||||
"ascent_rate": 5,
|
||||
"burst_altitude": 30000,
|
||||
"descent_rate": 5
|
||||
}
|
||||
|
||||
# Add float altitude for float profile
|
||||
if profile == "float_profile":
|
||||
params["float_altitude"] = 25000
|
||||
|
||||
try:
|
||||
# Test local API
|
||||
local_resp = requests.post(
|
||||
"http://localhost:8080/api/v1/prediction",
|
||||
json=params,
|
||||
timeout=30
|
||||
)
|
||||
local_resp.raise_for_status()
|
||||
local_data = local_resp.json()
|
||||
|
||||
print(f"[INFO] {profile} - Local API returned {len(local_data.get('prediction', [{}])[0].get('trajectory', []))} trajectory points")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"[ERROR] {profile} test failed: {e}")
|
||||
return False
|
||||
|
||||
def main():
|
||||
"""Main test function."""
|
||||
print("[INFO] Starting comprehensive predictor API tests...")
|
||||
|
||||
# Run the original standard profile test
|
||||
print("\n[TEST] Running original standard_profile test...")
|
||||
run_compose_up()
|
||||
wait_for_ready()
|
||||
ref = fetch_reference()
|
||||
local = fetch_local()
|
||||
|
||||
print("[INFO] Comparing results ...")
|
||||
original_success = compare_results(ref, local)
|
||||
|
||||
if original_success:
|
||||
print("[SUCCESS] Original standard_profile test passed!")
|
||||
else:
|
||||
print("[FAIL] Original standard_profile test failed!")
|
||||
|
||||
# Test all profiles
|
||||
print("\n[TEST] Running all profile tests...")
|
||||
all_profiles_success = test_all_profiles()
|
||||
|
||||
# Final result
|
||||
overall_success = original_success and all_profiles_success
|
||||
print(f"\n[FINAL RESULT] Overall: {'PASS' if overall_success else 'FAIL'}")
|
||||
|
||||
if overall_success:
|
||||
sys.exit(0)
|
||||
else:
|
||||
sys.exit(1)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -1,55 +0,0 @@
|
|||
#!/bin/bash
|
||||
|
||||
# Docker validation script
|
||||
set -e
|
||||
|
||||
echo "🔍 Validating Docker configuration..."
|
||||
|
||||
# Check if Docker is available
|
||||
if ! command -v docker &> /dev/null; then
|
||||
echo "❌ Docker is not installed or not in PATH"
|
||||
echo "Please install Docker Desktop and enable WSL 2 integration"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Check if docker-compose is available
|
||||
if ! command -v docker-compose &> /dev/null; then
|
||||
echo "❌ Docker Compose is not installed or not in PATH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✅ Docker and Docker Compose are available"
|
||||
|
||||
# Validate Dockerfile syntax
|
||||
echo "🔍 Validating Dockerfile..."
|
||||
if docker build --dry-run . > /dev/null 2>&1; then
|
||||
echo "✅ Dockerfile syntax is valid"
|
||||
else
|
||||
echo "❌ Dockerfile syntax is invalid"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate docker-compose.yml
|
||||
echo "🔍 Validating docker-compose.yml..."
|
||||
if docker-compose config > /dev/null 2>&1; then
|
||||
echo "✅ docker-compose.yml is valid"
|
||||
else
|
||||
echo "❌ docker-compose.yml is invalid"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Validate docker-compose.dev.yml
|
||||
echo "🔍 Validating docker-compose.dev.yml..."
|
||||
if docker-compose -f docker-compose.dev.yml config > /dev/null 2>&1; then
|
||||
echo "✅ docker-compose.dev.yml is valid"
|
||||
else
|
||||
echo "❌ docker-compose.dev.yml is invalid"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "✅ All Docker configurations are valid!"
|
||||
echo ""
|
||||
echo "🚀 Ready to build and run:"
|
||||
echo " make build # Build Docker image"
|
||||
echo " make up # Start services"
|
||||
echo " make logs # View logs"
|
||||
Loading…
Add table
Add a link
Reference in a new issue