Refactor #7

Open
a.antonov wants to merge 8 commits from afanasyev.aa/predictor:refactor into main
72 changed files with 4511 additions and 4100 deletions

View file

@ -1,55 +0,0 @@
# Git
.git
.gitignore
# Docker
Dockerfile
docker-compose.yml
.dockerignore
# Documentation
README.md
*.md
# IDE
.vscode/
.idea/
*.swp
*.swo
*~
# OS
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# Environment files
.env
.env.local
.env.*.local
# Build artifacts
predictor
*.exe
*.exe~
*.dll
*.so
*.dylib
*.test
# Logs
*.log
# Temporary files
/tmp/
/temp/
# Test coverage
*.out
# Go workspace
go.work

59
.gitignore vendored
View file

@ -1,62 +1,5 @@
# Binaries for programs and plugins
predictor
*.exe
*.exe~
*.dll
*.so
*.dylib
# Test binary, built with `go test -c`
*.test
# Output of the go coverage tool, specifically when used with LiteIDE
*.out
# Dependency directories (remove the comment below to include it)
# vendor/
# Go workspace file
go.work
# Environment variables
.env
.env.local
.env.*.local
# IDE files
.vscode/
.idea/
*.swp
*.swo
*~
# OS generated files
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# Logs
*.log
# Temporary files
/tmp/
/temp/
# Build artifacts
/build/
/dist/
# GRIB files
/grib_data/
/grib_data/*
# Leaflet WebUI
/leaflet_predictor
/leaflet_predictor/*
# Tawhiri
/tawhiri
/tawhiri/*

View file

@ -1,3 +0,0 @@
{
"makefile.configureOnOpen": false
}

View file

@ -1,57 +0,0 @@
# Build stage
FROM golang:1.24.4-alpine AS builder
# Install build dependencies
RUN apk add --no-cache git ca-certificates tzdata
# Set working directory
WORKDIR /app
# Copy go mod files
COPY go.mod go.sum ./
# Download dependencies
RUN go mod download
# Copy source code
COPY . .
# Build the application
RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \
-ldflags="-w -s" \
-o predictor \
./cmd/api
# Runtime stage
FROM alpine:3.19
# Install runtime dependencies
RUN apk add --no-cache ca-certificates tzdata
# Create non-root user
RUN addgroup -g 1001 -S appgroup && \
adduser -u 1001 -S appuser -G appgroup
# Set working directory
WORKDIR /app
# Copy binary from builder stage
COPY --from=builder /app/predictor .
# Create necessary directories
RUN mkdir -p /tmp/grib && \
chown -R appuser:appgroup /app && \
chmod -R 777 /tmp/grib
# Switch to non-root user
USER appuser
# Expose port
EXPOSE 8080
# Health check
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
CMD wget --no-verbose --tries=1 --spider http://localhost:8080/ready || exit 1
# Run the application
CMD ["./predictor"]

111
Makefile
View file

@ -1,111 +1,40 @@
# Variables
IMAGE_NAME = predictor
TAG = latest
COMPOSE_FILE = docker-compose.yml
.PHONY: build run test fmt lint clean generate-ogen help
# Validate Docker configuration
.PHONY: validate-docker
validate-docker:
./scripts/validate-docker.sh
# Build the Docker image
.PHONY: build
# Build the application
build:
docker build -t $(IMAGE_NAME):$(TAG) .
# Run the application with docker-compose
.PHONY: up
up:
docker-compose -f $(COMPOSE_FILE) up -d
# Run the application with docker-compose and rebuild
.PHONY: up-build
up-build:
docker-compose -f $(COMPOSE_FILE) up -d --build
# Stop the application
.PHONY: down
down:
docker-compose -f $(COMPOSE_FILE) down
# Stop the application and remove volumes
.PHONY: down-volumes
down-volumes:
docker-compose -f $(COMPOSE_FILE) down -v
# View logs
.PHONY: logs
logs:
docker-compose -f $(COMPOSE_FILE) logs -f
# View logs for specific service
.PHONY: logs-predictor
logs-predictor:
docker-compose -f $(COMPOSE_FILE) logs -f predictor
# Check service status
.PHONY: ps
ps:
docker-compose -f $(COMPOSE_FILE) ps
# Execute command in predictor container
.PHONY: exec
exec:
docker-compose -f $(COMPOSE_FILE) exec predictor sh
# Clean up Docker resources
.PHONY: clean
clean:
docker-compose -f $(COMPOSE_FILE) down -v --rmi all
docker system prune -f
# Run tests
.PHONY: test
test:
go test ./...
# Build locally
.PHONY: build-local
build-local:
go build -o predictor ./cmd/api
# Run locally
.PHONY: run-local
run-local:
cd cmd/api && go run .
run:
go run ./cmd/api
# Run tests
test:
go test ./...
# Format code
.PHONY: fmt
fmt:
go fmt ./...
# Lint code
.PHONY: lint
lint:
golangci-lint run
# Generate ogen API code from swagger spec
generate-ogen:
go run github.com/ogen-go/ogen/cmd/ogen@latest --target pkg/rest --package rest --clean api/rest/predictor.swagger.yml
# Clean build artifacts
clean:
rm -f predictor
# Show help
.PHONY: help
help:
@echo "Available commands:"
@echo " validate-docker - Validate Docker configuration"
@echo " build - Build Docker image"
@echo " up - Start services with docker-compose"
@echo " up-build - Start services and rebuild images"
@echo " down - Stop services"
@echo " down-volumes - Stop services and remove volumes"
@echo " logs - View all logs"
@echo " logs-predictor - View predictor logs"
@echo " ps - Show service status"
@echo " exec - Execute shell in predictor container"
@echo " clean - Clean up Docker resources"
@echo " build - Build binary"
@echo " run - Run locally"
@echo " test - Run tests"
@echo " build-local - Build locally"
@echo " run-local - Run locally"
@echo " fmt - Format code"
@echo " lint - Lint code"
@echo " help - Show this help"
generate-ogen:
go run github.com/ogen-go/ogen/cmd/ogen@latest --target pkg/rest -package gsn --clean api/rest/predictor.swagger.yml
@echo " generate-ogen - Generate API code from swagger spec"
@echo " clean - Remove build artifacts"

261
README.md Normal file
View file

@ -0,0 +1,261 @@
# Balloon Trajectory Predictor
High-altitude balloon trajectory prediction service. Predicts ascent, burst, and descent trajectories using GFS wind forecast data from NOAA.
The prediction algorithms are an exact port of [Tawhiri](https://github.com/cuspaceflight/tawhiri) (Cambridge University Spaceflight) to Go, verified to produce identical results.
## Quick Start
```bash
# Build
make build
# Run (downloads ~9 GB of GFS data on first start, takes 30-60 min)
PREDICTOR_DATA_DIR=/tmp/predictor-data go run ./cmd/api
# Check readiness
curl http://localhost:8080/ready
# Run a prediction
curl 'http://localhost:8080/api/v1/prediction?launch_latitude=52.2&launch_longitude=0.1&launch_datetime=2026-03-28T12:00:00Z&launch_altitude=0&ascent_rate=5&burst_altitude=30000&descent_rate=5'
```
## Configuration
All configuration is via environment variables.
| Variable | Default | Description |
|---|---|---|
| `PREDICTOR_PORT` | `8080` | HTTP server port |
| `PREDICTOR_DATA_DIR` | `/tmp/predictor-data` | Directory for wind datasets and temp files |
| `PREDICTOR_DOWNLOAD_PARALLEL` | `8` | Max concurrent GRIB download goroutines |
| `PREDICTOR_UPDATE_INTERVAL` | `6h` | How often to check for new forecasts |
| `PREDICTOR_DATASET_TTL` | `48h` | Max age before a dataset is considered stale |
| `PREDICTOR_ELEVATION_DATASET` | `/srv/ruaumoko-dataset` | Path to elevation dataset (optional) |
## API
### `GET /api/v1/prediction`
Run a balloon trajectory prediction.
**Parameters** (query string):
| Parameter | Required | Description |
|---|---|---|
| `launch_latitude` | yes | Launch latitude in degrees (-90 to 90) |
| `launch_longitude` | yes | Launch longitude in degrees (-180 to 180 or 0 to 360) |
| `launch_datetime` | yes | Launch time in RFC 3339 format |
| `launch_altitude` | no | Launch altitude in metres ASL (default: 0) |
| `profile` | no | `standard_profile` (default) or `float_profile` |
| `ascent_rate` | no | Ascent rate in m/s (default: 5) |
| `burst_altitude` | no | Burst altitude in metres (default: 28000) |
| `descent_rate` | no | Sea-level descent rate in m/s (default: 5) |
| `float_altitude` | no | Float altitude in metres (float_profile only) |
| `stop_datetime` | no | Float end time (float_profile only, default: +24h) |
**Response** (Tawhiri-compatible):
```json
{
"prediction": [
{
"stage": "ascent",
"trajectory": [
{"datetime": "2026-03-28T12:00:00Z", "latitude": 52.2, "longitude": 0.1, "altitude": 0},
...
]
},
{
"stage": "descent",
"trajectory": [...]
}
],
"metadata": {
"start_datetime": "...",
"complete_datetime": "..."
},
"request": {
"dataset": "2026-03-28T06:00:00Z",
"launch_latitude": 52.2,
...
}
}
```
### `GET /ready`
Health check. Returns `{"status": "ok"}` when a dataset is loaded.
## Elevation Dataset
Without elevation data, descent terminates at sea level (altitude <= 0). With elevation data, descent terminates at ground level, matching Tawhiri's behaviour.
### Building the elevation dataset
The elevation dataset uses ETOPO 2022 at 30 arc-second resolution, converted to a ruaumoko-compatible binary format (21601 x 43200 grid of int16 little-endian elevation values in metres).
**Requirements**: Python 3, xarray, netcdf4, numpy.
```bash
pip install xarray netcdf4 numpy
# Downloads ~1.1 GB from NOAA, produces ~1.74 GB binary file
python3 scripts/build_elevation.py /tmp/predictor-data/ruaumoko-dataset
```
To skip the download if you already have the ETOPO NetCDF file:
```bash
ETOPO_NC_PATH=/path/to/ETOPO_2022_v1_30s_N90W180_surface.nc \
python3 scripts/build_elevation.py /tmp/predictor-data/ruaumoko-dataset
```
The ETOPO 2022 NetCDF can be manually downloaded from:
https://www.ncei.noaa.gov/products/etopo-global-relief-model
### Using the elevation dataset
```bash
PREDICTOR_ELEVATION_DATASET=/tmp/predictor-data/ruaumoko-dataset go run ./cmd/api
```
If the file doesn't exist or can't be read, the service starts normally with a warning and falls back to sea-level termination.
## Architecture
```
cmd/api/main.go Entry point, config, scheduler, HTTP server
internal/
dataset/
dataset.go Shape constants, pressure levels, S3 URLs
file.go mmap-backed dataset file (read/write/blit)
downloader/
downloader.go S3 partial GRIB download (idx + range requests)
idx.go NOAA .idx file parser
config.go Environment-based configuration
elevation/
elevation.go Ruaumoko-compatible elevation dataset (mmap int16)
prediction/
interpolate.go 4D wind interpolation (time, lat, lon, altitude)
solver.go RK4 integrator with binary search termination
models.go Ascent, descent, wind models; flight profiles
warnings.go Prediction warning counters
service/
service.go Dataset lifecycle, concurrent-safe access
transport/
middleware/log.go Request logging middleware
rest/
handler/handler.go ogen API handler implementation
handler/deps.go Service interface
transport.go ogen HTTP server, CORS
api/rest/predictor.swagger.yml OpenAPI 3.0 spec
pkg/rest/ Generated ogen code (17 files)
scripts/
build_elevation.py ETOPO 2022 to ruaumoko converter
```
## Wind Dataset
The service downloads GFS 0.5-degree forecast data from NOAA S3:
| Property | Value |
|---|---|
| Source | `noaa-gfs-bdp-pds.s3.amazonaws.com` |
| Resolution | 0.5 degrees |
| Grid | 361 lat x 720 lon |
| Time steps | 65 (every 3 hours, 0-192h) |
| Pressure levels | 47 (1000 to 1 hPa) |
| Variables | Geopotential height, U-wind, V-wind |
| Dataset size | 9,528,667,200 bytes (~8.87 GiB) |
| Update cadence | Every 6 hours (GFS runs at 00, 06, 12, 18 UTC) |
Data is downloaded using HTTP Range requests against `.idx` index files, fetching only the needed GRIB messages (HGT, UGRD, VGRD at 47 pressure levels). Full download takes 30-60 minutes depending on bandwidth.
The dataset is stored as a memory-mapped flat binary file of float32 values in C-order with shape `(65, 47, 3, 361, 720)`.
## Prediction Algorithms
All algorithms are exact ports of the reference implementations in Tawhiri. The following sections describe the key components.
### Interpolation (`internal/prediction/interpolate.go`)
4D wind interpolation from the dataset grid to arbitrary coordinates.
1. **Trilinear weights** (`pick3`): compute 8 interpolation weights for the (hour, lat, lon) cube corners.
2. **Altitude search** (`search`): binary search on interpolated geopotential height to find the two pressure levels bracketing the target altitude.
3. **Wind extraction** (`interp4`): 8-point weighted sum at each bracket level, then linear interpolation between levels.
Reference: `tawhiri/interpolate.pyx`
### Solver (`internal/prediction/solver.go`)
4th-order Runge-Kutta integrator with dt = 60 seconds.
- State vector: (latitude, longitude, altitude) in degrees and metres.
- Time: UNIX timestamp in seconds.
- Longitude is kept in [0, 360) via Python-style modulo after each `vecadd`.
- When a terminator fires, binary search refinement (tolerance 0.01) finds the precise termination point between the last good step and the first terminated step.
- Longitude interpolation (`lngLerp`) handles the 0/360 wrap-around.
Reference: `tawhiri/solver.pyx`
### Models (`internal/prediction/models.go`)
- **Constant ascent**: vertical velocity = ascent_rate m/s.
- **Drag descent**: NASA atmosphere density model with drag coefficient = sea_level_rate * 1.1045. Descent rate increases with altitude due to thinner air.
- **Wind velocity**: u, v components from interpolation converted to degrees/second: `dlat = (180/pi) * v / (R)`, `dlng = (180/pi) * u / (R * cos(lat))` where R = 6371009 + altitude.
- **Linear model**: sum of component models (e.g., wind + ascent).
- **Elevation termination**: `ground_elevation > altitude` using ruaumoko dataset.
Reference: `tawhiri/models.py`
### Profiles
- **standard_profile**: ascent (constant rate + wind) until burst altitude, then descent (drag + wind) until ground level.
- **float_profile**: ascent to float altitude, then drift at constant altitude until stop time.
## Verification
The predictor has been verified against the reference Tawhiri implementation:
| Test | Result |
|---|---|
| Dataset (step 0): 36.6M float32 values vs Python/cfgrib | 0 mismatches, max diff = 0.0 |
| Prediction burst point vs public Tawhiri API | Identical (lat, lon, alt all match) |
| Prediction landing point vs public Tawhiri API | Identical lat/lon, 5m altitude diff (different elevation datasets) |
| Descent point count | Identical (46 points) |
| Ascent point count | Identical (101 points) |
## Development
```bash
# Regenerate ogen API code after modifying the swagger spec
make generate-ogen
# Run tests
make test
# Format
make fmt
```
### Comparison tools
```bash
# Compare single dataset step against Python/cfgrib reference
go run ./cmd/compare_step0 <run_YYYYMMDDHH> <output_path>
# Run prediction and compare against public Tawhiri API
go run ./cmd/compare_prediction
```
## References
- [Tawhiri](https://github.com/cuspaceflight/tawhiri) — Reference Python/Cython predictor (Cambridge University Spaceflight)
- [tawhiri-downloader](https://github.com/cuspaceflight/tawhiri-downloader) — OCaml dataset downloader
- [ruaumoko](https://github.com/cuspaceflight/ruaumoko) — Global elevation dataset
- [NOAA GFS](https://www.ncei.noaa.gov/products/weather-climate-models/global-forecast) — Global Forecast System
- [NOAA GFS on S3](https://noaa-gfs-bdp-pds.s3.amazonaws.com/index.html) — Public S3 bucket
- [ETOPO 2022](https://www.ncei.noaa.gov/products/etopo-global-relief-model) — Global relief model for elevation data
- [SondeHub Tawhiri API](https://api.v2.sondehub.org/tawhiri) — Public Tawhiri instance for comparison

View file

@ -1,6 +1,6 @@
openapi: 3.0.4
info:
title: GSN Predictor - OpenAPI 3.0
title: Predictor API
version: 0.0.1
paths:
/api/v1/prediction:
@ -12,14 +12,17 @@ paths:
parameters:
- in: query
name: launch_latitude
required: true
schema:
type: number
- in: query
name: launch_longitude
required: true
schema:
type: number
- in: query
name: launch_datetime
required: true
schema:
type: string
format: date-time
@ -31,7 +34,8 @@ paths:
name: profile
schema:
type: string
enum: [standard_profile, float_profile, reverse_profile, custom_profile]
enum: [standard_profile, float_profile]
default: standard_profile
- in: query
name: ascent_rate
schema:
@ -53,23 +57,6 @@ paths:
schema:
type: string
format: date-time
- in: query
name: ascent_curve
schema:
type: string
- in: query
name: descent_curve
schema:
type: string
- in: query
name: interpolate
schema:
type: boolean
- in: query
name: format
schema:
type: string
enum: [json]
- in: query
name: dataset
schema:
@ -77,17 +64,17 @@ paths:
format: date-time
responses:
"200":
description: "Prediction response"
description: Prediction response
content:
application/json:
schema:
$ref: '#/components/schemas/PredictionResult'
$ref: '#/components/schemas/PredictionResponse'
default:
description: Error
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
$ref: '#/components/schemas/Error'
/ready:
get:
tags:
@ -106,37 +93,52 @@ paths:
content:
application/json:
schema:
$ref: "#/components/schemas/Error"
$ref: '#/components/schemas/Error'
components:
schemas:
Error:
type: object
required:
- message
- error
properties:
message:
type: string
details:
type: string
PredictionResult:
type: object
required:
- metadata
- prediction
properties:
metadata:
error:
type: object
required:
- complete_datetime
- start_datetime
- type
- description
properties:
complete_datetime:
type:
type: string
format: date-time
start_datetime:
description:
type: string
format: date-time
PredictionResponse:
type: object
required:
- prediction
- metadata
properties:
request:
type: object
properties:
dataset:
type: string
launch_latitude:
type: number
launch_longitude:
type: number
launch_datetime:
type: string
launch_altitude:
type: number
profile:
type: string
ascent_rate:
type: number
burst_altitude:
type: number
descent_rate:
type: number
prediction:
type: array
items:
@ -147,7 +149,7 @@ components:
properties:
stage:
type: string
enum: ["ascent", "descent"]
enum: ["ascent", "descent", "float"]
trajectory:
type: array
items:
@ -167,18 +169,31 @@ components:
type: number
altitude:
type: number
metadata:
type: object
required:
- start_datetime
- complete_datetime
properties:
start_datetime:
type: string
format: date-time
complete_datetime:
type: string
format: date-time
warnings:
type: object
additionalProperties: true
ReadinessResponse:
type: object
required:
- status
properties:
status:
type: string
enum: [ok, not_ready, error]
last_update:
dataset_time:
type: string
format: date-time
is_fresh:
type: boolean
error_message:
type: string
required:
- status

View file

@ -2,112 +2,97 @@ package main
import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
"time"
"git.intra.yksa.space/gsn/predictor/internal/jobs/grib/updater"
"git.intra.yksa.space/gsn/predictor/internal/pkg/grib"
"git.intra.yksa.space/gsn/predictor/internal/pkg/log"
"git.intra.yksa.space/gsn/predictor/internal/service"
"git.intra.yksa.space/gsn/predictor/internal/transport/rest"
"git.intra.yksa.space/gsn/predictor/internal/transport/rest/handler"
"git.intra.yksa.space/gsn/predictor/pkg/scheduler"
"predictor-refactored/internal/downloader"
"predictor-refactored/internal/service"
"predictor-refactored/internal/transport/rest"
"predictor-refactored/internal/transport/rest/handler"
"github.com/go-co-op/gocron"
"go.uber.org/zap"
)
const servicePrefix = "GSN_PREDICTOR"
func main() {
lg, err := zap.NewProduction()
log, err := zap.NewProduction()
if err != nil {
panic(err)
}
defer lg.Sync()
ctx := log.ToCtx(context.Background(), lg)
defer log.Sync()
schedulerConfig, err := scheduler.NewConfig()
if err != nil {
log.Ctx(ctx).Fatal("failed to load scheduler configuration", zap.Error(err))
cfg := downloader.LoadConfig()
log.Info("configuration loaded",
zap.String("data_dir", cfg.DataDir),
zap.Int("parallel", cfg.Parallel),
zap.Duration("update_interval", cfg.UpdateInterval),
zap.Duration("dataset_ttl", cfg.DatasetTTL))
if err := os.MkdirAll(cfg.DataDir, 0o755); err != nil {
log.Fatal("failed to create data directory", zap.Error(err))
}
gribUpdaterConfig, err := updater.NewConfig()
if err != nil {
log.Ctx(ctx).Fatal("failed to load GRIB updater configuration", zap.Error(err))
}
gribCfg, err := grib.NewConfig()
if err != nil {
log.Ctx(ctx).Fatal("failed to load GRIB configuration", zap.Error(err))
}
gribService, err := grib.New(gribCfg)
if err != nil {
log.Ctx(ctx).Fatal("failed to initialize GRIB service", zap.Error(err))
}
defer gribService.Close()
// Force GRIB update on startup in a goroutine
go func() {
log.Ctx(ctx).Info("Performing initial GRIB update (async)...")
if err := gribService.Update(ctx); err != nil {
log.Ctx(ctx).Error("initial GRIB update failed", zap.Error(err))
} else {
log.Ctx(ctx).Info("initial GRIB update complete")
}
}()
svc, err := service.New(gribService)
if err != nil {
log.Ctx(ctx).Fatal("failed to initialize service", zap.Error(err))
}
svc := service.New(cfg, log)
defer svc.Close()
var sched *scheduler.Scheduler
if schedulerConfig.Enabled {
sched = scheduler.New()
gribJob := updater.New(gribService, gribUpdaterConfig)
if err := sched.AddJob(gribJob); err != nil {
log.Ctx(ctx).Error("failed to add GRIB update job to scheduler", zap.Error(err))
}
log.Ctx(ctx).Info("scheduler initialized with jobs")
// Load elevation dataset (optional — falls back to sea-level termination)
elevPath := "/srv/ruaumoko-dataset"
if v := os.Getenv("PREDICTOR_ELEVATION_DATASET"); v != "" {
elevPath = v
}
svc.LoadElevation(elevPath)
handler := handler.New(svc)
restConfig, err := rest.NewConfig()
if err != nil {
lg.Fatal("failed to init transport config", zap.Error(err))
}
transport, err := rest.New(handler, restConfig)
if err != nil {
lg.Fatal("failed to init transport", zap.Error(err))
}
svc.Start()
if sched != nil {
sched.Start()
lg.Info("scheduler started")
}
lg.Info("service started successfully")
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Initial dataset load (async so the server starts immediately)
go func() {
lg.Info("starting HTTP server on port", zap.Int("port", restConfig.Port))
transport.Run()
log.Info("performing initial dataset update...")
if err := svc.Update(context.Background()); err != nil {
log.Error("initial dataset update failed", zap.Error(err))
} else {
log.Info("initial dataset update complete")
}
}()
<-sigChan
lg.Info("received shutdown signal, stopping service")
// Scheduler for periodic dataset updates
scheduler := gocron.NewScheduler(time.UTC)
scheduler.Every(cfg.UpdateInterval).Do(func() {
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
defer cancel()
log.Info("scheduled dataset update starting")
if err := svc.Update(ctx); err != nil {
log.Error("scheduled dataset update failed", zap.Error(err))
} else {
log.Info("scheduled dataset update complete")
}
})
scheduler.StartAsync()
defer scheduler.Stop()
if sched != nil {
sched.Stop()
lg.Info("scheduler stopped")
// HTTP transport (ogen)
port := 8080
if p := os.Getenv("PREDICTOR_PORT"); p != "" {
fmt.Sscanf(p, "%d", &port)
}
h := handler.New(svc, log)
transport, err := rest.New(h, port, log)
if err != nil {
log.Fatal("failed to create transport", zap.Error(err))
}
go func() {
if err := transport.Run(); err != nil {
log.Fatal("HTTP server error", zap.Error(err))
}
}()
log.Info("service started")
// Graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
sig := <-sigChan
log.Info("received shutdown signal", zap.String("signal", sig.String()))
}

View file

@ -0,0 +1,195 @@
package main
import (
"context"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"os"
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/downloader"
"predictor-refactored/internal/prediction"
"go.uber.org/zap"
)
// Downloads a few forecast steps and runs a prediction, then compares
// against the public Tawhiri API.
func main() {
log, _ := zap.NewDevelopment()
cfg := &downloader.Config{
DataDir: os.TempDir(),
Parallel: 4,
}
dl := downloader.NewDownloader(cfg, log)
ctx := context.Background()
// Find latest run
run, err := dl.FindLatestRun(ctx)
if err != nil {
fmt.Fprintf(os.Stderr, "FindLatestRun: %v\n", err)
os.Exit(1)
}
fmt.Printf("Using run: %s\n", run.Format("2006010215"))
// Create dataset and download first 10 steps (0-27 hours, enough for a prediction)
dsPath := fmt.Sprintf("/tmp/pred_test_%s.bin", run.Format("2006010215"))
defer os.Remove(dsPath)
ds, err := dataset.Create(dsPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Create: %v\n", err)
os.Exit(1)
}
date := run.Format("20060102")
runHour := run.Hour()
stepsToDownload := []int{0, 3, 6, 9, 12, 15, 18, 21, 24, 27}
fmt.Printf("Downloading %d steps...\n", len(stepsToDownload))
for _, step := range stepsToDownload {
hourIdx := dataset.HourIndex(step)
fmt.Printf(" step %d (hour idx %d)...\n", step, hourIdx)
urlA := dataset.GribURL(date, runHour, step)
if err := dl.DownloadAndBlit(ctx, ds, urlA, hourIdx, dataset.LevelSetA); err != nil {
fmt.Fprintf(os.Stderr, " pgrb2 step %d: %v\n", step, err)
os.Exit(1)
}
urlB := dataset.GribURLB(date, runHour, step)
if err := dl.DownloadAndBlit(ctx, ds, urlB, hourIdx, dataset.LevelSetB); err != nil {
fmt.Fprintf(os.Stderr, " pgrb2b step %d: %v\n", step, err)
os.Exit(1)
}
}
ds.Flush()
fmt.Println("Download complete")
// Set dataset time
ds.DSTime = run
// Run our prediction
launchLat := 52.2135
launchLon := 0.0964 // already in [0, 360)
launchAlt := 0.0
ascentRate := 5.0
burstAlt := 30000.0
descentRate := 5.0
// Launch 3 hours into the forecast
launchTime := run.Add(3 * time.Hour)
launchTimestamp := float64(launchTime.Unix())
dsEpoch := float64(run.Unix())
warnings := &prediction.Warnings{}
stages := prediction.StandardProfile(ascentRate, burstAlt, descentRate, ds, dsEpoch, warnings, nil)
results := prediction.RunPrediction(launchTimestamp, launchLat, launchLon, launchAlt, stages)
fmt.Printf("\n=== Our prediction ===\n")
for i, sr := range results {
stage := "ascent"
if i == 1 {
stage = "descent"
}
first := sr.Points[0]
last := sr.Points[len(sr.Points)-1]
fmt.Printf(" %s: %d points, start=(%.4f, %.4f, %.0fm) end=(%.4f, %.4f, %.0fm)\n",
stage, len(sr.Points),
first.Lat, first.Lng, first.Alt,
last.Lat, last.Lng, last.Alt)
}
// Get landing point
var ourLandLat, ourLandLon float64
if len(results) >= 2 {
last := results[1].Points[len(results[1].Points)-1]
ourLandLat = last.Lat
ourLandLon = last.Lng
if ourLandLon > 180 {
ourLandLon -= 360
}
}
fmt.Printf(" Landing: lat=%.4f, lon=%.4f\n", ourLandLat, ourLandLon)
// Compare against public Tawhiri API
fmt.Printf("\n=== Tawhiri API comparison ===\n")
tawhiriLandLat, tawhiriLandLon, err := queryTawhiri(launchLat, launchLon, launchAlt, launchTime, ascentRate, burstAlt, descentRate)
if err != nil {
fmt.Printf(" Tawhiri API error: %v\n", err)
fmt.Println(" (Cannot compare — Tawhiri may use a different dataset)")
ds.Close()
return
}
fmt.Printf(" Tawhiri landing: lat=%.4f, lon=%.4f\n", tawhiriLandLat, tawhiriLandLon)
dist := haversine(ourLandLat, ourLandLon, tawhiriLandLat, tawhiriLandLon)
fmt.Printf(" Distance between landing points: %.2f km\n", dist/1000)
if dist < 1000 {
fmt.Println(" CLOSE MATCH (< 1 km)")
} else if dist < 50000 {
fmt.Printf(" MODERATE DIFFERENCE (%.1f km) — likely different datasets\n", dist/1000)
} else {
fmt.Printf(" LARGE DIFFERENCE (%.1f km) — possible bug\n", dist/1000)
}
ds.Close()
}
func queryTawhiri(lat, lon, alt float64, launchTime time.Time, ascentRate, burstAlt, descentRate float64) (landLat, landLon float64, err error) {
url := fmt.Sprintf(
"https://api.v2.sondehub.org/tawhiri?launch_latitude=%.4f&launch_longitude=%.4f&launch_altitude=%.0f&launch_datetime=%s&ascent_rate=%.1f&burst_altitude=%.0f&descent_rate=%.1f",
lat, lon, alt, launchTime.Format(time.RFC3339), ascentRate, burstAlt, descentRate)
resp, err := http.Get(url)
if err != nil {
return 0, 0, err
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
return 0, 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
}
var result struct {
Prediction []struct {
Stage string `json:"stage"`
Trajectory []struct {
Latitude float64 `json:"latitude"`
Longitude float64 `json:"longitude"`
Altitude float64 `json:"altitude"`
} `json:"trajectory"`
} `json:"prediction"`
}
if err := json.Unmarshal(body, &result); err != nil {
return 0, 0, err
}
for _, stage := range result.Prediction {
if stage.Stage == "descent" && len(stage.Trajectory) > 0 {
last := stage.Trajectory[len(stage.Trajectory)-1]
return last.Latitude, last.Longitude, nil
}
}
return 0, 0, fmt.Errorf("no descent stage found")
}
func haversine(lat1, lon1, lat2, lon2 float64) float64 {
const R = 6371000.0
phi1 := lat1 * math.Pi / 180
phi2 := lat2 * math.Pi / 180
dphi := (lat2 - lat1) * math.Pi / 180
dlam := (lon2 - lon1) * math.Pi / 180
a := math.Sin(dphi/2)*math.Sin(dphi/2) + math.Cos(phi1)*math.Cos(phi2)*math.Sin(dlam/2)*math.Sin(dlam/2)
return R * 2 * math.Atan2(math.Sqrt(a), math.Sqrt(1-a))
}

104
cmd/compare_step0/main.go Normal file
View file

@ -0,0 +1,104 @@
package main
import (
"context"
"fmt"
"os"
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/downloader"
"go.uber.org/zap"
)
// Downloads step 0 of a given run and writes a minimal dataset for comparison.
// Usage: go run ./cmd/compare_step0 <run_YYYYMMDDHH> <output_path>
func main() {
if len(os.Args) < 3 {
fmt.Fprintf(os.Stderr, "Usage: %s <run_YYYYMMDDHH> <output_path>\n", os.Args[0])
os.Exit(1)
}
runStr := os.Args[1]
outPath := os.Args[2]
run, err := time.Parse("2006010215", runStr)
if err != nil {
fmt.Fprintf(os.Stderr, "Invalid run time %q: %v\n", runStr, err)
os.Exit(1)
}
log, _ := zap.NewDevelopment()
// Create a full-size dataset (we only fill step 0)
fmt.Printf("Creating dataset at %s (%d bytes)...\n", outPath, dataset.DatasetSize)
ds, err := dataset.Create(outPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Create dataset: %v\n", err)
os.Exit(1)
}
defer ds.Close()
cfg := &downloader.Config{
DataDir: os.TempDir(),
Parallel: 4,
}
dl := downloader.NewDownloader(cfg, log)
ctx := context.Background()
date := run.Format("20060102")
runHour := run.Hour()
// Download and blit step 0 from pgrb2
fmt.Println("Downloading pgrb2 step 0...")
urlA := dataset.GribURL(date, runHour, 0)
if err := dl.DownloadAndBlit(ctx, ds, urlA, 0, dataset.LevelSetA); err != nil {
fmt.Fprintf(os.Stderr, "pgrb2: %v\n", err)
os.Exit(1)
}
fmt.Println(" done")
// Download and blit step 0 from pgrb2b
fmt.Println("Downloading pgrb2b step 0...")
urlB := dataset.GribURLB(date, runHour, 0)
if err := dl.DownloadAndBlit(ctx, ds, urlB, 0, dataset.LevelSetB); err != nil {
fmt.Fprintf(os.Stderr, "pgrb2b: %v\n", err)
os.Exit(1)
}
fmt.Println(" done")
if err := ds.Flush(); err != nil {
fmt.Fprintf(os.Stderr, "Flush: %v\n", err)
os.Exit(1)
}
// Spot-check: print same values as the Python script
fmt.Println("\n=== Go dataset values (spot check) ===")
type testPoint struct {
varName string
varIdx int
levelIdx int
lat, lon int
}
points := []testPoint{
{"HGT", 0, 0, 0, 0}, // HGT @ 1000mb, lat=-90, lon=0
{"HGT", 0, 0, 180, 0}, // HGT @ 1000mb, lat=0, lon=0
{"HGT", 0, 0, 360, 0}, // HGT @ 1000mb, lat=+90, lon=0
{"HGT", 0, 20, 180, 360}, // HGT @ 500mb, lat=0, lon=180
{"UGRD", 1, 0, 180, 0}, // UGRD @ 1000mb, lat=0, lon=0
{"VGRD", 2, 0, 180, 0}, // VGRD @ 1000mb, lat=0, lon=0
{"UGRD", 1, 20, 284, 0}, // UGRD @ 500mb, lat=52N, lon=0
}
for _, p := range points {
val := ds.Val(0, p.levelIdx, p.varIdx, p.lat, p.lon)
actualLat := -90.0 + float64(p.lat)*0.5
actualLon := float64(p.lon) * 0.5
fmt.Printf(" %-4s %4dmb lat=%+7.1f lon=%6.1f: %12.4f\n",
p.varName, dataset.Pressures[p.levelIdx], actualLat, actualLon, val)
}
fmt.Printf("\nDataset written to %s\n", outPath)
}

51
go.mod
View file

@ -1,44 +1,25 @@
module git.intra.yksa.space/gsn/predictor
module predictor-refactored
go 1.24.4
go 1.25.0
require (
github.com/caarlos0/env/v11 v11.3.1
github.com/edsrzf/mmap-go v1.2.0
github.com/go-co-op/gocron v1.37.0
github.com/go-faster/errors v0.7.1
github.com/go-faster/jx v1.1.0
github.com/go-faster/jx v1.2.0
github.com/nilsmagnus/grib v1.2.8
github.com/ogen-go/ogen v1.16.0
github.com/rs/cors v1.11.1
go.opentelemetry.io/otel v1.38.0
go.opentelemetry.io/otel/metric v1.38.0
go.opentelemetry.io/otel/trace v1.38.0
go.uber.org/zap v1.27.0
golang.org/x/sync v0.17.0
github.com/ogen-go/ogen v1.20.2
go.opentelemetry.io/otel v1.42.0
go.opentelemetry.io/otel/metric v1.42.0
go.opentelemetry.io/otel/trace v1.42.0
go.uber.org/zap v1.27.1
golang.org/x/sync v0.20.0
)
require (
github.com/aws/aws-sdk-go-v2 v1.39.3 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2 // indirect
github.com/aws/aws-sdk-go-v2/config v1.31.13 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.18.17 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 // indirect
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 // indirect
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10 // indirect
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.10 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.1 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.10 // indirect
github.com/aws/aws-sdk-go-v2/service/s3 v1.88.5 // indirect
github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 // indirect
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 // indirect
github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 // indirect
github.com/aws/smithy-go v1.23.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/fatih/color v1.18.0 // indirect
github.com/fatih/color v1.19.0 // indirect
github.com/ghodss/yaml v1.0.0 // indirect
github.com/go-faster/yaml v0.4.6 // indirect
github.com/go-logr/logr v1.4.3 // indirect
@ -50,11 +31,11 @@ require (
github.com/segmentio/asm v1.2.1 // indirect
github.com/shopspring/decimal v1.4.0 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.uber.org/atomic v1.11.0 // indirect
go.uber.org/atomic v1.9.0 // indirect
go.uber.org/multierr v1.11.0 // indirect
golang.org/x/exp v0.0.0-20251017212417-90e834f514db // indirect
golang.org/x/net v0.46.0 // indirect
golang.org/x/sys v0.37.0 // indirect
golang.org/x/text v0.30.0 // indirect
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 // indirect
golang.org/x/net v0.52.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)

122
go.sum
View file

@ -1,41 +1,5 @@
github.com/aws/aws-sdk-go-v2 v1.39.3 h1:h7xSsanJ4EQJXG5iuW4UqgP7qBopLpj84mpkNx3wPjM=
github.com/aws/aws-sdk-go-v2 v1.39.3/go.mod h1:yWSxrnioGUZ4WVv9TgMrNUeLV3PFESn/v+6T/Su8gnM=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2 h1:t9yYsydLYNBk9cJ73rgPhPWqOh/52fcWDQB5b1JsKSY=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2/go.mod h1:IusfVNTmiSN3t4rhxWFaBAqn+mcNdwKtPcV16eYdgko=
github.com/aws/aws-sdk-go-v2/config v1.31.13 h1:wcqQB3B0PgRPUF5ZE/QL1JVOyB0mbPevHFoAMpemR9k=
github.com/aws/aws-sdk-go-v2/config v1.31.13/go.mod h1:ySB5D5ybwqGbT6c3GszZ+u+3KvrlYCUQNo62+hkKOFk=
github.com/aws/aws-sdk-go-v2/credentials v1.18.17 h1:skpEwzN/+H8cdrrtT8y+rvWJGiWWv0DeNAe+4VTf+Vs=
github.com/aws/aws-sdk-go-v2/credentials v1.18.17/go.mod h1:Ed+nXsaYa5uBINovJhcAWkALvXw2ZLk36opcuiSZfJM=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 h1:UuGVOX48oP4vgQ36oiKmW9RuSeT8jlgQgBFQD+HUiHY=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10/go.mod h1:vM/Ini41PzvudT4YkQyE/+WiQJiQ6jzeDyU8pQKwCac=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 h1:mj/bdWleWEh81DtpdHKkw41IrS+r3uw1J/VQtbwYYp8=
github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10/go.mod h1:7+oEMxAZWP8gZCyjcm9VicI0M61Sx4DJtcGfKYv2yKQ=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10 h1:wh+/mn57yhUrFtLIxyFPh2RgxgQz/u+Yrf7hiHGHqKY=
github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10/go.mod h1:7zirD+ryp5gitJJ2m1BBux56ai8RIRDykXZrJSp540w=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk=
github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.10 h1:FHw90xCTsofzk6vjU808TSuDtDfOOKPNdz5Weyc3tUI=
github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.10/go.mod h1:n8jdIE/8F3UYkg8O4IGkQpn2qUmapg/1K1yl29/uf/c=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 h1:xtuxji5CS0JknaXoACOunXOYOQzgfTvGAc9s2QdCJA4=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2/go.mod h1:zxwi0DIR0rcRcgdbl7E2MSOvxDyyXGBlScvBkARFaLQ=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.1 h1:ne+eepnDB2Wh5lHKzELgEncIqeVlQ1rSF9fEa4r5I+A=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.1/go.mod h1:u0Jkg0L+dcG1ozUq21uFElmpbmjBnhHR5DELHIme4wg=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 h1:DRND0dkCKtJzCj4Xl4OpVbXZgfttY5q712H9Zj7qc/0=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10/go.mod h1:tGGNmJKOTernmR2+VJ0fCzQRurcPZj9ut60Zu5Fi6us=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.10 h1:DA+Hl5adieRyFvE7pCvBWm3VOZTRexGVkXw33SUqNoY=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.10/go.mod h1:L+A89dH3/gr8L4ecrdzuXUYd1znoko6myzndVGZx/DA=
github.com/aws/aws-sdk-go-v2/service/s3 v1.88.5 h1:FlGScxzCGNzT+2AvHT1ZGMvxTwAMa6gsooFb1pO/AiM=
github.com/aws/aws-sdk-go-v2/service/s3 v1.88.5/go.mod h1:N/iojY+8bW3MYol9NUMuKimpSbPEur75cuI1SmtonFM=
github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 h1:fspVFg6qMx0svs40YgRmE7LZXh9VRZvTT35PfdQR6FM=
github.com/aws/aws-sdk-go-v2/service/sso v1.29.7/go.mod h1:BQTKL3uMECaLaUV3Zc2L4Qybv8C6BIXjuu1dOPyxTQs=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 h1:scVnW+NLXasGOhy7HhkdT9AGb6kjgW7fJ5xYkUaqHs0=
github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2/go.mod h1:FRNCY3zTEWZXBKm2h5UBUPvCVDOecTad9KhynDyGBc0=
github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 h1:VEO5dqFkMsl8QZ2yHsFDJAIZLAkEbaYDB+xdKi0Feic=
github.com/aws/aws-sdk-go-v2/service/sts v1.38.7/go.mod h1:L1xxV3zAdB+qVrVW/pBIrIAnHFWHo6FBbFe4xOGsG/o=
github.com/aws/smithy-go v1.23.1 h1:sLvcH6dfAFwGkHLZ7dGiYF7aK6mg4CgKA/iDKjLDt9M=
github.com/aws/smithy-go v1.23.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0=
github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA=
github.com/caarlos0/env/v11 v11.3.1/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
@ -44,21 +8,19 @@ github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZ
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/edsrzf/mmap-go v1.2.0 h1:hXLYlkbaPzt1SaQk+anYwKSRNhufIDCchSPkUD6dD84=
github.com/edsrzf/mmap-go v1.2.0/go.mod h1:19H/e8pUPLicwkyNgOykDXkJ9F0MHE+Z52B8EIth78Q=
github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
github.com/fatih/color v1.19.0 h1:Zp3PiM21/9Ld6FzSKyL5c/BULoe/ONr9KlbYVOfG8+w=
github.com/fatih/color v1.19.0/go.mod h1:zNk67I0ZUT1bEGsSGyCZYZNrHuTkJJB+r6Q9VuMi0LE=
github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/go-co-op/gocron v1.37.0 h1:ZYDJGtQ4OMhTLKOKMIch+/CY70Brbb1dGdooLEhh7b0=
github.com/go-co-op/gocron v1.37.0/go.mod h1:3L/n6BkO7ABj+TrfSVXLRzsP26zmikL4ISkLQ0O8iNY=
github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg=
github.com/go-faster/errors v0.7.1/go.mod h1:5ySTjWFiphBs07IKuiL69nxdfd5+fzh1u7FPGZP2quo=
github.com/go-faster/jx v1.1.0 h1:ZsW3wD+snOdmTDy9eIVgQdjUpXRRV4rqW8NS3t+20bg=
github.com/go-faster/jx v1.1.0/go.mod h1:vKDNikrKoyUmpzaJ0OkIkRQClNHFX/nF3dnTJZb3skg=
github.com/go-faster/jx v1.2.0 h1:T2YHJPrFaYu21fJtUxC9GzmluKu8rVIFDwwGBKTDseI=
github.com/go-faster/jx v1.2.0/go.mod h1:UWLOVDmMG597a5tBFPLIWJdUxz5/2emOpfsj9Neg0PE=
github.com/go-faster/yaml v0.4.6 h1:lOK/EhI04gCpPgPhgt0bChS6bvw7G3WwI8xxVe0sw9I=
github.com/go-faster/yaml v0.4.6/go.mod h1:390dRIvV4zbnO7qC9FGo6YYutc+wyyUSHBgbXL52eXk=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
@ -77,19 +39,14 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/nilsmagnus/grib v1.2.8 h1:H7ch/1/agaCqM3MC8hW1Ft+EJ+q2XB757uml/IfPvp4=
github.com/nilsmagnus/grib v1.2.8/go.mod h1:XHm+5zuoOk0NSIWaGmA3JaAxI4i50YvD1L1vz+aqPOQ=
github.com/ogen-go/ogen v1.14.0 h1:TU1Nj4z9UBsAfTkf+IhuNNp7igdFQKqkk9+6/y4XuWg=
github.com/ogen-go/ogen v1.14.0/go.mod h1:Iw1vkqkx6SU7I9th5ceP+fVPJ6Wge4e3kAVzAxJEpPE=
github.com/ogen-go/ogen v1.16.0 h1:fKHEYokW/QrMzVNXId74/6RObRIUs9T2oroGKtR25Iw=
github.com/ogen-go/ogen v1.16.0/go.mod h1:s3nWiMzybSf8fhxckyO+wtto92+QHpEL8FmkPnhL3jI=
github.com/ogen-go/ogen v1.20.2 h1:mEZGPST7ZeX84AkqRlFawDLwcwuzcLO5PtYpAXLT1YE=
github.com/ogen-go/ogen v1.20.2/go.mod h1:sJ1pJVp4S1RcSZlYIiMLo0QSMSt2pls4zfrc+hNKnzk=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
@ -97,13 +54,8 @@ github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA=
github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU=
github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys=
github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0=
github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
@ -115,57 +67,35 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA=
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA=
go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.36.0 h1:UumtzIklRBY6cI/lllNZlALOF5nNIzJVb16APdvgTXg=
go.opentelemetry.io/otel v1.36.0/go.mod h1:/TcFMXYjyRNh8khOAO9ybYkqaDBb/70aVwkNML4pP8E=
go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8=
go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM=
go.opentelemetry.io/otel/metric v1.36.0 h1:MoWPKVhQvJ+eeXWHFBOPoBOi20jh6Iq2CcCREuTYufE=
go.opentelemetry.io/otel/metric v1.36.0/go.mod h1:zC7Ks+yeyJt4xig9DEw9kuUFe5C3zLbVjV2PzT6qzbs=
go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA=
go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI=
go.opentelemetry.io/otel/trace v1.36.0 h1:ahxWNuqZjpdiFAyrIoQ4GIiAIhxAunQR6MUoKrsNd4w=
go.opentelemetry.io/otel/trace v1.36.0/go.mod h1:gQ+OnDZzrybY4k4seLzPAWNwVBBVlF2szhehOBB/tGA=
go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE=
go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs=
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE=
go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/exp v0.0.0-20251017212417-90e834f514db h1:by6IehL4BH5k3e3SJmcoNbOobMey2SLpAF79iPOEBvw=
golang.org/x/exp v0.0.0-20251017212417-90e834f514db/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70=
golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY=
golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds=
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ=
golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw=
golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4=
golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA=
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=

158
internal/dataset/dataset.go Normal file
View file

@ -0,0 +1,158 @@
package dataset
import "fmt"
// Dataset shape constants.
// Shape: (65, 47, 3, 361, 720) = (hour, pressure_level, variable, latitude, longitude)
// This matches the reference predictor exactly.
const (
NumHours = 65 // 0, 3, 6, ..., 192
NumLevels = 47 // pressure levels
NumVariables = 3 // height, wind_u, wind_v
NumLatitudes = 361 // -90.0 to +90.0 in 0.5 degree steps
NumLongitudes = 720 // 0.0 to 359.5 in 0.5 degree steps
HourStep = 3 // hours between forecast time steps
MaxHour = 192 // maximum forecast hour
Resolution = 0.5 // grid resolution in degrees
LatStart = -90.0 // first latitude in the dataset
LonStart = 0.0 // first longitude in the dataset
// Variable indices within the dataset.
VarHeight = 0
VarWindU = 1
VarWindV = 2
ElementSize = 4 // float32 = 4 bytes
)
// DatasetSize is the total size of the dataset file in bytes.
// 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200
const DatasetSize int64 = int64(NumHours) * int64(NumLevels) * int64(NumVariables) *
int64(NumLatitudes) * int64(NumLongitudes) * int64(ElementSize)
// LevelSet identifies which GRIB file set a pressure level belongs to.
type LevelSet int
const (
LevelSetA LevelSet = iota // pgrb2 (primary)
LevelSetB // pgrb2b (secondary)
)
// Pressures contains the 47 pressure levels in hPa, sorted descending.
// Index 0 = 1000 hPa (near surface), Index 46 = 1 hPa (high atmosphere).
var Pressures = [NumLevels]int{
1000, 975, 950, 925, 900, 875, 850, 825, 800, 775,
750, 725, 700, 675, 650, 625, 600, 575, 550, 525,
500, 475, 450, 425, 400, 375, 350, 325, 300, 275,
250, 225, 200, 175, 150, 125, 100, 70, 50, 30,
20, 10, 7, 5, 3, 2, 1,
}
// pressureIndex maps pressure in hPa to its index in the Pressures array.
var pressureIndex map[int]int
// pressureLevelSet maps pressure in hPa to its GRIB file set.
var pressureLevelSet map[int]LevelSet
func init() {
pressureIndex = make(map[int]int, NumLevels)
for i, p := range Pressures {
pressureIndex[p] = i
}
pressureLevelSet = make(map[int]LevelSet, NumLevels)
for _, p := range PressuresPgrb2 {
pressureLevelSet[p] = LevelSetA
}
for _, p := range PressuresPgrb2b {
pressureLevelSet[p] = LevelSetB
}
}
// PressuresPgrb2 contains levels found in the primary pgrb2 file (26 levels).
var PressuresPgrb2 = []int{
10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 350, 400,
450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 925,
950, 975, 1000,
}
// PressuresPgrb2b contains levels found in the secondary pgrb2b file (21 levels).
var PressuresPgrb2b = []int{
1, 2, 3, 5, 7, 125, 175, 225, 275, 325, 375, 425,
475, 525, 575, 625, 675, 725, 775, 825, 875,
}
// PressureIndex returns the dataset index for a given pressure level in hPa.
// Returns -1 if the level is not found.
func PressureIndex(hPa int) int {
idx, ok := pressureIndex[hPa]
if !ok {
return -1
}
return idx
}
// PressureLevelSet returns which GRIB file set a pressure level belongs to.
func PressureLevelSet(hPa int) (LevelSet, bool) {
ls, ok := pressureLevelSet[hPa]
return ls, ok
}
// HourIndex returns the dataset time index for a forecast hour.
// Returns -1 if the hour is invalid (not a multiple of HourStep or out of range).
func HourIndex(hour int) int {
if hour < 0 || hour > MaxHour || hour%HourStep != 0 {
return -1
}
return hour / HourStep
}
// Hours returns all forecast hours as a slice: [0, 3, 6, ..., 192].
func Hours() []int {
out := make([]int, 0, NumHours)
for h := 0; h <= MaxHour; h += HourStep {
out = append(out, h)
}
return out
}
// S3 URL configuration for NOAA GFS data.
const S3BaseURL = "https://noaa-gfs-bdp-pds.s3.amazonaws.com"
// GribURL returns the S3 URL for a primary (pgrb2) GRIB file.
func GribURL(date string, runHour, forecastStep int) string {
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d",
S3BaseURL, date, runHour, runHour, forecastStep)
}
// GribURLB returns the S3 URL for a secondary (pgrb2b) GRIB file.
func GribURLB(date string, runHour, forecastStep int) string {
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2b.0p50.f%03d",
S3BaseURL, date, runHour, runHour, forecastStep)
}
// GribFileName returns the local filename for a primary GRIB file.
func GribFileName(runHour, forecastStep int) string {
return fmt.Sprintf("gfs.t%02dz.pgrb2.0p50.f%03d", runHour, forecastStep)
}
// GribFileNameB returns the local filename for a secondary GRIB file.
func GribFileNameB(runHour, forecastStep int) string {
return fmt.Sprintf("gfs.t%02dz.pgrb2b.0p50.f%03d", runHour, forecastStep)
}
// VariableIndex returns the dataset variable index for a GRIB parameter.
// Returns -1 if the parameter is not recognized.
func VariableIndex(parameterCategory, parameterNumber int) int {
switch {
case parameterCategory == 3 && parameterNumber == 5:
return VarHeight // Geopotential Height
case parameterCategory == 2 && parameterNumber == 2:
return VarWindU // U-component of wind
case parameterCategory == 2 && parameterNumber == 3:
return VarWindV // V-component of wind
default:
return -1
}
}

View file

@ -0,0 +1,152 @@
package dataset
import (
"testing"
)
func TestDatasetShape(t *testing.T) {
if NumHours != 65 {
t.Errorf("NumHours = %d, want 65", NumHours)
}
if NumLevels != 47 {
t.Errorf("NumLevels = %d, want 47", NumLevels)
}
if NumVariables != 3 {
t.Errorf("NumVariables = %d, want 3", NumVariables)
}
if NumLatitudes != 361 {
t.Errorf("NumLatitudes = %d, want 361", NumLatitudes)
}
if NumLongitudes != 720 {
t.Errorf("NumLongitudes = %d, want 720", NumLongitudes)
}
}
func TestDatasetSize(t *testing.T) {
// 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200
want := int64(9_528_667_200)
if DatasetSize != want {
t.Errorf("DatasetSize = %d, want %d", DatasetSize, want)
}
}
func TestPressureLevels(t *testing.T) {
if len(Pressures) != 47 {
t.Fatalf("len(Pressures) = %d, want 47", len(Pressures))
}
// First should be 1000 (highest pressure, near surface)
if Pressures[0] != 1000 {
t.Errorf("Pressures[0] = %d, want 1000", Pressures[0])
}
// Last should be 1 (lowest pressure, high atmosphere)
if Pressures[46] != 1 {
t.Errorf("Pressures[46] = %d, want 1", Pressures[46])
}
// Should be sorted descending
for i := 1; i < len(Pressures); i++ {
if Pressures[i] >= Pressures[i-1] {
t.Errorf("Pressures not descending at [%d]: %d >= %d", i, Pressures[i], Pressures[i-1])
}
}
// Total levels: 26 from pgrb2 + 21 from pgrb2b = 47
if len(PressuresPgrb2) != 26 {
t.Errorf("len(PressuresPgrb2) = %d, want 26", len(PressuresPgrb2))
}
if len(PressuresPgrb2b) != 21 {
t.Errorf("len(PressuresPgrb2b) = %d, want 21", len(PressuresPgrb2b))
}
}
func TestPressureIndex(t *testing.T) {
if PressureIndex(1000) != 0 {
t.Errorf("PressureIndex(1000) = %d, want 0", PressureIndex(1000))
}
if PressureIndex(1) != 46 {
t.Errorf("PressureIndex(1) = %d, want 46", PressureIndex(1))
}
if PressureIndex(500) != 20 {
t.Errorf("PressureIndex(500) = %d, want 20", PressureIndex(500))
}
if PressureIndex(9999) != -1 {
t.Errorf("PressureIndex(9999) = %d, want -1", PressureIndex(9999))
}
}
func TestPressureLevelSet(t *testing.T) {
// 1000 mb should be in pgrb2 (A)
ls, ok := PressureLevelSet(1000)
if !ok || ls != LevelSetA {
t.Errorf("PressureLevelSet(1000) = %v, %v; want A, true", ls, ok)
}
// 125 mb should be in pgrb2b (B)
ls, ok = PressureLevelSet(125)
if !ok || ls != LevelSetB {
t.Errorf("PressureLevelSet(125) = %v, %v; want B, true", ls, ok)
}
// 1, 2, 3, 5, 7 should be in pgrb2b (B)
for _, p := range []int{1, 2, 3, 5, 7} {
ls, ok := PressureLevelSet(p)
if !ok || ls != LevelSetB {
t.Errorf("PressureLevelSet(%d) = %v, %v; want B, true", p, ls, ok)
}
}
// Every pressure level should have a level set assignment
for _, p := range Pressures {
_, ok := PressureLevelSet(p)
if !ok {
t.Errorf("PressureLevelSet(%d) not found", p)
}
}
}
func TestHourIndex(t *testing.T) {
if HourIndex(0) != 0 {
t.Errorf("HourIndex(0) = %d, want 0", HourIndex(0))
}
if HourIndex(3) != 1 {
t.Errorf("HourIndex(3) = %d, want 1", HourIndex(3))
}
if HourIndex(192) != 64 {
t.Errorf("HourIndex(192) = %d, want 64", HourIndex(192))
}
if HourIndex(1) != -1 {
t.Errorf("HourIndex(1) = %d, want -1 (not multiple of 3)", HourIndex(1))
}
if HourIndex(195) != -1 {
t.Errorf("HourIndex(195) = %d, want -1 (out of range)", HourIndex(195))
}
}
func TestHours(t *testing.T) {
hours := Hours()
if len(hours) != NumHours {
t.Fatalf("len(Hours()) = %d, want %d", len(hours), NumHours)
}
if hours[0] != 0 {
t.Errorf("Hours()[0] = %d, want 0", hours[0])
}
if hours[len(hours)-1] != MaxHour {
t.Errorf("Hours()[last] = %d, want %d", hours[len(hours)-1], MaxHour)
}
}
func TestVariableIndex(t *testing.T) {
if VariableIndex(3, 5) != VarHeight {
t.Errorf("HGT: got %d, want %d", VariableIndex(3, 5), VarHeight)
}
if VariableIndex(2, 2) != VarWindU {
t.Errorf("UGRD: got %d, want %d", VariableIndex(2, 2), VarWindU)
}
if VariableIndex(2, 3) != VarWindV {
t.Errorf("VGRD: got %d, want %d", VariableIndex(2, 3), VarWindV)
}
if VariableIndex(0, 0) != -1 {
t.Errorf("unknown: got %d, want -1", VariableIndex(0, 0))
}
}

140
internal/dataset/file.go Normal file
View file

@ -0,0 +1,140 @@
package dataset
import (
"encoding/binary"
"fmt"
"math"
"os"
"time"
mmap "github.com/edsrzf/mmap-go"
)
// File represents an mmap-backed wind dataset file.
type File struct {
mm mmap.MMap
file *os.File
writable bool
DSTime time.Time // forecast run time (UTC)
}
// Open opens an existing dataset file for reading.
func Open(path string, dsTime time.Time) (*File, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open dataset: %w", err)
}
info, err := f.Stat()
if err != nil {
f.Close()
return nil, fmt.Errorf("stat dataset: %w", err)
}
if info.Size() != DatasetSize {
f.Close()
return nil, fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, info.Size())
}
mm, err := mmap.Map(f, mmap.RDONLY, 0)
if err != nil {
f.Close()
return nil, fmt.Errorf("mmap dataset: %w", err)
}
return &File{mm: mm, file: f, writable: false, DSTime: dsTime}, nil
}
// Create creates a new dataset file for writing.
// The file is truncated to the correct size and mmap'd read-write.
func Create(path string) (*File, error) {
f, err := os.Create(path)
if err != nil {
return nil, fmt.Errorf("create dataset: %w", err)
}
if err := f.Truncate(DatasetSize); err != nil {
f.Close()
return nil, fmt.Errorf("truncate dataset: %w", err)
}
mm, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0)
if err != nil {
f.Close()
return nil, fmt.Errorf("mmap dataset: %w", err)
}
return &File{mm: mm, file: f, writable: true}, nil
}
// offset computes the byte offset for element [hour][level][variable][lat][lon].
// Row-major C-order indexing matching the reference implementation:
// shape = (65, 47, 3, 361, 720)
func offset(hour, level, variable, lat, lon int) int64 {
idx := int64(hour)
idx = idx*int64(NumLevels) + int64(level)
idx = idx*int64(NumVariables) + int64(variable)
idx = idx*int64(NumLatitudes) + int64(lat)
idx = idx*int64(NumLongitudes) + int64(lon)
return idx * int64(ElementSize)
}
// Val reads a float32 value from the dataset at [hour][level][variable][lat][lon].
func (d *File) Val(hour, level, variable, lat, lon int) float32 {
off := offset(hour, level, variable, lat, lon)
bits := binary.LittleEndian.Uint32(d.mm[off : off+4])
return math.Float32frombits(bits)
}
// SetVal writes a float32 value to the dataset at [hour][level][variable][lat][lon].
// Only valid on writable (created) datasets.
func (d *File) SetVal(hour, level, variable, lat, lon int, val float32) {
off := offset(hour, level, variable, lat, lon)
binary.LittleEndian.PutUint32(d.mm[off:off+4], math.Float32bits(val))
}
// BlitGribData copies decoded GRIB grid data into the dataset at the given position.
// gribData is 361*720 float64 values in GRIB scan order (north-to-south, west-to-east).
// This function flips the latitude so that dataset index 0 = -90 (south) and 360 = +90 (north).
func (d *File) BlitGribData(hourIdx, levelIdx, varIdx int, gribData []float64) error {
expected := NumLatitudes * NumLongitudes
if len(gribData) != expected {
return fmt.Errorf("grib data has %d values, expected %d", len(gribData), expected)
}
for lat := 0; lat < NumLatitudes; lat++ {
for lon := 0; lon < NumLongitudes; lon++ {
// GRIB scans north-to-south: row 0 = 90N, row 360 = 90S
// Dataset stores south-to-north: index 0 = -90 (90S), index 360 = +90 (90N)
gribIdx := (360-lat)*NumLongitudes + lon
val := float32(gribData[gribIdx])
d.SetVal(hourIdx, levelIdx, varIdx, lat, lon, val)
}
}
return nil
}
// Flush flushes the mmap to disk.
func (d *File) Flush() error {
if d.mm != nil {
return d.mm.Flush()
}
return nil
}
// Close unmaps and closes the dataset file.
func (d *File) Close() error {
if d.mm != nil {
if err := d.mm.Unmap(); err != nil {
d.file.Close()
return fmt.Errorf("unmap: %w", err)
}
d.mm = nil
}
if d.file != nil {
err := d.file.Close()
d.file = nil
return err
}
return nil
}

View file

@ -0,0 +1,58 @@
package downloader
import (
"os"
"strconv"
"time"
)
// Config holds downloader configuration, loaded from environment variables.
type Config struct {
// DataDir is the directory for storing dataset files and temporary GRIB data.
DataDir string
// Parallel is the maximum number of concurrent GRIB downloads.
Parallel int
// UpdateInterval is how often the scheduler checks for new forecast data.
UpdateInterval time.Duration
// DatasetTTL is how long a dataset is considered fresh before a new one is needed.
DatasetTTL time.Duration
}
// DefaultConfig returns the default configuration.
func DefaultConfig() *Config {
return &Config{
DataDir: "/tmp/predictor-data",
Parallel: 8,
UpdateInterval: 6 * time.Hour,
DatasetTTL: 48 * time.Hour,
}
}
// LoadConfig loads configuration from environment variables, falling back to defaults.
func LoadConfig() *Config {
cfg := DefaultConfig()
if v := os.Getenv("PREDICTOR_DATA_DIR"); v != "" {
cfg.DataDir = v
}
if v := os.Getenv("PREDICTOR_DOWNLOAD_PARALLEL"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
cfg.Parallel = n
}
}
if v := os.Getenv("PREDICTOR_UPDATE_INTERVAL"); v != "" {
if d, err := time.ParseDuration(v); err == nil {
cfg.UpdateInterval = d
}
}
if v := os.Getenv("PREDICTOR_DATASET_TTL"); v != "" {
if d, err := time.ParseDuration(v); err == nil {
cfg.DatasetTTL = d
}
}
return cfg
}

View file

@ -0,0 +1,441 @@
package downloader
import (
"context"
"fmt"
"io"
"math"
"net/http"
"os"
"path/filepath"
"sync/atomic"
"time"
"predictor-refactored/internal/dataset"
"github.com/nilsmagnus/grib/griblib"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
)
// Downloader handles fetching GFS forecast data from S3 and assembling dataset files.
type Downloader struct {
cfg *Config
client *http.Client
log *zap.Logger
}
// NewDownloader creates a new Downloader.
func NewDownloader(cfg *Config, log *zap.Logger) *Downloader {
return &Downloader{
cfg: cfg,
client: &http.Client{
Timeout: 2 * time.Minute,
},
log: log,
}
}
// neededVariables is the set of GRIB variable names we need.
var neededVariables = map[string]bool{
"HGT": true,
"UGRD": true,
"VGRD": true,
}
// FindLatestRun finds the most recent available GFS model run on S3.
// It checks the last forecast step of each run to confirm availability.
func (d *Downloader) FindLatestRun(ctx context.Context) (time.Time, error) {
now := time.Now().UTC()
hour := now.Hour() - (now.Hour() % 6)
current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC)
for i := 0; i < 8; i++ {
date := current.Format("20060102")
url := dataset.GribURL(date, current.Hour(), dataset.MaxHour) + ".idx"
req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
if err != nil {
current = current.Add(-6 * time.Hour)
continue
}
resp, err := d.client.Do(req)
if err == nil {
resp.Body.Close()
if resp.StatusCode == http.StatusOK {
d.log.Info("found latest model run",
zap.Time("run", current),
zap.String("verified_url", url))
return current, nil
}
}
current = current.Add(-6 * time.Hour)
}
return time.Time{}, fmt.Errorf("no recent GFS forecast found (checked 8 runs)")
}
// progress tracks download progress across concurrent goroutines.
type progress struct {
bytesDownloaded atomic.Int64
stepsCompleted atomic.Int64
totalSteps int64
startTime time.Time
log *zap.Logger
}
func newProgress(totalSteps int, log *zap.Logger) *progress {
return &progress{
totalSteps: int64(totalSteps),
startTime: time.Now(),
log: log,
}
}
func (p *progress) addBytes(n int64) {
p.bytesDownloaded.Add(n)
}
func (p *progress) completeStep() {
done := p.stepsCompleted.Add(1)
total := p.totalSteps
bytes := p.bytesDownloaded.Load()
elapsed := time.Since(p.startTime).Seconds()
pct := float64(done) / float64(total) * 100
mbDownloaded := float64(bytes) / (1024 * 1024)
mbPerSec := 0.0
if elapsed > 0 {
mbPerSec = mbDownloaded / elapsed
}
// Estimate remaining
eta := ""
if done > 0 && done < total {
secsPerStep := elapsed / float64(done)
remaining := secsPerStep * float64(total-done)
if remaining > 60 {
eta = fmt.Sprintf("%.0fm%02.0fs", math.Floor(remaining/60), math.Mod(remaining, 60))
} else {
eta = fmt.Sprintf("%.0fs", remaining)
}
}
p.log.Info("download progress",
zap.String("progress", fmt.Sprintf("%d/%d", done, total)),
zap.String("percent", fmt.Sprintf("%.1f%%", pct)),
zap.String("downloaded", fmt.Sprintf("%.1f MB", mbDownloaded)),
zap.String("speed", fmt.Sprintf("%.1f MB/s", mbPerSec)),
zap.String("eta", eta))
}
// Download downloads a complete forecast and assembles a dataset file.
// Returns the path to the completed dataset file.
func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error) {
date := run.Format("20060102")
runHour := run.Hour()
finalPath := filepath.Join(d.cfg.DataDir, run.Format("2006010215"))
tempPath := finalPath + ".downloading"
// Check if final dataset already exists
if info, err := os.Stat(finalPath); err == nil && info.Size() == dataset.DatasetSize {
d.log.Info("dataset already exists", zap.String("path", finalPath))
return finalPath, nil
}
steps := dataset.Hours()
totalSteps := len(steps) * 2 // pgrb2 + pgrb2b per step
prog := newProgress(totalSteps, d.log)
d.log.Info("starting dataset download",
zap.Time("run", run),
zap.Int("total_steps", totalSteps),
zap.String("temp_path", tempPath))
// Create the dataset file
ds, err := dataset.Create(tempPath)
if err != nil {
return "", fmt.Errorf("create dataset: %w", err)
}
defer ds.Close()
// Process each forecast step with bounded concurrency
g, ctx := errgroup.WithContext(ctx)
sem := make(chan struct{}, d.cfg.Parallel)
for _, step := range steps {
step := step
hourIdx := dataset.HourIndex(step)
if hourIdx < 0 {
continue
}
// Download pgrb2 (level set A)
sem <- struct{}{}
g.Go(func() error {
defer func() { <-sem }()
url := dataset.GribURL(date, runHour, step)
err := d.downloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetA, prog)
if err != nil {
return fmt.Errorf("step %d pgrb2: %w", step, err)
}
prog.completeStep()
return nil
})
// Download pgrb2b (level set B)
sem <- struct{}{}
g.Go(func() error {
defer func() { <-sem }()
url := dataset.GribURLB(date, runHour, step)
err := d.downloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetB, prog)
if err != nil {
return fmt.Errorf("step %d pgrb2b: %w", step, err)
}
prog.completeStep()
return nil
})
}
if err := g.Wait(); err != nil {
os.Remove(tempPath)
return "", err
}
elapsed := time.Since(prog.startTime)
totalMB := float64(prog.bytesDownloaded.Load()) / (1024 * 1024)
d.log.Info("download complete, flushing to disk",
zap.String("downloaded", fmt.Sprintf("%.1f MB", totalMB)),
zap.Duration("elapsed", elapsed),
zap.String("avg_speed", fmt.Sprintf("%.1f MB/s", totalMB/elapsed.Seconds())))
// Flush to disk
if err := ds.Flush(); err != nil {
os.Remove(tempPath)
return "", fmt.Errorf("flush dataset: %w", err)
}
// Close before rename
ds.Close()
// Atomic rename
if err := os.Rename(tempPath, finalPath); err != nil {
os.Remove(tempPath)
return "", fmt.Errorf("rename dataset: %w", err)
}
d.log.Info("dataset ready", zap.String("path", finalPath))
return finalPath, nil
}
// DownloadAndBlit downloads needed GRIB fields from a URL and writes them into the dataset.
func (d *Downloader) DownloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet) error {
return d.downloadAndBlit(ctx, ds, baseURL, hourIdx, levelSet, nil)
}
// downloadAndBlit is the internal implementation with optional progress tracking.
func (d *Downloader) downloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet, prog *progress) error {
// 1. Download .idx
idxURL := baseURL + ".idx"
idxBody, err := d.httpGet(ctx, idxURL)
if err != nil {
return fmt.Errorf("download idx: %w", err)
}
// 2. Parse and filter
entries := ParseIdx(idxBody)
filtered := FilterIdx(entries, neededVariables)
// Further filter to only levels in this level set
var relevant []IdxEntry
for _, e := range filtered {
ls, ok := dataset.PressureLevelSet(e.LevelMB)
if ok && ls == levelSet {
relevant = append(relevant, e)
}
}
if len(relevant) == 0 {
d.log.Warn("no relevant entries found in idx",
zap.String("url", idxURL),
zap.Int("total_entries", len(entries)),
zap.Int("filtered", len(filtered)))
return nil
}
// 3. Download byte ranges and write to temp file
ranges := EntriesToRanges(relevant)
tmpFile, err := d.downloadRangesToTempFile(ctx, baseURL, ranges, prog)
if err != nil {
return fmt.Errorf("download ranges: %w", err)
}
defer os.Remove(tmpFile)
// 4. Read GRIB messages from temp file
f, err := os.Open(tmpFile)
if err != nil {
return fmt.Errorf("open temp grib: %w", err)
}
messages, err := griblib.ReadMessages(f)
f.Close()
if err != nil {
return fmt.Errorf("read grib messages: %w", err)
}
// 5. Decode and blit each message into the dataset
for _, msg := range messages {
if msg.Section4.ProductDefinitionTemplateNumber != 0 {
continue
}
product := msg.Section4.ProductDefinitionTemplate
varIdx := dataset.VariableIndex(int(product.ParameterCategory), int(product.ParameterNumber))
if varIdx < 0 {
continue
}
if product.FirstSurface.Type != 100 { // isobaric surface
continue
}
pressurePa := float64(product.FirstSurface.Value)
pressureMB := int(math.Round(pressurePa / 100.0))
levelIdx := dataset.PressureIndex(pressureMB)
if levelIdx < 0 {
continue
}
data := msg.Data()
if err := ds.BlitGribData(hourIdx, levelIdx, varIdx, data); err != nil {
d.log.Warn("blit failed",
zap.Int("var", varIdx),
zap.Int("level_mb", pressureMB),
zap.Error(err))
continue
}
}
return nil
}
// downloadRangesToTempFile downloads multiple byte ranges from a URL,
// concatenating them into a single temp file (valid concatenated GRIB messages).
func (d *Downloader) downloadRangesToTempFile(ctx context.Context, baseURL string, ranges []ByteRange, prog *progress) (string, error) {
tmpFile, err := os.CreateTemp(d.cfg.DataDir, "grib-*.tmp")
if err != nil {
return "", fmt.Errorf("create temp file: %w", err)
}
tmpPath := tmpFile.Name()
for _, r := range ranges {
data, err := d.httpGetRange(ctx, baseURL, r.Start, r.End)
if err != nil {
tmpFile.Close()
os.Remove(tmpPath)
return "", fmt.Errorf("download range %d-%d: %w", r.Start, r.End, err)
}
if _, err := tmpFile.Write(data); err != nil {
tmpFile.Close()
os.Remove(tmpPath)
return "", fmt.Errorf("write temp: %w", err)
}
if prog != nil {
prog.addBytes(int64(len(data)))
}
}
if err := tmpFile.Close(); err != nil {
os.Remove(tmpPath)
return "", err
}
return tmpPath, nil
}
// httpGet downloads a URL and returns the body bytes.
func (d *Downloader) httpGet(ctx context.Context, url string) ([]byte, error) {
var lastErr error
for attempt := 0; attempt < 3; attempt++ {
if attempt > 0 {
select {
case <-time.After(time.Duration(attempt*2) * time.Second):
case <-ctx.Done():
return nil, ctx.Err()
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := d.client.Do(req)
if err != nil {
lastErr = err
continue
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url)
continue
}
if err != nil {
lastErr = err
continue
}
return body, nil
}
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}
// httpGetRange downloads a byte range from a URL.
func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64) ([]byte, error) {
var lastErr error
for attempt := 0; attempt < 3; attempt++ {
if attempt > 0 {
select {
case <-time.After(time.Duration(attempt*2) * time.Second):
case <-ctx.Done():
return nil, ctx.Err()
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
resp, err := d.client.Do(req)
if err != nil {
lastErr = err
continue
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("HTTP %d for range %d-%d of %s", resp.StatusCode, start, end, url)
continue
}
if err != nil {
lastErr = err
continue
}
return body, nil
}
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}

157
internal/downloader/idx.go Normal file
View file

@ -0,0 +1,157 @@
package downloader
import (
"fmt"
"strconv"
"strings"
)
// IdxEntry represents a single parsed line from a GRIB .idx file.
// Example line: "15:1207405:d=2024010100:HGT:1000 mb:0 hour fcst:"
type IdxEntry struct {
Index int
Offset int64
Variable string // "HGT", "UGRD", "VGRD", etc.
LevelMB int // pressure level in mb (0 if not a pressure level)
Hour int // forecast hour
EndOffset int64 // byte after this message (from next entry's offset, or -1 if last)
}
// Length returns the byte length of this GRIB message, or -1 if unknown.
func (e *IdxEntry) Length() int64 {
if e.EndOffset <= 0 {
return -1
}
return e.EndOffset - e.Offset
}
// ParseIdx parses a .idx file body and returns all entries.
// Lines that can't be parsed are silently skipped.
func ParseIdx(body []byte) []IdxEntry {
lines := strings.Split(string(body), "\n")
var entries []IdxEntry
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
parts := strings.Split(line, ":")
if len(parts) < 7 {
continue
}
idx, err := strconv.Atoi(parts[0])
if err != nil {
continue
}
offset, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
continue
}
variable := parts[3]
levelStr := parts[4]
hourStr := parts[5]
levelMB := parseLevelMB(levelStr)
hour := parseHour(hourStr)
entries = append(entries, IdxEntry{
Index: idx,
Offset: offset,
Variable: variable,
LevelMB: levelMB,
Hour: hour,
EndOffset: -1, // filled in below
})
}
// Fill in EndOffset from the next entry's Offset.
for i := 0; i < len(entries)-1; i++ {
entries[i].EndOffset = entries[i+1].Offset
}
return entries
}
// FilterIdx returns entries matching the given variables at pressure levels.
// Only entries with a recognized pressure level (levelMB > 0) are returned.
func FilterIdx(entries []IdxEntry, variables map[string]bool) []IdxEntry {
var filtered []IdxEntry
for _, e := range entries {
if !variables[e.Variable] {
continue
}
if e.LevelMB <= 0 {
continue
}
// Must have a known length (not the last entry) or be handled specially
if e.Length() <= 0 {
continue
}
filtered = append(filtered, e)
}
return filtered
}
// parseLevelMB parses a level string like "1000 mb" and returns the pressure in mb.
// Returns 0 if not a pressure level.
func parseLevelMB(s string) int {
s = strings.TrimSpace(s)
if !strings.HasSuffix(s, " mb") {
return 0
}
numStr := strings.TrimSuffix(s, " mb")
n, err := strconv.Atoi(numStr)
if err != nil {
return 0
}
return n
}
// parseHour parses a forecast hour string like "0 hour fcst" or "anl".
// Returns -1 if it can't be parsed.
func parseHour(s string) int {
s = strings.TrimSpace(s)
if s == "anl" {
return 0
}
s = strings.TrimSuffix(s, " hour fcst")
n, err := strconv.Atoi(s)
if err != nil {
return -1
}
return n
}
// GroupByRange groups idx entries into byte ranges suitable for HTTP Range downloads.
// Each range covers one contiguous GRIB message.
type ByteRange struct {
Start int64
End int64 // inclusive
Entry IdxEntry
}
// EntriesToRanges converts filtered idx entries to byte ranges.
func EntriesToRanges(entries []IdxEntry) []ByteRange {
ranges := make([]ByteRange, 0, len(entries))
for _, e := range entries {
if e.Length() <= 0 {
continue
}
ranges = append(ranges, ByteRange{
Start: e.Offset,
End: e.EndOffset - 1, // inclusive
Entry: e,
})
}
return ranges
}
// FormatRange returns an HTTP Range header value for a byte range.
func (r ByteRange) FormatRange() string {
return fmt.Sprintf("bytes=%d-%d", r.Start, r.End)
}

View file

@ -0,0 +1,110 @@
package downloader
import (
"testing"
)
const sampleIdx = `1:0:d=2024010100:HGT:1000 mb:0 hour fcst:
2:289012:d=2024010100:HGT:975 mb:0 hour fcst:
3:541876:d=2024010100:TMP:1000 mb:0 hour fcst:
4:789012:d=2024010100:UGRD:1000 mb:0 hour fcst:
5:1045678:d=2024010100:VGRD:1000 mb:0 hour fcst:
6:1298765:d=2024010100:UGRD:975 mb:0 hour fcst:
7:1567890:d=2024010100:UGRD:2 m above ground:0 hour fcst:
8:1812345:d=2024010100:VGRD:975 mb:0 hour fcst:
9:2098765:d=2024010100:HGT:500 mb:3 hour fcst:
`
func TestParseIdx(t *testing.T) {
entries := ParseIdx([]byte(sampleIdx))
if len(entries) != 9 {
t.Fatalf("expected 9 entries, got %d", len(entries))
}
// Check first entry
e := entries[0]
if e.Index != 1 || e.Offset != 0 || e.Variable != "HGT" || e.LevelMB != 1000 || e.Hour != 0 {
t.Errorf("entry 0: got %+v", e)
}
if e.EndOffset != 289012 {
t.Errorf("entry 0 EndOffset: got %d, want 289012", e.EndOffset)
}
// Check "2 m above ground" is not a pressure level
e = entries[6] // UGRD at "2 m above ground"
if e.LevelMB != 0 {
t.Errorf("non-pressure level should have LevelMB=0, got %d", e.LevelMB)
}
// Last entry should have EndOffset = -1
last := entries[len(entries)-1]
if last.EndOffset != -1 {
t.Errorf("last entry EndOffset: got %d, want -1", last.EndOffset)
}
}
func TestFilterIdx(t *testing.T) {
entries := ParseIdx([]byte(sampleIdx))
filtered := FilterIdx(entries, neededVariables)
// Should include HGT/UGRD/VGRD at pressure levels, exclude TMP and "above ground"
// And exclude last entry (no EndOffset)
for _, e := range filtered {
if !neededVariables[e.Variable] {
t.Errorf("unexpected variable %s", e.Variable)
}
if e.LevelMB <= 0 {
t.Errorf("non-pressure level included: %+v", e)
}
if e.Length() <= 0 {
t.Errorf("entry with unknown length included: %+v", e)
}
}
// Count expected: HGT@1000, HGT@975, UGRD@1000, VGRD@1000, UGRD@975, VGRD@975 = 6
// But HGT@500 at 3hr fcst is the last entry (no EndOffset), so excluded
if len(filtered) != 6 {
t.Errorf("expected 6 filtered entries, got %d", len(filtered))
for _, e := range filtered {
t.Logf(" %s %d mb (offset %d, len %d)", e.Variable, e.LevelMB, e.Offset, e.Length())
}
}
}
func TestParseLevelMB(t *testing.T) {
tests := []struct {
input string
want int
}{
{"1000 mb", 1000},
{"975 mb", 975},
{"1 mb", 1},
{"2 m above ground", 0},
{"surface", 0},
{"tropopause", 0},
}
for _, tt := range tests {
got := parseLevelMB(tt.input)
if got != tt.want {
t.Errorf("parseLevelMB(%q) = %d, want %d", tt.input, got, tt.want)
}
}
}
func TestParseHour(t *testing.T) {
tests := []struct {
input string
want int
}{
{"0 hour fcst", 0},
{"3 hour fcst", 3},
{"192 hour fcst", 192},
{"anl", 0},
}
for _, tt := range tests {
got := parseHour(tt.input)
if got != tt.want {
t.Errorf("parseHour(%q) = %d, want %d", tt.input, got, tt.want)
}
}
}

View file

@ -0,0 +1,113 @@
package elevation
import (
"encoding/binary"
"fmt"
"math"
"os"
mmap "github.com/edsrzf/mmap-go"
)
// Dataset provides global elevation lookup, compatible with ruaumoko.
// Binary format: int16 little-endian elevation values in metres, row-major (lat, lon).
// Latitude axis: -90 to +90 (south to north), Longitude axis: 0 to 360 (wraps).
// Resolution: 120 cells per degree (30 arc-seconds).
const (
CellsPerDegree = 120
NumLats = 180*CellsPerDegree + 1 // 21601
NumLons = 360 * CellsPerDegree // 43200
DataSize = NumLats * NumLons * 2 // 1,866,326,400 bytes (~1.74 GiB)
)
// Dataset is a memory-mapped global elevation grid.
type Dataset struct {
mm mmap.MMap
file *os.File
}
// Open opens an existing elevation dataset file.
func Open(path string) (*Dataset, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open elevation: %w", err)
}
info, err := f.Stat()
if err != nil {
f.Close()
return nil, fmt.Errorf("stat elevation: %w", err)
}
if info.Size() != DataSize {
f.Close()
return nil, fmt.Errorf("elevation dataset should be %d bytes (was %d)", DataSize, info.Size())
}
mm, err := mmap.Map(f, mmap.RDONLY, 0)
if err != nil {
f.Close()
return nil, fmt.Errorf("mmap elevation: %w", err)
}
return &Dataset{mm: mm, file: f}, nil
}
// getCell reads the int16 elevation at grid indices (latIdx, lngIdx).
func (d *Dataset) getCell(latIdx, lngIdx int) int16 {
// Clamp latitude
if latIdx < 0 {
latIdx = 0
}
if latIdx >= NumLats {
latIdx = NumLats - 1
}
// Wrap longitude
lngIdx = lngIdx % NumLons
if lngIdx < 0 {
lngIdx += NumLons
}
off := (latIdx*NumLons + lngIdx) * 2
return int16(binary.LittleEndian.Uint16(d.mm[off : off+2]))
}
// Get returns the interpolated elevation in metres at the given coordinates.
// lat: -90 to +90, lng: 0 to 360 (or -180 to 180, will be normalised).
func (d *Dataset) Get(lat, lng float64) float64 {
// Normalise longitude to [0, 360)
if lng < 0 {
lng += 360
}
// Convert to cell coordinates
latCell := (lat + 90.0) * CellsPerDegree
lngCell := lng * CellsPerDegree
lat0 := int(math.Floor(latCell))
lng0 := int(math.Floor(lngCell))
latFrac := latCell - float64(lat0)
lngFrac := lngCell - float64(lng0)
// Bilinear interpolation
v00 := float64(d.getCell(lat0, lng0))
v10 := float64(d.getCell(lat0+1, lng0))
v01 := float64(d.getCell(lat0, lng0+1))
v11 := float64(d.getCell(lat0+1, lng0+1))
return (1-latFrac)*((1-lngFrac)*v00+lngFrac*v01) +
latFrac*((1-lngFrac)*v10+lngFrac*v11)
}
// Close unmaps and closes the dataset.
func (d *Dataset) Close() error {
if d.mm != nil {
d.mm.Unmap()
d.mm = nil
}
if d.file != nil {
err := d.file.Close()
d.file = nil
return err
}
return nil
}

View file

@ -1,23 +0,0 @@
package updater
import (
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
env "github.com/caarlos0/env/v11"
)
type Config struct {
Interval time.Duration `env:"INTERVAL" envDefault:"6h"`
Timeout time.Duration `env:"TIMEOUT" envDefault:"45m"`
}
func NewConfig() (*Config, error) {
cfg := &Config{}
if err := env.ParseWithOptions(cfg, env.Options{
PrefixTagName: "GSN_PREDICTOR_GRIB_UPDATER_",
}); err != nil {
return nil, errcodes.Wrap(err, "failed to parse GRIB updater config")
}
return cfg, nil
}

View file

@ -1,8 +0,0 @@
package updater
import "context"
// GribService defines the interface for GRIB operations needed by the updater job
type GribService interface {
Update(ctx context.Context) error
}

View file

@ -1,51 +0,0 @@
package updater
import (
"context"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
"git.intra.yksa.space/gsn/predictor/internal/pkg/log"
"go.uber.org/zap"
)
type Job struct {
service GribService
config *Config
}
func New(service GribService, config *Config) *Job {
return &Job{
service: service,
config: config,
}
}
func (j *Job) GetInterval() time.Duration {
return j.config.Interval
}
func (j *Job) GetTimeout() time.Duration {
return j.config.Timeout
}
func (j *Job) GetCount() int {
return 1
}
func (j *Job) GetAsync() bool {
return false
}
func (j *Job) Execute(ctx context.Context) error {
log := log.Ctx(ctx)
log.Info("executing GRIB update job")
if err := j.service.Update(ctx); err != nil {
log.Error("GRIB update failed", zap.Error(err))
return errcodes.Wrap(err, "failed to update GRIB data")
}
log.Info("GRIB update completed successfully")
return nil
}

View file

@ -1,89 +0,0 @@
package ds
import (
"time"
api "git.intra.yksa.space/gsn/predictor/pkg/rest"
)
type PredictionParameters struct {
LaunchLatitude *float64
LaunchLongitude *float64
LaunchDatetime *time.Time
LaunchAltitude *float64
Profile *string
AscentRate *float64
BurstAltitude *float64
DescentRate *float64
FloatAltitude *float64
StopDatetime *time.Time
AscentCurve *string // base64
DescentCurve *string // base64
Interpolate *bool
Format *string
Dataset *time.Time
// Add other parameters as needed
}
type PredicitonResult struct {
Latitude *float64
Longitude *float64
Altitude *float64
Timestamp *time.Time
WindU *float64
WindV *float64
// Add other result fields as needed
}
// Converts flat ogen params to internal pointer-based model
func ConvertFlatPredictionParams(params api.PerformPredictionParams) *PredictionParameters {
out := &PredictionParameters{}
if v, ok := params.LaunchLatitude.Get(); ok {
out.LaunchLatitude = &v
}
if v, ok := params.LaunchLongitude.Get(); ok {
out.LaunchLongitude = &v
}
if v, ok := params.LaunchDatetime.Get(); ok {
out.LaunchDatetime = &v
}
if v, ok := params.LaunchAltitude.Get(); ok {
out.LaunchAltitude = &v
}
if v, ok := params.Profile.Get(); ok {
s := string(v)
out.Profile = &s
}
if v, ok := params.AscentRate.Get(); ok {
out.AscentRate = &v
}
if v, ok := params.BurstAltitude.Get(); ok {
out.BurstAltitude = &v
}
if v, ok := params.DescentRate.Get(); ok {
out.DescentRate = &v
}
if v, ok := params.FloatAltitude.Get(); ok {
out.FloatAltitude = &v
}
if v, ok := params.StopDatetime.Get(); ok {
out.StopDatetime = &v
}
if v, ok := params.AscentCurve.Get(); ok {
out.AscentCurve = &v
}
if v, ok := params.DescentCurve.Get(); ok {
out.DescentCurve = &v
}
if v, ok := params.Interpolate.Get(); ok {
out.Interpolate = &v
}
if v, ok := params.Format.Get(); ok {
s := string(v)
out.Format = &s
}
if v, ok := params.Dataset.Get(); ok {
out.Dataset = &v
}
return out
}

View file

@ -1,102 +0,0 @@
package errcodes
import (
"net/http"
"strings"
)
type ErrorCode struct {
StatusCode int
Message string
Details string
}
func New(statusCode int, message string, details ...string) *ErrorCode {
return &ErrorCode{
StatusCode: statusCode,
Message: message,
Details: strings.Join(details, " "),
}
}
func (e *ErrorCode) Error() string {
return e.Message
}
func IsErr(err error) bool {
_, ok := err.(*ErrorCode)
return ok
}
func AsErr(err error) (*ErrorCode, bool) {
if err == nil {
return nil, false
}
errcode, ok := err.(*ErrorCode)
return errcode, ok
}
func Join(errs ...error) error {
if len(errs) == 0 {
return nil
}
var messages []string
var details []string
for _, err := range errs {
if err == nil {
continue
}
if errcode, ok := AsErr(err); ok {
messages = append(messages, errcode.Message)
if errcode.Details != "" {
details = append(details, errcode.Details)
}
} else {
messages = append(messages, err.Error())
}
}
if len(messages) == 0 {
return nil
}
statusCode := http.StatusInternalServerError
if len(errs) > 0 {
if errcode, ok := AsErr(errs[0]); ok {
statusCode = errcode.StatusCode
}
}
return New(statusCode, strings.Join(messages, "; "), details...)
}
func Wrap(err error, message string) error {
if err == nil {
return nil
}
if errcode, ok := AsErr(err); ok {
return New(errcode.StatusCode, message, errcode.Message, errcode.Details)
}
return New(http.StatusInternalServerError, message, err.Error())
}
var (
ErrNoDataset = New(http.StatusNotFound, "no grib dataset found")
ErrOutOfBounds = New(http.StatusBadRequest, "requested time is out of bounds")
ErrConfig = New(http.StatusInternalServerError, "configuration error")
ErrConfigInvalidEnv = New(http.StatusInternalServerError, "invalid environment configuration")
ErrConfigMissingRequired = New(http.StatusInternalServerError, "missing required configuration")
ErrDownload = New(http.StatusInternalServerError, "download error")
ErrProcessing = New(http.StatusInternalServerError, "data processing error")
ErrNoCubeFilesFound = New(http.StatusNotFound, "no cube files found")
ErrNoValidCubeFilesFound = New(http.StatusNotFound, "no valid cube files found")
ErrLatestCubeFileIsTooOld = New(http.StatusNotFound, "latest cube file is too old")
ErrScheduler = New(http.StatusInternalServerError, "scheduler error")
ErrSchedulerInvalidJob = New(http.StatusBadRequest, "invalid job configuration")
ErrSchedulerTimeoutTooLong = New(http.StatusBadRequest, "job timeout too long", "timeout cannot exceed interval")
)

View file

@ -1,100 +0,0 @@
# GRIB Module
Этот модуль реализует функциональность для работы с GRIB-файлами, аналогичную tawhiri-downloader и tawhiri, но на Go.
## Основные возможности
- **Скачивание GRIB-файлов** с NOMADS (GFS прогнозы)
- **Сборка 5D-куба** (время, давление, широта, долгота, переменные u/v)
- **Эффективное хранение** с использованием mmap
- **Интерполяция** ветровых данных для произвольных координат и времени
- **Кэширование** результатов (in-memory)
- **Распределенные блокировки** для предотвращения дублирования загрузок
## Архитектура
### Основные компоненты
- **Downloader** - скачивает GRIB-файлы с NOMADS
- **Cube** - управляет 5D-массивом данных через mmap
- **Extractor** - выполняет интерполяцию данных
- **Cache** - кэширует результаты запросов
- **Service** - основной интерфейс для работы с модулем
### Структура данных
5D-куб содержит:
- **Время**: 17 временных срезов (0, 3, 6, ..., 48 часов)
- **Давление**: 34 уровня давления (1000, 975, 950, ..., 2 hPa)
- **Широта**: 361 точка (-90° до +90°)
- **Долгота**: 720 точек (0° до 359.5°)
- **Переменные**: u-ветер и v-ветер
## Использование
```go
// Создание сервиса
cfg := grib.ServiceConfig{
Dir: "/tmp/grib",
TTL: 24 * time.Hour,
CacheTTL: 1 * time.Hour,
Parallel: 4,
Client: &http.Client{Timeout: 30 * time.Second},
}
service, err := grib.New(cfg)
if err != nil {
log.Fatal(err)
}
defer service.Close()
// Обновление данных
err = service.Update(ctx)
// Извлечение ветровых данных
wind, err := service.Extract(ctx, lat, lon, alt, timestamp)
// wind[0] - u-компонента ветра
// wind[1] - v-компонента ветра
```
## Интерполяция
Модуль выполняет 16-точечную интерполяцию:
1. **Временная интерполяция** между двумя ближайшими срезами
2. **Интерполяция по давлению** между двумя ближайшими уровнями
3. **Билинейная интерполяция** по широте и долготе
## Кэширование
- **In-memory кэш**: быстрый доступ к недавно запрошенным данным
## Расписание обновлений
Рекомендуемая частота вызова `Update()`:
- **Каждые 6 часов** - для получения свежих GFS прогнозов
- **При запуске** - для загрузки начальных данных
- **По требованию** - при отсутствии данных для запрашиваемого времени
## Отличия от tawhiri
### Преимущества Go-реализации:
- **Высокая производительность** (mmap, конкурентные загрузки)
- **Эффективное использование памяти** (не загружает весь массив в RAM)
- **Горизонтальное масштабирование** (stateless, множество реплик)
- **Встроенное кэширование** (in-memory)
### Особенности:
- Использует `github.com/nilsmagnus/grib` вместо pygrib
- Реализует собственную логику интерполяции
## Конфигурация
### Переменные окружения:
- `PREDICTOR_GRIB_DATASET_URL` - URL источника данных (опционально)
### Параметры ServiceConfig:
- `Dir` - директория для хранения файлов
- `TTL` - время жизни данных (по умолчанию 24 часа)
- `CacheTTL` - время жизни кэша (по умолчанию 1 час)
- `Parallel` - количество параллельных загрузок
- `Client` - HTTP клиент для загрузок

View file

@ -1,36 +0,0 @@
package grib
import (
"sync"
"time"
)
type vec [2]float64
type item struct {
v vec
exp time.Time
}
type memCache struct {
ttl time.Duration
m sync.Map
}
func (c *memCache) get(k uint64) (vec, bool) {
if v, ok := c.m.Load(k); ok {
it := v.(item)
if time.Now().Before(it.exp) {
return it.v, true
}
c.m.Delete(k)
}
return vec{}, false
}
func (c *memCache) set(k uint64, v vec) {
c.m.Store(k, item{v, time.Now().Add(c.ttl)})
}

View file

@ -1,32 +0,0 @@
package grib
import (
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
env "github.com/caarlos0/env/v11"
)
type Config struct {
Dir string `env:"DIR" envDefault:"/tmp/grib"`
TTL time.Duration `env:"TTL" envDefault:"24h"`
CacheTTL time.Duration `env:"CACHE_TTL" envDefault:"1h"`
Parallel int `env:"PARALLEL" envDefault:"8"`
DatasetURL string `env:"DATASET_URL" envDefault:"https://nomads.ncep.noaa.gov/pub/data/nccf/com/gfs/prod"`
// S3 configuration
UseS3 bool `env:"USE_S3" envDefault:"true"`
S3Bucket string `env:"S3_BUCKET" envDefault:"noaa-gfs-bdp-pds"`
S3Region string `env:"S3_REGION" envDefault:"us-east-1"`
S3Timeout time.Duration `env:"S3_TIMEOUT" envDefault:"300s"`
}
func NewConfig() (*Config, error) {
cfg := &Config{}
if err := env.ParseWithOptions(cfg, env.Options{
PrefixTagName: "GSN_PREDICTOR_GRIB_",
}); err != nil {
return nil, errcodes.Wrap(err, "failed to parse GRIB config")
}
return cfg, nil
}

View file

@ -1,55 +0,0 @@
package grib
import (
"encoding/binary"
"math"
"os"
mmap "github.com/edsrzf/mmap-go"
)
type cube struct {
mm mmap.MMap
t, p, lat, lon int
bytesPerVar int64
file *os.File
}
func openCube(path string) (*cube, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
mm, err := mmap.Map(f, mmap.RDONLY, 0)
if err != nil {
f.Close()
return nil, err
}
const (
nT = 97 // 0-96 hours with step 1 hour
nP = 47 // 47 pressure levels matching tawhiri
nLat = 361
nLon = 720
)
return &cube{mm: mm, t: nT, p: nP, lat: nLat, lon: nLon, bytesPerVar: int64(nT * nP * nLat * nLon * 4), file: f}, nil
}
func (c *cube) val(varIdx, ti, pi, y, x int) float32 {
idx := (((ti*c.p+pi)*c.lat + y) * c.lon) + x
off := int64(varIdx)*c.bytesPerVar + int64(idx)*4
bits := binary.LittleEndian.Uint32(c.mm[off : off+4])
return math.Float32frombits(bits)
}
func (c *cube) Close() error {
if c.mm != nil {
c.mm.Unmap()
}
if c.file != nil {
return c.file.Close()
}
return nil
}

View file

@ -1,13 +0,0 @@
package grib
type dataset struct {
cube *cube
runUTC int64 // unix seconds
}
func (d *dataset) Close() error {
if d.cube != nil {
return d.cube.Close()
}
return nil
}

View file

@ -1,91 +0,0 @@
package grib
import (
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
"golang.org/x/sync/errgroup"
)
type Downloader struct {
Dir string
Parallel int
Client *http.Client
DatasetURL string
}
func (d *Downloader) fileURL(run string, hour int, step int) string {
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d", d.DatasetURL, run, hour, hour, step)
}
func (d *Downloader) fetch(ctx context.Context, url, dst string) (err error) {
// Check if final file already exists
if _, err := os.Stat(dst); err == nil {
return nil
}
tmp := dst + ".part"
// Remove old .part file if it exists (fixes race condition)
os.Remove(tmp)
f, err := os.Create(tmp)
if err != nil {
return err
}
// Cleanup .part file on any error (using named return value)
defer func() {
f.Close()
if err != nil {
os.Remove(tmp)
}
}()
req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
resp, err := d.Client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return errcodes.Wrap(errcodes.ErrDownload, "bad status: "+resp.Status)
}
if _, err := io.Copy(f, resp.Body); err != nil {
return err
}
// Close file before rename
if err := f.Close(); err != nil {
return err
}
// If rename fails, err will be set and defer will cleanup .part file
return os.Rename(tmp, dst)
}
func (d *Downloader) Run(ctx context.Context, run time.Time) error {
runStr := run.Format("20060102")
hour := run.Hour()
g, ctx := errgroup.WithContext(ctx)
sem := make(chan struct{}, d.Parallel)
for _, step := range steps {
step := step
sem <- struct{}{}
g.Go(func() error {
defer func() { <-sem }()
url := d.fileURL(runStr, hour, step)
dst := filepath.Join(d.Dir, fileName(run, step))
return d.fetch(ctx, url, dst)
})
}
return g.Wait()
}

View file

@ -1,54 +0,0 @@
package grib
import "math"
func lerp(a, b, t float64) float64 { return a + t*(b-a) }
// Interpolate 16point (time, p, lat, lon)
func (d *dataset) uv(lat, lon, alt float64, tHours float64) (float64, float64) {
if lon < 0 {
lon += 360
}
iy := (lat + 90) * 2
y0 := int(math.Floor(iy))
y1 := y0 + 1
wy := iy - float64(y0)
ix := lon * 2
x0 := int(math.Floor(ix)) % d.cube.lon
x1 := (x0 + 1) % d.cube.lon
wx := ix - float64(x0)
// For hourly data (step = 1 hour)
it0 := int(math.Floor(tHours))
wt := tHours - float64(it0)
p := pressureFromAlt(alt)
ip0 := 0
for ip0+1 < len(pressureLevels) && pressureLevels[ip0+1] > p {
ip0++
}
ip1 := ip0 + 1
wp := (pressureLevels[ip0] - p) / (pressureLevels[ip0] - pressureLevels[ip1])
fetch := func(ti, pi int) (float64, float64) {
u00 := d.cube.val(1, ti, pi, y0, x0)
u10 := d.cube.val(1, ti, pi, y0, x1)
u01 := d.cube.val(1, ti, pi, y1, x0)
u11 := d.cube.val(1, ti, pi, y1, x1)
v00 := d.cube.val(2, ti, pi, y0, x0)
v10 := d.cube.val(2, ti, pi, y0, x1)
v01 := d.cube.val(2, ti, pi, y1, x0)
v11 := d.cube.val(2, ti, pi, y1, x1)
uxy := (1-wy)*((1-wx)*float64(u00)+wx*float64(u10)) + wy*((1-wx)*float64(u01)+wx*float64(u11))
vxy := (1-wy)*((1-wx)*float64(v00)+wx*float64(v10)) + wy*((1-wx)*float64(v01)+wx*float64(v11))
return uxy, vxy
}
u0p0, v0p0 := fetch(it0, ip0)
u0p1, v0p1 := fetch(it0, ip1)
u1p0, v1p0 := fetch(it0+1, ip0)
u1p1, v1p1 := fetch(it0+1, ip1)
uLow := lerp(u0p0, u0p1, wp)
vLow := lerp(v0p0, v0p1, wp)
uHig := lerp(u1p0, u1p1, wp)
vHig := lerp(v1p0, v1p1, wp)
u := lerp(uLow, uHig, wt)
v := lerp(vLow, vHig, wt)
return u, v
}

View file

@ -1,321 +0,0 @@
package grib
import (
"context"
"encoding/binary"
"math"
"net/http"
"os"
"path/filepath"
"strings"
"sync/atomic"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
"github.com/edsrzf/mmap-go"
"github.com/nilsmagnus/grib/griblib"
)
type Service interface {
Update(ctx context.Context) error
Extract(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error)
Close() error
GetStatus() (ready bool, lastUpdate time.Time, isFresh bool, errMsg string)
}
type service struct {
cfg *Config
cache memCache
data atomic.Pointer[dataset]
}
func New(cfg *Config) (Service, error) {
if cfg.TTL == 0 {
cfg.TTL = 24 * time.Hour
}
if err := os.MkdirAll(cfg.Dir, 0o755); err != nil {
return nil, err
}
s := &service{cfg: cfg, cache: memCache{ttl: cfg.CacheTTL}}
// Try to load existing dataset on startup
if err := s.loadExistingDataset(); err != nil {
// Log error but don't fail startup - dataset will be loaded on first Update()
// This allows the service to start even if no data is available yet
}
return s, nil
}
// loadExistingDataset tries to load the most recent available dataset
func (s *service) loadExistingDataset() error {
// Find the most recent cube file
pattern := filepath.Join(s.cfg.Dir, "*.cube")
matches, err := filepath.Glob(pattern)
if err != nil {
return err
}
if len(matches) == 0 {
return errcodes.ErrNoCubeFilesFound
}
// Sort by modification time (newest first)
var latestFile string
var latestTime time.Time
for _, match := range matches {
info, err := os.Stat(match)
if err != nil {
continue
}
if info.ModTime().After(latestTime) {
latestTime = info.ModTime()
latestFile = match
}
}
if latestFile == "" {
return errcodes.ErrNoValidCubeFilesFound
}
// Check if the file is fresh enough
if time.Since(latestTime) > s.cfg.TTL {
return errcodes.Wrap(errcodes.ErrLatestCubeFileIsTooOld, "latest cube file is too old")
}
// Load the dataset
c, err := openCube(latestFile)
if err != nil {
return err
}
// Extract run time from filename
base := filepath.Base(latestFile)
runStr := strings.TrimSuffix(base, ".cube")
run, err := time.Parse("20060102_15", runStr)
if err != nil {
c.Close()
return err
}
ds := &dataset{cube: c, runUTC: run.Unix()}
s.data.Store(ds)
return nil
}
// Update() downloads missing GRIBs, assembles cube into a single mmapfile.
func (s *service) Update(ctx context.Context) error {
// Check if we already have fresh data
if d := s.data.Load(); d != nil {
runTime := time.Unix(d.runUTC, 0)
if time.Since(runTime) < s.cfg.TTL {
// Data is still fresh, no need to update
return nil
}
}
// Check again after acquiring lock (double-checked locking pattern)
if d := s.data.Load(); d != nil {
runTime := time.Unix(d.runUTC, 0)
if time.Since(runTime) < s.cfg.TTL {
// Another instance already updated the data
return nil
}
}
run := nearestRun(time.Now().UTC().Add(-24 * time.Hour))
// Check if we already have this run
cubePath := filepath.Join(s.cfg.Dir, run.Format("20060102_15")) + ".cube"
if _, err := os.Stat(cubePath); err == nil {
// File exists, check if it's fresh
info, err := os.Stat(cubePath)
if err == nil && time.Since(info.ModTime()) < s.cfg.TTL {
// File is fresh, just load it
c, err := openCube(cubePath)
if err != nil {
return err
}
ds := &dataset{cube: c, runUTC: run.Unix()}
s.data.Store(ds)
s.cache = memCache{ttl: s.cfg.CacheTTL}
return nil
}
}
// Download new data using S3 or HTTP
var downloadErr error
if s.cfg.UseS3 {
s3dl, err := NewS3Downloader(s.cfg.Dir, s.cfg.Parallel, s.cfg.S3Bucket, s.cfg.S3Region)
if err != nil {
return errcodes.Wrap(err, "failed to create S3 downloader")
}
downloadErr = s3dl.Run(ctx, run)
} else {
dl := Downloader{
Dir: s.cfg.Dir,
Parallel: s.cfg.Parallel,
Client: http.DefaultClient,
DatasetURL: s.cfg.DatasetURL,
}
downloadErr = dl.Run(ctx, run)
}
if downloadErr != nil {
return downloadErr
}
// Assemble cube if it doesn't exist
if _, err := os.Stat(cubePath); err != nil {
if err := assembleCube(s.cfg.Dir, run, cubePath); err != nil {
return err
}
}
c, err := openCube(cubePath)
if err != nil {
return err
}
ds := &dataset{cube: c, runUTC: run.Unix()}
s.data.Store(ds)
s.cache = memCache{ttl: s.cfg.CacheTTL}
return nil
}
func assembleCube(dir string, run time.Time, cubePath string) error {
const sizePerVar = 97 * 47 * 361 * 720 * 4 // 97 time steps (0-96 hours), 47 pressure levels
total := int64(sizePerVar * 3) // 3 variables: gh, u, v
f, err := os.Create(cubePath)
if err != nil {
return err
}
if err := f.Truncate(total); err != nil {
return err
}
mm, err := mmap.MapRegion(f, int(total), mmap.RDWR, 0, 0)
if err != nil {
return err
}
defer mm.Unmap()
defer f.Close()
pIndex := make(map[int]int)
for i, p := range pressureLevels {
pIndex[int(math.Round(p))] = i
}
for ti, step := range steps {
fn := filepath.Join(dir, fileName(run, step))
file, err := os.Open(fn)
if err != nil {
return err
}
messages, err := griblib.ReadMessages(file)
file.Close() // Close immediately after reading
if err != nil {
return err
}
for _, m := range messages {
// Check if this is a wind component (u or v) or geopotential height
// ParameterCategory 2 = momentum, ParameterNumber 2 = u-wind, 3 = v-wind
// ParameterCategory 3 = mass, ParameterNumber 5 = geopotential height
if m.Section4.ProductDefinitionTemplateNumber != 0 {
continue
}
product := m.Section4.ProductDefinitionTemplate
var varIdx int
// Match tawhiri variable order: ['gh', 'u', 'v'] (indices 0, 1, 2)
if product.ParameterCategory == 2 {
switch product.ParameterNumber {
case 2: // u-wind
varIdx = 1
case 3: // v-wind
varIdx = 2
default:
continue
}
} else if product.ParameterCategory == 3 && product.ParameterNumber == 5 {
// geopotential height
varIdx = 0
} else {
continue
}
// Check if this is a pressure level (type 100)
if product.FirstSurface.Type != 100 {
continue
}
// Get pressure level in hPa
pressure := float64(product.FirstSurface.Value) / 100.0
pIdx, ok := pIndex[int(math.Round(pressure))]
if !ok {
continue
}
vals := m.Data()
// GRIB library returns scan north->south, west->east already in row-major order
raw := make([]byte, len(vals)*4)
for i, v := range vals {
binary.LittleEndian.PutUint32(raw[i*4:], math.Float32bits(float32(v)))
}
base := int64(varIdx*sizePerVar + (ti*47+pIdx)*361*720*4)
copy(mm[base:base+int64(len(raw))], raw)
}
}
return mm.Flush()
}
func (s *service) Extract(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error) {
var zero [2]float64
d := s.data.Load()
if d == nil {
return zero, errcodes.ErrNoDataset
}
if ts.Before(time.Unix(d.runUTC, 0)) || ts.After(time.Unix(d.runUTC, 0).Add(96*time.Hour)) {
return zero, errcodes.ErrOutOfBounds
}
// Try memory cache first
key := encodeKey(lat, lon, alt, ts)
if v, ok := s.cache.get(key); ok {
return [2]float64(v), nil
}
// Calculate result
td := ts.Sub(time.Unix(d.runUTC, 0)).Hours()
u, v := d.uv(lat, lon, alt, td)
out := [2]float64{u, v}
// Cache in memory
s.cache.set(key, vec(out))
return out, nil
}
func (s *service) Close() error {
if d := s.data.Load(); d != nil {
return d.Close()
}
return nil
}
func (s *service) GetStatus() (ready bool, lastUpdate time.Time, isFresh bool, errMsg string) {
d := s.data.Load()
if d == nil {
return false, time.Time{}, false, "no dataset loaded"
}
runTime := time.Unix(d.runUTC, 0)
fresh := time.Since(runTime) < s.cfg.TTL
if !fresh {
return false, runTime, false, "dataset is too old"
}
return true, runTime, true, ""
}

View file

@ -1,16 +0,0 @@
package grib
import "math"
// 47 pressure levels matching tawhiri configuration
var pressureLevels = []float64{
1000, 975, 950, 925, 900, 875, 850, 825, 800, 775,
750, 725, 700, 675, 650, 625, 600, 575, 550, 525,
500, 475, 450, 425, 400, 375, 350, 325, 300, 275,
250, 225, 200, 175, 150, 125, 100, 70, 50, 30,
20, 10, 7, 5, 3, 2, 1,
}
func pressureFromAlt(alt float64) float64 { // ICAO ISA
return 1013.25 * math.Pow(1-alt/44307.69396, 5.255877)
}

View file

@ -1,265 +0,0 @@
package grib
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/s3"
"golang.org/x/sync/errgroup"
)
// S3Downloader downloads GRIB files from AWS S3
type S3Downloader struct {
Dir string
Parallel int
Bucket string
Region string
Client *s3.Client
}
// NewS3Downloader creates a new S3 downloader with anonymous access
func NewS3Downloader(dir string, parallel int, bucket, region string) (*S3Downloader, error) {
// Create AWS config with anonymous credentials for public bucket
cfg, err := config.LoadDefaultConfig(context.Background(),
config.WithRegion(region),
config.WithCredentialsProvider(aws.AnonymousCredentials{}),
)
if err != nil {
return nil, errcodes.Wrap(err, "failed to load AWS config")
}
client := s3.NewFromConfig(cfg)
return &S3Downloader{
Dir: dir,
Parallel: parallel,
Bucket: bucket,
Region: region,
Client: client,
}, nil
}
// s3Key generates the S3 key for a GRIB file
// Path format: gfs.YYYYMMDD/HH/atmos/gfs.tHHz.pgrb2.0p50.fFFF
func (d *S3Downloader) s3Key(run string, hour int, step int) string {
return fmt.Sprintf("gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d", run, hour, hour, step)
}
// CheckFileExists checks if a file exists in S3 using HeadObject
func (d *S3Downloader) CheckFileExists(ctx context.Context, key string) (bool, int64, error) {
input := &s3.HeadObjectInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(key),
}
result, err := d.Client.HeadObject(ctx, input)
if err != nil {
// Check if error is NotFound
// AWS SDK v2 doesn't export specific error types, check error string
if isNotFoundError(err) {
return false, 0, nil
}
return false, 0, errcodes.Wrap(err, "failed to check file existence")
}
size := int64(0)
if result.ContentLength != nil {
size = *result.ContentLength
}
return true, size, nil
}
// isNotFoundError checks if error is a NotFound error
func isNotFoundError(err error) bool {
if err == nil {
return false
}
// AWS SDK v2 error handling
errStr := err.Error()
return contains(errStr, "NotFound") || contains(errStr, "404") || contains(errStr, "NoSuchKey")
}
func contains(s, substr string) bool {
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstring(s, substr))
}
func findSubstring(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}
// ListAvailableFiles lists all available files for a given run
func (d *S3Downloader) ListAvailableFiles(ctx context.Context, run string, hour int) ([]string, error) {
prefix := fmt.Sprintf("gfs.%s/%02d/atmos/", run, hour)
input := &s3.ListObjectsV2Input{
Bucket: aws.String(d.Bucket),
Prefix: aws.String(prefix),
}
var files []string
paginator := s3.NewListObjectsV2Paginator(d.Client, input)
for paginator.HasMorePages() {
page, err := paginator.NextPage(ctx)
if err != nil {
return nil, errcodes.Wrap(err, "failed to list S3 objects")
}
for _, obj := range page.Contents {
if obj.Key != nil {
files = append(files, *obj.Key)
}
}
}
return files, nil
}
// fetchFromS3 downloads a file from S3 to local disk with retry logic
func (d *S3Downloader) fetchFromS3(ctx context.Context, key, dst string) (err error) {
// Check if final file already exists
if _, err := os.Stat(dst); err == nil {
return nil
}
const maxRetries = 3
var lastErr error
for attempt := 0; attempt < maxRetries; attempt++ {
if attempt > 0 {
// Exponential backoff: 2s, 4s, 8s
waitTime := time.Duration(1<<uint(attempt)) * time.Second
time.Sleep(waitTime)
}
lastErr = d.fetchFromS3Once(ctx, key, dst)
if lastErr == nil {
return nil
}
}
return errcodes.Wrap(lastErr, fmt.Sprintf("failed after %d retries", maxRetries))
}
// fetchFromS3Once performs a single download attempt
func (d *S3Downloader) fetchFromS3Once(ctx context.Context, key, dst string) (err error) {
tmp := dst + ".part"
// Remove old .part file if it exists
os.Remove(tmp)
f, err := os.Create(tmp)
if err != nil {
return err
}
fileClosed := false
// Cleanup .part file on any error (using named return value)
defer func() {
if !fileClosed {
f.Close()
}
if err != nil {
os.Remove(tmp)
}
}()
// Check if file exists in S3
exists, size, checkErr := d.CheckFileExists(ctx, key)
if checkErr != nil {
return errcodes.Wrap(checkErr, "failed to check S3 file existence")
}
if !exists {
return errcodes.Wrap(errcodes.ErrDownload, fmt.Sprintf("file not found in S3: %s", key))
}
// Download from S3
input := &s3.GetObjectInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(key),
}
result, err := d.Client.GetObject(ctx, input)
if err != nil {
return errcodes.Wrap(err, "failed to get S3 object")
}
defer result.Body.Close()
// Copy to local file
written, err := io.Copy(f, result.Body)
if err != nil {
return errcodes.Wrap(err, fmt.Sprintf("failed to write S3 object to file %s", dst))
}
// Verify size if available
if size > 0 && written != size {
return errcodes.Wrap(errcodes.ErrDownload, fmt.Sprintf("size mismatch: got %d bytes, expected %d", written, size))
}
// Close file before rename
if err := f.Close(); err != nil {
return err
}
fileClosed = true
// If rename fails, err will be set and defer will cleanup .part file
return os.Rename(tmp, dst)
}
// Run downloads all required GRIB files for a forecast run
func (d *S3Downloader) Run(ctx context.Context, run time.Time) error {
runStr := run.Format("20060102")
hour := run.Hour()
// First, list available files to verify they exist
availableFiles, err := d.ListAvailableFiles(ctx, runStr, hour)
if err != nil {
return errcodes.Wrap(err, "failed to list available files")
}
if len(availableFiles) == 0 {
return errcodes.Wrap(errcodes.ErrDownload, fmt.Sprintf("no files found for run %s/%02d", runStr, hour))
}
// Build a map of available files for quick lookup
availableMap := make(map[string]bool)
for _, file := range availableFiles {
availableMap[file] = true
}
g, ctx := errgroup.WithContext(ctx)
sem := make(chan struct{}, d.Parallel)
for _, step := range steps {
step := step
key := d.s3Key(runStr, hour, step)
// Check if file is available in S3
if !availableMap[key] {
// Log warning but don't fail - some forecast hours might not be available yet
continue
}
sem <- struct{}{}
g.Go(func() error {
defer func() { <-sem }()
dst := filepath.Join(d.Dir, fileName(run, step))
return d.fetchFromS3(ctx, key, dst)
})
}
return g.Wait()
}

View file

@ -1,34 +0,0 @@
package grib
import (
"fmt"
"hash/fnv"
"time"
)
// Generate steps from 0 to 96 with step 1 hour (97 steps total)
// GFS provides hourly data for 0-120 hours, we use first 96 hours
var steps = func() []int {
result := make([]int, 0, 97)
for i := 0; i <= 96; i++ {
result = append(result, i)
}
return result
}()
func nearestRun(t time.Time) time.Time {
h := t.UTC().Hour() - t.UTC().Hour()%6
return time.Date(t.Year(), t.Month(), t.Day(), h, 0, 0, 0, time.UTC)
}
func fileName(run time.Time, step int) string {
return fmt.Sprintf("gfs.t%02dz.pgrb2.0p50.f%03d", run.Hour(), step)
}
func encodeKey(a ...any) uint64 {
h := fnv.New64a()
for _, v := range a {
fmt.Fprint(h, v)
}
return h.Sum64()
}

View file

@ -1,23 +0,0 @@
package log
import (
"context"
"go.uber.org/zap"
)
type ctxLogKey struct{}
func ToCtx(ctx context.Context, lg *zap.Logger) context.Context {
return context.WithValue(ctx, ctxLogKey{}, lg)
}
func Ctx(ctx context.Context) *zap.Logger {
lg, ok := ctx.Value(ctxLogKey{}).(*zap.Logger)
if !ok || lg == nil {
zap.L().Error("no logger in context, using global")
return zap.L()
}
return lg
}

View file

@ -0,0 +1,153 @@
package prediction
import (
"fmt"
"predictor-refactored/internal/dataset"
)
// Exact port of the reference interpolation logic (interpolate.pyx).
// 4D interpolation: time, latitude, longitude, altitude (via geopotential height).
// lerp1 holds an index and interpolation weight for one axis.
type lerp1 struct {
index int
lerp float64
}
// lerp3 holds indices and a combined weight for the (hour, lat, lon) axes.
type lerp3 struct {
hour, lat, lng int
lerp float64
}
// RangeError indicates a coordinate is outside the dataset bounds.
type RangeError struct {
Variable string
Value float64
}
func (e *RangeError) Error() string {
return fmt.Sprintf("%s=%f out of range", e.Variable, e.Value)
}
// pick computes interpolation indices and weights for a single axis.
// left: axis start, step: axis spacing, n: number of points, value: query value.
// Returns two lerp1 values (lower and upper bracket).
func pick(left, step float64, n int, value float64, variableName string) ([2]lerp1, error) {
a := (value - left) / step
b := int(a) // truncation toward zero, same as Cython <long> cast
if b < 0 || b >= n-1 {
return [2]lerp1{}, &RangeError{Variable: variableName, Value: value}
}
l := a - float64(b)
return [2]lerp1{
{index: b, lerp: 1 - l},
{index: b + 1, lerp: l},
}, nil
}
// pick3 computes 8 trilinear interpolation weights for (hour, lat, lng).
func pick3(hour, lat, lng float64) ([8]lerp3, error) {
lhour, err := pick(0, 3, 65, hour, "hour")
if err != nil {
return [8]lerp3{}, err
}
llat, err := pick(-90, 0.5, 361, lat, "lat")
if err != nil {
return [8]lerp3{}, err
}
// Longitude wraps: tell pick the axis is one larger, then wrap index 720 → 0
llng, err := pick(0, 0.5, 720+1, lng, "lng")
if err != nil {
return [8]lerp3{}, err
}
if llng[1].index == 720 {
llng[1].index = 0
}
var out [8]lerp3
i := 0
for _, a := range lhour {
for _, b := range llat {
for _, c := range llng {
out[i] = lerp3{
hour: a.index,
lat: b.index,
lng: c.index,
lerp: a.lerp * b.lerp * c.lerp,
}
i++
}
}
}
return out, nil
}
// interp3 performs 8-point weighted interpolation at a given variable and pressure level.
func interp3(ds *dataset.File, lerps [8]lerp3, variable, level int) float64 {
var r float64
for i := 0; i < 8; i++ {
v := ds.Val(lerps[i].hour, level, variable, lerps[i].lat, lerps[i].lng)
r += float64(v) * lerps[i].lerp
}
return r
}
// search finds the largest pressure level index where interpolated geopotential
// height is less than the target altitude. Searches levels 0..45 (excludes topmost).
func search(ds *dataset.File, lerps [8]lerp3, target float64) int {
lower, upper := 0, 45
for lower < upper {
mid := (lower + upper + 1) / 2
test := interp3(ds, lerps, dataset.VarHeight, mid)
if target <= test {
upper = mid - 1
} else {
lower = mid
}
}
return lower
}
// interp4 performs altitude-interpolated wind lookup using two bracketing levels.
func interp4(ds *dataset.File, lerps [8]lerp3, altLerp lerp1, variable int) float64 {
lower := interp3(ds, lerps, variable, altLerp.index)
upper := interp3(ds, lerps, variable, altLerp.index+1)
return lower*altLerp.lerp + upper*(1-altLerp.lerp)
}
// GetWind returns interpolated (u, v) wind components for the given position.
// hour: fractional hours since dataset start.
// lat: latitude in degrees (-90 to +90).
// lng: longitude in degrees (0 to 360).
// alt: altitude in metres above sea level.
func GetWind(ds *dataset.File, warnings *Warnings, hour, lat, lng, alt float64) (u, v float64, err error) {
lerps, err := pick3(hour, lat, lng)
if err != nil {
return 0, 0, err
}
altidx := search(ds, lerps, alt)
lower := interp3(ds, lerps, dataset.VarHeight, altidx)
upper := interp3(ds, lerps, dataset.VarHeight, altidx+1)
var altLerp float64
if lower != upper {
altLerp = (upper - alt) / (upper - lower)
} else {
altLerp = 0.5
}
if altLerp < 0 {
warnings.AltitudeTooHigh.Add(1)
}
alt1 := lerp1{index: altidx, lerp: altLerp}
u = interp4(ds, lerps, alt1, dataset.VarWindU)
v = interp4(ds, lerps, alt1, dataset.VarWindV)
return u, v, nil
}

View file

@ -0,0 +1,188 @@
package prediction
import (
"math"
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/elevation"
)
// Exact port of the reference flight models (models.py).
const (
pi180 = math.Pi / 180.0
_180pi = 180.0 / math.Pi
)
// --- Up/Down Models ---
// ConstantAscent returns a model with constant vertical velocity (m/s).
func ConstantAscent(ascentRate float64) Model {
return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
return 0, 0, ascentRate
}
}
// DragDescent returns a descent-under-parachute model.
// seaLevelDescentRate is the descent rate at sea level (m/s, positive value).
// Uses the NASA atmosphere model for density at altitude.
func DragDescent(seaLevelDescentRate float64) Model {
dragCoefficient := seaLevelDescentRate * 1.1045
return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
return 0, 0, -dragCoefficient / math.Sqrt(nasaDensity(alt))
}
}
// nasaDensity computes air density using the NASA atmosphere model.
// Reference: http://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html
func nasaDensity(alt float64) float64 {
var temp, pressure float64
switch {
case alt > 25000:
temp = -131.21 + 0.00299*alt
pressure = 2.488 * math.Pow((temp+273.1)/216.6, -11.388)
case alt > 11000:
temp = -56.46
pressure = 22.65 * math.Exp(1.73-0.000157*alt)
default:
temp = 15.04 - 0.00649*alt
pressure = 101.29 * math.Pow((temp+273.1)/288.08, 5.256)
}
return pressure / (0.2869 * (temp + 273.1))
}
// --- Sideways Models ---
// WindVelocity returns a model that gives lateral movement at the wind velocity.
// ds is the wind dataset, dsEpoch is the dataset start time as UNIX timestamp.
func WindVelocity(ds *dataset.File, dsEpoch float64, warnings *Warnings) Model {
return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
tHours := (t - dsEpoch) / 3600.0
u, v, err := GetWind(ds, warnings, tHours, lat, lng, alt)
if err != nil {
return 0, 0, 0
}
R := 6371009.0 + alt
dlat = _180pi * v / R
dlng = _180pi * u / (R * math.Cos(lat*pi180))
return dlat, dlng, 0
}
}
// --- Model Combinations ---
// LinearModel returns a model that sums all component models.
func LinearModel(models ...Model) Model {
return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
for _, m := range models {
d1, d2, d3 := m(t, lat, lng, alt)
dlat += d1
dlng += d2
dalt += d3
}
return
}
}
// --- Termination Criteria ---
// BurstTermination returns a terminator that fires when altitude >= burstAltitude.
func BurstTermination(burstAltitude float64) Terminator {
return func(t, lat, lng, alt float64) bool {
return alt >= burstAltitude
}
}
// SeaLevelTermination fires when altitude <= 0.
func SeaLevelTermination(t, lat, lng, alt float64) bool {
return alt <= 0
}
// TimeTermination returns a terminator that fires when t > maxTime.
func TimeTermination(maxTime float64) Terminator {
return func(t, lat, lng, alt float64) bool {
return t > maxTime
}
}
// ElevationTermination returns a terminator that fires when alt < ground level.
// Uses ruaumoko-compatible elevation data. Longitude is normalised internally.
func ElevationTermination(elev *elevation.Dataset) Terminator {
return func(t, lat, lng, alt float64) bool {
return elev.Get(lat, lng) > alt
}
}
// --- Pre-Defined Profiles ---
// Stage pairs a model with its termination criterion.
type Stage struct {
Model Model
Terminator Terminator
}
// StandardProfile creates the chain for a standard high-altitude balloon flight:
// ascent at constant rate → burst → descent under parachute.
// If elev is non-nil, descent terminates at ground level; otherwise at sea level.
func StandardProfile(ascentRate, burstAltitude, descentRate float64,
ds *dataset.File, dsEpoch float64, warnings *Warnings,
elev *elevation.Dataset) []Stage {
wind := WindVelocity(ds, dsEpoch, warnings)
modelUp := LinearModel(ConstantAscent(ascentRate), wind)
termUp := BurstTermination(burstAltitude)
modelDown := LinearModel(DragDescent(descentRate), wind)
var termDown Terminator
if elev != nil {
termDown = ElevationTermination(elev)
} else {
termDown = Terminator(SeaLevelTermination)
}
return []Stage{
{Model: modelUp, Terminator: termUp},
{Model: modelDown, Terminator: termDown},
}
}
// FloatProfile creates the chain for a floating balloon flight:
// ascent to float altitude → float until stop time.
func FloatProfile(ascentRate, floatAltitude float64, stopTime time.Time,
ds *dataset.File, dsEpoch float64, warnings *Warnings) []Stage {
wind := WindVelocity(ds, dsEpoch, warnings)
modelUp := LinearModel(ConstantAscent(ascentRate), wind)
termUp := BurstTermination(floatAltitude)
modelFloat := wind
termFloat := TimeTermination(float64(stopTime.Unix()))
return []Stage{
{Model: modelUp, Terminator: termUp},
{Model: modelFloat, Terminator: termFloat},
}
}
// RunPrediction runs a prediction with the given profile stages.
// launchTime is a UNIX timestamp.
func RunPrediction(launchTime float64, lat, lng, alt float64, stages []Stage) []StageResult {
chain := make([]struct {
Model Model
Terminator Terminator
}, len(stages))
for i, s := range stages {
chain[i].Model = s.Model
chain[i].Terminator = s.Terminator
}
return Solve(launchTime, lat, lng, alt, chain)
}

View file

@ -0,0 +1,180 @@
package prediction
import "math"
// Exact port of the reference RK4 solver (solver.pyx).
// Integrates balloon state using RK4 with dt=60 seconds.
// Termination uses binary search refinement (tolerance 0.01).
// Vec holds the balloon state: latitude, longitude, altitude.
type Vec struct {
Lat float64
Lng float64
Alt float64
}
// Model is a function that returns (dlat/dt, dlng/dt, dalt/dt) given state.
// t is UNIX timestamp, lat/lng in degrees, alt in metres.
type Model func(t float64, lat, lng, alt float64) (dlat, dlng, dalt float64)
// Terminator returns true when integration should stop.
type Terminator func(t float64, lat, lng, alt float64) bool
// StageResult holds the trajectory points for one flight stage.
type StageResult struct {
Points []TrajectoryPoint
}
// TrajectoryPoint is a single point in a trajectory (used by solver).
type TrajectoryPoint struct {
T float64 // UNIX timestamp
Lat float64
Lng float64
Alt float64
}
// pymod returns a % b with Python semantics (always non-negative when b > 0).
func pymod(a, b float64) float64 {
r := math.Mod(a, b)
if r < 0 {
r += b
}
return r
}
// vecadd returns a + k*b, with lng wrapped to [0, 360).
func vecadd(a Vec, k float64, b Vec) Vec {
return Vec{
Lat: a.Lat + k*b.Lat,
Lng: pymod(a.Lng+k*b.Lng, 360.0),
Alt: a.Alt + k*b.Alt,
}
}
// scalarLerp returns (1-l)*a + l*b.
func scalarLerp(a, b, l float64) float64 {
return (1-l)*a + l*b
}
// lngLerp interpolates longitude handling the 0/360 wrap-around.
func lngLerp(a, b, l float64) float64 {
l2 := 1 - l
if a > b {
a, b = b, a
l, l2 = l2, l
}
// distance round one way: b - a
// distance around other: (a + 360) - b
if b-a < 180.0 {
return l2*a + l*b
}
return pymod(l2*(a+360)+l*b, 360.0)
}
// vecLerp returns (1-l)*a + l*b with proper longitude wrapping.
func vecLerp(a, b Vec, l float64) Vec {
return Vec{
Lat: scalarLerp(a.Lat, b.Lat, l),
Lng: lngLerp(a.Lng, b.Lng, l),
Alt: scalarLerp(a.Alt, b.Alt, l),
}
}
// rk4 integrates from initial conditions using RK4.
// dt=60.0 seconds, terminationTolerance=0.01.
func rk4(t float64, lat, lng, alt float64, model Model, terminator Terminator) []TrajectoryPoint {
const dt = 60.0
const terminationTolerance = 0.01
y := Vec{Lat: lat, Lng: lng, Alt: alt}
result := []TrajectoryPoint{{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt}}
for {
// Evaluate model at 4 points (standard RK4)
k1lat, k1lng, k1alt := model(t, y.Lat, y.Lng, y.Alt)
k1 := Vec{Lat: k1lat, Lng: k1lng, Alt: k1alt}
mid1 := vecadd(y, dt/2, k1)
k2lat, k2lng, k2alt := model(t+dt/2, mid1.Lat, mid1.Lng, mid1.Alt)
k2 := Vec{Lat: k2lat, Lng: k2lng, Alt: k2alt}
mid2 := vecadd(y, dt/2, k2)
k3lat, k3lng, k3alt := model(t+dt/2, mid2.Lat, mid2.Lng, mid2.Alt)
k3 := Vec{Lat: k3lat, Lng: k3lng, Alt: k3alt}
end := vecadd(y, dt, k3)
k4lat, k4lng, k4alt := model(t+dt, end.Lat, end.Lng, end.Alt)
k4 := Vec{Lat: k4lat, Lng: k4lng, Alt: k4alt}
// y2 = y + dt/6*k1 + dt/3*k2 + dt/3*k3 + dt/6*k4
y2 := y
y2 = vecadd(y2, dt/6, k1)
y2 = vecadd(y2, dt/3, k2)
y2 = vecadd(y2, dt/3, k3)
y2 = vecadd(y2, dt/6, k4)
t2 := t + dt
if terminator(t2, y2.Lat, y2.Lng, y2.Alt) {
// Binary search to refine the termination point.
// Find l in [0, 1] such that (t3, y3) = lerp((t, y), (t2, y2), l)
// is near where the terminator fires.
left := 0.0
right := 1.0
var t3 float64
var y3 Vec
t3 = t2
y3 = y2
for right-left > terminationTolerance {
mid := (left + right) / 2
t3 = scalarLerp(t, t2, mid)
y3 = vecLerp(y, y2, mid)
if terminator(t3, y3.Lat, y3.Lng, y3.Alt) {
right = mid
} else {
left = mid
}
}
result = append(result, TrajectoryPoint{T: t3, Lat: y3.Lat, Lng: y3.Lng, Alt: y3.Alt})
break
}
// Update current state
t = t2
y = y2
result = append(result, TrajectoryPoint{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt})
}
return result
}
// Solve runs through a chain of (model, terminator) stages.
// Returns one StageResult per stage.
func Solve(t, lat, lng, alt float64, chain []struct {
Model Model
Terminator Terminator
}) []StageResult {
var results []StageResult
for _, stage := range chain {
points := rk4(t, lat, lng, alt, stage.Model, stage.Terminator)
results = append(results, StageResult{Points: points})
// Next stage starts where this one ended
if len(points) > 0 {
last := points[len(points)-1]
t = last.T
lat = last.Lat
lng = last.Lng
alt = last.Alt
}
}
return results
}

View file

@ -0,0 +1,21 @@
package prediction
import "sync/atomic"
// Warnings tracks warning conditions during a prediction run.
type Warnings struct {
AltitudeTooHigh atomic.Int64
}
// ToMap returns warnings as a map suitable for JSON serialization.
// Only includes warnings that have fired.
func (w *Warnings) ToMap() map[string]any {
result := make(map[string]any)
if n := w.AltitudeTooHigh.Load(); n > 0 {
result["altitude_too_high"] = map[string]any{
"count": n,
"description": "The altitude went too high, above the max forecast wind. Wind data will be unreliable",
}
}
return result
}

View file

@ -1,12 +0,0 @@
package service
import (
"context"
"time"
)
type Grib interface {
Update(ctx context.Context) error
Extract(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error)
Close() error
}

View file

@ -1,516 +0,0 @@
package service
import (
"context"
"encoding/base64"
"encoding/json"
"math"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/ds"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
"git.intra.yksa.space/gsn/predictor/internal/pkg/log"
"go.uber.org/zap"
)
var ErrInvalidParameters = errcodes.New(400, "missing required prediction parameters")
// Stage represents a prediction stage (ascent, descent, float)
type Stage struct {
Name string
Results []ds.PredicitonResult
StartTime time.Time
EndTime time.Time
}
// CustomCurve represents a custom ascent/descent curve
type CustomCurve struct {
Altitude []float64 `json:"altitude"`
Time []float64 `json:"time"` // seconds from start
}
func (s *Service) PerformPrediction(ctx context.Context, params ds.PredictionParameters) ([]ds.PredicitonResult, error) {
// Validate required parameters
if params.LaunchLatitude == nil || params.LaunchLongitude == nil || params.LaunchAltitude == nil || params.LaunchDatetime == nil {
return nil, ErrInvalidParameters
}
// Get default values
profile := "standard_profile"
if params.Profile != nil {
profile = *params.Profile
}
ascentRate := 5.0
if params.AscentRate != nil {
ascentRate = *params.AscentRate
}
burstAltitude := 30000.0
if params.BurstAltitude != nil {
burstAltitude = *params.BurstAltitude
}
descentRate := 5.0
if params.DescentRate != nil {
descentRate = *params.DescentRate
}
floatAltitude := 0.0
if params.FloatAltitude != nil {
floatAltitude = *params.FloatAltitude
}
// Parse custom curves if provided
var ascentCurve, descentCurve *CustomCurve
if params.AscentCurve != nil && *params.AscentCurve != "" {
if curve, err := parseCustomCurve(*params.AscentCurve); err == nil {
ascentCurve = curve
}
}
if params.DescentCurve != nil && *params.DescentCurve != "" {
if curve, err := parseCustomCurve(*params.DescentCurve); err == nil {
descentCurve = curve
}
}
log.Ctx(ctx).Info("Starting prediction",
zap.String("profile", profile),
zap.Float64("lat", *params.LaunchLatitude),
zap.Float64("lon", *params.LaunchLongitude),
zap.Float64("alt", *params.LaunchAltitude),
zap.Time("time", *params.LaunchDatetime),
)
var allResults []ds.PredicitonResult
switch profile {
case "standard_profile":
allResults = s.standardProfile(ctx, params, ascentRate, burstAltitude, descentRate, ascentCurve, descentCurve)
case "float_profile":
allResults = s.floatProfile(ctx, params, ascentRate, burstAltitude, floatAltitude, descentRate, ascentCurve, descentCurve)
case "reverse_profile":
allResults = s.reverseProfile(ctx, params, ascentRate, burstAltitude, descentRate, ascentCurve, descentCurve)
case "custom_profile":
allResults = s.customProfile(ctx, params, ascentCurve, descentCurve)
default:
return nil, errcodes.New(400, "unsupported profile: "+profile)
}
log.Ctx(ctx).Info("Prediction complete", zap.Int("total_steps", len(allResults)))
return allResults, nil
}
func (s *Service) standardProfile(ctx context.Context, params ds.PredictionParameters, ascentRate, burstAltitude, descentRate float64, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult {
var results []ds.PredicitonResult
// Stage 1: Ascent
ascentResults := s.simulateAscent(ctx, params, ascentRate, burstAltitude, ascentCurve)
results = append(results, ascentResults...)
if len(ascentResults) > 0 {
// Get final position from ascent
lastResult := ascentResults[len(ascentResults)-1]
// Stage 2: Descent
descentParams := ds.PredictionParameters{
LaunchLatitude: lastResult.Latitude,
LaunchLongitude: lastResult.Longitude,
LaunchAltitude: lastResult.Altitude,
LaunchDatetime: lastResult.Timestamp,
}
descentResults := s.simulateDescent(ctx, descentParams, descentRate, 0, descentCurve)
results = append(results, descentResults...)
}
return results
}
func (s *Service) floatProfile(ctx context.Context, params ds.PredictionParameters, ascentRate, burstAltitude, floatAltitude, descentRate float64, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult {
var results []ds.PredicitonResult
// Stage 1: Ascent to float altitude
ascentResults := s.simulateAscent(ctx, params, ascentRate, floatAltitude, ascentCurve)
results = append(results, ascentResults...)
if len(ascentResults) > 0 {
// Stage 2: Float (simulate for some time)
lastResult := ascentResults[len(ascentResults)-1]
floatResults := s.simulateFloat(ctx, lastResult, 30*time.Minute) // Float for 30 minutes
results = append(results, floatResults...)
if len(floatResults) > 0 {
// Stage 3: Descent
finalFloat := floatResults[len(floatResults)-1]
descentParams := ds.PredictionParameters{
LaunchLatitude: finalFloat.Latitude,
LaunchLongitude: finalFloat.Longitude,
LaunchAltitude: finalFloat.Altitude,
LaunchDatetime: finalFloat.Timestamp,
}
descentResults := s.simulateDescent(ctx, descentParams, descentRate, 0, descentCurve)
results = append(results, descentResults...)
}
}
return results
}
func (s *Service) reverseProfile(ctx context.Context, params ds.PredictionParameters, ascentRate, burstAltitude, descentRate float64, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult {
var results []ds.PredicitonResult
// Stage 1: Ascent
ascentResults := s.simulateAscent(ctx, params, ascentRate, burstAltitude, ascentCurve)
results = append(results, ascentResults...)
if len(ascentResults) > 0 {
// Stage 2: Descent to float altitude
lastResult := ascentResults[len(ascentResults)-1]
descentParams := ds.PredictionParameters{
LaunchLatitude: lastResult.Latitude,
LaunchLongitude: lastResult.Longitude,
LaunchAltitude: lastResult.Altitude,
LaunchDatetime: lastResult.Timestamp,
}
// Descent to float altitude (if specified)
floatAlt := 0.0
if params.FloatAltitude != nil {
floatAlt = *params.FloatAltitude
}
descentResults := s.simulateDescent(ctx, descentParams, descentRate, floatAlt, descentCurve)
results = append(results, descentResults...)
if floatAlt > 0 && len(descentResults) > 0 {
// Stage 3: Float
finalDescent := descentResults[len(descentResults)-1]
floatResults := s.simulateFloat(ctx, finalDescent, 30*time.Minute)
results = append(results, floatResults...)
}
}
return results
}
func (s *Service) customProfile(ctx context.Context, params ds.PredictionParameters, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult {
var results []ds.PredicitonResult
if ascentCurve != nil {
ascentResults := s.simulateCustomAscent(ctx, params, ascentCurve)
results = append(results, ascentResults...)
}
if descentCurve != nil && len(results) > 0 {
lastResult := results[len(results)-1]
descentParams := ds.PredictionParameters{
LaunchLatitude: lastResult.Latitude,
LaunchLongitude: lastResult.Longitude,
LaunchAltitude: lastResult.Altitude,
LaunchDatetime: lastResult.Timestamp,
}
descentResults := s.simulateCustomDescent(ctx, descentParams, descentCurve)
results = append(results, descentResults...)
}
return results
}
func rk4Step(lat, lon, alt float64, t time.Time, dt float64, windFunc func(lat, lon, alt float64, t time.Time) (float64, float64), altRate float64) (float64, float64, float64) {
// Helper for RK4 integration step
toRad := math.Pi / 180.0
toDeg := 180.0 / math.Pi
R := func(alt float64) float64 { return 6371009.0 + alt }
f := func(lat, lon, alt float64, t time.Time) (float64, float64, float64) {
windU, windV := windFunc(lat, lon, alt, t)
Rnow := R(alt)
dlat := toDeg * windV / Rnow
dlon := toDeg * windU / (Rnow * math.Cos(lat*toRad))
return dlat, dlon, altRate
}
k1_lat, k1_lon, k1_alt := f(lat, lon, alt, t)
k2_lat, k2_lon, k2_alt := f(lat+0.5*k1_lat*dt, lon+0.5*k1_lon*dt, alt+0.5*k1_alt*dt, t.Add(time.Duration(0.5*dt)*time.Second))
k3_lat, k3_lon, k3_alt := f(lat+0.5*k2_lat*dt, lon+0.5*k2_lon*dt, alt+0.5*k2_alt*dt, t.Add(time.Duration(0.5*dt)*time.Second))
k4_lat, k4_lon, k4_alt := f(lat+k3_lat*dt, lon+k3_lon*dt, alt+k3_alt*dt, t.Add(time.Duration(dt)*time.Second))
latNew := lat + (dt/6.0)*(k1_lat+2*k2_lat+2*k3_lat+k4_lat)
lonNew := lon + (dt/6.0)*(k1_lon+2*k2_lon+2*k3_lon+k4_lon)
altNew := alt + (dt/6.0)*(k1_alt+2*k2_alt+2*k3_alt+k4_alt)
return latNew, lonNew, altNew
}
func (s *Service) simulateAscent(ctx context.Context, params ds.PredictionParameters, ascentRate, targetAltitude float64, customCurve *CustomCurve) []ds.PredicitonResult {
const dt = 10.0 // simulation step in seconds
const outputInterval = 60.0 // output every 60 seconds
lat := *params.LaunchLatitude
lon := *params.LaunchLongitude
alt := *params.LaunchAltitude
timeCur := *params.LaunchDatetime
results := make([]ds.PredicitonResult, 0, 1000)
latCopy := lat
lonCopy := lon
altCopy := alt
timeCopy := timeCur
wind := [2]float64{0, 0}
windU := wind[0]
windV := wind[1]
results = append(results, ds.PredicitonResult{
Latitude: &latCopy,
Longitude: &lonCopy,
Altitude: &altCopy,
Timestamp: &timeCopy,
WindU: &windU,
WindV: &windV,
})
nextOutputTime := timeCur.Add(time.Duration(outputInterval) * time.Second)
windFunc := func(lat, lon, alt float64, t time.Time) (float64, float64) {
w, err := s.ExtractWind(ctx, lat, lon, alt, t)
if err != nil {
log.Ctx(ctx).Warn("Wind extraction failed during ascent", zap.Error(err))
return 0, 0
}
return w[0], w[1]
}
for alt < targetAltitude {
altRate := ascentRate
if customCurve != nil {
altRate = s.getCustomAltitudeRate(customCurve, alt, ascentRate)
}
latNew, lonNew, altNew := rk4Step(lat, lon, alt, timeCur, dt, windFunc, altRate)
timeCur = timeCur.Add(time.Duration(dt) * time.Second)
lat = latNew
lon = lonNew
alt = altNew
if alt >= targetAltitude {
break
}
if !timeCur.Before(nextOutputTime) {
wU, wV := windFunc(lat, lon, alt, timeCur)
latCopy := lat
lonCopy := lon
altCopy := alt
timeCopy := timeCur
windU := wU
windV := wV
results = append(results, ds.PredicitonResult{
Latitude: &latCopy,
Longitude: &lonCopy,
Altitude: &altCopy,
Timestamp: &timeCopy,
WindU: &windU,
WindV: &windV,
})
nextOutputTime = nextOutputTime.Add(time.Duration(outputInterval) * time.Second)
}
}
return results
}
func (s *Service) simulateDescent(ctx context.Context, params ds.PredictionParameters, descentRate, targetAltitude float64, customCurve *CustomCurve) []ds.PredicitonResult {
const dt = 10.0 // simulation step in seconds
const outputInterval = 60.0 // output every 60 seconds
lat := *params.LaunchLatitude
lon := *params.LaunchLongitude
alt := *params.LaunchAltitude
timeCur := *params.LaunchDatetime
results := make([]ds.PredicitonResult, 0, 1000)
latCopy := lat
lonCopy := lon
altCopy := alt
timeCopy := timeCur
wind := [2]float64{0, 0}
windU := wind[0]
windV := wind[1]
results = append(results, ds.PredicitonResult{
Latitude: &latCopy,
Longitude: &lonCopy,
Altitude: &altCopy,
Timestamp: &timeCopy,
WindU: &windU,
WindV: &windV,
})
nextOutputTime := timeCur.Add(time.Duration(outputInterval) * time.Second)
windFunc := func(lat, lon, alt float64, t time.Time) (float64, float64) {
w, err := s.ExtractWind(ctx, lat, lon, alt, t)
if err != nil {
log.Ctx(ctx).Warn("Wind extraction failed during descent", zap.Error(err))
return 0, 0
}
return w[0], w[1]
}
for alt > targetAltitude {
altRate := -descentRate
if customCurve != nil {
altRate = -s.getCustomAltitudeRate(customCurve, alt, descentRate)
}
latNew, lonNew, altNew := rk4Step(lat, lon, alt, timeCur, dt, windFunc, altRate)
timeCur = timeCur.Add(time.Duration(dt) * time.Second)
lat = latNew
lon = lonNew
alt = altNew
if alt <= targetAltitude {
break
}
if !timeCur.Before(nextOutputTime) {
wU, wV := windFunc(lat, lon, alt, timeCur)
latCopy := lat
lonCopy := lon
altCopy := alt
timeCopy := timeCur
windU := wU
windV := wV
results = append(results, ds.PredicitonResult{
Latitude: &latCopy,
Longitude: &lonCopy,
Altitude: &altCopy,
Timestamp: &timeCopy,
WindU: &windU,
WindV: &windV,
})
nextOutputTime = nextOutputTime.Add(time.Duration(outputInterval) * time.Second)
}
}
return results
}
func (s *Service) simulateFloat(ctx context.Context, startResult ds.PredicitonResult, duration time.Duration) []ds.PredicitonResult {
const dt = 10.0 // simulation step in seconds
const outputInterval = 60.0 // output every 60 seconds
lat := *startResult.Latitude
lon := *startResult.Longitude
alt := *startResult.Altitude
timeCur := *startResult.Timestamp
endTime := timeCur.Add(duration)
results := make([]ds.PredicitonResult, 0, 1000)
// Always include the initial float point
latCopy := lat
lonCopy := lon
altCopy := alt
timeCopy := timeCur
wind := [2]float64{0, 0}
windU := wind[0]
windV := wind[1]
results = append(results, ds.PredicitonResult{
Latitude: &latCopy,
Longitude: &lonCopy,
Altitude: &altCopy,
Timestamp: &timeCopy,
WindU: &windU,
WindV: &windV,
})
var nextOutputTime = timeCur.Add(time.Duration(outputInterval) * time.Second)
for timeCur.Before(endTime) {
wind, err := s.ExtractWind(ctx, lat, lon, alt, timeCur)
if err != nil {
log.Ctx(ctx).Warn("Wind extraction failed during float", zap.Error(err))
break
}
latDot := (wind[1] / 111320.0)
lonDot := (wind[0] / (40075000.0 * math.Cos(lat*math.Pi/180) / 360.0))
lat += latDot * dt
lon += lonDot * dt
// alt remains constant during float
timeCur = timeCur.Add(time.Duration(dt) * time.Second)
if !timeCur.Before(nextOutputTime) {
latCopy := lat
lonCopy := lon
altCopy := alt
timeCopy := timeCur
windU := wind[0]
windV := wind[1]
results = append(results, ds.PredicitonResult{
Latitude: &latCopy,
Longitude: &lonCopy,
Altitude: &altCopy,
Timestamp: &timeCopy,
WindU: &windU,
WindV: &windV,
})
nextOutputTime = nextOutputTime.Add(time.Duration(outputInterval) * time.Second)
}
}
return results
}
func (s *Service) simulateCustomAscent(ctx context.Context, params ds.PredictionParameters, curve *CustomCurve) []ds.PredicitonResult {
// Implementation for custom ascent curve
// This would interpolate the altitude rate from the custom curve
return s.simulateAscent(ctx, params, 5.0, 30000.0, curve)
}
func (s *Service) simulateCustomDescent(ctx context.Context, params ds.PredictionParameters, curve *CustomCurve) []ds.PredicitonResult {
// Implementation for custom descent curve
// This would interpolate the altitude rate from the custom curve
return s.simulateDescent(ctx, params, 5.0, 0.0, curve)
}
func (s *Service) getCustomAltitudeRate(curve *CustomCurve, currentAltitude, defaultRate float64) float64 {
if curve == nil || len(curve.Altitude) < 2 {
return defaultRate
}
// Find the two points in the curve that bracket the current altitude
for i := 0; i < len(curve.Altitude)-1; i++ {
if curve.Altitude[i] <= currentAltitude && currentAltitude <= curve.Altitude[i+1] {
// Linear interpolation
alt1, alt2 := curve.Altitude[i], curve.Altitude[i+1]
time1, time2 := curve.Time[i], curve.Time[i+1]
if alt2 == alt1 {
return defaultRate
}
// Calculate rate (change in altitude per second)
if time2 > time1 {
return (alt2 - alt1) / (time2 - time1)
}
return defaultRate
}
}
return defaultRate
}
func parseCustomCurve(base64Data string) (*CustomCurve, error) {
data, err := base64.StdEncoding.DecodeString(base64Data)
if err != nil {
return nil, err
}
var curve CustomCurve
if err := json.Unmarshal(data, &curve); err != nil {
return nil, err
}
return &curve, nil
}

View file

@ -2,59 +2,244 @@ package service
import (
"context"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/log"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/downloader"
"predictor-refactored/internal/elevation"
"go.uber.org/zap"
)
// Service orchestrates the dataset lifecycle and provides prediction capabilities.
type Service struct {
grib Grib
mu sync.RWMutex
ds *dataset.File
elev *elevation.Dataset
cfg *downloader.Config
dl *downloader.Downloader
log *zap.Logger
updating sync.Mutex // prevents concurrent downloads
}
func New(gribService Grib) (*Service, error) {
svc := &Service{
grib: gribService,
// New creates a new Service.
func New(cfg *downloader.Config, log *zap.Logger) *Service {
return &Service{
cfg: cfg,
dl: downloader.NewDownloader(cfg, log),
log: log,
}
}
// LoadElevation loads the ruaumoko-compatible elevation dataset from path.
// If the file doesn't exist, elevation termination is disabled (falls back to sea level).
func (s *Service) LoadElevation(path string) {
ds, err := elevation.Open(path)
if err != nil {
s.log.Warn("elevation dataset not available, using sea-level termination",
zap.String("path", path), zap.Error(err))
return
}
s.elev = ds
s.log.Info("elevation dataset loaded", zap.String("path", path))
}
// Elevation returns the elevation dataset (may be nil).
func (s *Service) Elevation() *elevation.Dataset {
return s.elev
}
// Ready returns true if the service has a loaded dataset.
func (s *Service) Ready() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.ds != nil
}
// DatasetTime returns the forecast time of the currently loaded dataset.
func (s *Service) DatasetTime() (time.Time, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
if s.ds == nil {
return time.Time{}, false
}
return s.ds.DSTime, true
}
// Dataset returns the current dataset for reading.
func (s *Service) Dataset() *dataset.File {
s.mu.RLock()
defer s.mu.RUnlock()
return s.ds
}
// Update checks for and downloads new forecast data if needed.
func (s *Service) Update(ctx context.Context) error {
if !s.updating.TryLock() {
s.log.Info("update already in progress, skipping")
return nil
}
defer s.updating.Unlock()
// Check if current dataset is still fresh
if dsTime, ok := s.DatasetTime(); ok {
if time.Since(dsTime) < s.cfg.DatasetTTL {
s.log.Info("dataset still fresh, skipping update",
zap.Time("dataset_time", dsTime),
zap.Duration("age", time.Since(dsTime)))
return nil
}
}
return svc, nil
}
// Try loading an existing dataset from disk first
if err := s.loadExistingDataset(); err == nil {
return nil
}
// UpdateWeatherData updates weather forecast data using the configured grib service
func (s *Service) UpdateWeatherData(ctx context.Context) error {
return s.grib.Update(ctx)
}
// Find latest available model run
run, err := s.dl.FindLatestRun(ctx)
if err != nil {
return err
}
// ExtractWind extracts wind data for given coordinates and time
func (s *Service) ExtractWind(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error) {
return s.grib.Extract(ctx, lat, lon, alt, ts)
}
// Download and assemble
path, err := s.dl.Download(ctx, run)
if err != nil {
return err
}
// Update updates the GRIB data (implements updater.GribService)
func (s *Service) Update(ctx context.Context) error {
return s.UpdateWeatherData(ctx)
}
// Open the new dataset
ds, err := dataset.Open(path, run)
if err != nil {
return err
}
// Start starts the service
func (s *Service) Start() {
log.Ctx(context.Background()).Info("service started")
}
// Swap in the new dataset
s.setDataset(ds)
s.log.Info("dataset loaded", zap.Time("run", run), zap.String("path", path))
// Stop stops the service
func (s *Service) Stop() {
log.Ctx(context.Background()).Info("service stopped")
}
// Clean old datasets
s.cleanOldDatasets(path)
// Close closes the service and releases resources
func (s *Service) Close() error {
s.Stop()
return nil
}
func (s *Service) GetGribStatus(ctx context.Context) (ready bool, lastUpdate time.Time, isFresh bool, errMsg string) {
if gribStatus, ok := s.grib.(interface {
GetStatus() (ready bool, lastUpdate time.Time, isFresh bool, errMsg string)
}); ok {
return gribStatus.GetStatus()
// loadExistingDataset tries to find and load an existing dataset from the data directory.
func (s *Service) loadExistingDataset() error {
entries, err := os.ReadDir(s.cfg.DataDir)
if err != nil {
return err
}
return false, time.Time{}, false, "grib service does not implement GetStatus"
// Collect valid dataset files (name is YYYYMMDDHH, no extension, correct size)
type candidate struct {
name string
path string
run time.Time
}
var candidates []candidate
for _, e := range entries {
if e.IsDir() || strings.Contains(e.Name(), ".") {
continue
}
if len(e.Name()) != 10 {
continue
}
run, err := time.Parse("2006010215", e.Name())
if err != nil {
continue
}
path := filepath.Join(s.cfg.DataDir, e.Name())
info, err := os.Stat(path)
if err != nil || info.Size() != dataset.DatasetSize {
continue
}
if time.Since(run) > s.cfg.DatasetTTL {
continue
}
candidates = append(candidates, candidate{name: e.Name(), path: path, run: run})
}
if len(candidates) == 0 {
return os.ErrNotExist
}
// Pick the newest
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].run.After(candidates[j].run)
})
best := candidates[0]
ds, err := dataset.Open(best.path, best.run)
if err != nil {
return err
}
s.setDataset(ds)
s.log.Info("loaded existing dataset",
zap.Time("run", best.run),
zap.String("path", best.path))
return nil
}
// setDataset swaps the current dataset with a new one, closing the old one.
func (s *Service) setDataset(ds *dataset.File) {
s.mu.Lock()
old := s.ds
s.ds = ds
s.mu.Unlock()
if old != nil {
if err := old.Close(); err != nil {
s.log.Error("failed to close old dataset", zap.Error(err))
}
}
}
// cleanOldDatasets removes dataset files other than the one at keepPath.
func (s *Service) cleanOldDatasets(keepPath string) {
entries, err := os.ReadDir(s.cfg.DataDir)
if err != nil {
return
}
for _, e := range entries {
if e.IsDir() {
continue
}
path := filepath.Join(s.cfg.DataDir, e.Name())
if path == keepPath {
continue
}
// Remove old datasets and temp files
if len(e.Name()) == 10 || strings.HasSuffix(e.Name(), ".downloading") {
if err := os.Remove(path); err != nil {
s.log.Warn("failed to remove old file", zap.String("path", path), zap.Error(err))
} else {
s.log.Info("removed old dataset", zap.String("path", path))
}
}
}
}
// Close releases all resources.
func (s *Service) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.ds != nil {
err := s.ds.Close()
s.ds = nil
return err
}
return nil
}

View file

@ -3,42 +3,28 @@ package middleware
import (
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
"git.intra.yksa.space/gsn/predictor/internal/pkg/log"
"github.com/ogen-go/ogen/middleware"
"go.uber.org/zap"
)
func Logging() middleware.Middleware {
// Logging returns an ogen middleware that logs request duration.
func Logging(log *zap.Logger) middleware.Middleware {
return func(req middleware.Request, next func(req middleware.Request) (middleware.Response, error)) (middleware.Response, error) {
lg := log.Ctx(req.Context).With(
zap.String("operationId", req.OperationID),
)
lg.Info("started request")
req.Context = log.ToCtx(req.Context, lg)
lg := log.With(zap.String("operation", req.OperationID))
start := time.Now()
resp, err := next(req)
dur := time.Since(start).Microseconds()
dur := time.Since(start)
if err != nil {
if errcode, ok := err.(*errcodes.ErrorCode); ok {
lg.Error("request error",
zap.Int("status_code", errcode.StatusCode),
zap.String("message", errcode.Message),
zap.String("details", errcode.Details),
)
} else {
lg.Error("request internal error",
zap.Error(err),
)
}
lg.Error("request failed",
zap.Duration("duration", dur),
zap.Error(err))
} else {
lg.Info("request completed",
zap.Duration("duration", dur))
}
lg.Info("done request", zap.Float64("duration_ms", float64(dur)/float64(1000)))
return resp, err
}
}

View file

@ -1,24 +0,0 @@
package rest
import (
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
env "github.com/caarlos0/env/v11"
)
type Config struct {
Host string `env:"HOST" envDefault:"0.0.0.0"`
Port int `env:"PORT" envDefault:"8080"`
ReadTimeout string `env:"READ_TIMEOUT" envDefault:"30s"`
WriteTimeout string `env:"WRITE_TIMEOUT" envDefault:"30s"`
IdleTimeout string `env:"IDLE_TIMEOUT" envDefault:"60s"`
}
func NewConfig() (*Config, error) {
cfg := &Config{}
if err := env.ParseWithOptions(cfg, env.Options{
PrefixTagName: "GSN_PREDICTOR_REST_",
}); err != nil {
return nil, errcodes.Wrap(err, "failed to parse REST config")
}
return cfg, nil
}

View file

@ -1,14 +1,16 @@
package handler
import (
"context"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/ds"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/elevation"
)
// Service defines the interface the handler needs from the service layer.
type Service interface {
UpdateWeatherData(ctx context.Context) error
ExtractWind(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error)
PerformPrediction(ctx context.Context, params ds.PredictionParameters) ([]ds.PredicitonResult, error)
Ready() bool
DatasetTime() (time.Time, bool)
Dataset() *dataset.File
Elevation() *elevation.Dataset
}

View file

@ -5,190 +5,212 @@ import (
"net/http"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/ds"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
api "git.intra.yksa.space/gsn/predictor/pkg/rest"
"predictor-refactored/internal/prediction"
api "predictor-refactored/pkg/rest"
"go.uber.org/zap"
)
var (
_ api.Handler = (*Handler)(nil)
)
var _ api.Handler = (*Handler)(nil)
// Handler implements the ogen-generated api.Handler interface.
type Handler struct {
svc Service
log *zap.Logger
}
func New(svc Service) *Handler {
return &Handler{
svc: svc,
}
// New creates a new Handler.
func New(svc Service, log *zap.Logger) *Handler {
return &Handler{svc: svc, log: log}
}
func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredictionParams) (*api.PredictionResult, error) {
internalParams := ds.ConvertFlatPredictionParams(params)
if internalParams == nil {
return nil, errcodes.New(http.StatusBadRequest, "invalid or missing parameters")
}
results, err := h.svc.PerformPrediction(ctx, *internalParams)
if err != nil {
return nil, err
}
if len(results) == 0 {
return nil, errcodes.New(http.StatusInternalServerError, "no prediction results")
// PerformPrediction implements the prediction endpoint.
func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredictionParams) (*api.PredictionResponse, error) {
if !h.svc.Ready() {
return nil, newError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up")
}
// Group results into stages (ascent and descent)
stages := h.groupResultsIntoStages(results)
ds := h.svc.Dataset()
if ds == nil {
return nil, newError(http.StatusServiceUnavailable, "dataset unavailable")
}
// Map to OpenAPI schema
var predictionItems []api.PredictionResultPredictionItem
dsEpoch := float64(ds.DSTime.Unix())
for _, stage := range stages {
var trajectory []api.PredictionResultPredictionItemTrajectoryItem
// Parse parameters with defaults
profile := "standard_profile"
if p, ok := params.Profile.Get(); ok {
profile = string(p)
}
for _, result := range stage.Results {
traj := api.PredictionResultPredictionItemTrajectoryItem{
Datetime: *result.Timestamp,
Latitude: *result.Latitude,
Longitude: *result.Longitude,
Altitude: *result.Altitude,
ascentRate := 5.0
if v, ok := params.AscentRate.Get(); ok {
ascentRate = v
}
burstAltitude := 28000.0
if v, ok := params.BurstAltitude.Get(); ok {
burstAltitude = v
}
descentRate := 5.0
if v, ok := params.DescentRate.Get(); ok {
descentRate = v
}
launchAlt := 0.0
if v, ok := params.LaunchAltitude.Get(); ok {
launchAlt = v
}
// Normalize longitude to [0, 360)
lng := params.LaunchLongitude
if lng < 0 {
lng += 360.0
}
launchTime := float64(params.LaunchDatetime.Unix())
warnings := &prediction.Warnings{}
// Build profile chain
elev := h.svc.Elevation()
var stages []prediction.Stage
switch profile {
case "standard_profile":
stages = prediction.StandardProfile(
ascentRate, burstAltitude, descentRate,
ds, dsEpoch, warnings, elev)
case "float_profile":
floatAlt := 25000.0
if v, ok := params.FloatAltitude.Get(); ok {
floatAlt = v
}
stopTime := params.LaunchDatetime.Add(24 * time.Hour)
if v, ok := params.StopDatetime.Get(); ok {
stopTime = v
}
stages = prediction.FloatProfile(
ascentRate, floatAlt, stopTime,
ds, dsEpoch, warnings)
default:
return nil, newError(http.StatusBadRequest, "unknown profile: "+profile)
}
// Run prediction
startTime := time.Now().UTC()
results := prediction.RunPrediction(launchTime, params.LaunchLatitude, lng, launchAlt, stages)
completeTime := time.Now().UTC()
// Build response
stageNames := []string{"ascent", "descent"}
if profile == "float_profile" {
stageNames = []string{"ascent", "float"}
}
var predItems []api.PredictionResponsePredictionItem
for i, sr := range results {
stageName := "ascent"
if i < len(stageNames) {
stageName = stageNames[i]
}
var stageEnum api.PredictionResponsePredictionItemStage
switch stageName {
case "ascent":
stageEnum = api.PredictionResponsePredictionItemStageAscent
case "descent":
stageEnum = api.PredictionResponsePredictionItemStageDescent
case "float":
stageEnum = api.PredictionResponsePredictionItemStageFloat
}
var traj []api.PredictionResponsePredictionItemTrajectoryItem
for _, pt := range sr.Points {
ptLng := pt.Lng
if ptLng > 180 {
ptLng -= 360
}
trajectory = append(trajectory, traj)
}
item := api.PredictionResultPredictionItem{
Stage: stage.Stage,
Trajectory: trajectory,
}
predictionItems = append(predictionItems, item)
}
metadata := api.PredictionResultMetadata{
StartDatetime: *results[0].Timestamp,
CompleteDatetime: *results[len(results)-1].Timestamp,
}
resp := &api.PredictionResult{
Metadata: metadata,
Prediction: predictionItems,
}
return resp, nil
}
// StageResult represents a stage with its results
type StageResult struct {
Stage api.PredictionResultPredictionItemStage
Results []ds.PredicitonResult
}
// groupResultsIntoStages groups the prediction results into ascent and descent stages
func (h *Handler) groupResultsIntoStages(results []ds.PredicitonResult) []StageResult {
if len(results) == 0 {
return nil
}
var stages []StageResult
var currentStage []ds.PredicitonResult
var currentStageType api.PredictionResultPredictionItemStage
// Determine if we're in ascent or descent based on altitude changes
prevAlt := *results[0].Altitude
currentStage = append(currentStage, results[0])
currentStageType = api.PredictionResultPredictionItemStageAscent
for i := 1; i < len(results); i++ {
result := results[i]
currentAlt := *result.Altitude
// Determine if we're still in the same stage
var stageType api.PredictionResultPredictionItemStage
if currentAlt > prevAlt {
stageType = api.PredictionResultPredictionItemStageAscent
} else if currentAlt < prevAlt {
stageType = api.PredictionResultPredictionItemStageDescent
} else {
// Same altitude - continue with current stage
stageType = currentStageType
}
// If stage type changed, finalize current stage and start new one
if stageType != currentStageType && len(currentStage) > 0 {
stages = append(stages, StageResult{
Stage: currentStageType,
Results: currentStage,
traj = append(traj, api.PredictionResponsePredictionItemTrajectoryItem{
Datetime: time.Unix(int64(pt.T), 0).UTC(),
Latitude: pt.Lat,
Longitude: ptLng,
Altitude: pt.Alt,
})
currentStage = nil
currentStageType = stageType
}
currentStage = append(currentStage, result)
prevAlt = currentAlt
}
// Add the final stage
if len(currentStage) > 0 {
stages = append(stages, StageResult{
Stage: currentStageType,
Results: currentStage,
predItems = append(predItems, api.PredictionResponsePredictionItem{
Stage: stageEnum,
Trajectory: traj,
})
}
return stages
}
func (h *Handler) NewError(ctx context.Context, err error) *api.ErrorStatusCode {
if errcode, ok := err.(*errcodes.ErrorCode); ok {
resp := api.Error{
Message: errcode.Message,
}
if errcode.Details != "" {
resp.Details = api.NewOptString(errcode.Details)
}
return &api.ErrorStatusCode{
StatusCode: errcode.StatusCode,
Response: resp,
}
resp := &api.PredictionResponse{
Prediction: predItems,
Metadata: api.PredictionResponseMetadata{
StartDatetime: startTime,
CompleteDatetime: completeTime,
},
}
// Echo request
resp.Request = api.NewOptPredictionResponseRequest(api.PredictionResponseRequest{
Dataset: api.NewOptString(ds.DSTime.Format("2006-01-02T15:04:05Z")),
LaunchLatitude: api.NewOptFloat64(params.LaunchLatitude),
LaunchLongitude: api.NewOptFloat64(params.LaunchLongitude),
LaunchDatetime: api.NewOptString(params.LaunchDatetime.Format(time.RFC3339)),
LaunchAltitude: params.LaunchAltitude,
})
// Warnings
warnMap := warnings.ToMap()
if len(warnMap) > 0 {
resp.Warnings = api.NewOptPredictionResponseWarnings(api.PredictionResponseWarnings{})
}
h.log.Info("prediction complete",
zap.String("profile", profile),
zap.Int("stages", len(results)),
zap.Duration("elapsed", completeTime.Sub(startTime)))
return resp, nil
}
// ReadinessCheck implements the health check endpoint.
func (h *Handler) ReadinessCheck(ctx context.Context) (*api.ReadinessResponse, error) {
resp := &api.ReadinessResponse{}
if h.svc.Ready() {
resp.Status = api.ReadinessResponseStatusOk
if dsTime, ok := h.svc.DatasetTime(); ok {
resp.DatasetTime = api.NewOptDateTime(dsTime)
}
} else {
resp.Status = api.ReadinessResponseStatusNotReady
resp.ErrorMessage = api.NewOptString("no dataset loaded")
}
return resp, nil
}
// NewError creates an ErrorStatusCode from an error returned by a handler.
func (h *Handler) NewError(ctx context.Context, err error) *api.ErrorStatusCode {
if statusErr, ok := err.(*api.ErrorStatusCode); ok {
return statusErr
}
h.log.Error("unhandled error", zap.Error(err))
return newError(http.StatusInternalServerError, err.Error())
}
func newError(status int, description string) *api.ErrorStatusCode {
return &api.ErrorStatusCode{
StatusCode: http.StatusInternalServerError,
StatusCode: status,
Response: api.Error{
Message: "undefined internal error",
Details: api.NewOptString(err.Error()),
Error: api.ErrorError{
Type: http.StatusText(status),
Description: description,
},
},
}
}
func (h *Handler) ReadinessCheck(ctx context.Context) (*api.ReadinessResponse, error) {
status := api.ReadinessResponseStatusNotReady
var lastUpdate time.Time
var isFresh bool
var errMsg string
if s, ok := h.svc.(interface {
GetGribStatus(ctx context.Context) (ready bool, lastUpdate time.Time, isFresh bool, errMsg string)
}); ok {
ready, lu, fresh, em := s.GetGribStatus(ctx)
lastUpdate = lu
isFresh = fresh
errMsg = em
if ready {
status = api.ReadinessResponseStatusOk
} else if em != "" {
status = api.ReadinessResponseStatusError
}
} else {
errMsg = "service does not implement GetGribStatus"
status = api.ReadinessResponseStatusError
}
resp := &api.ReadinessResponse{
Status: status,
IsFresh: api.NewOptBool(isFresh),
LastUpdate: api.NewOptDateTime(lastUpdate),
ErrorMessage: api.NewOptString(errMsg),
}
return resp, nil
}

View file

@ -5,43 +5,71 @@ import (
"fmt"
"net/http"
"git.intra.yksa.space/gsn/predictor/internal/pkg/log"
"git.intra.yksa.space/gsn/predictor/internal/transport/middleware"
handler "git.intra.yksa.space/gsn/predictor/internal/transport/rest/handler"
api "git.intra.yksa.space/gsn/predictor/pkg/rest"
"github.com/rs/cors"
"predictor-refactored/internal/transport/middleware"
"predictor-refactored/internal/transport/rest/handler"
api "predictor-refactored/pkg/rest"
"go.uber.org/zap"
)
// Transport wraps the ogen HTTP server.
type Transport struct {
cfg *Config
srv *api.Server
handler *handler.Handler
port int
log *zap.Logger
}
func New(handler *handler.Handler, cfg *Config) (*Transport, error) {
// New creates a new REST transport.
func New(h *handler.Handler, port int, log *zap.Logger) (*Transport, error) {
srv, err := api.NewServer(
handler,
api.WithMiddleware(middleware.Logging()),
h,
api.WithMiddleware(middleware.Logging(log)),
)
if err != nil {
return nil, err
return nil, fmt.Errorf("create ogen server: %w", err)
}
return &Transport{
srv: srv,
cfg: cfg,
handler: handler,
handler: h,
port: port,
log: log,
}, nil
}
func (t *Transport) Run() {
log.Ctx(context.Background()).Info("started")
// Run starts the HTTP server. Blocks until the server stops.
func (t *Transport) Run() error {
mux := http.NewServeMux()
mux.Handle("/", t.srv)
cors.AllowAll().Handler(mux)
if err := http.ListenAndServe(fmt.Sprintf(":%d", t.cfg.Port), t.srv); err != nil {
panic(err)
httpSrv := &http.Server{
Addr: fmt.Sprintf(":%d", t.port),
Handler: corsMiddleware(mux),
}
t.log.Info("starting HTTP server", zap.Int("port", t.port))
return httpSrv.ListenAndServe()
}
// Shutdown gracefully stops the HTTP server.
func (t *Transport) Shutdown(ctx context.Context) error {
// The ogen server doesn't have a shutdown method;
// shutdown is handled by the http.Server in main.go
return nil
}
func corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
next.ServeHTTP(w, r)
})
}

View file

@ -1,18 +1,19 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest
import (
"net/http"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/trace"
"strings"
ht "github.com/ogen-go/ogen/http"
"github.com/ogen-go/ogen/middleware"
"github.com/ogen-go/ogen/ogenerrors"
"github.com/ogen-go/ogen/otelogen"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/trace"
)
var (
@ -32,6 +33,7 @@ type otelConfig struct {
Tracer trace.Tracer
MeterProvider metric.MeterProvider
Meter metric.Meter
Attributes []attribute.KeyValue
}
func (cfg *otelConfig) initOTEL() {
@ -81,18 +83,8 @@ func (o otelOptionFunc) applyServer(c *serverConfig) {
func newServerConfig(opts ...ServerOption) serverConfig {
cfg := serverConfig{
NotFound: http.NotFound,
MethodNotAllowed: func(w http.ResponseWriter, r *http.Request, allowed string) {
status := http.StatusMethodNotAllowed
if r.Method == "OPTIONS" {
w.Header().Set("Access-Control-Allow-Methods", allowed)
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
status = http.StatusNoContent
} else {
w.Header().Set("Allow", allowed)
}
w.WriteHeader(status)
},
NotFound: http.NotFound,
MethodNotAllowed: nil,
ErrorHandler: ogenerrors.DefaultErrorHandler,
Middleware: nil,
MaxMultipartMemory: 32 << 20, // 32 MB
@ -115,8 +107,44 @@ func (s baseServer) notFound(w http.ResponseWriter, r *http.Request) {
s.cfg.NotFound(w, r)
}
func (s baseServer) notAllowed(w http.ResponseWriter, r *http.Request, allowed string) {
s.cfg.MethodNotAllowed(w, r, allowed)
type notAllowedParams struct {
allowedMethods string
allowedHeaders map[string]string
acceptPost string
acceptPatch string
}
func (s baseServer) notAllowed(w http.ResponseWriter, r *http.Request, params notAllowedParams) {
h := w.Header()
isOptions := r.Method == "OPTIONS"
if isOptions {
h.Set("Access-Control-Allow-Methods", params.allowedMethods)
if params.allowedHeaders != nil {
m := r.Header.Get("Access-Control-Request-Method")
if m != "" {
allowedHeaders, ok := params.allowedHeaders[strings.ToUpper(m)]
if ok {
h.Set("Access-Control-Allow-Headers", allowedHeaders)
}
}
}
if params.acceptPost != "" {
h.Set("Accept-Post", params.acceptPost)
}
if params.acceptPatch != "" {
h.Set("Accept-Patch", params.acceptPatch)
}
}
if s.cfg.MethodNotAllowed != nil {
s.cfg.MethodNotAllowed(w, r, params.allowedMethods)
return
}
status := http.StatusNoContent
if !isOptions {
h.Set("Allow", params.allowedMethods)
status = http.StatusMethodNotAllowed
}
w.WriteHeader(status)
}
func (cfg serverConfig) baseServer() (s baseServer, err error) {
@ -215,6 +243,13 @@ func WithMeterProvider(provider metric.MeterProvider) Option {
})
}
// WithAttributes specifies default otel attributes.
func WithAttributes(attributes ...attribute.KeyValue) Option {
return otelOptionFunc(func(cfg *otelConfig) {
cfg.Attributes = attributes
})
}
// WithClient specifies http client to use.
func WithClient(client ht.Client) ClientOption {
return optionFunc[clientConfig](func(cfg *clientConfig) {

View file

@ -1,6 +1,6 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest
import (
"context"
@ -9,16 +9,15 @@ import (
"time"
"github.com/go-faster/errors"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/metric"
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
"go.opentelemetry.io/otel/trace"
"github.com/ogen-go/ogen/conv"
ht "github.com/ogen-go/ogen/http"
"github.com/ogen-go/ogen/otelogen"
"github.com/ogen-go/ogen/uri"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/metric"
semconv "go.opentelemetry.io/otel/semconv/v1.39.0"
"go.opentelemetry.io/otel/trace"
)
func trimTrailingSlashes(u *url.URL) {
@ -33,7 +32,7 @@ type Invoker interface {
// Perform prediction.
//
// GET /api/v1/prediction
PerformPrediction(ctx context.Context, params PerformPredictionParams) (*PredictionResult, error)
PerformPrediction(ctx context.Context, params PerformPredictionParams) (*PredictionResponse, error)
// ReadinessCheck invokes readinessCheck operation.
//
// Readiness check.
@ -47,14 +46,6 @@ type Client struct {
serverURL *url.URL
baseClient
}
type errorHandler interface {
NewError(ctx context.Context, err error) *ErrorStatusCode
}
var _ Handler = struct {
errorHandler
*Client
}{}
// NewClient initializes new Client defined by OAS.
func NewClient(serverURL string, opts ...ClientOption) (*Client, error) {
@ -94,17 +85,18 @@ func (c *Client) requestURL(ctx context.Context) *url.URL {
// Perform prediction.
//
// GET /api/v1/prediction
func (c *Client) PerformPrediction(ctx context.Context, params PerformPredictionParams) (*PredictionResult, error) {
func (c *Client) PerformPrediction(ctx context.Context, params PerformPredictionParams) (*PredictionResponse, error) {
res, err := c.sendPerformPrediction(ctx, params)
return res, err
}
func (c *Client) sendPerformPrediction(ctx context.Context, params PerformPredictionParams) (res *PredictionResult, err error) {
func (c *Client) sendPerformPrediction(ctx context.Context, params PerformPredictionParams) (res *PredictionResponse, err error) {
otelAttrs := []attribute.KeyValue{
otelogen.OperationID("performPrediction"),
semconv.HTTPRequestMethodKey.String("GET"),
semconv.HTTPRouteKey.String("/api/v1/prediction"),
semconv.URLTemplateKey.String("/api/v1/prediction"),
}
otelAttrs = append(otelAttrs, c.cfg.Attributes...)
// Run stopwatch.
startTime := time.Now()
@ -150,10 +142,7 @@ func (c *Client) sendPerformPrediction(ctx context.Context, params PerformPredic
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
if val, ok := params.LaunchLatitude.Get(); ok {
return e.EncodeValue(conv.Float64ToString(val))
}
return nil
return e.EncodeValue(conv.Float64ToString(params.LaunchLatitude))
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
@ -167,10 +156,7 @@ func (c *Client) sendPerformPrediction(ctx context.Context, params PerformPredic
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
if val, ok := params.LaunchLongitude.Get(); ok {
return e.EncodeValue(conv.Float64ToString(val))
}
return nil
return e.EncodeValue(conv.Float64ToString(params.LaunchLongitude))
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
@ -184,10 +170,7 @@ func (c *Client) sendPerformPrediction(ctx context.Context, params PerformPredic
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
if val, ok := params.LaunchDatetime.Get(); ok {
return e.EncodeValue(conv.DateTimeToString(val))
}
return nil
return e.EncodeValue(conv.DateTimeToString(params.LaunchDatetime))
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
@ -311,74 +294,6 @@ func (c *Client) sendPerformPrediction(ctx context.Context, params PerformPredic
return res, errors.Wrap(err, "encode query")
}
}
{
// Encode "ascent_curve" parameter.
cfg := uri.QueryParameterEncodingConfig{
Name: "ascent_curve",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
if val, ok := params.AscentCurve.Get(); ok {
return e.EncodeValue(conv.StringToString(val))
}
return nil
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
}
{
// Encode "descent_curve" parameter.
cfg := uri.QueryParameterEncodingConfig{
Name: "descent_curve",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
if val, ok := params.DescentCurve.Get(); ok {
return e.EncodeValue(conv.StringToString(val))
}
return nil
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
}
{
// Encode "interpolate" parameter.
cfg := uri.QueryParameterEncodingConfig{
Name: "interpolate",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
if val, ok := params.Interpolate.Get(); ok {
return e.EncodeValue(conv.BoolToString(val))
}
return nil
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
}
{
// Encode "format" parameter.
cfg := uri.QueryParameterEncodingConfig{
Name: "format",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
if val, ok := params.Format.Get(); ok {
return e.EncodeValue(conv.StringToString(string(val)))
}
return nil
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
}
{
// Encode "dataset" parameter.
cfg := uri.QueryParameterEncodingConfig{
@ -409,7 +324,8 @@ func (c *Client) sendPerformPrediction(ctx context.Context, params PerformPredic
if err != nil {
return res, errors.Wrap(err, "do request")
}
defer resp.Body.Close()
body := resp.Body
defer body.Close()
stage = "DecodeResponse"
result, err := decodePerformPredictionResponse(resp)
@ -434,8 +350,9 @@ func (c *Client) sendReadinessCheck(ctx context.Context) (res *ReadinessResponse
otelAttrs := []attribute.KeyValue{
otelogen.OperationID("readinessCheck"),
semconv.HTTPRequestMethodKey.String("GET"),
semconv.HTTPRouteKey.String("/ready"),
semconv.URLTemplateKey.String("/ready"),
}
otelAttrs = append(otelAttrs, c.cfg.Attributes...)
// Run stopwatch.
startTime := time.Now()
@ -481,7 +398,8 @@ func (c *Client) sendReadinessCheck(ctx context.Context) (res *ReadinessResponse
if err != nil {
return res, errors.Wrap(err, "do request")
}
defer resp.Body.Close()
body := resp.Body
defer body.Close()
stage = "DecodeResponse"
result, err := decodeReadinessCheckResponse(resp)

View file

@ -1,6 +1,6 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest
import (
"context"
@ -8,16 +8,15 @@ import (
"time"
"github.com/go-faster/errors"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/metric"
semconv "go.opentelemetry.io/otel/semconv/v1.26.0"
"go.opentelemetry.io/otel/trace"
ht "github.com/ogen-go/ogen/http"
"github.com/ogen-go/ogen/middleware"
"github.com/ogen-go/ogen/ogenerrors"
"github.com/ogen-go/ogen/otelogen"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/metric"
semconv "go.opentelemetry.io/otel/semconv/v1.39.0"
"go.opentelemetry.io/otel/trace"
)
type codeRecorder struct {
@ -30,6 +29,10 @@ func (c *codeRecorder) WriteHeader(status int) {
c.ResponseWriter.WriteHeader(status)
}
func (c *codeRecorder) Unwrap() http.ResponseWriter {
return c.ResponseWriter
}
// handlePerformPredictionRequest handles performPrediction operation.
//
// Perform prediction.
@ -43,6 +46,8 @@ func (s *Server) handlePerformPredictionRequest(args [0]string, argsEscaped bool
semconv.HTTPRequestMethodKey.String("GET"),
semconv.HTTPRouteKey.String("/api/v1/prediction"),
}
// Add attributes from config.
otelAttrs = append(otelAttrs, s.cfg.Attributes...)
// Start a span for this request.
ctx, span := s.cfg.Tracer.Start(r.Context(), PerformPredictionOperation,
@ -86,7 +91,7 @@ func (s *Server) handlePerformPredictionRequest(args [0]string, argsEscaped bool
// unless there was another error (e.g., network error receiving the response body; or 3xx codes with
// max redirects exceeded), in which case status MUST be set to Error.
code := statusWriter.status
if code >= 100 && code < 500 {
if code < 100 || code >= 500 {
span.SetStatus(codes.Error, stage)
}
@ -115,7 +120,9 @@ func (s *Server) handlePerformPredictionRequest(args [0]string, argsEscaped bool
return
}
var response *PredictionResult
var rawBody []byte
var response *PredictionResponse
if m := s.cfg.Middleware; m != nil {
mreq := middleware.Request{
Context: ctx,
@ -123,6 +130,7 @@ func (s *Server) handlePerformPredictionRequest(args [0]string, argsEscaped bool
OperationSummary: "Perform prediction",
OperationID: "performPrediction",
Body: nil,
RawBody: rawBody,
Params: middleware.Parameters{
{
Name: "launch_latitude",
@ -164,22 +172,6 @@ func (s *Server) handlePerformPredictionRequest(args [0]string, argsEscaped bool
Name: "stop_datetime",
In: "query",
}: params.StopDatetime,
{
Name: "ascent_curve",
In: "query",
}: params.AscentCurve,
{
Name: "descent_curve",
In: "query",
}: params.DescentCurve,
{
Name: "interpolate",
In: "query",
}: params.Interpolate,
{
Name: "format",
In: "query",
}: params.Format,
{
Name: "dataset",
In: "query",
@ -191,7 +183,7 @@ func (s *Server) handlePerformPredictionRequest(args [0]string, argsEscaped bool
type (
Request = struct{}
Params = PerformPredictionParams
Response = *PredictionResult
Response = *PredictionResponse
)
response, err = middleware.HookMiddleware[
Request,
@ -248,6 +240,8 @@ func (s *Server) handleReadinessCheckRequest(args [0]string, argsEscaped bool, w
semconv.HTTPRequestMethodKey.String("GET"),
semconv.HTTPRouteKey.String("/ready"),
}
// Add attributes from config.
otelAttrs = append(otelAttrs, s.cfg.Attributes...)
// Start a span for this request.
ctx, span := s.cfg.Tracer.Start(r.Context(), ReadinessCheckOperation,
@ -291,7 +285,7 @@ func (s *Server) handleReadinessCheckRequest(args [0]string, argsEscaped bool, w
// unless there was another error (e.g., network error receiving the response body; or 3xx codes with
// max redirects exceeded), in which case status MUST be set to Error.
code := statusWriter.status
if code >= 100 && code < 500 {
if code < 100 || code >= 500 {
span.SetStatus(codes.Error, stage)
}
@ -306,6 +300,8 @@ func (s *Server) handleReadinessCheckRequest(args [0]string, argsEscaped bool, w
err error
)
var rawBody []byte
var response *ReadinessResponse
if m := s.cfg.Middleware; m != nil {
mreq := middleware.Request{
@ -314,6 +310,7 @@ func (s *Server) handleReadinessCheckRequest(args [0]string, argsEscaped bool, w
OperationSummary: "Readiness check",
OperationID: "readinessCheck",
Body: nil,
RawBody: rawBody,
Params: middleware.Parameters{},
Raw: r,
}

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest
import (
"context"

View file

@ -1,6 +1,6 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest
import (
"github.com/ogen-go/ogen/middleware"

View file

@ -1,6 +1,6 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest
// OperationName is the ogen operation name
type OperationName = string

View file

@ -1,13 +1,12 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest
import (
"net/http"
"time"
"github.com/go-faster/errors"
"github.com/ogen-go/ogen/conv"
"github.com/ogen-go/ogen/middleware"
"github.com/ogen-go/ogen/ogenerrors"
@ -17,21 +16,17 @@ import (
// PerformPredictionParams is parameters of performPrediction operation.
type PerformPredictionParams struct {
LaunchLatitude OptFloat64
LaunchLongitude OptFloat64
LaunchDatetime OptDateTime
LaunchAltitude OptFloat64
Profile OptPerformPredictionProfile
AscentRate OptFloat64
BurstAltitude OptFloat64
DescentRate OptFloat64
FloatAltitude OptFloat64
StopDatetime OptDateTime
AscentCurve OptString
DescentCurve OptString
Interpolate OptBool
Format OptPerformPredictionFormat
Dataset OptDateTime
LaunchLatitude float64
LaunchLongitude float64
LaunchDatetime time.Time
LaunchAltitude OptFloat64 `json:",omitempty,omitzero"`
Profile OptPerformPredictionProfile `json:",omitempty,omitzero"`
AscentRate OptFloat64 `json:",omitempty,omitzero"`
BurstAltitude OptFloat64 `json:",omitempty,omitzero"`
DescentRate OptFloat64 `json:",omitempty,omitzero"`
FloatAltitude OptFloat64 `json:",omitempty,omitzero"`
StopDatetime OptDateTime `json:",omitempty,omitzero"`
Dataset OptDateTime `json:",omitempty,omitzero"`
}
func unpackPerformPredictionParams(packed middleware.Parameters) (params PerformPredictionParams) {
@ -40,27 +35,21 @@ func unpackPerformPredictionParams(packed middleware.Parameters) (params Perform
Name: "launch_latitude",
In: "query",
}
if v, ok := packed[key]; ok {
params.LaunchLatitude = v.(OptFloat64)
}
params.LaunchLatitude = packed[key].(float64)
}
{
key := middleware.ParameterKey{
Name: "launch_longitude",
In: "query",
}
if v, ok := packed[key]; ok {
params.LaunchLongitude = v.(OptFloat64)
}
params.LaunchLongitude = packed[key].(float64)
}
{
key := middleware.ParameterKey{
Name: "launch_datetime",
In: "query",
}
if v, ok := packed[key]; ok {
params.LaunchDatetime = v.(OptDateTime)
}
params.LaunchDatetime = packed[key].(time.Time)
}
{
key := middleware.ParameterKey{
@ -125,42 +114,6 @@ func unpackPerformPredictionParams(packed middleware.Parameters) (params Perform
params.StopDatetime = v.(OptDateTime)
}
}
{
key := middleware.ParameterKey{
Name: "ascent_curve",
In: "query",
}
if v, ok := packed[key]; ok {
params.AscentCurve = v.(OptString)
}
}
{
key := middleware.ParameterKey{
Name: "descent_curve",
In: "query",
}
if v, ok := packed[key]; ok {
params.DescentCurve = v.(OptString)
}
}
{
key := middleware.ParameterKey{
Name: "interpolate",
In: "query",
}
if v, ok := packed[key]; ok {
params.Interpolate = v.(OptBool)
}
}
{
key := middleware.ParameterKey{
Name: "format",
In: "query",
}
if v, ok := packed[key]; ok {
params.Format = v.(OptPerformPredictionFormat)
}
}
{
key := middleware.ParameterKey{
Name: "dataset",
@ -185,43 +138,31 @@ func decodePerformPredictionParams(args [0]string, argsEscaped bool, r *http.Req
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
var paramsDotLaunchLatitudeVal float64
if err := func() error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToFloat64(val)
if err != nil {
return err
}
paramsDotLaunchLatitudeVal = c
return nil
}(); err != nil {
val, err := d.DecodeValue()
if err != nil {
return err
}
params.LaunchLatitude.SetTo(paramsDotLaunchLatitudeVal)
c, err := conv.ToFloat64(val)
if err != nil {
return err
}
params.LaunchLatitude = c
return nil
}); err != nil {
return err
}
if err := func() error {
if value, ok := params.LaunchLatitude.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
if err := (validate.Float{}).Validate(float64(params.LaunchLatitude)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
} else {
return err
}
return nil
}(); err != nil {
@ -241,43 +182,31 @@ func decodePerformPredictionParams(args [0]string, argsEscaped bool, r *http.Req
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
var paramsDotLaunchLongitudeVal float64
if err := func() error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToFloat64(val)
if err != nil {
return err
}
paramsDotLaunchLongitudeVal = c
return nil
}(); err != nil {
val, err := d.DecodeValue()
if err != nil {
return err
}
params.LaunchLongitude.SetTo(paramsDotLaunchLongitudeVal)
c, err := conv.ToFloat64(val)
if err != nil {
return err
}
params.LaunchLongitude = c
return nil
}); err != nil {
return err
}
if err := func() error {
if value, ok := params.LaunchLongitude.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
if err := (validate.Float{}).Validate(float64(params.LaunchLongitude)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
} else {
return err
}
return nil
}(); err != nil {
@ -297,28 +226,23 @@ func decodePerformPredictionParams(args [0]string, argsEscaped bool, r *http.Req
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
var paramsDotLaunchDatetimeVal time.Time
if err := func() error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToDateTime(val)
if err != nil {
return err
}
paramsDotLaunchDatetimeVal = c
return nil
}(); err != nil {
val, err := d.DecodeValue()
if err != nil {
return err
}
params.LaunchDatetime.SetTo(paramsDotLaunchDatetimeVal)
c, err := conv.ToDateTime(val)
if err != nil {
return err
}
params.LaunchDatetime = c
return nil
}); err != nil {
return err
}
} else {
return err
}
return nil
}(); err != nil {
@ -384,6 +308,11 @@ func decodePerformPredictionParams(args [0]string, argsEscaped bool, r *http.Req
Err: err,
}
}
// Set default value for query: profile.
{
val := PerformPredictionProfile("standard_profile")
params.Profile.SetTo(val)
}
// Decode query: profile.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
@ -705,185 +634,6 @@ func decodePerformPredictionParams(args [0]string, argsEscaped bool, r *http.Req
Err: err,
}
}
// Decode query: ascent_curve.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
Name: "ascent_curve",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
var paramsDotAscentCurveVal string
if err := func() error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToString(val)
if err != nil {
return err
}
paramsDotAscentCurveVal = c
return nil
}(); err != nil {
return err
}
params.AscentCurve.SetTo(paramsDotAscentCurveVal)
return nil
}); err != nil {
return err
}
}
return nil
}(); err != nil {
return params, &ogenerrors.DecodeParamError{
Name: "ascent_curve",
In: "query",
Err: err,
}
}
// Decode query: descent_curve.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
Name: "descent_curve",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
var paramsDotDescentCurveVal string
if err := func() error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToString(val)
if err != nil {
return err
}
paramsDotDescentCurveVal = c
return nil
}(); err != nil {
return err
}
params.DescentCurve.SetTo(paramsDotDescentCurveVal)
return nil
}); err != nil {
return err
}
}
return nil
}(); err != nil {
return params, &ogenerrors.DecodeParamError{
Name: "descent_curve",
In: "query",
Err: err,
}
}
// Decode query: interpolate.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
Name: "interpolate",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
var paramsDotInterpolateVal bool
if err := func() error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToBool(val)
if err != nil {
return err
}
paramsDotInterpolateVal = c
return nil
}(); err != nil {
return err
}
params.Interpolate.SetTo(paramsDotInterpolateVal)
return nil
}); err != nil {
return err
}
}
return nil
}(); err != nil {
return params, &ogenerrors.DecodeParamError{
Name: "interpolate",
In: "query",
Err: err,
}
}
// Decode query: format.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
Name: "format",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
var paramsDotFormatVal PerformPredictionFormat
if err := func() error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToString(val)
if err != nil {
return err
}
paramsDotFormatVal = PerformPredictionFormat(c)
return nil
}(); err != nil {
return err
}
params.Format.SetTo(paramsDotFormatVal)
return nil
}); err != nil {
return err
}
if err := func() error {
if value, ok := params.Format.Get(); ok {
if err := func() error {
if err := value.Validate(); err != nil {
return err
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
return params, &ogenerrors.DecodeParamError{
Name: "format",
In: "query",
Err: err,
}
}
// Decode query: dataset.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{

View file

@ -1,3 +1,3 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest

View file

@ -1,3 +1,3 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest

View file

@ -1,6 +1,6 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest
import (
"io"
@ -9,12 +9,11 @@ import (
"github.com/go-faster/errors"
"github.com/go-faster/jx"
"github.com/ogen-go/ogen/ogenerrors"
"github.com/ogen-go/ogen/validate"
)
func decodePerformPredictionResponse(resp *http.Response) (res *PredictionResult, _ error) {
func decodePerformPredictionResponse(resp *http.Response) (res *PredictionResponse, _ error) {
switch resp.StatusCode {
case 200:
// Code 200.
@ -30,7 +29,7 @@ func decodePerformPredictionResponse(resp *http.Response) (res *PredictionResult
}
d := jx.DecodeBytes(buf)
var response PredictionResult
var response PredictionResponse
if err := func() error {
if err := response.Decode(d); err != nil {
return err

View file

@ -1,19 +1,18 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest
import (
"net/http"
"github.com/go-faster/errors"
"github.com/go-faster/jx"
ht "github.com/ogen-go/ogen/http"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
ht "github.com/ogen-go/ogen/http"
)
func encodePerformPredictionResponse(response *PredictionResult, w http.ResponseWriter, span trace.Span) error {
func encodePerformPredictionResponse(response *PredictionResponse, w http.ResponseWriter, span trace.Span) error {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(200)
span.SetStatus(codes.Ok, http.StatusText(200))

View file

@ -1,6 +1,6 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest
import (
"net/http"
@ -74,7 +74,12 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case "GET":
s.handlePerformPredictionRequest([0]string{}, elemIsEscaped, w, r)
default:
s.notAllowed(w, r, "GET")
s.notAllowed(w, r, notAllowedParams{
allowedMethods: "GET",
allowedHeaders: nil,
acceptPost: "",
acceptPatch: "",
})
}
return
@ -94,7 +99,12 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case "GET":
s.handleReadinessCheckRequest([0]string{}, elemIsEscaped, w, r)
default:
s.notAllowed(w, r, "GET")
s.notAllowed(w, r, notAllowedParams{
allowedMethods: "GET",
allowedHeaders: nil,
acceptPost: "",
acceptPatch: "",
})
}
return
@ -109,12 +119,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Route is route object.
type Route struct {
name string
summary string
operationID string
pathPattern string
count int
args [0]string
name string
summary string
operationID string
operationGroup string
pathPattern string
count int
args [0]string
}
// Name returns ogen operation name.
@ -134,6 +145,11 @@ func (r Route) OperationID() string {
return r.operationID
}
// OperationGroup returns the x-ogen-operation-group value.
func (r Route) OperationGroup() string {
return r.operationGroup
}
// PathPattern returns OpenAPI path.
func (r Route) PathPattern() string {
return r.pathPattern
@ -209,6 +225,7 @@ func (s *Server) FindPath(method string, u *url.URL) (r Route, _ bool) {
r.name = PerformPredictionOperation
r.summary = "Perform prediction"
r.operationID = "performPrediction"
r.operationGroup = ""
r.pathPattern = "/api/v1/prediction"
r.args = args
r.count = 0
@ -233,6 +250,7 @@ func (s *Server) FindPath(method string, u *url.URL) (r Route, _ bool) {
r.name = ReadinessCheckOperation
r.summary = "Readiness check"
r.operationID = "readinessCheck"
r.operationGroup = ""
r.pathPattern = "/ready"
r.args = args
r.count = 0

View file

@ -1,12 +1,13 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest
import (
"fmt"
"time"
"github.com/go-faster/errors"
"github.com/go-faster/jx"
)
func (s *ErrorStatusCode) Error() string {
@ -15,28 +16,42 @@ func (s *ErrorStatusCode) Error() string {
// Ref: #/components/schemas/Error
type Error struct {
Message string `json:"message"`
Details OptString `json:"details"`
Error ErrorError `json:"error"`
}
// GetMessage returns the value of Message.
func (s *Error) GetMessage() string {
return s.Message
// GetError returns the value of Error.
func (s *Error) GetError() ErrorError {
return s.Error
}
// GetDetails returns the value of Details.
func (s *Error) GetDetails() OptString {
return s.Details
// SetError sets the value of Error.
func (s *Error) SetError(val ErrorError) {
s.Error = val
}
// SetMessage sets the value of Message.
func (s *Error) SetMessage(val string) {
s.Message = val
type ErrorError struct {
Type string `json:"type"`
Description string `json:"description"`
}
// SetDetails sets the value of Details.
func (s *Error) SetDetails(val OptString) {
s.Details = val
// GetType returns the value of Type.
func (s *ErrorError) GetType() string {
return s.Type
}
// GetDescription returns the value of Description.
func (s *ErrorError) GetDescription() string {
return s.Description
}
// SetType sets the value of Type.
func (s *ErrorError) SetType(val string) {
s.Type = val
}
// SetDescription sets the value of Description.
func (s *ErrorError) SetDescription(val string) {
s.Description = val
}
// ErrorStatusCode wraps Error with StatusCode.
@ -65,52 +80,6 @@ func (s *ErrorStatusCode) SetResponse(val Error) {
s.Response = val
}
// NewOptBool returns new OptBool with value set to v.
func NewOptBool(v bool) OptBool {
return OptBool{
Value: v,
Set: true,
}
}
// OptBool is optional bool.
type OptBool struct {
Value bool
Set bool
}
// IsSet returns true if OptBool was set.
func (o OptBool) IsSet() bool { return o.Set }
// Reset unsets value.
func (o *OptBool) Reset() {
var v bool
o.Value = v
o.Set = false
}
// SetTo sets value to v.
func (o *OptBool) SetTo(v bool) {
o.Set = true
o.Value = v
}
// Get returns value and boolean that denotes whether value was set.
func (o OptBool) Get() (v bool, ok bool) {
if !o.Set {
return v, false
}
return o.Value, true
}
// Or returns value if set, or given parameter if does not.
func (o OptBool) Or(d bool) bool {
if v, ok := o.Get(); ok {
return v
}
return d
}
// NewOptDateTime returns new OptDateTime with value set to v.
func NewOptDateTime(v time.Time) OptDateTime {
return OptDateTime{
@ -203,52 +172,6 @@ func (o OptFloat64) Or(d float64) float64 {
return d
}
// NewOptPerformPredictionFormat returns new OptPerformPredictionFormat with value set to v.
func NewOptPerformPredictionFormat(v PerformPredictionFormat) OptPerformPredictionFormat {
return OptPerformPredictionFormat{
Value: v,
Set: true,
}
}
// OptPerformPredictionFormat is optional PerformPredictionFormat.
type OptPerformPredictionFormat struct {
Value PerformPredictionFormat
Set bool
}
// IsSet returns true if OptPerformPredictionFormat was set.
func (o OptPerformPredictionFormat) IsSet() bool { return o.Set }
// Reset unsets value.
func (o *OptPerformPredictionFormat) Reset() {
var v PerformPredictionFormat
o.Value = v
o.Set = false
}
// SetTo sets value to v.
func (o *OptPerformPredictionFormat) SetTo(v PerformPredictionFormat) {
o.Set = true
o.Value = v
}
// Get returns value and boolean that denotes whether value was set.
func (o OptPerformPredictionFormat) Get() (v PerformPredictionFormat, ok bool) {
if !o.Set {
return v, false
}
return o.Value, true
}
// Or returns value if set, or given parameter if does not.
func (o OptPerformPredictionFormat) Or(d PerformPredictionFormat) PerformPredictionFormat {
if v, ok := o.Get(); ok {
return v
}
return d
}
// NewOptPerformPredictionProfile returns new OptPerformPredictionProfile with value set to v.
func NewOptPerformPredictionProfile(v PerformPredictionProfile) OptPerformPredictionProfile {
return OptPerformPredictionProfile{
@ -295,6 +218,98 @@ func (o OptPerformPredictionProfile) Or(d PerformPredictionProfile) PerformPredi
return d
}
// NewOptPredictionResponseRequest returns new OptPredictionResponseRequest with value set to v.
func NewOptPredictionResponseRequest(v PredictionResponseRequest) OptPredictionResponseRequest {
return OptPredictionResponseRequest{
Value: v,
Set: true,
}
}
// OptPredictionResponseRequest is optional PredictionResponseRequest.
type OptPredictionResponseRequest struct {
Value PredictionResponseRequest
Set bool
}
// IsSet returns true if OptPredictionResponseRequest was set.
func (o OptPredictionResponseRequest) IsSet() bool { return o.Set }
// Reset unsets value.
func (o *OptPredictionResponseRequest) Reset() {
var v PredictionResponseRequest
o.Value = v
o.Set = false
}
// SetTo sets value to v.
func (o *OptPredictionResponseRequest) SetTo(v PredictionResponseRequest) {
o.Set = true
o.Value = v
}
// Get returns value and boolean that denotes whether value was set.
func (o OptPredictionResponseRequest) Get() (v PredictionResponseRequest, ok bool) {
if !o.Set {
return v, false
}
return o.Value, true
}
// Or returns value if set, or given parameter if does not.
func (o OptPredictionResponseRequest) Or(d PredictionResponseRequest) PredictionResponseRequest {
if v, ok := o.Get(); ok {
return v
}
return d
}
// NewOptPredictionResponseWarnings returns new OptPredictionResponseWarnings with value set to v.
func NewOptPredictionResponseWarnings(v PredictionResponseWarnings) OptPredictionResponseWarnings {
return OptPredictionResponseWarnings{
Value: v,
Set: true,
}
}
// OptPredictionResponseWarnings is optional PredictionResponseWarnings.
type OptPredictionResponseWarnings struct {
Value PredictionResponseWarnings
Set bool
}
// IsSet returns true if OptPredictionResponseWarnings was set.
func (o OptPredictionResponseWarnings) IsSet() bool { return o.Set }
// Reset unsets value.
func (o *OptPredictionResponseWarnings) Reset() {
var v PredictionResponseWarnings
o.Value = v
o.Set = false
}
// SetTo sets value to v.
func (o *OptPredictionResponseWarnings) SetTo(v PredictionResponseWarnings) {
o.Set = true
o.Value = v
}
// Get returns value and boolean that denotes whether value was set.
func (o OptPredictionResponseWarnings) Get() (v PredictionResponseWarnings, ok bool) {
if !o.Set {
return v, false
}
return o.Value, true
}
// Or returns value if set, or given parameter if does not.
func (o OptPredictionResponseWarnings) Or(d PredictionResponseWarnings) PredictionResponseWarnings {
if v, ok := o.Get(); ok {
return v
}
return d
}
// NewOptString returns new OptString with value set to v.
func NewOptString(v string) OptString {
return OptString{
@ -341,47 +356,11 @@ func (o OptString) Or(d string) string {
return d
}
type PerformPredictionFormat string
const (
PerformPredictionFormatJSON PerformPredictionFormat = "json"
)
// AllValues returns all PerformPredictionFormat values.
func (PerformPredictionFormat) AllValues() []PerformPredictionFormat {
return []PerformPredictionFormat{
PerformPredictionFormatJSON,
}
}
// MarshalText implements encoding.TextMarshaler.
func (s PerformPredictionFormat) MarshalText() ([]byte, error) {
switch s {
case PerformPredictionFormatJSON:
return []byte(s), nil
default:
return nil, errors.Errorf("invalid value: %q", s)
}
}
// UnmarshalText implements encoding.TextUnmarshaler.
func (s *PerformPredictionFormat) UnmarshalText(data []byte) error {
switch PerformPredictionFormat(data) {
case PerformPredictionFormatJSON:
*s = PerformPredictionFormatJSON
return nil
default:
return errors.Errorf("invalid value: %q", data)
}
}
type PerformPredictionProfile string
const (
PerformPredictionProfileStandardProfile PerformPredictionProfile = "standard_profile"
PerformPredictionProfileFloatProfile PerformPredictionProfile = "float_profile"
PerformPredictionProfileReverseProfile PerformPredictionProfile = "reverse_profile"
PerformPredictionProfileCustomProfile PerformPredictionProfile = "custom_profile"
)
// AllValues returns all PerformPredictionProfile values.
@ -389,8 +368,6 @@ func (PerformPredictionProfile) AllValues() []PerformPredictionProfile {
return []PerformPredictionProfile{
PerformPredictionProfileStandardProfile,
PerformPredictionProfileFloatProfile,
PerformPredictionProfileReverseProfile,
PerformPredictionProfileCustomProfile,
}
}
@ -401,10 +378,6 @@ func (s PerformPredictionProfile) MarshalText() ([]byte, error) {
return []byte(s), nil
case PerformPredictionProfileFloatProfile:
return []byte(s), nil
case PerformPredictionProfileReverseProfile:
return []byte(s), nil
case PerformPredictionProfileCustomProfile:
return []byte(s), nil
default:
return nil, errors.Errorf("invalid value: %q", s)
}
@ -419,114 +392,134 @@ func (s *PerformPredictionProfile) UnmarshalText(data []byte) error {
case PerformPredictionProfileFloatProfile:
*s = PerformPredictionProfileFloatProfile
return nil
case PerformPredictionProfileReverseProfile:
*s = PerformPredictionProfileReverseProfile
return nil
case PerformPredictionProfileCustomProfile:
*s = PerformPredictionProfileCustomProfile
return nil
default:
return errors.Errorf("invalid value: %q", data)
}
}
// Ref: #/components/schemas/PredictionResult
type PredictionResult struct {
Metadata PredictionResultMetadata `json:"metadata"`
Prediction []PredictionResultPredictionItem `json:"prediction"`
// Ref: #/components/schemas/PredictionResponse
type PredictionResponse struct {
Request OptPredictionResponseRequest `json:"request"`
Prediction []PredictionResponsePredictionItem `json:"prediction"`
Metadata PredictionResponseMetadata `json:"metadata"`
Warnings OptPredictionResponseWarnings `json:"warnings"`
}
// GetMetadata returns the value of Metadata.
func (s *PredictionResult) GetMetadata() PredictionResultMetadata {
return s.Metadata
// GetRequest returns the value of Request.
func (s *PredictionResponse) GetRequest() OptPredictionResponseRequest {
return s.Request
}
// GetPrediction returns the value of Prediction.
func (s *PredictionResult) GetPrediction() []PredictionResultPredictionItem {
func (s *PredictionResponse) GetPrediction() []PredictionResponsePredictionItem {
return s.Prediction
}
// SetMetadata sets the value of Metadata.
func (s *PredictionResult) SetMetadata(val PredictionResultMetadata) {
s.Metadata = val
// GetMetadata returns the value of Metadata.
func (s *PredictionResponse) GetMetadata() PredictionResponseMetadata {
return s.Metadata
}
// GetWarnings returns the value of Warnings.
func (s *PredictionResponse) GetWarnings() OptPredictionResponseWarnings {
return s.Warnings
}
// SetRequest sets the value of Request.
func (s *PredictionResponse) SetRequest(val OptPredictionResponseRequest) {
s.Request = val
}
// SetPrediction sets the value of Prediction.
func (s *PredictionResult) SetPrediction(val []PredictionResultPredictionItem) {
func (s *PredictionResponse) SetPrediction(val []PredictionResponsePredictionItem) {
s.Prediction = val
}
type PredictionResultMetadata struct {
CompleteDatetime time.Time `json:"complete_datetime"`
StartDatetime time.Time `json:"start_datetime"`
// SetMetadata sets the value of Metadata.
func (s *PredictionResponse) SetMetadata(val PredictionResponseMetadata) {
s.Metadata = val
}
// GetCompleteDatetime returns the value of CompleteDatetime.
func (s *PredictionResultMetadata) GetCompleteDatetime() time.Time {
return s.CompleteDatetime
// SetWarnings sets the value of Warnings.
func (s *PredictionResponse) SetWarnings(val OptPredictionResponseWarnings) {
s.Warnings = val
}
type PredictionResponseMetadata struct {
StartDatetime time.Time `json:"start_datetime"`
CompleteDatetime time.Time `json:"complete_datetime"`
}
// GetStartDatetime returns the value of StartDatetime.
func (s *PredictionResultMetadata) GetStartDatetime() time.Time {
func (s *PredictionResponseMetadata) GetStartDatetime() time.Time {
return s.StartDatetime
}
// SetCompleteDatetime sets the value of CompleteDatetime.
func (s *PredictionResultMetadata) SetCompleteDatetime(val time.Time) {
s.CompleteDatetime = val
// GetCompleteDatetime returns the value of CompleteDatetime.
func (s *PredictionResponseMetadata) GetCompleteDatetime() time.Time {
return s.CompleteDatetime
}
// SetStartDatetime sets the value of StartDatetime.
func (s *PredictionResultMetadata) SetStartDatetime(val time.Time) {
func (s *PredictionResponseMetadata) SetStartDatetime(val time.Time) {
s.StartDatetime = val
}
type PredictionResultPredictionItem struct {
Stage PredictionResultPredictionItemStage `json:"stage"`
Trajectory []PredictionResultPredictionItemTrajectoryItem `json:"trajectory"`
// SetCompleteDatetime sets the value of CompleteDatetime.
func (s *PredictionResponseMetadata) SetCompleteDatetime(val time.Time) {
s.CompleteDatetime = val
}
type PredictionResponsePredictionItem struct {
Stage PredictionResponsePredictionItemStage `json:"stage"`
Trajectory []PredictionResponsePredictionItemTrajectoryItem `json:"trajectory"`
}
// GetStage returns the value of Stage.
func (s *PredictionResultPredictionItem) GetStage() PredictionResultPredictionItemStage {
func (s *PredictionResponsePredictionItem) GetStage() PredictionResponsePredictionItemStage {
return s.Stage
}
// GetTrajectory returns the value of Trajectory.
func (s *PredictionResultPredictionItem) GetTrajectory() []PredictionResultPredictionItemTrajectoryItem {
func (s *PredictionResponsePredictionItem) GetTrajectory() []PredictionResponsePredictionItemTrajectoryItem {
return s.Trajectory
}
// SetStage sets the value of Stage.
func (s *PredictionResultPredictionItem) SetStage(val PredictionResultPredictionItemStage) {
func (s *PredictionResponsePredictionItem) SetStage(val PredictionResponsePredictionItemStage) {
s.Stage = val
}
// SetTrajectory sets the value of Trajectory.
func (s *PredictionResultPredictionItem) SetTrajectory(val []PredictionResultPredictionItemTrajectoryItem) {
func (s *PredictionResponsePredictionItem) SetTrajectory(val []PredictionResponsePredictionItemTrajectoryItem) {
s.Trajectory = val
}
type PredictionResultPredictionItemStage string
type PredictionResponsePredictionItemStage string
const (
PredictionResultPredictionItemStageAscent PredictionResultPredictionItemStage = "ascent"
PredictionResultPredictionItemStageDescent PredictionResultPredictionItemStage = "descent"
PredictionResponsePredictionItemStageAscent PredictionResponsePredictionItemStage = "ascent"
PredictionResponsePredictionItemStageDescent PredictionResponsePredictionItemStage = "descent"
PredictionResponsePredictionItemStageFloat PredictionResponsePredictionItemStage = "float"
)
// AllValues returns all PredictionResultPredictionItemStage values.
func (PredictionResultPredictionItemStage) AllValues() []PredictionResultPredictionItemStage {
return []PredictionResultPredictionItemStage{
PredictionResultPredictionItemStageAscent,
PredictionResultPredictionItemStageDescent,
// AllValues returns all PredictionResponsePredictionItemStage values.
func (PredictionResponsePredictionItemStage) AllValues() []PredictionResponsePredictionItemStage {
return []PredictionResponsePredictionItemStage{
PredictionResponsePredictionItemStageAscent,
PredictionResponsePredictionItemStageDescent,
PredictionResponsePredictionItemStageFloat,
}
}
// MarshalText implements encoding.TextMarshaler.
func (s PredictionResultPredictionItemStage) MarshalText() ([]byte, error) {
func (s PredictionResponsePredictionItemStage) MarshalText() ([]byte, error) {
switch s {
case PredictionResultPredictionItemStageAscent:
case PredictionResponsePredictionItemStageAscent:
return []byte(s), nil
case PredictionResultPredictionItemStageDescent:
case PredictionResponsePredictionItemStageDescent:
return []byte(s), nil
case PredictionResponsePredictionItemStageFloat:
return []byte(s), nil
default:
return nil, errors.Errorf("invalid value: %q", s)
@ -534,20 +527,23 @@ func (s PredictionResultPredictionItemStage) MarshalText() ([]byte, error) {
}
// UnmarshalText implements encoding.TextUnmarshaler.
func (s *PredictionResultPredictionItemStage) UnmarshalText(data []byte) error {
switch PredictionResultPredictionItemStage(data) {
case PredictionResultPredictionItemStageAscent:
*s = PredictionResultPredictionItemStageAscent
func (s *PredictionResponsePredictionItemStage) UnmarshalText(data []byte) error {
switch PredictionResponsePredictionItemStage(data) {
case PredictionResponsePredictionItemStageAscent:
*s = PredictionResponsePredictionItemStageAscent
return nil
case PredictionResultPredictionItemStageDescent:
*s = PredictionResultPredictionItemStageDescent
case PredictionResponsePredictionItemStageDescent:
*s = PredictionResponsePredictionItemStageDescent
return nil
case PredictionResponsePredictionItemStageFloat:
*s = PredictionResponsePredictionItemStageFloat
return nil
default:
return errors.Errorf("invalid value: %q", data)
}
}
type PredictionResultPredictionItemTrajectoryItem struct {
type PredictionResponsePredictionItemTrajectoryItem struct {
Datetime time.Time `json:"datetime"`
Latitude float64 `json:"latitude"`
Longitude float64 `json:"longitude"`
@ -555,50 +551,162 @@ type PredictionResultPredictionItemTrajectoryItem struct {
}
// GetDatetime returns the value of Datetime.
func (s *PredictionResultPredictionItemTrajectoryItem) GetDatetime() time.Time {
func (s *PredictionResponsePredictionItemTrajectoryItem) GetDatetime() time.Time {
return s.Datetime
}
// GetLatitude returns the value of Latitude.
func (s *PredictionResultPredictionItemTrajectoryItem) GetLatitude() float64 {
func (s *PredictionResponsePredictionItemTrajectoryItem) GetLatitude() float64 {
return s.Latitude
}
// GetLongitude returns the value of Longitude.
func (s *PredictionResultPredictionItemTrajectoryItem) GetLongitude() float64 {
func (s *PredictionResponsePredictionItemTrajectoryItem) GetLongitude() float64 {
return s.Longitude
}
// GetAltitude returns the value of Altitude.
func (s *PredictionResultPredictionItemTrajectoryItem) GetAltitude() float64 {
func (s *PredictionResponsePredictionItemTrajectoryItem) GetAltitude() float64 {
return s.Altitude
}
// SetDatetime sets the value of Datetime.
func (s *PredictionResultPredictionItemTrajectoryItem) SetDatetime(val time.Time) {
func (s *PredictionResponsePredictionItemTrajectoryItem) SetDatetime(val time.Time) {
s.Datetime = val
}
// SetLatitude sets the value of Latitude.
func (s *PredictionResultPredictionItemTrajectoryItem) SetLatitude(val float64) {
func (s *PredictionResponsePredictionItemTrajectoryItem) SetLatitude(val float64) {
s.Latitude = val
}
// SetLongitude sets the value of Longitude.
func (s *PredictionResultPredictionItemTrajectoryItem) SetLongitude(val float64) {
func (s *PredictionResponsePredictionItemTrajectoryItem) SetLongitude(val float64) {
s.Longitude = val
}
// SetAltitude sets the value of Altitude.
func (s *PredictionResultPredictionItemTrajectoryItem) SetAltitude(val float64) {
func (s *PredictionResponsePredictionItemTrajectoryItem) SetAltitude(val float64) {
s.Altitude = val
}
type PredictionResponseRequest struct {
Dataset OptString `json:"dataset"`
LaunchLatitude OptFloat64 `json:"launch_latitude"`
LaunchLongitude OptFloat64 `json:"launch_longitude"`
LaunchDatetime OptString `json:"launch_datetime"`
LaunchAltitude OptFloat64 `json:"launch_altitude"`
Profile OptString `json:"profile"`
AscentRate OptFloat64 `json:"ascent_rate"`
BurstAltitude OptFloat64 `json:"burst_altitude"`
DescentRate OptFloat64 `json:"descent_rate"`
}
// GetDataset returns the value of Dataset.
func (s *PredictionResponseRequest) GetDataset() OptString {
return s.Dataset
}
// GetLaunchLatitude returns the value of LaunchLatitude.
func (s *PredictionResponseRequest) GetLaunchLatitude() OptFloat64 {
return s.LaunchLatitude
}
// GetLaunchLongitude returns the value of LaunchLongitude.
func (s *PredictionResponseRequest) GetLaunchLongitude() OptFloat64 {
return s.LaunchLongitude
}
// GetLaunchDatetime returns the value of LaunchDatetime.
func (s *PredictionResponseRequest) GetLaunchDatetime() OptString {
return s.LaunchDatetime
}
// GetLaunchAltitude returns the value of LaunchAltitude.
func (s *PredictionResponseRequest) GetLaunchAltitude() OptFloat64 {
return s.LaunchAltitude
}
// GetProfile returns the value of Profile.
func (s *PredictionResponseRequest) GetProfile() OptString {
return s.Profile
}
// GetAscentRate returns the value of AscentRate.
func (s *PredictionResponseRequest) GetAscentRate() OptFloat64 {
return s.AscentRate
}
// GetBurstAltitude returns the value of BurstAltitude.
func (s *PredictionResponseRequest) GetBurstAltitude() OptFloat64 {
return s.BurstAltitude
}
// GetDescentRate returns the value of DescentRate.
func (s *PredictionResponseRequest) GetDescentRate() OptFloat64 {
return s.DescentRate
}
// SetDataset sets the value of Dataset.
func (s *PredictionResponseRequest) SetDataset(val OptString) {
s.Dataset = val
}
// SetLaunchLatitude sets the value of LaunchLatitude.
func (s *PredictionResponseRequest) SetLaunchLatitude(val OptFloat64) {
s.LaunchLatitude = val
}
// SetLaunchLongitude sets the value of LaunchLongitude.
func (s *PredictionResponseRequest) SetLaunchLongitude(val OptFloat64) {
s.LaunchLongitude = val
}
// SetLaunchDatetime sets the value of LaunchDatetime.
func (s *PredictionResponseRequest) SetLaunchDatetime(val OptString) {
s.LaunchDatetime = val
}
// SetLaunchAltitude sets the value of LaunchAltitude.
func (s *PredictionResponseRequest) SetLaunchAltitude(val OptFloat64) {
s.LaunchAltitude = val
}
// SetProfile sets the value of Profile.
func (s *PredictionResponseRequest) SetProfile(val OptString) {
s.Profile = val
}
// SetAscentRate sets the value of AscentRate.
func (s *PredictionResponseRequest) SetAscentRate(val OptFloat64) {
s.AscentRate = val
}
// SetBurstAltitude sets the value of BurstAltitude.
func (s *PredictionResponseRequest) SetBurstAltitude(val OptFloat64) {
s.BurstAltitude = val
}
// SetDescentRate sets the value of DescentRate.
func (s *PredictionResponseRequest) SetDescentRate(val OptFloat64) {
s.DescentRate = val
}
type PredictionResponseWarnings map[string]jx.Raw
func (s *PredictionResponseWarnings) init() PredictionResponseWarnings {
m := *s
if m == nil {
m = map[string]jx.Raw{}
*s = m
}
return m
}
// Ref: #/components/schemas/ReadinessResponse
type ReadinessResponse struct {
Status ReadinessResponseStatus `json:"status"`
LastUpdate OptDateTime `json:"last_update"`
IsFresh OptBool `json:"is_fresh"`
DatasetTime OptDateTime `json:"dataset_time"`
ErrorMessage OptString `json:"error_message"`
}
@ -607,14 +715,9 @@ func (s *ReadinessResponse) GetStatus() ReadinessResponseStatus {
return s.Status
}
// GetLastUpdate returns the value of LastUpdate.
func (s *ReadinessResponse) GetLastUpdate() OptDateTime {
return s.LastUpdate
}
// GetIsFresh returns the value of IsFresh.
func (s *ReadinessResponse) GetIsFresh() OptBool {
return s.IsFresh
// GetDatasetTime returns the value of DatasetTime.
func (s *ReadinessResponse) GetDatasetTime() OptDateTime {
return s.DatasetTime
}
// GetErrorMessage returns the value of ErrorMessage.
@ -627,14 +730,9 @@ func (s *ReadinessResponse) SetStatus(val ReadinessResponseStatus) {
s.Status = val
}
// SetLastUpdate sets the value of LastUpdate.
func (s *ReadinessResponse) SetLastUpdate(val OptDateTime) {
s.LastUpdate = val
}
// SetIsFresh sets the value of IsFresh.
func (s *ReadinessResponse) SetIsFresh(val OptBool) {
s.IsFresh = val
// SetDatasetTime sets the value of DatasetTime.
func (s *ReadinessResponse) SetDatasetTime(val OptDateTime) {
s.DatasetTime = val
}
// SetErrorMessage sets the value of ErrorMessage.

View file

@ -1,6 +1,6 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest
import (
"context"
@ -13,7 +13,7 @@ type Handler interface {
// Perform prediction.
//
// GET /api/v1/prediction
PerformPrediction(ctx context.Context, params PerformPredictionParams) (*PredictionResult, error)
PerformPrediction(ctx context.Context, params PerformPredictionParams) (*PredictionResponse, error)
// ReadinessCheck implements readinessCheck operation.
//
// Readiness check.

View file

@ -1,6 +1,6 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest
import (
"context"
@ -18,7 +18,7 @@ var _ Handler = UnimplementedHandler{}
// Perform prediction.
//
// GET /api/v1/prediction
func (UnimplementedHandler) PerformPrediction(ctx context.Context, params PerformPredictionParams) (r *PredictionResult, _ error) {
func (UnimplementedHandler) PerformPrediction(ctx context.Context, params PerformPredictionParams) (r *PredictionResponse, _ error) {
return r, ht.ErrNotImplemented
}

View file

@ -1,45 +1,49 @@
// Code generated by ogen, DO NOT EDIT.
package gsn
package rest
import (
"fmt"
"github.com/go-faster/errors"
"github.com/ogen-go/ogen/validate"
)
func (s PerformPredictionFormat) Validate() error {
switch s {
case "json":
return nil
default:
return errors.Errorf("invalid value: %v", s)
}
}
func (s PerformPredictionProfile) Validate() error {
switch s {
case "standard_profile":
return nil
case "float_profile":
return nil
case "reverse_profile":
return nil
case "custom_profile":
return nil
default:
return errors.Errorf("invalid value: %v", s)
}
}
func (s *PredictionResult) Validate() error {
func (s *PredictionResponse) Validate() error {
if s == nil {
return validate.ErrNilPointer
}
var failures []validate.FieldError
if err := func() error {
if value, ok := s.Request.Get(); ok {
if err := func() error {
if err := value.Validate(); err != nil {
return err
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "request",
Error: err,
})
}
if err := func() error {
if s.Prediction == nil {
return errors.New("nil is invalid value")
@ -74,7 +78,7 @@ func (s *PredictionResult) Validate() error {
return nil
}
func (s *PredictionResultPredictionItem) Validate() error {
func (s *PredictionResponsePredictionItem) Validate() error {
if s == nil {
return validate.ErrNilPointer
}
@ -125,18 +129,20 @@ func (s *PredictionResultPredictionItem) Validate() error {
return nil
}
func (s PredictionResultPredictionItemStage) Validate() error {
func (s PredictionResponsePredictionItemStage) Validate() error {
switch s {
case "ascent":
return nil
case "descent":
return nil
case "float":
return nil
default:
return errors.Errorf("invalid value: %v", s)
}
}
func (s *PredictionResultPredictionItemTrajectoryItem) Validate() error {
func (s *PredictionResponsePredictionItemTrajectoryItem) Validate() error {
if s == nil {
return validate.ErrNilPointer
}
@ -181,6 +187,126 @@ func (s *PredictionResultPredictionItemTrajectoryItem) Validate() error {
return nil
}
func (s *PredictionResponseRequest) Validate() error {
if s == nil {
return validate.ErrNilPointer
}
var failures []validate.FieldError
if err := func() error {
if value, ok := s.LaunchLatitude.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "launch_latitude",
Error: err,
})
}
if err := func() error {
if value, ok := s.LaunchLongitude.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "launch_longitude",
Error: err,
})
}
if err := func() error {
if value, ok := s.LaunchAltitude.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "launch_altitude",
Error: err,
})
}
if err := func() error {
if value, ok := s.AscentRate.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "ascent_rate",
Error: err,
})
}
if err := func() error {
if value, ok := s.BurstAltitude.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "burst_altitude",
Error: err,
})
}
if err := func() error {
if value, ok := s.DescentRate.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "descent_rate",
Error: err,
})
}
if len(failures) > 0 {
return &validate.Error{Fields: failures}
}
return nil
}
func (s *ReadinessResponse) Validate() error {
if s == nil {
return validate.ErrNilPointer

View file

@ -1,20 +0,0 @@
package scheduler
import (
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
env "github.com/caarlos0/env/v11"
)
type Config struct {
Enabled bool `env:"ENABLED" envDefault:"true"`
}
func NewConfig() (*Config, error) {
cfg := &Config{}
if err := env.ParseWithOptions(cfg, env.Options{
PrefixTagName: "GSN_PREDICTOR_SCHEDULER_",
}); err != nil {
return nil, errcodes.Wrap(err, "failed to parse scheduler config")
}
return cfg, nil
}

View file

@ -1,97 +0,0 @@
package scheduler
import (
"context"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
"git.intra.yksa.space/gsn/predictor/internal/pkg/log"
"github.com/go-co-op/gocron"
"go.uber.org/zap"
)
type Job interface {
GetInterval() time.Duration
GetTimeout() time.Duration
GetCount() int
GetAsync() bool
Execute(context.Context) error
}
type Scheduler struct {
scheduler *gocron.Scheduler
}
func New() *Scheduler {
scheduler := gocron.NewScheduler(time.UTC)
return &Scheduler{
scheduler: scheduler,
}
}
func (s *Scheduler) AddJob(job Job) error {
interval := job.GetInterval()
timeout := job.GetTimeout()
count := job.GetCount()
async := job.GetAsync()
// Validate job parameters
if !async && count != 1 {
return errcodes.ErrSchedulerInvalidJob
}
if timeout > interval {
return errcodes.ErrSchedulerTimeoutTooLong
}
// Create job function with timeout
jobFunc := func() {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
logger := log.Ctx(ctx)
if err := job.Execute(ctx); err != nil {
logger.Error("job execution failed",
zap.Error(err),
zap.Duration("interval", interval),
zap.Duration("timeout", timeout))
} else {
logger.Debug("job executed successfully",
zap.Duration("interval", interval),
zap.Duration("timeout", timeout))
}
}
// Add job to scheduler
schedulerJob := s.scheduler.Every(interval)
if !async {
schedulerJob = schedulerJob.SingletonMode()
}
if count > 0 {
schedulerJob = schedulerJob.LimitRunsTo(count)
}
schedulerJob.Do(jobFunc)
log.Ctx(context.Background()).Info("job added to scheduler",
zap.Duration("interval", interval),
zap.Duration("timeout", timeout),
zap.Int("count", count),
zap.Bool("async", async))
return nil
}
func (s *Scheduler) Start() {
s.scheduler.StartAsync()
log.Ctx(context.Background()).Info("scheduler started")
}
func (s *Scheduler) Stop() {
s.scheduler.Stop()
log.Ctx(context.Background()).Info("scheduler stopped")
}
func (s *Scheduler) IsRunning() bool {
return s.scheduler.IsRunning()
}

155
scripts/build_elevation.py Normal file
View file

@ -0,0 +1,155 @@
#!/usr/bin/env python3
"""
Download ETOPO 2022 30-arc-second elevation data and convert to ruaumoko-compatible
binary format (int16 little-endian, 21601 lat x 43200 lon, south-to-north).
Output: ~1.74 GiB binary file.
Usage: python3 build_elevation.py [output_path]
Default output: /srv/ruaumoko-dataset
"""
import sys
import os
import struct
import tempfile
import numpy as np
CELLS_PER_DEGREE = 120
NUM_LATS = 180 * CELLS_PER_DEGREE + 1 # 21601
NUM_LONS = 360 * CELLS_PER_DEGREE # 43200
EXPECTED_SIZE = NUM_LATS * NUM_LONS * 2 # 1,866,326,400
ETOPO_URL = "https://www.ngdc.noaa.gov/thredds/fileServer/global/ETOPO2022/30s/30s_surface_elev_netcdf/ETOPO_2022_v1_30s_N90W180_surface.nc"
def download_etopo(output_path):
"""Download ETOPO 2022 NetCDF and convert to ruaumoko binary format."""
try:
import xarray as xr
except ImportError:
print("ERROR: xarray is required. Install with: pip install xarray netcdf4")
sys.exit(1)
# Check if we can download directly or need a local file
nc_path = os.environ.get("ETOPO_NC_PATH")
if nc_path and os.path.exists(nc_path):
print(f"Using local ETOPO file: {nc_path}")
else:
print(f"Downloading ETOPO 2022 30-second data (~1.1 GB)...")
print(f" URL: {ETOPO_URL}")
print(f" (Set ETOPO_NC_PATH env var to use a pre-downloaded file)")
import urllib.request
with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as f:
nc_path = f.name
try:
urllib.request.urlretrieve(ETOPO_URL, nc_path, _progress)
print()
except Exception as e:
os.unlink(nc_path)
print(f"\nDownload failed: {e}")
print("\nAlternative: manually download ETOPO 2022 30s NetCDF from:")
print(" https://www.ncei.noaa.gov/products/etopo-global-relief-model")
print(f" Then set ETOPO_NC_PATH=/path/to/file.nc and re-run")
sys.exit(1)
print(f"Opening NetCDF dataset...")
ds = xr.open_dataset(nc_path)
# ETOPO 2022 30s has:
# - lat: -90 to +90, 21601 points (south to north)
# - lon: -180 to +180, 43201 points
# We need:
# - lat: -90 to +90, 21601 points (south to north) ← same
# - lon: 0 to 360 (exclusive), 43200 points ← need to shift and drop last
z = ds["z"] # elevation variable
print(f" Shape: {z.shape}")
print(f" Lat range: {float(z.lat.min())} to {float(z.lat.max())}")
print(f" Lon range: {float(z.lon.min())} to {float(z.lon.max())}")
# Sort latitude south-to-north (should already be, but ensure)
z = z.sortby("lat")
# Shift longitude from [-180, 180] to [0, 360)
print("Shifting longitude to [0, 360)...")
z = z.assign_coords(lon=(z.lon % 360))
z = z.sortby("lon")
data = z.values
print(f" Raw shape after sort: {data.shape}")
# Handle longitude dimension: drop last col if it wraps (43201 → 43200)
if data.shape[1] == NUM_LONS + 1:
data = data[:, :NUM_LONS]
elif data.shape[1] != NUM_LONS:
print(f"ERROR: unexpected lon dimension: {data.shape[1]}, expected {NUM_LONS} or {NUM_LONS+1}")
sys.exit(1)
# Handle latitude dimension: ETOPO 2022 is cell-centered (21600 rows),
# ruaumoko expects grid-registered (21601 rows including both poles).
# Pad by duplicating edge rows for the poles.
if data.shape[0] == NUM_LATS - 1:
print(f" Padding latitude from {data.shape[0]} to {NUM_LATS} (adding north pole row)")
north_pole = data[-1:, :] # duplicate +89.99... as +90
data = np.concatenate([data, north_pole], axis=0)
elif data.shape[0] != NUM_LATS:
print(f"ERROR: unexpected lat dimension: {data.shape[0]}, expected {NUM_LATS} or {NUM_LATS-1}")
sys.exit(1)
print(f"Final grid shape: {data.shape}")
print(f"Elevation range: {data.min():.1f} to {data.max():.1f} metres")
# Write as int16 little-endian
print(f"Writing to {output_path}...")
elev_int16 = np.clip(data, -32768, 32767).astype(np.dtype("<i2"))
elev_int16.tofile(output_path)
actual_size = os.path.getsize(output_path)
print(f"Written {actual_size:,} bytes (expected {EXPECTED_SIZE:,})")
if actual_size == EXPECTED_SIZE:
print("SUCCESS")
else:
print("WARNING: size mismatch!")
ds.close()
# Spot check
verify(output_path)
def verify(path):
"""Quick spot-check of the elevation dataset."""
data = np.memmap(path, dtype="<i2", mode="r", shape=(NUM_LATS, NUM_LONS))
tests = [
("Mt Everest (~28.0N, 86.9E)", 28.0, 86.9, 8000, 9000),
("Dead Sea (~31.5N, 35.5E)", 31.5, 35.5, -500, 0),
("Pacific Ocean (~0N, 180E)", 0.0, 180.0, -6000, 0),
("Auburn AU (~-34.0S, 138.7E)", -34.03, 138.69, 200, 400),
]
print("\n Spot-check:")
for name, lat, lon, lo, hi in tests:
lat_idx = int((lat + 90) * CELLS_PER_DEGREE)
lon_idx = int(lon * CELLS_PER_DEGREE)
val = int(data[lat_idx, lon_idx])
ok = "OK" if lo <= val <= hi else "FAIL"
print(f" {name}: {val}m [{ok}] (expected {lo}-{hi})")
_last_pct = -1
def _progress(block_num, block_size, total_size):
global _last_pct
if total_size > 0:
pct = int(block_num * block_size * 100 / total_size)
if pct != _last_pct and pct % 5 == 0:
_last_pct = pct
print(f" {pct}%...", end="", flush=True)
if __name__ == "__main__":
output = sys.argv[1] if len(sys.argv) > 1 else "/srv/ruaumoko-dataset"
download_etopo(output)

View file

@ -1,303 +0,0 @@
#!/usr/bin/env python3
import subprocess
import sys
import time
import requests
import json
from typing import Any
import base64
import math
# --- Config ---
REFERENCE_API_URL = "https://fly.stratonautica.ru/api/v2/?profile=standard_profile&pred_type=single&launch_datetime=2025-06-25T20%3A45%3A00Z&launch_latitude=56.6992&launch_longitude=38.8247&launch_altitude=0&ascent_rate=5&burst_altitude=30000&descent_rate=5"
LOCAL_API_URL = "http://localhost:8080/api/v1/prediction?profile=standard_profile&pred_type=single&launch_datetime=2025-06-25T20%3A45%3A00Z&launch_latitude=56.6992&launch_longitude=38.8247&launch_altitude=0&ascent_rate=5&burst_altitude=30000&descent_rate=5"
LOCAL_API_PAYLOAD = {
"launch_latitude": 56.6992,
"launch_longitude": 38.8247,
"launch_datetime": "2025-06-25T20-45-000Z",
"launch_altitude": 0,
"profile": "standard_profile",
"ascent_rate": 5,
"burst_altitude": 30000,
"descent_rate": 5,
"format": "json"
}
READY_URL = "http://localhost:8080/ready"
# --- Utility functions ---
def run_compose_up():
print("[INFO] Running docker-compose down --remove-orphans ...")
result = subprocess.run(["docker-compose", "down", "--remove-orphans"], capture_output=True)
if result.returncode != 0:
print("[ERROR] docker-compose down failed:", result.stderr.decode())
sys.exit(1)
print("[INFO] docker-compose down completed.")
print("[INFO] Running docker-compose up -d ...")
result = subprocess.run(["docker-compose", "up", "-d"], capture_output=True)
if result.returncode != 0:
print("[ERROR] docker-compose up failed:", result.stderr.decode())
sys.exit(1)
print("[INFO] docker-compose up -d completed.")
return True
def wait_for_ready(timeout=900):
print(f"[INFO] Waiting for {READY_URL} to be ready ...")
start = time.time()
while time.time() - start < timeout:
try:
resp = requests.get(READY_URL, timeout=10)
if resp.status_code == 200:
data = resp.json()
if data.get("status") == "ok":
print("[INFO] Service is ready.")
return
else:
print(f"[INFO] Not ready yet: {data}")
else:
print(f"[INFO] /ready returned status {resp.status_code}")
except Exception as e:
print(f"[INFO] Exception while polling /ready: {e}")
time.sleep(10)
print(f"[ERROR] Service did not become ready in {timeout} seconds.")
sys.exit(1)
def fetch_reference():
print(f"[INFO] Fetching reference prediction from {REFERENCE_API_URL}")
resp = requests.get(REFERENCE_API_URL, timeout=60)
if resp.status_code != 200:
print(f"[ERROR] Reference API returned {resp.status_code}: {resp.text}")
sys.exit(1)
return resp.json()
def fetch_local():
print(f"[INFO] Fetching local prediction from {LOCAL_API_URL}")
resp = requests.get(LOCAL_API_URL, timeout=60)
if resp.status_code != 200:
print(f"[ERROR] Local API returned {resp.status_code}: {resp.text}")
sys.exit(1)
return resp.json()
def haversine(lat1, lon1, lat2, lon2):
"""Calculate the great-circle distance between two points on the Earth (specified in decimal degrees). Returns distance in kilometers."""
R = 6371.0 # Earth radius in kilometers
lat1, lon1, lat2, lon2 = map(math.radians, [lat1, lon1, lat2, lon2])
dlat = lat2 - lat1
dlon = lon2 - lon1
a = math.sin(dlat/2)**2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon/2)**2
c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))
return R * c
def compare_results(reference_data, local_data):
"""Compare prediction results between reference and local APIs."""
print("[INFO] Comparing results ...")
# Extract trajectory data
ref_trajectory = reference_data.get('prediction', [{}])[0].get('trajectory', [])
local_trajectory = local_data.get('prediction', [{}])[0].get('trajectory', [])
print(f"[DEBUG] Reference trajectory length: {len(ref_trajectory)}")
print(f"[DEBUG] Local trajectory length: {len(local_trajectory)}")
# Show first 3 points from both APIs
print("\n[DEBUG] First 3 points - Reference API:")
for i, point in enumerate(ref_trajectory[:3]):
print(f" [{i}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}")
print("\n[DEBUG] First 3 points - Local API:")
for i, point in enumerate(local_trajectory[:3]):
print(f" [{i}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}")
# Show last 3 points from both APIs
print("\n[DEBUG] Last 3 points - Reference API:")
for i, point in enumerate(ref_trajectory[-3:]):
idx = len(ref_trajectory) - 3 + i
print(f" [{idx}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}")
print("\n[DEBUG] Last 3 points - Local API:")
for i, point in enumerate(local_trajectory[-3:]):
idx = len(local_trajectory) - 3 + i
print(f" [{idx}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}")
# Compare trajectory lengths
if len(ref_trajectory) != len(local_trajectory):
print(f"[DIFF] Trajectory length mismatch: {len(local_trajectory)} vs {len(ref_trajectory)}")
return False
# Compare trajectory points and calculate drift
min_len = min(len(ref_trajectory), len(local_trajectory))
max_drift = 0.0
max_drift_idx = -1
drift_list = []
print("\n[DRIFT] Trajectory point-by-point distance (km):")
for i in range(min_len):
ref_point = ref_trajectory[i]
local_point = local_trajectory[i]
ref_lat = ref_point.get('latitude')
ref_lon = ref_point.get('longitude')
local_lat = local_point.get('latitude')
local_lon = local_point.get('longitude')
drift_km = None
if None not in (ref_lat, ref_lon, local_lat, local_lon):
drift_km = haversine(ref_lat, ref_lon, local_lat, local_lon)
drift_list.append(drift_km)
if drift_km > max_drift:
max_drift = drift_km
max_drift_idx = i
print(f" [{i}] Drift: {drift_km:.3f} km")
else:
print(f" [{i}] Drift: N/A (missing lat/lon)")
if drift_list:
mean_drift = sum(drift_list) / len(drift_list)
print(f"\n[DRIFT] Max drift: {max_drift:.3f} km at idx {max_drift_idx}")
print(f"[DRIFT] Mean drift: {mean_drift:.3f} km over {len(drift_list)} points")
else:
print("[DRIFT] No valid drift data to report.")
# Continue with original comparison for altitude, etc.
for i in range(min_len):
ref_point = ref_trajectory[i]
local_point = local_trajectory[i]
for key in ['altitude', 'latitude', 'longitude']:
ref_val = ref_point.get(key)
local_val = local_point.get(key)
if ref_val is not None and local_val is not None:
if abs(ref_val - local_val) > 0.1:
print(f"[DIFF] At idx {i}, key {key}: {local_val} != {ref_val}")
return False
print("[SUCCESS] Results match!")
return True
def test_custom_profile():
"""Test custom profile with base64 encoded curve."""
print("\n[TEST] Testing custom_profile...")
# Create a simple custom ascent curve (altitude vs time in seconds)
curve_data = {
"altitude": [0, 30000],
"time": [0, 6000]
}
curve_b64 = base64.b64encode(json.dumps(curve_data).encode()).decode()
# Test parameters for custom profile
params = {
"launch_latitude": 56.6992,
"launch_longitude": 38.8247,
"launch_datetime": "2025-06-25T13:28:00Z",
"launch_altitude": 0,
"profile": "custom_profile",
"ascent_curve": curve_b64
}
try:
# Test local API (use GET)
local_resp = requests.get(
"http://localhost:8080/api/v1/prediction",
params=params,
timeout=30
)
local_resp.raise_for_status()
local_data = local_resp.json()
print(f"[INFO] Custom profile test - Local API returned {len(local_data.get('prediction', [{}])[0].get('trajectory', []))} trajectory points")
return True
except Exception as e:
print(f"[ERROR] Custom profile test failed: {e}")
return False
def test_all_profiles():
"""Test all available profiles."""
profiles = [
("standard_profile", "Standard profile test"),
("float_profile", "Float profile test"),
("reverse_profile", "Reverse profile test"),
("custom_profile", "Custom profile test")
]
results = {}
for profile, description in profiles:
print(f"\n[TEST] {description}...")
if profile == "custom_profile":
success = test_custom_profile()
else:
success = test_single_profile(profile)
results[profile] = success
print(f"[RESULT] {profile}: {'PASS' if success else 'FAIL'}")
# Print summary
print("\n" + "="*50)
print("TEST SUMMARY")
print("="*50)
for profile, success in results.items():
status = "PASS" if success else "FAIL"
print(f"{profile:20} : {status}")
total_tests = len(results)
passed_tests = sum(results.values())
print(f"\nTotal tests: {total_tests}, Passed: {passed_tests}, Failed: {total_tests - passed_tests}")
return all(results.values())
def test_single_profile(profile):
"""Test a single profile against reference API."""
# Test parameters
params = {
"launch_latitude": 56.6992,
"launch_longitude": 38.8247,
"launch_datetime": "2025-06-25T13:28:00Z",
"launch_altitude": 0,
"profile": profile,
"ascent_rate": 5,
"burst_altitude": 30000,
"descent_rate": 5
}
# Add float altitude for float profile
if profile == "float_profile":
params["float_altitude"] = 25000
try:
# Test local API (use GET)
local_resp = requests.get(
"http://localhost:8080/api/v1/prediction",
params=params,
timeout=30
)
local_resp.raise_for_status()
local_data = local_resp.json()
print(f"[INFO] {profile} - Local API returned {len(local_data.get('prediction', [{}])[0].get('trajectory', []))} trajectory points")
return True
except Exception as e:
print(f"[ERROR] {profile} test failed: {e}")
return False
def main():
"""Main test function."""
print("[INFO] Starting comprehensive predictor API tests...")
# Run the original standard profile test
print("\n[TEST] Running original standard_profile test...")
run_compose_up()
wait_for_ready()
ref = fetch_reference()
local = fetch_local()
print("[INFO] Comparing results ...")
original_success = compare_results(ref, local)
if original_success:
print("[SUCCESS] Original standard_profile test passed!")
else:
print("[FAIL] Original standard_profile test failed!")
# Test all profiles
print("\n[TEST] Running all profile tests...")
all_profiles_success = test_all_profiles()
# Final result
overall_success = original_success and all_profiles_success
print(f"\n[FINAL RESULT] Overall: {'PASS' if overall_success else 'FAIL'}")
if overall_success:
sys.exit(0)
else:
sys.exit(1)
if __name__ == "__main__":
main()

View file

@ -1,89 +0,0 @@
package main
import (
"context"
"fmt"
"log"
"os"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/grib"
)
func main() {
ctx := context.Background()
// Create S3 downloader
downloader, err := grib.NewS3Downloader(
"/tmp/grib_test",
4, // parallel downloads
"noaa-gfs-bdp-pds",
"us-east-1",
)
if err != nil {
log.Fatalf("Failed to create S3 downloader: %v", err)
}
// Ensure directory exists
if err := os.MkdirAll("/tmp/grib_test", 0o755); err != nil {
log.Fatalf("Failed to create directory: %v", err)
}
// Find nearest run (6-hour intervals: 00, 06, 12, 18 UTC)
now := time.Now().UTC()
hour := now.Hour() - (now.Hour() % 6)
// Use data from 6 hours ago to ensure it's available
run := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC).Add(-6 * time.Hour)
fmt.Printf("Testing S3 download for run: %s\n", run.Format("2006-01-02 15:04 MST"))
// List available files first
runStr := run.Format("20060102")
fmt.Printf("Listing available files for %s/%02d...\n", runStr, run.Hour())
files, err := downloader.ListAvailableFiles(ctx, runStr, run.Hour())
if err != nil {
log.Fatalf("Failed to list files: %v", err)
}
fmt.Printf("Found %d files in S3:\n", len(files))
if len(files) > 0 {
// Show first 5 files
for i, file := range files {
if i >= 5 {
fmt.Printf("... and %d more files\n", len(files)-5)
break
}
fmt.Printf(" - %s\n", file)
}
}
// Try downloading just first 3 forecast hours (f000, f001, f002)
fmt.Println("\nTesting download of first 3 forecast hours...")
testRun := run
// Create a timeout context for the download
downloadCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
defer cancel()
if err := downloader.Run(downloadCtx, testRun); err != nil {
log.Fatalf("Failed to download: %v", err)
}
fmt.Println("\nDownload completed successfully!")
// Check downloaded files
entries, err := os.ReadDir("/tmp/grib_test")
if err != nil {
log.Fatalf("Failed to read directory: %v", err)
}
fmt.Printf("\nDownloaded %d files:\n", len(entries))
for i, entry := range entries {
if i >= 10 {
fmt.Printf("... and %d more files\n", len(entries)-10)
break
}
info, _ := entry.Info()
fmt.Printf(" - %s (%.2f MB)\n", entry.Name(), float64(info.Size())/1024/1024)
}
}

View file

@ -1,68 +0,0 @@
package main
import (
"context"
"fmt"
"io"
"log"
"os"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/s3"
)
func main() {
ctx := context.Background()
// Create AWS config with anonymous credentials
cfg, err := config.LoadDefaultConfig(ctx,
config.WithRegion("us-east-1"),
config.WithCredentialsProvider(aws.AnonymousCredentials{}),
)
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
client := s3.NewFromConfig(cfg)
// Try to download a single file
bucket := "noaa-gfs-bdp-pds"
key := "gfs.20251020/00/atmos/gfs.t00z.pgrb2.0p50.f000"
fmt.Printf("Downloading: s3://%s/%s\n", bucket, key)
input := &s3.GetObjectInput{
Bucket: aws.String(bucket),
Key: aws.String(key),
}
result, err := client.GetObject(ctx, input)
if err != nil {
log.Fatalf("Failed to get object: %v", err)
}
defer result.Body.Close()
// Create output file
outFile := "/tmp/test_grib.part"
f, err := os.Create(outFile)
if err != nil {
log.Fatalf("Failed to create file: %v", err)
}
defer f.Close()
// Copy data
written, err := io.Copy(f, result.Body)
if err != nil {
log.Fatalf("Failed to copy data: %v (wrote %d bytes)", err, written)
}
fmt.Printf("Successfully downloaded %d bytes\n", written)
// Rename
if err := os.Rename(outFile, "/tmp/test_grib"); err != nil {
log.Fatalf("Failed to rename: %v", err)
}
fmt.Println("Download complete!")
}