This commit is contained in:
Anatoly Antonov 2026-05-18 03:17:17 +09:00
parent 7a8d5d13fa
commit 9e663db9dc
68 changed files with 5647 additions and 2958 deletions

View file

@ -1,12 +1,20 @@
.PHONY: build run test fmt lint clean generate-ogen help .PHONY: build server cli compare test fmt lint clean generate-ogen docs help
# Build the application # Build all binaries
build: build: server cli compare
go build -o predictor ./cmd/api
server:
go build -o bin/predictor ./cmd/predictor
cli:
go build -o bin/predictor-cli ./cmd/predictor-cli
compare:
go build -o bin/compare-tawhiri ./cmd/compare-tawhiri
# Run locally # Run locally
run: run:
go run ./cmd/api go run ./cmd/predictor
# Run tests # Run tests
test: test:
@ -20,21 +28,28 @@ fmt:
lint: lint:
golangci-lint run golangci-lint run
# Generate ogen API code from swagger spec # Build the numerics LaTeX doc (requires pdflatex)
docs:
cd docs && pdflatex numerics.tex
# Regenerate ogen API code from the OpenAPI spec
generate-ogen: generate-ogen:
go run github.com/ogen-go/ogen/cmd/ogen@latest --target pkg/rest --package rest --clean api/rest/predictor.swagger.yml go run github.com/ogen-go/ogen/cmd/ogen@latest --target pkg/rest --package rest --clean api/rest/predictor.swagger.yml
# Clean build artifacts # Clean build artifacts
clean: clean:
rm -f predictor rm -rf bin/ docs/numerics.pdf docs/numerics.aux docs/numerics.log
# Show help
help: help:
@echo "Available commands:" @echo "Available commands:"
@echo " build - Build binary" @echo " build - Build all binaries to bin/"
@echo " run - Run locally" @echo " server - Build the HTTP server (cmd/predictor)"
@echo " test - Run tests" @echo " cli - Build the CLI client (cmd/predictor-cli)"
@echo " compare - Build the validation tool (cmd/compare-tawhiri)"
@echo " run - Run the server with default config"
@echo " test - Run unit tests"
@echo " fmt - Format code" @echo " fmt - Format code"
@echo " lint - Lint code" @echo " lint - Lint code (golangci-lint)"
@echo " generate-ogen - Generate API code from swagger spec" @echo " docs - Build the numerics LaTeX doc (requires pdflatex)"
@echo " generate-ogen - Regenerate ogen code from the OpenAPI spec"
@echo " clean - Remove build artifacts" @echo " clean - Remove build artifacts"

419
README.md
View file

@ -1,261 +1,274 @@
# Balloon Trajectory Predictor # stratoflights-predictor
High-altitude balloon trajectory prediction service. Predicts ascent, burst, and descent trajectories using GFS wind forecast data from NOAA. High-altitude balloon trajectory prediction service. Forecasts ascent, descent,
and float trajectories from NOAA GFS wind data, exposed as a REST API.
The prediction algorithms are an exact port of [Tawhiri](https://github.com/cuspaceflight/tawhiri) (Cambridge University Spaceflight) to Go, verified to produce identical results. The trajectory engine is a propagator-and-constraint system: any flight
profile can be expressed as a chain of propagators (constant-rate ascent,
parachute descent, piecewise rates, wind drift) with attached constraints
(altitude, time, terrain contact). The legacy Tawhiri request shape is kept
as a compatibility endpoint so existing clients work unchanged.
## Quick Start ## Quick start
```bash ```bash
# Build # Build all three binaries (server, CLI, validation tool)
make build make build
# Run (downloads ~9 GB of GFS data on first start, takes 30-60 min) # Run the server (first start downloads ~9 GB of GFS data over 30-60 min)
PREDICTOR_DATA_DIR=/tmp/predictor-data go run ./cmd/api ./bin/predictor
# Check readiness # Check readiness
curl http://localhost:8080/ready ./bin/predictor-cli ready
# Run a prediction # Run a Tawhiri-style prediction
curl 'http://localhost:8080/api/v1/prediction?launch_latitude=52.2&launch_longitude=0.1&launch_datetime=2026-03-28T12:00:00Z&launch_altitude=0&ascent_rate=5&burst_altitude=30000&descent_rate=5' ./bin/predictor-cli predict \
launch_latitude=52.2 launch_longitude=0.1 \
launch_datetime=2026-03-28T12:00:00Z \
ascent_rate=5 burst_altitude=30000 descent_rate=5
``` ```
## Configuration ## Configuration
All configuration is via environment variables. Configuration is layered: built-in defaults, then a YAML file
(`--config path.yml` or `PREDICTOR_CONFIG_FILE=path.yml`), then env vars,
then CLI flags. Flags override env vars, which override file values, which override defaults.
| Variable | Default | Description | | Setting | Env var | CLI flag | Default |
|---|---|---| |---|---|---|---|
| `PREDICTOR_PORT` | `8080` | HTTP server port | | HTTP port | `PREDICTOR_PORT` | `-port` | `8080` |
| `PREDICTOR_DATA_DIR` | `/tmp/predictor-data` | Directory for wind datasets and temp files | | Data directory | `PREDICTOR_DATA_DIR` | `-data-dir` | `/tmp/predictor-data` |
| `PREDICTOR_DOWNLOAD_PARALLEL` | `8` | Max concurrent GRIB download goroutines | | Elevation dataset | `PREDICTOR_ELEVATION_DATASET` | `-elevation` | `/srv/ruaumoko-dataset` |
| `PREDICTOR_UPDATE_INTERVAL` | `6h` | How often to check for new forecasts | | Source | `PREDICTOR_SOURCE` | — | `noaa-gfs-0p50` |
| `PREDICTOR_DATASET_TTL` | `48h` | Max age before a dataset is considered stale | | Download parallelism | `PREDICTOR_DOWNLOAD_PARALLEL` | `-download-parallel` | `8` |
| `PREDICTOR_ELEVATION_DATASET` | `/srv/ruaumoko-dataset` | Path to elevation dataset (optional) | | Download bandwidth (bytes/s; 0 = unlimited) | `PREDICTOR_DOWNLOAD_BANDWIDTH` | `-download-bandwidth` | `0` |
| Scheduler interval | `PREDICTOR_UPDATE_INTERVAL` | `-update-interval` | `6h` |
| Dataset freshness TTL | `PREDICTOR_DATASET_TTL` | `-freshness-ttl` | `48h` |
| Metrics enabled | `PREDICTOR_METRICS_ENABLED` | `-metrics` | `true` |
| Metrics HTTP path | `PREDICTOR_METRICS_PATH` | `-metrics-path` | `/metrics` |
| Log level | `PREDICTOR_LOG_LEVEL` | `-log-level` | `info` |
## API A YAML config file mirrors the same structure:
### `GET /api/v1/prediction` ```yaml
http:
port: 8080
data:
dir: /var/lib/predictor
elevation_path: /var/lib/predictor/elevation
source: noaa-gfs-0p50
download:
parallel: 8
bandwidth_bytes_per_second: 0
update_interval: 6h
freshness_ttl: 48h
metrics:
enabled: true
path: /metrics
log:
level: info
```
Run a balloon trajectory prediction. ## REST API
**Parameters** (query string): ### Tawhiri-compatible
`GET /api/v1/prediction` — preserves the exact request and response shape of
the upstream Cambridge University Spaceflight predictor. Query parameters:
| Parameter | Required | Description | | Parameter | Required | Description |
|---|---|---| |---|---|---|
| `launch_latitude` | yes | Launch latitude in degrees (-90 to 90) | | `launch_latitude` | yes | Degrees, -90 to 90 |
| `launch_longitude` | yes | Launch longitude in degrees (-180 to 180 or 0 to 360) | | `launch_longitude` | yes | Degrees, -180 to 180 or 0 to 360 |
| `launch_datetime` | yes | Launch time in RFC 3339 format | | `launch_datetime` | yes | RFC 3339 |
| `launch_altitude` | no | Launch altitude in metres ASL (default: 0) | | `launch_altitude` | no | Metres ASL (default 0) |
| `profile` | no | `standard_profile` (default) or `float_profile` | | `profile` | no | `standard_profile` (default) or `float_profile` |
| `ascent_rate` | no | Ascent rate in m/s (default: 5) | | `ascent_rate` | no | m/s (default 5) |
| `burst_altitude` | no | Burst altitude in metres (default: 28000) | | `burst_altitude` | no | Metres (default 28000) |
| `descent_rate` | no | Sea-level descent rate in m/s (default: 5) | | `descent_rate` | no | m/s (default 5) |
| `float_altitude` | no | Float altitude in metres (float_profile only) | | `float_altitude` | no | Metres (float profile only) |
| `stop_datetime` | no | Float end time (float_profile only, default: +24h) | | `stop_datetime` | no | Float-profile end time |
**Response** (Tawhiri-compatible): `GET /ready` — returns `{"status": "ok", "dataset_time": "..."}` once a
dataset is loaded; `{"status": "not_ready", ...}` before then.
### Profile-driven (new primary)
`POST /api/v2/prediction` — accepts an arbitrary chain of propagators with
optional constraints. Useful when the frontend wants flight profiles the
Tawhiri shape can't express (e.g. piecewise rates, fallback on constraint
violation).
```json ```json
{ {
"prediction": [ "launch": {
"time": "2026-03-28T12:00:00Z",
"latitude": 52.2,
"longitude": 0.1,
"altitude": 0
},
"profile": [
{ {
"stage": "ascent", "name": "ascent",
"trajectory": [ "model": {"type": "constant_rate", "rate": 5, "include_wind": true},
{"datetime": "2026-03-28T12:00:00Z", "latitude": 52.2, "longitude": 0.1, "altitude": 0}, "constraints": [{"type": "max_altitude", "limit": 30000}]
... },
{
"name": "descent",
"model": {"type": "parachute_descent", "sea_level_rate": 5, "include_wind": true},
"constraints": [{"type": "terrain_contact"}]
}
] ]
},
{
"stage": "descent",
"trajectory": [...]
}
],
"metadata": {
"start_datetime": "...",
"complete_datetime": "..."
},
"request": {
"dataset": "2026-03-28T06:00:00Z",
"launch_latitude": 52.2,
...
}
} }
``` ```
### `GET /ready` Model types: `constant_rate`, `parachute_descent`, `piecewise`, `wind`.
Constraint types: `max_altitude`, `min_altitude`, `max_time`,
`terrain_contact`. Constraint actions: `stop` (default), `fallback`, `clip`.
Set `"direction": "reverse"` to integrate backward from a known landing.
Health check. Returns `{"status": "ok"}` when a dataset is loaded. ### Dataset admin
## Elevation Dataset ```
GET /api/v1/admin/datasets list stored epochs
Without elevation data, descent terminates at sea level (altitude <= 0). With elevation data, descent terminates at ground level, matching Tawhiri's behaviour. POST /api/v1/admin/datasets {epoch | latest} trigger a download
DELETE /api/v1/admin/datasets/{epoch} delete a stored dataset
### Building the elevation dataset GET /api/v1/admin/jobs list every job
GET /api/v1/admin/jobs/{id} fetch one job
The elevation dataset uses ETOPO 2022 at 30 arc-second resolution, converted to a ruaumoko-compatible binary format (21601 x 43200 grid of int16 little-endian elevation values in metres). DELETE /api/v1/admin/jobs/{id} cancel a running job
**Requirements**: Python 3, xarray, netcdf4, numpy.
```bash
pip install xarray netcdf4 numpy
# Downloads ~1.1 GB from NOAA, produces ~1.74 GB binary file
python3 scripts/build_elevation.py /tmp/predictor-data/ruaumoko-dataset
``` ```
To skip the download if you already have the ETOPO NetCDF file: Returns `JobInfo`:
```bash ```json
ETOPO_NC_PATH=/path/to/ETOPO_2022_v1_30s_N90W180_surface.nc \ {"id":"…","source":"noaa-gfs-0p50","epoch":"…","status":"running",
python3 scripts/build_elevation.py /tmp/predictor-data/ruaumoko-dataset "started_at":"…","total_units":130,"done_units":47,"bytes":510000000}
``` ```
The ETOPO 2022 NetCDF can be manually downloaded from: ### Metrics
https://www.ncei.noaa.gov/products/etopo-global-relief-model
### Using the elevation dataset `GET /metrics` — Prometheus text exposition. Counters:
`predictor_predictions_total{profile,status}`,
```bash `predictor_downloads_total{source,status}`,
PREDICTOR_ELEVATION_DATASET=/tmp/predictor-data/ruaumoko-dataset go run ./cmd/api `predictor_download_bytes_total{source}`,
``` and a gauge `predictor_active_dataset_epoch_seconds`.
If the file doesn't exist or can't be read, the service starts normally with a warning and falls back to sea-level termination.
## Architecture ## Architecture
``` ```
cmd/api/main.go Entry point, config, scheduler, HTTP server cmd/
predictor/main.go main server entry point
predictor-cli/main.go HTTP client
compare-tawhiri/main.go end-to-end validation against the public Tawhiri instance
internal/ internal/
dataset/ numerics/ pure numerical primitives (interp, bisect, RK4, refinement)
dataset.go Shape constants, pressure levels, S3 URLs engine/ propagator + constraint system + concrete models
file.go mmap-backed dataset file (read/write/blit) weather/ WindField interface; gfs/ — NOAA GFS file format + impl
downloader/ datasets/ Source/Storage/Manager + transactional, resumable downloads
downloader.go S3 partial GRIB download (idx + range requests) gfs/ — NOAA GFS source impl
idx.go NOAA .idx file parser elevation/ ruaumoko-format ground elevation reader
config.go Environment-based configuration config/ layered file+env+CLI config
elevation/ metrics/ Sink interface + Prometheus text impl
elevation.go Ruaumoko-compatible elevation dataset (mmap int16) api/ HTTP transport
prediction/ tawhiri/ — legacy v1 endpoint via ogen
interpolate.go 4D wind interpolation (time, lat, lon, altitude) v2/ — profile-driven endpoint
solver.go RK4 integrator with binary search termination admin/ — dataset/job admin endpoints
models.go Ascent, descent, wind models; flight profiles middleware/
warnings.go Prediction warning counters api/rest/predictor.swagger.yml OpenAPI 3 spec for v1 + /ready
service/ pkg/rest/ ogen-generated code (regenerate via `make generate-ogen`)
service.go Dataset lifecycle, concurrent-safe access docs/numerics.tex LaTeX math reference for the numerics package
transport/ scripts/build_elevation.py ETOPO 2022 → ruaumoko converter
middleware/log.go Request logging middleware
rest/
handler/handler.go ogen API handler implementation
handler/deps.go Service interface
transport.go ogen HTTP server, CORS
api/rest/predictor.swagger.yml OpenAPI 3.0 spec
pkg/rest/ Generated ogen code (17 files)
scripts/
build_elevation.py ETOPO 2022 to ruaumoko converter
``` ```
## Wind Dataset ## Deployment
The service downloads GFS 0.5-degree forecast data from NOAA S3: ### Local single instance
```bash
./bin/predictor --data-dir /var/lib/predictor
```
No external dependencies beyond the NOAA S3 mirror.
### Docker single container
```dockerfile
FROM golang:1.25 AS build
WORKDIR /src
COPY . .
RUN go build -o /predictor ./cmd/predictor
FROM gcr.io/distroless/base
COPY --from=build /predictor /predictor
EXPOSE 8080
ENTRYPOINT ["/predictor"]
```
Mount a volume at `/data` and set `PREDICTOR_DATA_DIR=/data`.
### Load-balanced cluster
The server is stateless apart from the on-disk dataset cache and in-memory
job table. For multiple replicas, point all replicas at a shared filesystem
(NFS or similar) for `data_dir`; each replica only reads the dataset, through its own mmap. Active
download coordination across replicas is not implemented — run downloads on
one node, or accept that two nodes may download the same epoch concurrently
(only one Commit wins via atomic rename).
## Elevation dataset
Without elevation data, descent terminates at sea level. With elevation,
descent terminates at ground level, matching upstream Tawhiri.
```bash
pip install xarray netcdf4 numpy
python3 scripts/build_elevation.py /var/lib/predictor/elevation
```
`PREDICTOR_ELEVATION_DATASET=/var/lib/predictor/elevation ./bin/predictor`
## Numerical methods
The numerics package (`internal/numerics`) provides:
- regular-grid multilinear interpolation,
- monotone bisection,
- classical RK4 (forward and reverse time),
- binary-search refinement of a termination point.
Detailed math reference: `docs/numerics.tex`. The package has no
domain dependencies and is small enough for manual verification (~300
lines of Go), enabling a future C or Rust port without changes to the
trajectory engine.
## Wind data
| Property | Value | | Property | Value |
|---|---| |---|---|
| Source | `noaa-gfs-bdp-pds.s3.amazonaws.com` | | Source | NOAA GFS, S3 mirror (`noaa-gfs-bdp-pds.s3.amazonaws.com`) |
| Resolution | 0.5 degrees | | Resolution | 0.5° |
| Grid | 361 lat x 720 lon | | Grid | 361 × 720 (lat × lng) |
| Time steps | 65 (every 3 hours, 0-192h) | | Forecast steps | 65 (every 3 hours, 0–192h) |
| Pressure levels | 47 (1000 to 1 hPa) | | Pressure levels | 47 (1000–1 hPa) |
| Variables | Geopotential height, U-wind, V-wind | | Variables | Geopotential height, U-wind, V-wind |
| Dataset size | 9,528,667,200 bytes (~8.87 GiB) | | File size | ~8.87 GiB (float32 flat binary, mmap-backed) |
| Update cadence | Every 6 hours (GFS runs at 00, 06, 12, 18 UTC) | | Update cadence | every 6 hours |
Data is downloaded using HTTP Range requests against `.idx` index files, fetching only the needed GRIB messages (HGT, UGRD, VGRD at 47 pressure levels). Full download takes 30-60 minutes depending on bandwidth. Downloads use HTTP Range requests against `.idx` index files to fetch only
the needed GRIB messages. Downloads are transactional (temp file, manifest,
atomic rename on commit) and resumable: interrupted downloads pick up where
they left off via the manifest.
The dataset is stored as a memory-mapped flat binary file of float32 values in C-order with shape `(65, 47, 3, 361, 720)`. ## Validation
## Prediction Algorithms `./bin/compare-tawhiri --server http://localhost:8080` runs an identical
prediction against the local server and against the public SondeHub Tawhiri
All algorithms are exact ports of the reference implementations in Tawhiri. The following sections describe the key components. instance, reporting the great-circle distance between landing points.
### Interpolation (`internal/prediction/interpolate.go`)
4D wind interpolation from the dataset grid to arbitrary coordinates.
1. **Trilinear weights** (`pick3`): compute 8 interpolation weights for the (hour, lat, lon) cube corners.
2. **Altitude search** (`search`): binary search on interpolated geopotential height to find the two pressure levels bracketing the target altitude.
3. **Wind extraction** (`interp4`): 8-point weighted sum at each bracket level, then linear interpolation between levels.
Reference: `tawhiri/interpolate.pyx`
### Solver (`internal/prediction/solver.go`)
4th-order Runge-Kutta integrator with dt = 60 seconds.
- State vector: (latitude, longitude, altitude) in degrees and metres.
- Time: UNIX timestamp in seconds.
- Longitude is kept in [0, 360) via Python-style modulo after each `vecadd`.
- When a terminator fires, binary search refinement (tolerance 0.01) finds the precise termination point between the last good step and the first terminated step.
- Longitude interpolation (`lngLerp`) handles the 0/360 wrap-around.
Reference: `tawhiri/solver.pyx`
### Models (`internal/prediction/models.go`)
- **Constant ascent**: vertical velocity = ascent_rate m/s.
- **Drag descent**: NASA atmosphere density model with drag coefficient = sea_level_rate * 1.1045. Descent rate increases with altitude due to thinner air.
- **Wind velocity**: u, v components from interpolation converted to degrees/second: `dlat = (180/pi) * v / (R)`, `dlng = (180/pi) * u / (R * cos(lat))` where R = 6371009 + altitude.
- **Linear model**: sum of component models (e.g., wind + ascent).
- **Elevation termination**: `ground_elevation > altitude` using ruaumoko dataset.
Reference: `tawhiri/models.py`
### Profiles
- **standard_profile**: ascent (constant rate + wind) until burst altitude, then descent (drag + wind) until ground level.
- **float_profile**: ascent to float altitude, then drift at constant altitude until stop time.
## Verification
The predictor has been verified against the reference Tawhiri implementation:
| Test | Result |
|---|---|
| Dataset (step 0): 36.6M float32 values vs Python/cfgrib | 0 mismatches, max diff = 0.0 |
| Prediction burst point vs public Tawhiri API | Identical (lat, lon, alt all match) |
| Prediction landing point vs public Tawhiri API | Identical lat/lon, 5m altitude diff (different elevation datasets) |
| Descent point count | Identical (46 points) |
| Ascent point count | Identical (101 points) |
## Development
```bash
# Regenerate ogen API code after modifying the swagger spec
make generate-ogen
# Run tests
make test
# Format
make fmt
```
### Comparison tools
```bash
# Compare single dataset step against Python/cfgrib reference
go run ./cmd/compare_step0 <run_YYYYMMDDHH> <output_path>
# Run prediction and compare against public Tawhiri API
go run ./cmd/compare_prediction
```
## References ## References
- [Tawhiri](https://github.com/cuspaceflight/tawhiri) — Reference Python/Cython predictor (Cambridge University Spaceflight) - [Tawhiri](https://github.com/cuspaceflight/tawhiri) — reference Python/Cython predictor
- [tawhiri-downloader](https://github.com/cuspaceflight/tawhiri-downloader) — OCaml dataset downloader - [ruaumoko](https://github.com/cuspaceflight/ruaumoko) — global elevation dataset format
- [ruaumoko](https://github.com/cuspaceflight/ruaumoko) — Global elevation dataset - [NOAA GFS](https://www.ncei.noaa.gov/products/weather-climate-models/global-forecast)
- [NOAA GFS](https://www.ncei.noaa.gov/products/weather-climate-models/global-forecast) — Global Forecast System - [ETOPO 2022](https://www.ncei.noaa.gov/products/etopo-global-relief-model)
- [NOAA GFS on S3](https://noaa-gfs-bdp-pds.s3.amazonaws.com/index.html) — Public S3 bucket - [SondeHub Tawhiri API](https://api.v2.sondehub.org/tawhiri) — public Tawhiri instance
- [ETOPO 2022](https://www.ncei.noaa.gov/products/etopo-global-relief-model) — Global relief model for elevation data
- [SondeHub Tawhiri API](https://api.v2.sondehub.org/tawhiri) — Public Tawhiri instance for comparison

View file

@ -1,98 +0,0 @@
package main
import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
"time"
"predictor-refactored/internal/downloader"
"predictor-refactored/internal/service"
"predictor-refactored/internal/transport/rest"
"predictor-refactored/internal/transport/rest/handler"
"github.com/go-co-op/gocron"
"go.uber.org/zap"
)
// main wires the predictor service together: it loads configuration, creates
// the data directory, kicks off the initial and the scheduled dataset updates,
// starts the HTTP transport, and then blocks until SIGINT or SIGTERM arrives.
func main() {
	log, err := zap.NewProduction()
	if err != nil {
		panic(err)
	}
	defer log.Sync()

	cfg := downloader.LoadConfig()
	log.Info("configuration loaded",
		zap.String("data_dir", cfg.DataDir),
		zap.Int("parallel", cfg.Parallel),
		zap.Duration("update_interval", cfg.UpdateInterval),
		zap.Duration("dataset_ttl", cfg.DatasetTTL))

	if err := os.MkdirAll(cfg.DataDir, 0o755); err != nil {
		log.Fatal("failed to create data directory", zap.Error(err))
	}

	svc := service.New(cfg, log)
	defer svc.Close()

	// Load elevation dataset (optional — falls back to sea-level termination)
	elevPath := "/srv/ruaumoko-dataset"
	if v := os.Getenv("PREDICTOR_ELEVATION_DATASET"); v != "" {
		elevPath = v
	}
	svc.LoadElevation(elevPath)

	// Initial dataset load (async so the server starts immediately)
	go func() {
		log.Info("performing initial dataset update...")
		if err := svc.Update(context.Background()); err != nil {
			log.Error("initial dataset update failed", zap.Error(err))
		} else {
			log.Info("initial dataset update complete")
		}
	}()

	// Scheduler for periodic dataset updates
	scheduler := gocron.NewScheduler(time.UTC)
	scheduler.Every(cfg.UpdateInterval).Do(func() {
		ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
		defer cancel()
		log.Info("scheduled dataset update starting")
		if err := svc.Update(ctx); err != nil {
			log.Error("scheduled dataset update failed", zap.Error(err))
		} else {
			log.Info("scheduled dataset update complete")
		}
	})
	scheduler.StartAsync()
	defer scheduler.Stop()

	// HTTP transport (ogen)
	port := 8080
	if p := os.Getenv("PREDICTOR_PORT"); p != "" {
		// Parse into a scratch variable so a malformed or out-of-range value
		// is reported and the default is kept, instead of being silently
		// ignored or half-applied.
		var parsed int
		if _, err := fmt.Sscanf(p, "%d", &parsed); err != nil || parsed < 0 || parsed > 65535 {
			log.Warn("ignoring invalid PREDICTOR_PORT, using default",
				zap.String("value", p),
				zap.Int("default", port),
				zap.Error(err))
		} else {
			port = parsed
		}
	}

	h := handler.New(svc, log)
	transport, err := rest.New(h, port, log)
	if err != nil {
		log.Fatal("failed to create transport", zap.Error(err))
	}

	go func() {
		if err := transport.Run(); err != nil {
			log.Fatal("HTTP server error", zap.Error(err))
		}
	}()

	log.Info("service started")

	// Graceful shutdown
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	sig := <-sigChan
	log.Info("received shutdown signal", zap.String("signal", sig.String()))
}

153
cmd/compare-tawhiri/main.go Normal file
View file

@ -0,0 +1,153 @@
// Command compare-tawhiri runs the same prediction against a local predictor
// instance and against the public SondeHub Tawhiri instance, reporting the
// distance between the two predicted landing points.
//
// Intended use:
//
// ./compare-tawhiri --server http://localhost:8080
package main
import (
"encoding/json"
"flag"
"fmt"
"io"
"math"
"net/http"
"net/url"
"os"
"time"
)
const tawhiriPublicURL = "https://api.v2.sondehub.org/tawhiri"
// main parses the command-line flags, asks the local server for its active
// dataset epoch, runs the same prediction against the local server and the
// public Tawhiri instance, and prints both landing points together with the
// distance between them. Any failure is written to stderr and exits with
// status 1.
func main() {
	var (
		serverURL   string
		launchArg   string
		launchLat   float64
		launchLng   float64
		launchAlt   float64
		ascentRate  float64
		burstAlt    float64
		descentRate float64
	)
	flag.StringVar(&serverURL, "server", "http://localhost:8080", "local predictor server URL")
	flag.Float64Var(&launchLat, "lat", 52.2135, "launch latitude")
	flag.Float64Var(&launchLng, "lng", 0.0964, "launch longitude")
	flag.Float64Var(&launchAlt, "alt", 0, "launch altitude")
	flag.Float64Var(&ascentRate, "ascent-rate", 5, "ascent rate m/s")
	flag.Float64Var(&burstAlt, "burst", 30000, "burst altitude m")
	flag.Float64Var(&descentRate, "descent-rate", 5, "descent rate m/s")
	flag.StringVar(&launchArg, "launch", "", "launch time RFC3339; default: 3 hours after the active dataset epoch")
	flag.Parse()

	// fatal reports err under the given label and terminates the process.
	fatal := func(label string, err error) {
		fmt.Fprintln(os.Stderr, label, err)
		os.Exit(1)
	}

	// The readiness endpoint is queried first, even when -launch is given.
	epoch, err := fetchActiveEpoch(serverURL)
	if err != nil {
		fatal("ready:", err)
	}

	var launchTime time.Time
	if launchArg == "" {
		launchTime = epoch.Add(3 * time.Hour)
	} else {
		parsed, err := time.Parse(time.RFC3339, launchArg)
		if err != nil {
			fatal("invalid launch time:", err)
		}
		launchTime = parsed
	}

	localLat, localLng, err := runPrediction(serverURL+"/api/v1/prediction", launchLat, launchLng, launchAlt, launchTime, ascentRate, burstAlt, descentRate)
	if err != nil {
		fatal("local prediction:", err)
	}
	fmt.Printf("local landing: lat=%.4f, lng=%.4f\n", localLat, localLng)

	refLat, refLng, err := runPrediction(tawhiriPublicURL, launchLat, launchLng, launchAlt, launchTime, ascentRate, burstAlt, descentRate)
	if err != nil {
		fatal("tawhiri prediction:", err)
	}
	fmt.Printf("tawhiri landing: lat=%.4f, lng=%.4f\n", refLat, refLng)

	metres := haversine(localLat, localLng, refLat, refLng)
	km := metres / 1000
	fmt.Printf("distance: %.2f km\n", km)
	if metres < 1000 {
		fmt.Println("MATCH (< 1 km)")
	} else if metres < 50000 {
		fmt.Printf("MODERATE (%.1f km) — likely different forecast runs\n", km)
	} else {
		fmt.Printf("LARGE (%.1f km) — investigate\n", km)
	}
}
// readinessResp holds the fields of the /ready response body that this tool
// decodes.
type readinessResp struct {
	Status      string `json:"status"`
	DatasetTime string `json:"dataset_time"`
}

// fetchActiveEpoch requests base + "/ready" and returns the dataset time the
// server reports.
//
// It returns an error when the request fails or times out, the response
// status is not 200, the body cannot be read or is not valid JSON, the
// reported status is not "ok", or dataset_time is not an RFC 3339 timestamp.
func fetchActiveEpoch(base string) (time.Time, error) {
	// Bound the request so an unresponsive server cannot hang the tool.
	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Get(base + "/ready")
	if err != nil {
		return time.Time{}, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return time.Time{}, fmt.Errorf("reading response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return time.Time{}, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
	}

	var r readinessResp
	if err := json.Unmarshal(body, &r); err != nil {
		return time.Time{}, fmt.Errorf("decoding response: %w", err)
	}
	if r.Status != "ok" {
		return time.Time{}, fmt.Errorf("server status %q", r.Status)
	}

	// Name the offending field: a bare time.Parse error does not say where
	// the bad value came from.
	epoch, err := time.Parse(time.RFC3339, r.DatasetTime)
	if err != nil {
		return time.Time{}, fmt.Errorf("parsing dataset_time %q: %w", r.DatasetTime, err)
	}
	return epoch, nil
}
// runPrediction issues a Tawhiri-style GET request against endpoint and
// returns the latitude and longitude of the last point of the "descent" stage
// in the response.
//
// The launch parameters are sent as query parameters: lat and lng with four
// decimals, alt and burst with none, rate and descent with one, and launch as
// RFC 3339. An error is returned when the request fails or times out, the
// response status is not 200, the body cannot be read or decoded, or the
// response contains no non-empty descent stage.
func runPrediction(endpoint string, lat, lng, alt float64, launch time.Time, rate, burst, descent float64) (float64, float64, error) {
	q := url.Values{}
	q.Set("launch_latitude", fmt.Sprintf("%.4f", lat))
	q.Set("launch_longitude", fmt.Sprintf("%.4f", lng))
	q.Set("launch_altitude", fmt.Sprintf("%.0f", alt))
	q.Set("launch_datetime", launch.Format(time.RFC3339))
	q.Set("ascent_rate", fmt.Sprintf("%.1f", rate))
	q.Set("burst_altitude", fmt.Sprintf("%.0f", burst))
	q.Set("descent_rate", fmt.Sprintf("%.1f", descent))

	// Bound the request so a stalled server cannot hang the tool.
	client := &http.Client{Timeout: 60 * time.Second}
	resp, err := client.Get(endpoint + "?" + q.Encode())
	if err != nil {
		return 0, 0, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return 0, 0, fmt.Errorf("reading response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return 0, 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
	}

	// Only the fields needed to locate the landing point are decoded.
	var result struct {
		Prediction []struct {
			Stage      string `json:"stage"`
			Trajectory []struct {
				Latitude  float64 `json:"latitude"`
				Longitude float64 `json:"longitude"`
			} `json:"trajectory"`
		} `json:"prediction"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return 0, 0, fmt.Errorf("decoding response: %w", err)
	}

	for _, stage := range result.Prediction {
		if stage.Stage == "descent" && len(stage.Trajectory) > 0 {
			last := stage.Trajectory[len(stage.Trajectory)-1]
			return last.Latitude, last.Longitude, nil
		}
	}
	return 0, 0, fmt.Errorf("no descent stage in response")
}
// haversine returns the great-circle distance between two points using the
// haversine formula on a sphere of radius 6371000. Inputs are in degrees; the
// result is in the same unit as the radius (metres).
func haversine(lat1, lng1, lat2, lng2 float64) float64 {
	const earthRadius = 6371000.0

	toRad := func(deg float64) float64 { return deg * math.Pi / 180 }

	sinHalfLat := math.Sin(toRad(lat2-lat1) / 2)
	sinHalfLng := math.Sin(toRad(lng2-lng1) / 2)

	a := sinHalfLat*sinHalfLat + math.Cos(toRad(lat1))*math.Cos(toRad(lat2))*sinHalfLng*sinHalfLng
	return earthRadius * 2 * math.Atan2(math.Sqrt(a), math.Sqrt(1-a))
}

View file

@ -1,195 +0,0 @@
package main
import (
"context"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"os"
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/downloader"
"predictor-refactored/internal/prediction"
"go.uber.org/zap"
)
// main downloads a few forecast steps of the latest run, runs a standard
// profile prediction on them, and compares the landing point against the
// public Tawhiri API. It exits with status 1 when a setup or download step
// fails.
func main() {
	if err := comparePrediction(); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}

// comparePrediction does the work of main. It returns errors instead of
// calling os.Exit so that the deferred dataset close and temporary-file
// removal run on every path, including download failures.
func comparePrediction() error {
	log, err := zap.NewDevelopment()
	if err != nil {
		return fmt.Errorf("creating logger: %w", err)
	}

	cfg := &downloader.Config{
		DataDir:  os.TempDir(),
		Parallel: 4,
	}
	dl := downloader.NewDownloader(cfg, log)
	ctx := context.Background()

	// Find latest run
	run, err := dl.FindLatestRun(ctx)
	if err != nil {
		return fmt.Errorf("FindLatestRun: %w", err)
	}
	fmt.Printf("Using run: %s\n", run.Format("2006010215"))

	// Create dataset and download first 10 steps (0-27 hours, enough for a prediction)
	dsPath := fmt.Sprintf("/tmp/pred_test_%s.bin", run.Format("2006010215"))
	defer os.Remove(dsPath)
	ds, err := dataset.Create(dsPath)
	if err != nil {
		return fmt.Errorf("Create: %w", err)
	}
	// Deferred after os.Remove, so the dataset is closed before its file is
	// removed.
	defer ds.Close()

	date := run.Format("20060102")
	runHour := run.Hour()
	stepsToDownload := []int{0, 3, 6, 9, 12, 15, 18, 21, 24, 27}
	fmt.Printf("Downloading %d steps...\n", len(stepsToDownload))
	for _, step := range stepsToDownload {
		hourIdx := dataset.HourIndex(step)
		fmt.Printf(" step %d (hour idx %d)...\n", step, hourIdx)
		urlA := dataset.GribURL(date, runHour, step)
		if err := dl.DownloadAndBlit(ctx, ds, urlA, hourIdx, dataset.LevelSetA); err != nil {
			return fmt.Errorf("pgrb2 step %d: %w", step, err)
		}
		urlB := dataset.GribURLB(date, runHour, step)
		if err := dl.DownloadAndBlit(ctx, ds, urlB, hourIdx, dataset.LevelSetB); err != nil {
			return fmt.Errorf("pgrb2b step %d: %w", step, err)
		}
	}
	ds.Flush()
	fmt.Println("Download complete")

	// Set dataset time
	ds.DSTime = run

	// Run our prediction
	launchLat := 52.2135
	launchLon := 0.0964 // already in [0, 360)
	launchAlt := 0.0
	ascentRate := 5.0
	burstAlt := 30000.0
	descentRate := 5.0

	// Launch 3 hours into the forecast
	launchTime := run.Add(3 * time.Hour)
	launchTimestamp := float64(launchTime.Unix())
	dsEpoch := float64(run.Unix())
	warnings := &prediction.Warnings{}
	stages := prediction.StandardProfile(ascentRate, burstAlt, descentRate, ds, dsEpoch, warnings, nil)
	results := prediction.RunPrediction(launchTimestamp, launchLat, launchLon, launchAlt, stages)

	fmt.Printf("\n=== Our prediction ===\n")
	for i, sr := range results {
		stage := "ascent"
		if i == 1 {
			stage = "descent"
		}
		// Guard the first/last lookups below against a stage with no points.
		if len(sr.Points) == 0 {
			fmt.Printf(" %s: 0 points\n", stage)
			continue
		}
		first := sr.Points[0]
		last := sr.Points[len(sr.Points)-1]
		fmt.Printf(" %s: %d points, start=(%.4f, %.4f, %.0fm) end=(%.4f, %.4f, %.0fm)\n",
			stage, len(sr.Points),
			first.Lat, first.Lng, first.Alt,
			last.Lat, last.Lng, last.Alt)
	}

	// Get landing point
	var ourLandLat, ourLandLon float64
	if len(results) >= 2 && len(results[1].Points) > 0 {
		last := results[1].Points[len(results[1].Points)-1]
		ourLandLat = last.Lat
		ourLandLon = last.Lng
		// Shift longitudes above 180 down by 360 before printing and comparing.
		if ourLandLon > 180 {
			ourLandLon -= 360
		}
	}
	fmt.Printf(" Landing: lat=%.4f, lon=%.4f\n", ourLandLat, ourLandLon)

	// Compare against public Tawhiri API
	fmt.Printf("\n=== Tawhiri API comparison ===\n")
	tawhiriLandLat, tawhiriLandLon, err := queryTawhiri(launchLat, launchLon, launchAlt, launchTime, ascentRate, burstAlt, descentRate)
	if err != nil {
		// A failed comparison is reported but is not treated as a failure of
		// the tool itself.
		fmt.Printf(" Tawhiri API error: %v\n", err)
		fmt.Println(" (Cannot compare — Tawhiri may use a different dataset)")
		return nil
	}
	fmt.Printf(" Tawhiri landing: lat=%.4f, lon=%.4f\n", tawhiriLandLat, tawhiriLandLon)

	dist := haversine(ourLandLat, ourLandLon, tawhiriLandLat, tawhiriLandLon)
	fmt.Printf(" Distance between landing points: %.2f km\n", dist/1000)
	switch {
	case dist < 1000:
		fmt.Println(" CLOSE MATCH (< 1 km)")
	case dist < 50000:
		fmt.Printf(" MODERATE DIFFERENCE (%.1f km) — likely different datasets\n", dist/1000)
	default:
		fmt.Printf(" LARGE DIFFERENCE (%.1f km) — possible bug\n", dist/1000)
	}
	return nil
}
// queryTawhiri requests a prediction from the public SondeHub Tawhiri API for
// the given launch parameters and returns the last trajectory point of the
// first non-empty "descent" stage in the response.
//
// It returns an error when the request fails, the body cannot be read, the
// server answers with a non-200 status, the body is not valid JSON, or the
// response contains no non-empty descent stage.
func queryTawhiri(lat, lon, alt float64, launchTime time.Time, ascentRate, burstAlt, descentRate float64) (landLat, landLon float64, err error) {
	// Format the launch time in UTC so the RFC 3339 suffix is "Z": a "+hh:mm"
	// offset would put a raw '+' into the query string, which servers decode
	// as a space.
	endpoint := fmt.Sprintf(
		"https://api.v2.sondehub.org/tawhiri?launch_latitude=%.4f&launch_longitude=%.4f&launch_altitude=%.0f&launch_datetime=%s&ascent_rate=%.1f&burst_altitude=%.0f&descent_rate=%.1f",
		lat, lon, alt, launchTime.UTC().Format(time.RFC3339), ascentRate, burstAlt, descentRate)

	// http.Get uses the default client, which never times out; bound the call
	// so a stalled server cannot hang the tool.
	httpClient := &http.Client{Timeout: 30 * time.Second}
	resp, err := httpClient.Get(endpoint)
	if err != nil {
		return 0, 0, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return 0, 0, fmt.Errorf("read response body: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return 0, 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
	}

	// Only the fields needed to find the landing point are decoded.
	var result struct {
		Prediction []struct {
			Stage      string `json:"stage"`
			Trajectory []struct {
				Latitude  float64 `json:"latitude"`
				Longitude float64 `json:"longitude"`
				Altitude  float64 `json:"altitude"`
			} `json:"trajectory"`
		} `json:"prediction"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return 0, 0, err
	}

	for _, stage := range result.Prediction {
		if stage.Stage == "descent" && len(stage.Trajectory) > 0 {
			last := stage.Trajectory[len(stage.Trajectory)-1]
			return last.Latitude, last.Longitude, nil
		}
	}
	return 0, 0, fmt.Errorf("no descent stage found")
}
// haversine returns the great-circle distance between (lat1, lon1) and
// (lat2, lon2) using the haversine formula on a sphere of radius 6371000.
// The inputs are converted from degrees to radians before use; the result is
// in the same unit as the radius.
func haversine(lat1, lon1, lat2, lon2 float64) float64 {
	const earthRadius = 6371000.0

	// Half-angle sines of the latitude and longitude differences.
	sinHalfLat := math.Sin(((lat2 - lat1) * math.Pi / 180) / 2)
	sinHalfLon := math.Sin(((lon2 - lon1) * math.Pi / 180) / 2)

	cosLat1 := math.Cos(lat1 * math.Pi / 180)
	cosLat2 := math.Cos(lat2 * math.Pi / 180)

	h := sinHalfLat*sinHalfLat + cosLat1*cosLat2*sinHalfLon*sinHalfLon
	return earthRadius * 2 * math.Atan2(math.Sqrt(h), math.Sqrt(1-h))
}

View file

@ -1,104 +0,0 @@
package main
import (
"context"
"fmt"
"os"
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/downloader"
"go.uber.org/zap"
)
// Downloads step 0 of a given run and writes a minimal dataset for comparison.
// Usage: go run ./cmd/compare_step0 <run_YYYYMMDDHH> <output_path>
//
// main only converts realMain's status into the process exit code. Keeping
// os.Exit out of realMain guarantees that realMain's deferred calls (closing
// the dataset) run on every path, including failures.
func main() {
	os.Exit(realMain())
}

// realMain parses the command line, creates the dataset at the output path,
// downloads and blits forecast step 0 from both GRIB level sets, flushes the
// dataset and prints a spot check. It returns 0 on success and 1 on any
// failure, after writing a message to stderr.
func realMain() int {
	if len(os.Args) < 3 {
		fmt.Fprintf(os.Stderr, "Usage: %s <run_YYYYMMDDHH> <output_path>\n", os.Args[0])
		return 1
	}
	runStr := os.Args[1]
	outPath := os.Args[2]

	run, err := time.Parse("2006010215", runStr)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Invalid run time %q: %v\n", runStr, err)
		return 1
	}

	log, err := zap.NewDevelopment()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Init logger: %v\n", err)
		return 1
	}

	// Create a full-size dataset (we only fill step 0)
	fmt.Printf("Creating dataset at %s (%d bytes)...\n", outPath, dataset.DatasetSize)
	ds, err := dataset.Create(outPath)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Create dataset: %v\n", err)
		return 1
	}
	defer ds.Close()

	cfg := &downloader.Config{
		DataDir:  os.TempDir(),
		Parallel: 4,
	}
	dl := downloader.NewDownloader(cfg, log)
	ctx := context.Background()
	date := run.Format("20060102")
	runHour := run.Hour()

	// Download and blit step 0 from pgrb2
	fmt.Println("Downloading pgrb2 step 0...")
	urlA := dataset.GribURL(date, runHour, 0)
	if err := dl.DownloadAndBlit(ctx, ds, urlA, 0, dataset.LevelSetA); err != nil {
		fmt.Fprintf(os.Stderr, "pgrb2: %v\n", err)
		return 1
	}
	fmt.Println(" done")

	// Download and blit step 0 from pgrb2b
	fmt.Println("Downloading pgrb2b step 0...")
	urlB := dataset.GribURLB(date, runHour, 0)
	if err := dl.DownloadAndBlit(ctx, ds, urlB, 0, dataset.LevelSetB); err != nil {
		fmt.Fprintf(os.Stderr, "pgrb2b: %v\n", err)
		return 1
	}
	fmt.Println(" done")

	if err := ds.Flush(); err != nil {
		fmt.Fprintf(os.Stderr, "Flush: %v\n", err)
		return 1
	}

	// Spot-check: print same values as the Python script
	fmt.Println("\n=== Go dataset values (spot check) ===")
	type testPoint struct {
		varName  string
		varIdx   int
		levelIdx int
		lat, lon int
	}
	points := []testPoint{
		{"HGT", 0, 0, 0, 0},      // HGT @ 1000mb, lat=-90, lon=0
		{"HGT", 0, 0, 180, 0},    // HGT @ 1000mb, lat=0, lon=0
		{"HGT", 0, 0, 360, 0},    // HGT @ 1000mb, lat=+90, lon=0
		{"HGT", 0, 20, 180, 360}, // HGT @ 500mb, lat=0, lon=180
		{"UGRD", 1, 0, 180, 0},   // UGRD @ 1000mb, lat=0, lon=0
		{"VGRD", 2, 0, 180, 0},   // VGRD @ 1000mb, lat=0, lon=0
		{"UGRD", 1, 20, 284, 0},  // UGRD @ 500mb, lat=52N, lon=0
	}
	for _, p := range points {
		val := ds.Val(0, p.levelIdx, p.varIdx, p.lat, p.lon)
		// Grid indices map to degrees on a 0.5-degree grid starting at
		// lat=-90 and lon=0.
		actualLat := -90.0 + float64(p.lat)*0.5
		actualLon := float64(p.lon) * 0.5
		fmt.Printf(" %-4s %4dmb lat=%+7.1f lon=%6.1f: %12.4f\n",
			p.varName, dataset.Pressures[p.levelIdx], actualLat, actualLon, val)
	}
	fmt.Printf("\nDataset written to %s\n", outPath)
	return 0
}

216
cmd/predictor-cli/main.go Normal file
View file

@ -0,0 +1,216 @@
// Command predictor-cli is a small HTTP client for stratoflights-predictor.
//
// It is intended for operations and development; production callers should
// use the REST API directly.
package main
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
)
// usage is the help text for the CLI. main installs it as the top-level flag
// set's Usage hook, which writes it to stderr.
const usage = `predictor-cli HTTP client for stratoflights-predictor
USAGE
predictor-cli [--server URL] <command> [args...]
COMMANDS
ready Check service health
predict <KEY=VAL>... Run a Tawhiri-compat prediction (key=value pairs)
datasets list List stored dataset epochs
datasets download [--latest|--epoch RFC3339]
Trigger a dataset download
datasets delete <epoch> Delete a stored dataset
jobs list List download jobs
jobs get <id> Show one job
jobs cancel <id> Cancel a running job
ENVIRONMENT
PREDICTOR_SERVER Default --server (overridden by the flag)
`
// main parses the global flags, builds the HTTP client and runs the requested
// command.
//
// Exit codes: 0 on success or when help was requested with -h/-help, 1 when
// the command fails, 2 on a flag error or when no command is given.
func main() {
	fs := flag.NewFlagSet("predictor-cli", flag.ContinueOnError)
	fs.Usage = func() { fmt.Fprint(os.Stderr, usage) }
	server := fs.String("server", envDefault("PREDICTOR_SERVER", "http://localhost:8080"), "predictor server URL")
	if err := fs.Parse(os.Args[1:]); err != nil {
		// Parse returns flag.ErrHelp itself (unwrapped) for -h/-help; asking
		// for help is not a failure, so do not report it as one.
		if err == flag.ErrHelp {
			os.Exit(0)
		}
		os.Exit(2)
	}
	args := fs.Args()
	if len(args) == 0 {
		fs.Usage()
		os.Exit(2)
	}

	// Strip trailing slashes so request paths, which all start with "/", can
	// be appended directly.
	c := &client{base: strings.TrimRight(*server, "/"), http: &http.Client{Timeout: 30 * time.Second}}
	if err := dispatch(c, args); err != nil {
		fmt.Fprintln(os.Stderr, "error:", err)
		os.Exit(1)
	}
}
// envDefault returns the value of the environment variable name, or fallback
// when the variable is unset or set to the empty string.
func envDefault(name, fallback string) string {
	value, present := os.LookupEnv(name)
	if !present || value == "" {
		return fallback
	}
	return value
}
// dispatch runs the command named by args[0] (and, for "datasets" and "jobs",
// the subcommand in args[1]) against c.
//
// It returns a usage error when a required argument is missing, an "unknown
// ... subcommand" error when args[1] is not recognised, and an "unknown
// command" error when args[0] is not recognised. Otherwise it returns the
// error of the client call it ran.
func dispatch(c *client, args []string) error {
	if len(args) == 0 {
		return fmt.Errorf("no command given")
	}
	switch args[0] {
	case "ready":
		return c.ready()
	case "predict":
		return c.predict(args[1:])
	case "datasets":
		if len(args) < 2 {
			return fmt.Errorf("usage: datasets {list|download|delete}")
		}
		switch args[1] {
		case "list":
			return c.datasetsList()
		case "download":
			return c.datasetsDownload(args[2:])
		case "delete":
			if len(args) < 3 {
				return fmt.Errorf("usage: datasets delete <epoch>")
			}
			return c.datasetsDelete(args[2])
		default:
			// Report the bad subcommand itself instead of falling through to
			// the misleading `unknown command "datasets"`.
			return fmt.Errorf("unknown datasets subcommand %q (usage: datasets {list|download|delete})", args[1])
		}
	case "jobs":
		if len(args) < 2 {
			return fmt.Errorf("usage: jobs {list|get|cancel}")
		}
		switch args[1] {
		case "list":
			return c.jobsList()
		case "get":
			if len(args) < 3 {
				return fmt.Errorf("usage: jobs get <id>")
			}
			return c.jobsGet(args[2])
		case "cancel":
			if len(args) < 3 {
				return fmt.Errorf("usage: jobs cancel <id>")
			}
			return c.jobsCancel(args[2])
		default:
			return fmt.Errorf("unknown jobs subcommand %q (usage: jobs {list|get|cancel})", args[1])
		}
	default:
		return fmt.Errorf("unknown command %q", args[0])
	}
}
// client issues requests against one predictor server and prints the
// responses.
type client struct {
// base is the server URL that request paths are appended to. main strips
// trailing "/" characters before storing it.
base string
// http performs the requests. main configures it with a 30-second timeout.
http *http.Client
}
// ready fetches the server's /ready endpoint and prints the response.
func (c *client) ready() error {
	const path = "/ready"
	return c.getPrint(path)
}
// predict turns each "key=value" argument into a query parameter and issues
// GET /api/v1/prediction with them, printing the response.
//
// An argument with no '=' or with an empty key is rejected. When a key is
// repeated, the last value wins.
func (c *client) predict(kv []string) error {
	query := make(url.Values, len(kv))
	for _, pair := range kv {
		key, value, found := strings.Cut(pair, "=")
		if !found || key == "" {
			return fmt.Errorf("expected key=value, got %q", pair)
		}
		query.Set(key, value)
	}
	return c.getPrint("/api/v1/prediction?" + query.Encode())
}
// datasetsList fetches GET /api/v1/admin/datasets and prints the response.
func (c *client) datasetsList() error {
	const path = "/api/v1/admin/datasets"
	return c.getPrint(path)
}
// datasetsDownload parses the "datasets download" flags and POSTs a download
// request to /api/v1/admin/datasets, printing the response.
//
// The JSON body carries "latest": true when --latest is set and "epoch" when
// --epoch is non-empty; flags that were not given are left out of the body.
func (c *client) datasetsDownload(args []string) error {
	flags := flag.NewFlagSet("datasets download", flag.ContinueOnError)
	var (
		wantLatest = flags.Bool("latest", false, "download the latest available run")
		wantEpoch  = flags.String("epoch", "", "RFC3339 epoch to download")
	)
	if parseErr := flags.Parse(args); parseErr != nil {
		return parseErr
	}

	payload := make(map[string]any, 2)
	if *wantLatest {
		payload["latest"] = true
	}
	if *wantEpoch != "" {
		payload["epoch"] = *wantEpoch
	}
	return c.postPrint("/api/v1/admin/datasets", payload)
}
// datasetsDelete issues DELETE /api/v1/admin/datasets/{epoch}, with epoch
// path-escaped, and prints the response.
func (c *client) datasetsDelete(epoch string) error {
	path := "/api/v1/admin/datasets/" + url.PathEscape(epoch)
	return c.deletePrint(path)
}
// jobsList fetches GET /api/v1/admin/jobs and prints the response.
func (c *client) jobsList() error {
	const path = "/api/v1/admin/jobs"
	return c.getPrint(path)
}
// jobsGet fetches GET /api/v1/admin/jobs/{id}, with id path-escaped, and
// prints the response.
func (c *client) jobsGet(id string) error {
	path := "/api/v1/admin/jobs/" + url.PathEscape(id)
	return c.getPrint(path)
}
// jobsCancel issues DELETE /api/v1/admin/jobs/{id}, with id path-escaped, and
// prints the response.
func (c *client) jobsCancel(id string) error {
	path := "/api/v1/admin/jobs/" + url.PathEscape(id)
	return c.deletePrint(path)
}
// getPrint sends a GET request for path (relative to the server base URL)
// and hands the response to printResp.
func (c *client) getPrint(path string) error {
	target := c.base + path
	resp, reqErr := c.http.Get(target)
	if reqErr != nil {
		return reqErr
	}
	return printResp(resp)
}
// postPrint JSON-encodes body, POSTs it to path (relative to the server base
// URL) with Content-Type application/json, and hands the response to
// printResp.
func (c *client) postPrint(path string, body any) error {
	payload, encErr := json.Marshal(body)
	if encErr != nil {
		return encErr
	}

	target := c.base + path
	resp, reqErr := c.http.Post(target, "application/json", bytes.NewReader(payload))
	if reqErr != nil {
		return reqErr
	}
	return printResp(resp)
}
// deletePrint sends a DELETE request for path (relative to the server base
// URL) with no body and hands the response to printResp.
func (c *client) deletePrint(path string) error {
	target := c.base + path
	req, buildErr := http.NewRequest(http.MethodDelete, target, nil)
	if buildErr != nil {
		return buildErr
	}

	resp, reqErr := c.http.Do(req)
	if reqErr != nil {
		return reqErr
	}
	return printResp(resp)
}
// printResp consumes and closes resp.Body and writes the body to stdout.
//
// A status of 400 or above is returned as an error carrying the status code
// and the trimmed body; nothing is printed in that case. A failure to read
// the body is also returned as an error. Otherwise, when the Content-Type
// mentions "json" and the body is valid JSON, the body is re-indented and
// printed; any other non-empty body is printed trimmed of surrounding
// whitespace.
func printResp(resp *http.Response) error {
	defer resp.Body.Close()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("read response body (HTTP %d): %w", resp.StatusCode, err)
	}
	trimmed := bytes.TrimSpace(body)
	if resp.StatusCode >= 400 {
		return fmt.Errorf("HTTP %d: %s", resp.StatusCode, trimmed)
	}
	if len(body) == 0 {
		return nil
	}

	// Pretty-print JSON when possible; raw bytes otherwise.
	if strings.Contains(resp.Header.Get("Content-Type"), "json") {
		// json.Indent reformats the bytes without decoding them, so large
		// integers keep every digit and object keys keep the server's order.
		// Decoding into `any` and re-encoding would route numbers through
		// float64. The body is trimmed first because Indent copies trailing
		// whitespace through to its output.
		var pretty bytes.Buffer
		if json.Indent(&pretty, trimmed, "", " ") == nil {
			fmt.Println(pretty.String())
			return nil
		}
	}
	fmt.Println(string(trimmed))
	return nil
}

181
cmd/predictor/main.go Normal file
View file

@ -0,0 +1,181 @@
// Command predictor is the stratoflights-predictor HTTP server.
//
// It wires the configuration, dataset manager, scheduler, and API layer
// into a single process and exits cleanly on SIGINT/SIGTERM.
package main
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/go-co-op/gocron"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"predictor-refactored/internal/api"
"predictor-refactored/internal/config"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/datasets/gfs"
"predictor-refactored/internal/elevation"
"predictor-refactored/internal/metrics"
)
// main runs the server with the process arguments. If run returns an error it
// is reported on stderr and the process exits with status 1.
func main() {
	err := run(os.Args[1:])
	if err == nil {
		return
	}
	fmt.Fprintln(os.Stderr, "fatal:", err)
	os.Exit(1)
}
// run loads the configuration, wires the store, dataset source, manager,
// optional elevation dataset, refresh scheduler and HTTP server together, and
// serves until the signal context is cancelled.
//
// It returns an error when any component fails to initialise, when the
// periodic refresh cannot be scheduled, or when the HTTP server stops with an
// error other than http.ErrServerClosed.
func run(args []string) error {
	cfg, err := config.Load(args)
	if err != nil {
		return fmt.Errorf("load config: %w", err)
	}
	log, err := newLogger(cfg.Log.Level)
	if err != nil {
		return fmt.Errorf("init logger: %w", err)
	}
	// Best-effort flush on exit; nothing useful can be done with a Sync error
	// at this point.
	defer func() { _ = log.Sync() }()
	log.Info("configuration loaded",
		zap.Int("port", cfg.HTTP.Port),
		zap.String("data_dir", cfg.Data.Dir),
		zap.String("source", cfg.Data.Source),
		zap.Int("download_parallel", cfg.Download.Parallel),
		zap.Duration("update_interval", cfg.Download.UpdateInterval),
		zap.Duration("freshness_ttl", cfg.Download.FreshnessTTL),
		zap.Bool("metrics_enabled", cfg.Metrics.Enabled),
	)

	store, err := datasets.NewLocalStore(cfg.Data.Dir, cfg.Data.Source)
	if err != nil {
		return fmt.Errorf("init store: %w", err)
	}

	// Source is GFS today; the spec leaves room for ECMWF later via the
	// same datasets.Source interface.
	if cfg.Data.Source != "noaa-gfs-0p50" {
		return fmt.Errorf("source %q not supported", cfg.Data.Source)
	}
	source := gfs.NewSource(log)
	source.Parallel = cfg.Download.Parallel

	var throttle datasets.Throttle
	if cfg.Download.BandwidthBytesPerSecond > 0 {
		throttle = datasets.NewTokenBucket(cfg.Download.BandwidthBytesPerSecond)
	}

	// Metrics (optional).
	var sink metrics.Sink = metrics.Noop()
	var metricsHandler http.Handler
	if cfg.Metrics.Enabled {
		prom := metrics.NewProm()
		sink = prom
		metricsHandler = prom
	}

	mgr := datasets.New(source, store, throttle, log)
	defer mgr.Close()

	// Optional elevation dataset. Missing or unreadable elevation is logged
	// but non-fatal; descent terminates at sea level instead.
	var elev *elevation.Dataset
	if cfg.Data.ElevationPath != "" {
		if d, err := elevation.Open(cfg.Data.ElevationPath); err == nil {
			elev = d
			log.Info("elevation dataset loaded", zap.String("path", cfg.Data.ElevationPath))
			defer elev.Close()
		} else {
			log.Warn("elevation dataset not available, using sea-level termination",
				zap.String("path", cfg.Data.ElevationPath),
				zap.Error(err))
		}
	}

	// refresh asks the manager for a refresh under a 60-minute deadline and
	// then reports the active epoch, if any, to the metrics sink. kind
	// ("initial" or "scheduled") only labels the failure log line.
	refresh := func(kind string) {
		ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
		defer cancel()
		if _, err := mgr.Refresh(ctx, cfg.Download.FreshnessTTL); err != nil {
			log.Error(kind+" dataset refresh failed", zap.Error(err))
		}
		if a := mgr.Active(); a != nil {
			sink.ActiveEpoch(a.Epoch())
		}
	}

	// Kick off the initial refresh in the background so the server can start
	// answering /ready immediately.
	go refresh("initial")

	scheduler := gocron.NewScheduler(time.UTC)
	// Do reports an invalid schedule through its error result; surface it at
	// startup rather than running with no periodic refresh.
	if _, err := scheduler.Every(cfg.Download.UpdateInterval).Do(func() {
		log.Info("scheduled dataset refresh starting")
		refresh("scheduled")
	}); err != nil {
		return fmt.Errorf("schedule dataset refresh: %w", err)
	}
	scheduler.StartAsync()
	defer scheduler.Stop()

	server, err := api.New(cfg.HTTP.Port, api.Deps{
		Manager:        mgr,
		Elevation:      elev,
		Metrics:        sink,
		MetricsHandler: metricsHandler,
		MetricsPath:    cfg.Metrics.Path,
		Log:            log,
	})
	if err != nil {
		return fmt.Errorf("init server: %w", err)
	}

	// Graceful shutdown
	ctx, cancel := signalContext()
	defer cancel()
	log.Info("service started")
	if err := server.Run(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
		return fmt.Errorf("http server: %w", err)
	}
	log.Info("service stopped")
	return nil
}
// signalContext returns a context that is cancelled when the process receives
// SIGINT or SIGTERM, or when the returned cancel function is called.
//
// Calling cancel also unregisters the signal handling and releases the
// resources behind it, so callers should always call it (typically with
// defer).
func signalContext() (context.Context, context.CancelFunc) {
	return signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
}
// newLogger builds a zap logger from the production config at the given
// level. Recognised levels are "debug", "info", "warn" and "error"; any other
// value falls back to info.
func newLogger(level string) (*zap.Logger, error) {
	chosen := zapcore.InfoLevel
	switch level {
	case "debug":
		chosen = zapcore.DebugLevel
	case "warn":
		chosen = zapcore.WarnLevel
	case "error":
		chosen = zapcore.ErrorLevel
	}

	cfg := zap.NewProductionConfig()
	cfg.Level = zap.NewAtomicLevelAt(chosen)
	return cfg.Build()
}

160
docs/numerics.tex Normal file
View file

@ -0,0 +1,160 @@
\documentclass[a4paper,11pt]{article}
\usepackage[margin=1in]{geometry}
\usepackage{amsmath, amssymb}
\usepackage{algorithm, algpseudocode}
\usepackage{hyperref}
\title{Numerics Library: Mathematical Reference}
\author{stratoflights-predictor}
\date{}
\begin{document}
\maketitle
This document describes every numerical primitive in the
\verb|internal/numerics| package. Each section pairs the mathematical
definition with a pointer to the Go implementation and at least one
worked example that can be reproduced manually.
\section{Regular-grid bracketing}
\paragraph{Definition.} An \emph{axis} is the regularly-spaced sequence
$x_i = \ell + i \cdot s$ for $i = 0, 1, \ldots, N - 1$, parameterised by
the left edge $\ell$, the step $s > 0$, and the point count $N$.
Given a query $v$, the \emph{bracket} is the pair of adjacent indices
$(i_0, i_1)$, $i_1 = i_0 + 1$ (reduced modulo $N$ on wrapping axes, see
below), with
$x_{i_0} \le v < x_{i_1}$ and the dimensionless position
\[
f = \frac{v - x_{i_0}}{s} \in [0, 1).
\]
Implemented as \verb|Axis.Locate| (\verb|internal/numerics/grid.go|).
\paragraph{Wrapping axes.} For periodic axes (e.g.\ longitude), the
sequence is extended by the convention $x_N = x_0$ so a value approaching
$x_N$ from below brackets $(N{-}1, 0)$ with fraction
$f = (v - x_{N-1})/s$.
\paragraph{Domain.} The bracket is undefined when $v$ falls outside the
half-open interval $[\ell, \ell + (N{-}1)\,s)$ (for non-wrapping axes) or
$[\ell, \ell + N\,s)$ (for wrapping axes); the implementation returns
an \verb|AxisError| in those cases.
\paragraph{Worked example.} Latitude axis $\ell = -90$, $s = 0{.}5$,
$N = 361$. Query $v = -89{.}75$ yields the fractional index
$p = (v - \ell)/s = (-89{.}75 - (-90))/0{.}5 = 0{.}5$, so
$i_0 = \lfloor p \rfloor = 0$, $i_1 = 1$, and $f = p - i_0 = 0{.}5$.
\section{Multilinear interpolation}
\paragraph{Definition.} For a scalar field $u$ defined at the grid nodes
of three axes, the trilinear interpolant at brackets $b_a, b_b, b_c$ is
\[
\tilde u = \sum_{i, j, k \in \{0, 1\}} w_{a,i} \, w_{b,j} \, w_{c,k}
\; u\bigl(b_a^i, b_b^j, b_c^k\bigr),
\]
where $w_{\bullet, 0} = 1 - f_\bullet$ and $w_{\bullet, 1} = f_\bullet$.
Implemented as \verb|EvalTrilinear|.
\paragraph{Linear exactness.} For any affine field
$u(i, j, k) = \alpha i + \beta j + \gamma k + \delta$, the formula returns
$\alpha \cdot p_a + \beta \cdot p_b + \gamma \cdot p_c + \delta$ exactly
(modulo floating-point rounding), where $p_\bullet = b_\bullet^0 + f_\bullet$.
\paragraph{Evaluation order.} The eight corner terms are accumulated in
the order $(0,0,0), (0,0,1), \ldots, (1,1,1)$, matching the reference
Tawhiri implementation \emph{exactly} so that double-precision results
agree bit-for-bit.
\section{Monotone bisection}
\paragraph{Definition.} For an integer-indexed monotone non-decreasing
sequence $f : \{i_{\min}, \ldots, i_{\max}\} \to \mathbb{R}$ and a target
$t$, $\mathrm{Bisect}$ returns the largest index $i^\star$ with
$f(i^\star) < t$. The implementation evaluates $f$ on a midpoint
$m = \lceil(i_{\min} + i_{\max})/2\rceil$ each iteration and halves the
interval, taking $\mathcal{O}(\log(i_{\max} - i_{\min}))$ evaluations.
\paragraph{Boundary behaviour.} If $t \le f(i_{\min})$, the function
returns $i_{\min}$; if $t > f(i_{\max})$, it returns $i_{\max}$.
\paragraph{Usage in this codebase.} The pressure-level search in the GFS
wind field locates the largest level whose interpolated geopotential
height is below the query altitude; vertical interpolation then runs
between that level and its successor.
\section{Classical Runge--Kutta--4 integrator}
\paragraph{Definition.} For a state $y$, derivative $\dot y = f(t, y)$,
and step $\Delta t$, \verb|RK4Step| applies the classical RK4 update
\[
\begin{aligned}
k_1 &= f(t, y), \\
k_2 &= f\bigl(t + \tfrac{\Delta t}{2}, \; y + \tfrac{\Delta t}{2} k_1\bigr), \\
k_3 &= f\bigl(t + \tfrac{\Delta t}{2}, \; y + \tfrac{\Delta t}{2} k_2\bigr), \\
k_4 &= f\bigl(t + \Delta t, \; y + \Delta t \cdot k_3\bigr), \\
y(t + \Delta t) &= y + \tfrac{\Delta t}{6}\bigl(k_1 + 2 k_2 + 2 k_3 + k_4\bigr).
\end{aligned}
\]
\paragraph{Reverse-time integration.} Passing $\Delta t < 0$ integrates
backwards in time. The derivative $f$ is treated as direction-independent:
all sign accounting lives in the integrator. The implementation contains
no explicit branch on the sign of $\Delta t$.
\paragraph{Vector state.} \verb|RK4Step| is generic on the state type.
Domain-specific vector arithmetic (in particular longitude wrap on the
$0\!:\!360$ degree circle) is injected via the \verb|VecAdd| operation
$\mathrm{add}(y, k, \delta y) = y + k \cdot \delta y$.
\section{Termination-point refinement}
After each integration step the propagator checks one or more
constraints. When a constraint reports a violation between $(t_1, y_1)$
(not violated) and $(t_2, y_2)$ (violated), \verb|RefineTrigger|
locates the crossing within tolerance $\tau \in (0, 1)$ by binary
search in the linear interpolation parameter $\lambda \in [0, 1]$
(sampled at the midpoints $m$ in Algorithm~\ref{alg:refine}):
\begin{algorithm}[H]
\caption{RefineTrigger}\label{alg:refine}
\begin{algorithmic}[1]
\State $L \gets 0,\; R \gets 1$
\State $t_3 \gets t_2,\; y_3 \gets y_2$
\While{$R - L > \tau$}
\State $m \gets (L + R)/2$
\State $t_3 \gets (1 - m)\,t_1 + m\,t_2$
\State $y_3 \gets \mathrm{lerp}(y_1, y_2, m)$
\If{constraint violated at $(t_3, y_3)$}
\State $R \gets m$
\Else
\State $L \gets m$
\EndIf
\EndWhile
\State \Return $(t_3, y_3)$
\end{algorithmic}
\end{algorithm}
\paragraph{Termination guarantee.} After $\lceil \log_2 \tau^{-1} \rceil$
iterations, $R - L \le \tau$. With $\tau = 0{.}01$ and $\Delta t = 60$~s,
the returned point is within $\tau$ of the true crossing in parameter
space, i.e.\ within $\tau\,\Delta t = 0{.}6$~s in time; the corresponding
altitude error is bounded by $0{.}6~\mathrm{s} \cdot |\dot y|$,
which for typical balloon ascent and parachute descent rates (about
$5$~m/s) is at most $\sim 3$~m.
\paragraph{Quirk.} The returned $(t_3, y_3)$ is the \emph{last midpoint
sampled} rather than guaranteed to lie on the triggered side; this
matches the reference Tawhiri implementation byte-for-byte.
\paragraph{Vector lerp.} As with \verb|RK4Step|, the per-coordinate
linear interpolation is delegated to the caller's \verb|VecLerp| to keep
the integrator agnostic of state semantics. The engine package's
\verb|lerpState| applies the shorter-arc convention for longitudes
crossing the $0\!:\!360$ boundary.
\section{Implementation notes}
The library is intentionally small (under 300 lines of Go) and uses no
runtime allocations on the hot path. The type-generic \verb|RK4Step| and
\verb|RefineTrigger| compile to per-type specialisations under Go's
generics, so a future C or Rust port can mirror the implementation
verbatim without changing the call sites in the trajectory engine.
\end{document}

4
go.mod
View file

@ -7,6 +7,7 @@ require (
github.com/go-co-op/gocron v1.37.0 github.com/go-co-op/gocron v1.37.0
github.com/go-faster/errors v0.7.1 github.com/go-faster/errors v0.7.1
github.com/go-faster/jx v1.2.0 github.com/go-faster/jx v1.2.0
github.com/google/uuid v1.6.0
github.com/nilsmagnus/grib v1.2.8 github.com/nilsmagnus/grib v1.2.8
github.com/ogen-go/ogen v1.20.2 github.com/ogen-go/ogen v1.20.2
go.opentelemetry.io/otel v1.42.0 go.opentelemetry.io/otel v1.42.0
@ -14,6 +15,7 @@ require (
go.opentelemetry.io/otel/trace v1.42.0 go.opentelemetry.io/otel/trace v1.42.0
go.uber.org/zap v1.27.1 go.uber.org/zap v1.27.1
golang.org/x/sync v0.20.0 golang.org/x/sync v0.20.0
gopkg.in/yaml.v2 v2.4.0
) )
require ( require (
@ -24,7 +26,6 @@ require (
github.com/go-faster/yaml v0.4.6 // indirect github.com/go-faster/yaml v0.4.6 // indirect
github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-isatty v0.0.20 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect
@ -37,5 +38,4 @@ require (
golang.org/x/net v0.52.0 // indirect golang.org/x/net v0.52.0 // indirect
golang.org/x/sys v0.42.0 // indirect golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect golang.org/x/text v0.35.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
) )

View file

@ -0,0 +1,206 @@
// Package admin implements dataset-management HTTP endpoints used by the
// stratoflights operator console.
//
// Endpoints:
//
// GET /api/v1/admin/datasets list stored epochs
// POST /api/v1/admin/datasets trigger a download
// DELETE /api/v1/admin/datasets/{epoch} delete a stored epoch
// GET /api/v1/admin/jobs list all jobs
// GET /api/v1/admin/jobs/{id} fetch one job
// DELETE /api/v1/admin/jobs/{id} cancel a running job
package admin
import (
"context"
"encoding/json"
"net/http"
"time"
"go.uber.org/zap"
"predictor-refactored/internal/datasets"
)
// Handler serves all /api/v1/admin/* endpoints.
type Handler struct {
// mgr is the dataset manager every endpoint delegates to (listing and
// removing epochs, starting downloads, inspecting and cancelling jobs).
mgr *datasets.Manager
// log is the handler's logger; New substitutes a no-op logger for nil.
log *zap.Logger
}
// New returns an admin Handler backed by mgr. A nil log is replaced with a
// no-op logger.
func New(mgr *datasets.Manager, log *zap.Logger) *Handler {
	h := &Handler{mgr: mgr, log: log}
	if h.log == nil {
		h.log = zap.NewNop()
	}
	return h
}
// Register mounts every admin endpoint on mux under /api/v1/admin/, using
// method-qualified patterns.
func (h *Handler) Register(mux *http.ServeMux) {
	routes := []struct {
		pattern string
		handle  func(http.ResponseWriter, *http.Request)
	}{
		{"GET /api/v1/admin/datasets", h.listDatasets},
		{"POST /api/v1/admin/datasets", h.triggerDownload},
		{"DELETE /api/v1/admin/datasets/{epoch}", h.deleteDataset},
		{"GET /api/v1/admin/jobs", h.listJobs},
		{"GET /api/v1/admin/jobs/{id}", h.getJob},
		{"DELETE /api/v1/admin/jobs/{id}", h.cancelJob},
	}
	for _, route := range routes {
		mux.HandleFunc(route.pattern, route.handle)
	}
}
// listDatasets handles GET /api/v1/admin/datasets.
//
// It responds 200 with the source name, the active epoch (omitted when no
// dataset is active) and the stored epochs formatted as RFC 3339 UTC
// timestamps. A failure to list epochs yields a 500 error response.
func (h *Handler) listDatasets(w http.ResponseWriter, _ *http.Request) {
	epochs, err := h.mgr.ListEpochs()
	if err != nil {
		writeError(w, http.StatusInternalServerError, err.Error())
		return
	}
	active := ""
	if a := h.mgr.Active(); a != nil {
		active = a.Epoch().UTC().Format(time.RFC3339)
	}
	out := struct {
		Source string   `json:"source"`
		Active string   `json:"active,omitempty"`
		Epochs []string `json:"epochs"`
	}{
		Source: h.mgr.Source(),
		Active: active,
		// Start from an empty, non-nil slice so an empty store encodes as
		// "epochs": [] rather than "epochs": null.
		Epochs: make([]string, 0, len(epochs)),
	}
	for _, e := range epochs {
		out.Epochs = append(out.Epochs, e.UTC().Format(time.RFC3339))
	}
	writeJSON(w, http.StatusOK, out)
}
// triggerDownload handles POST /api/v1/admin/datasets.
//
// Body: {"epoch": "2026-03-28T06:00:00Z"} OR {"latest": true}.
//
// With latest=true the manager is asked for a refresh (under a 30-second
// deadline derived from the request context); otherwise the RFC 3339 epoch is
// parsed and a download of that epoch is requested. Both paths answer 202
// with the job id. When both fields are present, latest takes precedence.
func (h *Handler) triggerDownload(w http.ResponseWriter, r *http.Request) {
	var req struct {
		Epoch  string `json:"epoch,omitempty"`
		Latest bool   `json:"latest,omitempty"`
	}
	if decodeErr := json.NewDecoder(r.Body).Decode(&req); decodeErr != nil {
		writeError(w, http.StatusBadRequest, "invalid body: "+decodeErr.Error())
		return
	}

	switch {
	case req.Latest:
		ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
		defer cancel()
		id, refreshErr := h.mgr.Refresh(ctx, 0)
		if refreshErr != nil {
			writeError(w, http.StatusInternalServerError, refreshErr.Error())
			return
		}
		writeJSON(w, http.StatusAccepted, map[string]string{"job_id": id})

	case req.Epoch != "":
		epoch, parseErr := time.Parse(time.RFC3339, req.Epoch)
		if parseErr != nil {
			writeError(w, http.StatusBadRequest, "invalid epoch: "+parseErr.Error())
			return
		}
		id := h.mgr.Download(epoch)
		writeJSON(w, http.StatusAccepted, map[string]string{"job_id": id})

	default:
		writeError(w, http.StatusBadRequest, "specify either epoch or latest=true")
	}
}
// deleteDataset handles DELETE /api/v1/admin/datasets/{epoch}.
//
// The {epoch} path value must be an RFC 3339 timestamp (400 otherwise). A
// removal failure yields 500; success yields 204 with no body.
func (h *Handler) deleteDataset(w http.ResponseWriter, r *http.Request) {
	epoch, parseErr := time.Parse(time.RFC3339, r.PathValue("epoch"))
	if parseErr != nil {
		writeError(w, http.StatusBadRequest, "invalid epoch: "+parseErr.Error())
		return
	}

	removeErr := h.mgr.RemoveEpoch(epoch)
	if removeErr != nil {
		writeError(w, http.StatusInternalServerError, removeErr.Error())
		return
	}
	w.WriteHeader(http.StatusNoContent)
}
// listJobs handles GET /api/v1/admin/jobs and responds 200 with every job
// known to the manager, converted to its JSON representation. With no jobs
// the response is an empty array.
func (h *Handler) listJobs(w http.ResponseWriter, _ *http.Request) {
	infos := h.mgr.ListJobs()
	dtos := make([]jobDTO, 0, len(infos))
	for _, info := range infos {
		dtos = append(dtos, toDTO(info))
	}
	writeJSON(w, http.StatusOK, dtos)
}
// getJob handles GET /api/v1/admin/jobs/{id}: 200 with the job when the
// manager knows the id, 404 otherwise.
func (h *Handler) getJob(w http.ResponseWriter, r *http.Request) {
	job, found := h.mgr.GetJob(r.PathValue("id"))
	if found {
		writeJSON(w, http.StatusOK, toDTO(job))
		return
	}
	writeError(w, http.StatusNotFound, "job not found")
}
// cancelJob handles DELETE /api/v1/admin/jobs/{id}: 204 when the manager
// reports the job as cancelled, 409 when it does not.
func (h *Handler) cancelJob(w http.ResponseWriter, r *http.Request) {
	if cancelled := h.mgr.CancelJob(r.PathValue("id")); cancelled {
		w.WriteHeader(http.StatusNoContent)
		return
	}
	writeError(w, http.StatusConflict, "job not found or already terminal")
}
// jobDTO is the JSON representation of a datasets.JobInfo returned by the
// job endpoints. toDTO fills it; timestamps are RFC 3339 strings in UTC.
type jobDTO struct {
ID string `json:"id"`
Source string `json:"source"`
Epoch string `json:"epoch"`
Status string `json:"status"`
StartedAt string `json:"started_at"`
// EndedAt is left empty (and omitted from the JSON) when the job has no end
// time.
EndedAt string `json:"ended_at,omitempty"`
// Err is omitted from the JSON when empty.
Err string `json:"error,omitempty"`
Total int `json:"total_units"`
Done int `json:"done_units"`
Bytes int64 `json:"bytes"`
}
// toDTO converts a datasets.JobInfo into its JSON representation. Timestamps
// are rendered as RFC 3339 in UTC; EndedAt stays empty when the job has no
// end time.
func toDTO(j datasets.JobInfo) jobDTO {
	const layout = time.RFC3339

	var endedAt string
	if j.EndedAt != nil {
		endedAt = j.EndedAt.UTC().Format(layout)
	}

	return jobDTO{
		ID:        j.ID,
		Source:    j.Source,
		Epoch:     j.Epoch.UTC().Format(layout),
		Status:    string(j.Status),
		StartedAt: j.StartedAt.UTC().Format(layout),
		EndedAt:   endedAt,
		Err:       j.Err,
		Total:     j.Total,
		Done:      j.Done,
		Bytes:     j.Bytes,
	}
}
// writeJSON sends body as a JSON response with the given status code and a
// Content-Type of application/json.
func writeJSON(w http.ResponseWriter, status int, body any) {
	header := w.Header()
	header.Set("Content-Type", "application/json")
	w.WriteHeader(status)

	// The status line is already written, so an encoding failure cannot be
	// reported to the client; the error is intentionally dropped.
	encoder := json.NewEncoder(w)
	_ = encoder.Encode(body)
}
// writeError sends a JSON error response of the form
// {"error": {"type": <status text>, "description": <description>}} with the
// given status code.
func writeError(w http.ResponseWriter, status int, description string) {
	detail := map[string]string{
		"type":        http.StatusText(status),
		"description": description,
	}
	writeJSON(w, status, map[string]any{"error": detail})
}

View file

@ -0,0 +1,20 @@
package middleware
import "net/http"
// CORS returns a handler that adds permissive CORS headers to every response
// and answers OPTIONS requests itself with 204 No Content, without calling
// next.
//
// The allowed origin is "*" because this service is meant to sit behind an
// authenticated gateway. Tighten this if you deploy elsewhere.
func CORS(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		header := w.Header()
		header.Set("Access-Control-Allow-Origin", "*")
		header.Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
		header.Set("Access-Control-Allow-Headers", "Content-Type")

		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		// Preflight: the headers above are the whole answer.
		w.WriteHeader(http.StatusNoContent)
	}
	return http.HandlerFunc(serve)
}

View file

@ -0,0 +1,51 @@
// Package middleware contains HTTP and ogen middleware used by the API layer.
package middleware
import (
"net/http"
"time"
"github.com/ogen-go/ogen/middleware"
"go.uber.org/zap"
)
// OgenLogging returns an ogen middleware that times each operation and logs
// one line per request, tagged with the operation id: at error level when the
// next handler returns an error, at info level otherwise. The response and
// error are passed through unchanged.
func OgenLogging(log *zap.Logger) middleware.Middleware {
	return func(req middleware.Request, next func(req middleware.Request) (middleware.Response, error)) (middleware.Response, error) {
		opLog := log.With(zap.String("op", req.OperationID))

		began := time.Now()
		resp, err := next(req)
		elapsed := zap.Duration("duration", time.Since(began))

		switch {
		case err != nil:
			opLog.Error("request failed", elapsed, zap.Error(err))
		default:
			opLog.Info("request completed", elapsed)
		}
		return resp, err
	}
}
// statusRecorder captures the response status for HTTPLogging.
//
// It records the status that net/http actually sends: the first final status
// passed to WriteHeader, or 200 when the handler writes a body without
// calling WriteHeader. Later WriteHeader calls are still forwarded to the
// wrapped writer but do not change the recorded value.
type statusRecorder struct {
	http.ResponseWriter
	status      int
	wroteHeader bool // true once a final status has been committed
}

// WriteHeader records code if no final status has been committed yet and
// forwards the call to the wrapped writer.
func (r *statusRecorder) WriteHeader(code int) {
	// 1xx responses other than 101 are informational and may be followed by
	// the real status, so they must not be recorded as the outcome.
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !r.wroteHeader && !informational {
		r.status = code
		r.wroteHeader = true
	}
	r.ResponseWriter.WriteHeader(code)
}

// Write forwards to the wrapped writer. A first Write without a prior
// WriteHeader commits an implicit 200, which is recorded here.
func (r *statusRecorder) Write(p []byte) (int, error) {
	if !r.wroteHeader {
		r.status = http.StatusOK
		r.wroteHeader = true
	}
	return r.ResponseWriter.Write(p)
}

// Unwrap returns the wrapped writer so http.ResponseController can reach
// optional interfaces (such as flushing) that the wrapper itself does not
// implement.
func (r *statusRecorder) Unwrap() http.ResponseWriter {
	return r.ResponseWriter
}
// HTTPLogging returns a handler that runs next and then logs one info-level
// line per request with its method, URL path, response status and duration.
// The status defaults to 200 when the handler never sets one.
func HTTPLogging(log *zap.Logger, next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		recorder := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
		next.ServeHTTP(recorder, r)

		log.Info("http",
			zap.String("method", r.Method),
			zap.String("path", r.URL.Path),
			zap.Int("status", recorder.status),
			zap.Duration("duration", time.Since(began)),
		)
	}
	return http.HandlerFunc(serve)
}

View file

@ -0,0 +1,252 @@
// Package tawhiri implements the legacy Tawhiri-compatible HTTP endpoint
// (GET /api/v1/prediction). The request/response shapes match the original
// Cambridge University Spaceflight predictor for drop-in compatibility.
//
// Internally the handler builds an engine.Profile from query parameters and
// dispatches it through the same engine path as the new v2 endpoint.
package tawhiri
import (
"context"
"errors"
"net/http"
"time"
"go.uber.org/zap"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/elevation"
"predictor-refactored/internal/engine"
"predictor-refactored/internal/metrics"
api "predictor-refactored/pkg/rest"
)
// Handler implements api.Handler (the ogen-generated interface for
// performPrediction and readinessCheck).
type Handler struct {
// mgr supplies the active dataset that predictions run against.
mgr *datasets.Manager
// elev is the elevation dataset passed to the descent constraints.
// NOTE(review): New accepts it unchecked, so it may be nil — confirm.
elev *elevation.Dataset
// metrics receives per-prediction measurements; New substitutes a no-op
// sink for nil.
metrics metrics.Sink
// log is the handler's logger; New substitutes a no-op logger for nil.
log *zap.Logger
}
// New returns a Handler built from the given dependencies. A nil log is
// replaced with a no-op logger and a nil sink with a no-op metrics sink; mgr
// and elev are stored as given.
func New(mgr *datasets.Manager, elev *elevation.Dataset, sink metrics.Sink, log *zap.Logger) *Handler {
	h := &Handler{mgr: mgr, elev: elev, metrics: sink, log: log}
	if h.log == nil {
		h.log = zap.NewNop()
	}
	if h.metrics == nil {
		h.metrics = metrics.Noop()
	}
	return h
}
// Compile-time check that Handler satisfies api.Handler.
var _ api.Handler = (*Handler)(nil)
// PerformPrediction runs the Tawhiri-style prediction.
//
// It reads the active dataset from the manager, fills unset query parameters
// with the defaults below, builds a two-stage engine.Profile for the requested
// profile kind, runs it, and converts the per-stage results into the
// ogen-generated response type. It returns a 503 error while no dataset is
// active and a 400 error for an unrecognised profile kind.
func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredictionParams) (*api.PredictionResponse, error) {
field := h.mgr.Active()
if field == nil {
return nil, newError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up")
}
// Parameters with Tawhiri defaults.
profileKind := "standard_profile"
if v, ok := params.Profile.Get(); ok {
profileKind = string(v)
}
ascentRate := 5.0
if v, ok := params.AscentRate.Get(); ok {
ascentRate = v
}
burstAltitude := 28000.0
if v, ok := params.BurstAltitude.Get(); ok {
burstAltitude = v
}
descentRate := 5.0
if v, ok := params.DescentRate.Get(); ok {
descentRate = v
}
launchAlt := 0.0
if v, ok := params.LaunchAltitude.Get(); ok {
launchAlt = v
}
// Negative longitudes are shifted by +360 before being handed to the engine.
lng := params.LaunchLongitude
if lng < 0 {
lng += 360
}
// The engine takes time as UNIX seconds in a float64.
launchTime := float64(params.LaunchDatetime.Unix())
// One collector is shared by every WindTransport model built below.
warnings := &engine.Warnings{}
// Build the profile.
// stageNames parallels prof.Stages; results are labelled by index from it.
var stageNames []string
var prof engine.Profile
switch profileKind {
case "standard_profile":
// Ascent at a constant rate until burstAltitude, then parachute descent.
stageNames = []string{"ascent", "descent"}
prof = engine.Profile{
Direction: engine.Forward,
Stages: []*engine.Propagator{
{
Name: "ascent",
Step: 60,
Model: engine.Sum(
engine.ConstantRate(ascentRate),
engine.WindTransport(field, warnings),
),
Constraints: []engine.Constraint{engine.MaxAltitude{Limit: burstAltitude, On: engine.ActionStop}},
},
{
Name: "descent",
Step: 60,
Model: engine.Sum(
engine.ParachuteDescent(descentRate),
engine.WindTransport(field, warnings),
),
Constraints: descentConstraints(h.elev),
},
},
}
case "float_profile":
// Ascent to floatAlt, then wind-only drift until stopTime.
floatAlt := 25000.0
if v, ok := params.FloatAltitude.Get(); ok {
floatAlt = v
}
// Default stop time is 24 hours after launch.
stopTime := params.LaunchDatetime.Add(24 * time.Hour)
if v, ok := params.StopDatetime.Get(); ok {
stopTime = v
}
stageNames = []string{"ascent", "float"}
prof = engine.Profile{
Direction: engine.Forward,
Stages: []*engine.Propagator{
{
Name: "ascent",
Step: 60,
Model: engine.Sum(
engine.ConstantRate(ascentRate),
engine.WindTransport(field, warnings),
),
Constraints: []engine.Constraint{engine.MaxAltitude{Limit: floatAlt, On: engine.ActionStop}},
},
{
Name: "float",
Step: 60,
Model: engine.WindTransport(field, warnings),
Constraints: []engine.Constraint{engine.MaxTime{Limit: float64(stopTime.Unix()), On: engine.ActionStop}},
},
},
}
default:
return nil, newError(http.StatusBadRequest, "unknown profile: "+profileKind)
}
// started/completed bracket only the engine run; they feed both the
// response metadata and the metrics sample.
started := time.Now().UTC()
results := prof.Run(launchTime, engine.State{Lat: params.LaunchLatitude, Lng: lng, Altitude: launchAlt})
completed := time.Now().UTC()
// The error argument is always nil here.
h.metrics.Prediction(profileKind, completed.Sub(started), nil)
resp := &api.PredictionResponse{
Metadata: api.PredictionResponseMetadata{
StartDatetime: started,
CompleteDatetime: completed,
},
}
for i, r := range results {
// Results beyond the known stage names fall back to the "ascent" label.
stageName := "ascent"
if i < len(stageNames) {
stageName = stageNames[i]
}
stageEnum := api.PredictionResponsePredictionItemStageAscent
switch stageName {
case "descent":
stageEnum = api.PredictionResponsePredictionItemStageDescent
case "float":
stageEnum = api.PredictionResponsePredictionItemStageFloat
}
traj := make([]api.PredictionResponsePredictionItemTrajectoryItem, 0, len(r.Points))
for _, pt := range r.Points {
// Longitudes above 180 are shifted by -360 for the response.
ptLng := pt.Lng
if ptLng > 180 {
ptLng -= 360
}
traj = append(traj, api.PredictionResponsePredictionItemTrajectoryItem{
// int64(pt.Time) drops any fractional second.
Datetime: time.Unix(int64(pt.Time), 0).UTC(),
Latitude: pt.Lat,
Longitude: ptLng,
Altitude: pt.Altitude,
})
}
resp.Prediction = append(resp.Prediction, api.PredictionResponsePredictionItem{
Stage: stageEnum,
Trajectory: traj,
})
}
// Echo the request; LaunchLongitude is the caller's value, not the shifted one.
resp.Request = api.NewOptPredictionResponseRequest(api.PredictionResponseRequest{
Dataset: api.NewOptString(field.Epoch().Format("2006-01-02T15:04:05Z")),
LaunchLatitude: api.NewOptFloat64(params.LaunchLatitude),
LaunchLongitude: api.NewOptFloat64(params.LaunchLongitude),
LaunchDatetime: api.NewOptString(params.LaunchDatetime.Format(time.RFC3339)),
LaunchAltitude: params.LaunchAltitude,
})
// NOTE(review): the contents of warns are not copied into the response; an
// empty PredictionResponseWarnings value is attached whenever any warning
// was recorded. Confirm whether individual warnings should be surfaced.
if warns := warnings.ToMap(); len(warns) > 0 {
resp.Warnings = api.NewOptPredictionResponseWarnings(api.PredictionResponseWarnings{})
}
h.log.Info("prediction complete",
zap.String("profile", profileKind),
zap.Int("stages", len(results)),
zap.Duration("elapsed", completed.Sub(started)))
return resp, nil
}
// descentConstraints selects how the descent stage terminates. Without an
// elevation dataset it uses a MinAltitude constraint with limit 0; with one it
// uses TerrainContact backed by that dataset. Both stop the stage.
func descentConstraints(elev *elevation.Dataset) []engine.Constraint {
	if elev == nil {
		return []engine.Constraint{engine.MinAltitude{Limit: 0, On: engine.ActionStop}}
	}
	return []engine.Constraint{engine.TerrainContact{Provider: elev, On: engine.ActionStop}}
}
// ReadinessCheck reports whether a dataset is currently active. It answers
// "ok" together with the dataset epoch when one is loaded, and "not ready"
// with an explanatory message otherwise. The error result is always nil.
func (h *Handler) ReadinessCheck(_ context.Context) (*api.ReadinessResponse, error) {
	field := h.mgr.Active()
	if field == nil {
		return &api.ReadinessResponse{
			Status:       api.ReadinessResponseStatusNotReady,
			ErrorMessage: api.NewOptString("no dataset loaded"),
		}, nil
	}
	return &api.ReadinessResponse{
		Status:      api.ReadinessResponseStatusOk,
		DatasetTime: api.NewOptDateTime(field.Epoch()),
	}, nil
}
// NewError converts an error returned by a handler method into the response
// ogen should send. An *api.ErrorStatusCode anywhere in err's chain is passed
// through unchanged; anything else is logged and reported as a 500 whose
// description is err's message.
func (h *Handler) NewError(_ context.Context, err error) *api.ErrorStatusCode {
	var known *api.ErrorStatusCode
	if !errors.As(err, &known) {
		h.log.Error("unhandled error", zap.Error(err))
		return newError(http.StatusInternalServerError, err.Error())
	}
	return known
}
// newError builds the API error value for the given HTTP status. The error
// type is the standard status text for that code (for example "Bad Request").
func newError(status int, description string) *api.ErrorStatusCode {
	detail := api.ErrorError{
		Type:        http.StatusText(status),
		Description: description,
	}
	return &api.ErrorStatusCode{
		StatusCode: status,
		Response:   api.Error{Error: detail},
	}
}

109
internal/api/transport.go Normal file
View file

@ -0,0 +1,109 @@
// Package api wires together every HTTP-facing component of the service:
//
// - Tawhiri-compatible v1 endpoints generated from the OpenAPI spec (ogen);
// - The new v2 prediction endpoint;
// - Dataset and job admin endpoints under /api/v1/admin/;
// - Optional Prometheus-format metrics endpoint.
package api
import (
"context"
"fmt"
"net/http"
"time"
"go.uber.org/zap"
"predictor-refactored/internal/api/admin"
"predictor-refactored/internal/api/middleware"
"predictor-refactored/internal/api/tawhiri"
v2 "predictor-refactored/internal/api/v2"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/elevation"
"predictor-refactored/internal/metrics"
apirest "predictor-refactored/pkg/rest"
)
// Server is the top-level HTTP server.
type Server struct {
// port is the TCP port Run listens on (all interfaces).
port int
// mux holds every route registered by New.
mux *http.ServeMux
// log is used for server lifecycle messages and by the logging middleware.
log *zap.Logger
}
// Deps are the runtime dependencies the API layer needs.
type Deps struct {
// Manager and Elevation are passed to the prediction handlers; Manager is
// also passed to the admin handler.
Manager *datasets.Manager
Elevation *elevation.Dataset
// Metrics may be nil; New substitutes metrics.Noop().
Metrics metrics.Sink
MetricsHandler http.Handler // optional; mounted at MetricsPath when non-nil
// MetricsPath is the route for MetricsHandler; nothing is mounted when empty.
MetricsPath string
// Log may be nil; New substitutes a no-op logger.
Log *zap.Logger
}
// New assembles the routing table and returns a Server that is not yet
// listening. Nil Log and Metrics dependencies are replaced by no-op
// implementations. It fails only if the ogen-generated server cannot be built.
func New(port int, d Deps) (*Server, error) {
	logger := d.Log
	if logger == nil {
		logger = zap.NewNop()
	}
	sink := d.Metrics
	if sink == nil {
		sink = metrics.Noop()
	}

	// Tawhiri-compatible surface (GET /api/v1/prediction and GET /ready),
	// served by the ogen-generated router.
	legacyHandler := tawhiri.New(d.Manager, d.Elevation, sink, logger)
	legacy, err := apirest.NewServer(legacyHandler, apirest.WithMiddleware(middleware.OgenLogging(logger)))
	if err != nil {
		return nil, fmt.Errorf("create ogen server: %w", err)
	}

	router := http.NewServeMux()

	// Primary prediction endpoint.
	router.Handle("/api/v2/prediction", v2.New(d.Manager, d.Elevation, sink, logger))

	// Dataset and job administration.
	adminHandler := admin.New(d.Manager, logger)
	adminHandler.Register(router)

	// Metrics are mounted only when both a handler and a path are supplied.
	if d.MetricsPath != "" && d.MetricsHandler != nil {
		router.Handle(d.MetricsPath, d.MetricsHandler)
	}

	// Everything else falls through to the ogen routes.
	router.Handle("/", legacy)

	return &Server{port: port, mux: router, log: logger}, nil
}
// Run starts the HTTP server and blocks until it stops.
//
// The handler chain is: CORS → request logger → mux.
//
// If the listener fails, that error is returned. If ctx is cancelled first,
// the server is shut down gracefully with a 10-second deadline and the result
// of Shutdown is returned.
func (s *Server) Run(ctx context.Context) error {
	srv := &http.Server{
		Addr:    fmt.Sprintf(":%d", s.port),
		Handler: middleware.CORS(middleware.HTTPLogging(s.log, s.mux)),
		// Bound how long a client may take to send its request headers so
		// that stalled or deliberately slow connections cannot be held open
		// indefinitely. Body and write timeouts are left unset because a
		// prediction may legitimately take a while to compute.
		ReadHeaderTimeout: 10 * time.Second,
	}

	// Buffered so the goroutine can always deliver its result and exit, even
	// after Run has returned through the ctx.Done branch.
	errCh := make(chan error, 1)
	go func() {
		s.log.Info("HTTP server starting", zap.Int("port", s.port))
		errCh <- srv.ListenAndServe()
	}()

	select {
	case err := <-errCh:
		return err
	case <-ctx.Done():
		// ctx is already cancelled, so the shutdown deadline needs a fresh
		// parent context.
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()
		return srv.Shutdown(shutdownCtx)
	}
}

173
internal/api/v2/handler.go Normal file
View file

@ -0,0 +1,173 @@
package v2
import (
"encoding/json"
"fmt"
"net/http"
"time"
"go.uber.org/zap"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/elevation"
"predictor-refactored/internal/engine"
"predictor-refactored/internal/metrics"
)
// Handler serves POST /api/v2/prediction.
type Handler struct {
// mgr supplies the currently active wind dataset through Active().
mgr *datasets.Manager
// elev is the optional elevation dataset; when nil, no terrain provider is
// passed to buildProfile.
elev *elevation.Dataset
// metrics receives one Prediction sample (labelled "v2") per completed run.
metrics metrics.Sink
// log is never nil after New: a no-op logger is substituted for nil.
log *zap.Logger
}
// New constructs a v2 Handler. A nil log is replaced by a no-op logger and a
// nil sink by metrics.Noop().
func New(mgr *datasets.Manager, elev *elevation.Dataset, sink metrics.Sink, log *zap.Logger) *Handler {
	h := &Handler{mgr: mgr, elev: elev, metrics: sink, log: log}
	if h.log == nil {
		h.log = zap.NewNop()
	}
	if h.metrics == nil {
		h.metrics = metrics.Noop()
	}
	return h
}
// ServeHTTP handles POST /api/v2/prediction.
//
// The JSON body is decoded strictly (unknown fields are rejected), validated,
// translated into an engine.Profile and run against the currently active
// dataset. Every response is JSON: 200 with a PredictionResponse on success,
// 405 for non-POST methods, 400 for malformed or invalid requests, and 503
// while no dataset is loaded.
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	// maxBodyBytes caps how much of the request body is read. A profile is a
	// short list of stages, so 1 MiB is far above any legitimate request while
	// still preventing a client from streaming an unbounded body.
	const maxBodyBytes = 1 << 20

	if r.Method != http.MethodPost {
		// A 405 response is expected to advertise the methods that are allowed.
		w.Header().Set("Allow", http.MethodPost)
		writeError(w, http.StatusMethodNotAllowed, "use POST")
		return
	}

	var req PredictionRequest
	dec := json.NewDecoder(http.MaxBytesReader(w, r.Body, maxBodyBytes))
	dec.DisallowUnknownFields()
	if err := dec.Decode(&req); err != nil {
		writeError(w, http.StatusBadRequest, "invalid request body: "+err.Error())
		return
	}
	if err := validateRequest(req); err != nil {
		writeError(w, http.StatusBadRequest, err.Error())
		return
	}

	field := h.mgr.Active()
	if field == nil {
		writeError(w, http.StatusServiceUnavailable, "no dataset loaded, service is starting up")
		return
	}

	// Normalize longitude to [0, 360) for internal use.
	lng := req.Launch.Longitude
	if lng < 0 {
		lng += 360
	}

	warnings := &engine.Warnings{}

	// Only assign the interface when a dataset exists, so that a nil
	// *elevation.Dataset does not become a non-nil TerrainProvider.
	var terrain engine.TerrainProvider
	if h.elev != nil {
		terrain = h.elev
	}

	prof, err := buildProfile(req, field, terrain, warnings)
	if err != nil {
		writeError(w, http.StatusBadRequest, err.Error())
		return
	}

	// started/completed bracket only the engine run.
	started := time.Now().UTC()
	results := prof.Run(float64(req.Launch.Time.Unix()), engine.State{
		Lat:      req.Launch.Latitude,
		Lng:      lng,
		Altitude: req.Launch.Altitude,
	})
	completed := time.Now().UTC()
	h.metrics.Prediction("v2", completed.Sub(started), nil)

	resp := PredictionResponse{
		Stages:      make([]StageResult, 0, len(results)),
		StartedAt:   started,
		CompletedAt: completed,
		Dataset: DatasetInfo{
			Source: field.Source(),
			Epoch:  field.Epoch(),
		},
	}
	for _, r := range results {
		stage := StageResult{
			Name:    r.Propagator,
			Outcome: outcomeString(r.Outcome),
		}
		if r.Constraint != nil {
			stage.Constraint = r.Constraint.Name()
		}
		stage.Trajectory = make([]TrajectoryPoint, len(r.Points))
		for i, pt := range r.Points {
			// Longitudes above 180 are shifted by -360 for the response.
			ptLng := pt.Lng
			if ptLng > 180 {
				ptLng -= 360
			}
			stage.Trajectory[i] = TrajectoryPoint{
				Time:      time.Unix(int64(pt.Time), 0).UTC(),
				Latitude:  pt.Lat,
				Longitude: ptLng,
				Altitude:  pt.Altitude,
			}
		}
		resp.Stages = append(resp.Stages, stage)
	}
	if warns := warnings.ToMap(); len(warns) > 0 {
		resp.Warnings = warns
	}

	h.log.Info("v2 prediction complete",
		zap.Int("stages", len(results)),
		zap.Duration("elapsed", completed.Sub(started)))
	writeJSON(w, http.StatusOK, resp)
}
// validateRequest performs the structural checks that do not need a dataset:
// latitude in [-90, 90], longitude in [-180, 360), a launch time that was
// actually supplied, and a non-empty profile whose stages each carry a name
// and a model type. It returns the first problem found, or nil.
func validateRequest(req PredictionRequest) error {
	if req.Launch.Latitude < -90 || req.Launch.Latitude > 90 {
		return fmt.Errorf("launch.latitude must be in [-90, 90]")
	}
	if req.Launch.Longitude < -180 || req.Launch.Longitude >= 360 {
		return fmt.Errorf("launch.longitude must be in [-180, 360)")
	}
	// An omitted "time" decodes to the zero time.Time; running a prediction
	// for that instant is never what the caller meant.
	if req.Launch.Time.IsZero() {
		return fmt.Errorf("launch.time is required")
	}
	if len(req.Profile) == 0 {
		return fmt.Errorf("profile must contain at least one stage")
	}
	for i, s := range req.Profile {
		if s.Name == "" {
			return fmt.Errorf("profile[%d].name is required", i)
		}
		if s.Model.Type == "" {
			return fmt.Errorf("profile[%d].model.type is required", i)
		}
	}
	return nil
}
// outcomeString maps an engine outcome to its wire name: "stopped" and
// "fallback" for the corresponding outcomes, "continued" for anything else.
func outcomeString(o engine.Outcome) string {
	if o == engine.OutcomeStopped {
		return "stopped"
	}
	if o == engine.OutcomeFallback {
		return "fallback"
	}
	return "continued"
}
func writeError(w http.ResponseWriter, status int, description string) {
writeJSON(w, status, ErrorResponse{Error: ErrorBody{
Type: http.StatusText(status),
Description: description,
}})
}
// writeJSON sets the JSON content type, writes the status line and then
// streams body as JSON.
func writeJSON(w http.ResponseWriter, status int, body any) {
	header := w.Header()
	header.Set("Content-Type", "application/json")
	w.WriteHeader(status)

	// The status line has already been sent, so an encoding or write failure
	// can no longer be reported to the client; it is deliberately dropped.
	enc := json.NewEncoder(w)
	_ = enc.Encode(body)
}

145
internal/api/v2/profile.go Normal file
View file

@ -0,0 +1,145 @@
package v2
import (
"fmt"
"predictor-refactored/internal/engine"
"predictor-refactored/internal/weather"
)
// buildProfile translates a PredictionRequest into an engine.Profile.
//
// Every stage shares the integrator step and tolerance from req.Options; zero
// values select the defaults (60-second step, 0.01 tolerance) and negative
// values are rejected. req.Direction selects forward (default) or reverse.
//
// elev may be nil when no terrain dataset is loaded; in that case a stage that
// asks for a terrain_contact constraint makes buildProfile return an error.
//
// Fallback links are resolved after all stages have been built, so a stage may
// refer to a stage that appears later in the list.
func buildProfile(req PredictionRequest, field weather.WindField, elev engine.TerrainProvider, warnings *engine.Warnings) (engine.Profile, error) {
	if len(req.Profile) == 0 {
		return engine.Profile{}, fmt.Errorf("profile must contain at least one stage")
	}

	// Zero means "not specified"; a negative value would otherwise skip the
	// default and reach the integrator unchecked.
	step := req.Options.StepSeconds
	switch {
	case step < 0:
		return engine.Profile{}, fmt.Errorf("options.step_seconds must not be negative, got %v", step)
	case step == 0:
		step = 60
	}
	tol := req.Options.Tolerance
	switch {
	case tol < 0:
		return engine.Profile{}, fmt.Errorf("options.tolerance must not be negative, got %v", tol)
	case tol == 0:
		tol = 0.01
	}

	dir := engine.Forward
	switch req.Direction {
	case "", "forward":
		dir = engine.Forward
	case "reverse":
		dir = engine.Reverse
	default:
		return engine.Profile{}, fmt.Errorf("unknown direction %q", req.Direction)
	}

	props := make([]*engine.Propagator, len(req.Profile))
	for i, stage := range req.Profile {
		model, err := buildModel(stage.Model, field, warnings)
		if err != nil {
			return engine.Profile{}, fmt.Errorf("stage %q: %w", stage.Name, err)
		}
		constraints, err := buildConstraints(stage.Constraints, elev)
		if err != nil {
			return engine.Profile{}, fmt.Errorf("stage %q: %w", stage.Name, err)
		}
		props[i] = &engine.Propagator{
			Name:        stage.Name,
			Step:        step,
			Model:       model,
			Constraints: constraints,
			Tolerance:   tol,
		}
	}

	// Wire fallbacks once all stages exist.
	for i, stage := range req.Profile {
		if stage.FallbackIndex == nil {
			continue
		}
		idx := *stage.FallbackIndex
		if idx < 0 || idx >= len(props) {
			return engine.Profile{}, fmt.Errorf("stage %q: fallback_index %d out of range", stage.Name, idx)
		}
		props[i].Fallback = props[idx]
	}

	return engine.Profile{Stages: props, Direction: dir}, nil
}
// buildModel turns one ModelSpec into an engine.Model.
//
// The "wind" type yields a pure wind-transport model and ignores IncludeWind.
// The other types ("constant_rate", "parachute_descent", "piecewise") yield a
// vertical-rate model, which is summed with wind transport when IncludeWind is
// set. Any use of wind requires a non-nil field.
func buildModel(spec ModelSpec, field weather.WindField, warnings *engine.Warnings) (engine.Model, error) {
	var vertical engine.Model

	switch kind := spec.Type; kind {
	case "wind":
		if field == nil {
			return nil, fmt.Errorf("wind model requires a loaded dataset")
		}
		return engine.WindTransport(field, warnings), nil

	case "constant_rate":
		vertical = engine.ConstantRate(spec.Rate)

	case "parachute_descent":
		if spec.SeaLevelRate <= 0 {
			return nil, fmt.Errorf("parachute_descent requires positive sea_level_rate")
		}
		vertical = engine.ParachuteDescent(spec.SeaLevelRate)

	case "piecewise":
		schedule := make([]engine.RateSegment, 0, len(spec.Segments))
		for _, seg := range spec.Segments {
			schedule = append(schedule, engine.RateSegment{Until: seg.Until, Rate: seg.Rate})
		}
		vertical = engine.Piecewise(schedule)

	default:
		return nil, fmt.Errorf("unknown model type %q", kind)
	}

	if !spec.IncludeWind {
		return vertical, nil
	}
	if field == nil {
		return nil, fmt.Errorf("include_wind requires a loaded dataset")
	}
	return engine.Sum(vertical, engine.WindTransport(field, warnings)), nil
}
// buildConstraints converts constraint specs into engine constraints, keeping
// their order. The action is parsed before the type is examined, so an invalid
// action is reported even when the type is also invalid. A terrain_contact
// spec fails when elev is nil.
func buildConstraints(specs []ConstraintSpec, elev engine.TerrainProvider) ([]engine.Constraint, error) {
	built := make([]engine.Constraint, 0, len(specs))
	for _, spec := range specs {
		action, err := parseAction(spec.Action)
		if err != nil {
			return nil, err
		}

		switch spec.Type {
		case "max_altitude":
			built = append(built, engine.MaxAltitude{Limit: spec.Limit, On: action})
		case "min_altitude":
			built = append(built, engine.MinAltitude{Limit: spec.Limit, On: action})
		case "max_time":
			built = append(built, engine.MaxTime{Limit: spec.Limit, On: action})
		case "terrain_contact":
			if elev == nil {
				return nil, fmt.Errorf("terrain_contact requires an elevation dataset")
			}
			built = append(built, engine.TerrainContact{Provider: elev, On: action})
		default:
			return nil, fmt.Errorf("unknown constraint type %q", spec.Type)
		}
	}
	return built, nil
}
// parseAction maps the wire name of a constraint action to its engine value.
// The empty string is treated as "stop".
func parseAction(s string) (engine.Action, error) {
	if s == "" || s == "stop" {
		return engine.ActionStop, nil
	}
	if s == "fallback" {
		return engine.ActionFallback, nil
	}
	if s == "clip" {
		return engine.ActionClip, nil
	}
	return 0, fmt.Errorf("unknown constraint action %q", s)
}

114
internal/api/v2/types.go Normal file
View file

@ -0,0 +1,114 @@
// Package v2 implements the new primary prediction endpoint, which accepts a
// user-defined profile (chain of propagators with optional constraints) and
// returns the resulting trajectory.
//
// Endpoint: POST /api/v2/prediction
package v2
import "time"
// PredictionRequest is the request body for POST /api/v2/prediction.
type PredictionRequest struct {
// Launch is the initial state handed to the first stage.
Launch Launch `json:"launch"`
// Profile is the ordered list of stages; an empty list is rejected.
Profile []Stage `json:"profile"`
// Options carries integrator settings; zero values select the defaults
// applied in buildProfile.
Options Options `json:"options,omitempty"`
Direction string `json:"direction,omitempty"` // "forward" (default) or "reverse"
}
// Launch is the initial state of the balloon (or, for reverse predictions,
// the known landing point).
type Launch struct {
// Time is the start instant; the handler passes its UNIX-seconds value to
// the engine.
Time time.Time `json:"time"`
// Latitude must lie in [-90, 90] (checked by validateRequest).
Latitude float64 `json:"latitude"`
// Longitude must lie in [-180, 360); negative values are shifted by +360
// before the run.
Longitude float64 `json:"longitude"`
// Altitude is the starting altitude (presumably metres, like the altitude
// constraint limits; TODO confirm).
Altitude float64 `json:"altitude"`
}
// Stage is one entry in the propagator chain.
type Stage struct {
// Name is required and is echoed back as the stage name in the response.
Name string `json:"name"`
// Model selects how the state evolves during this stage; its Type is required.
Model ModelSpec `json:"model"`
// Constraints are converted to engine constraints in the order given.
Constraints []ConstraintSpec `json:"constraints,omitempty"`
// FallbackIndex, when set, points to another stage in the same profile to
// transfer to on ActionFallback constraints. Optional.
// It is a zero-based index into Profile and must be within range.
FallbackIndex *int `json:"fallback_index,omitempty"`
}
// ModelSpec describes the per-stage propagation model.
type ModelSpec struct {
// Type selects the model: "constant_rate", "parachute_descent", "piecewise", "wind".
// Any other value is rejected by buildModel.
Type string `json:"type"`
// Rate (m/s) for constant_rate.
Rate float64 `json:"rate,omitempty"`
// SeaLevelRate (m/s, positive) for parachute_descent.
// buildModel rejects values <= 0 for that type.
SeaLevelRate float64 `json:"sea_level_rate,omitempty"`
// Segments for piecewise.
Segments []PiecewiseSegment `json:"segments,omitempty"`
// IncludeWind sums a WindTransport model into the resulting derivative,
// allowing the same stage to model both vertical motion and wind drift.
// It has no effect when Type is "wind", which is already pure wind transport.
IncludeWind bool `json:"include_wind"`
}
// PiecewiseSegment is one entry in a piecewise rate schedule.
// Each entry is copied field-for-field into an engine.RateSegment by buildModel.
type PiecewiseSegment struct {
Until float64 `json:"until"` // UNIX seconds; segment applies for t < Until
Rate float64 `json:"rate"` // m/s
}
// ConstraintSpec describes one constraint attached to a stage.
type ConstraintSpec struct {
// Type: "max_altitude", "min_altitude", "max_time", "terrain_contact".
// Any other value is rejected by buildConstraints.
Type string `json:"type"`
// Limit is interpreted per Type: metres for altitude, UNIX seconds for time.
// It is not used for terrain_contact.
Limit float64 `json:"limit,omitempty"`
// Action: "stop" (default), "fallback", "clip".
Action string `json:"action,omitempty"`
}
// Options tweaks the integrator behaviour.
type Options struct {
// StepSeconds is the integrator step applied to every stage; 0 selects the
// default of 60.
StepSeconds float64 `json:"step_seconds,omitempty"`
// Tolerance is passed to every stage's propagator; 0 selects the default
// of 0.01.
Tolerance float64 `json:"tolerance,omitempty"`
}
// PredictionResponse is the response body for POST /api/v2/prediction.
type PredictionResponse struct {
// Stages holds one entry per result returned by the engine, in order.
Stages []StageResult `json:"stages"`
// Warnings is set only when the run recorded at least one warning.
Warnings map[string]any `json:"warnings,omitempty"`
// Dataset identifies the wind dataset that was active for this run.
Dataset DatasetInfo `json:"dataset"`
// StartedAt and CompletedAt are UTC wall-clock times taken immediately
// before and after the engine run.
StartedAt time.Time `json:"started_at"`
CompletedAt time.Time `json:"completed_at"`
}
// StageResult is the outcome of one stage.
type StageResult struct {
// Name is the propagator name reported by the engine for this result.
Name string `json:"name"`
Outcome string `json:"outcome"` // "stopped" | "fallback" | "continued"
// Constraint is the name of the constraint attached to the result; it is
// omitted when the result carries none.
Constraint string `json:"constraint,omitempty"`
// Trajectory lists the sampled points of this stage in engine order.
Trajectory []TrajectoryPoint `json:"trajectory"`
}
// TrajectoryPoint is one sampled point of the trajectory.
type TrajectoryPoint struct {
// Time is in UTC with whole-second resolution (the engine time is truncated).
Time time.Time `json:"time"`
Latitude float64 `json:"latitude"`
// Longitude: engine values above 180 are shifted by -360 before being returned.
Longitude float64 `json:"longitude"`
Altitude float64 `json:"altitude"`
}
// DatasetInfo identifies the dataset the prediction was computed against.
type DatasetInfo struct {
// Source is the value returned by the active field's Source method.
Source string `json:"source"`
// Epoch is the value returned by the active field's Epoch method.
Epoch time.Time `json:"epoch"`
}
// ErrorResponse is the JSON error shape used by both v2 and admin endpoints.
type ErrorResponse struct {
// Error is the single error object carried by the envelope.
Error ErrorBody `json:"error"`
}
// ErrorBody is the error detail.
type ErrorBody struct {
// Type is the HTTP status text for the response code (set by writeError).
Type string `json:"type"`
// Description is the human-readable explanation.
Description string `json:"description"`
}

252
internal/config/config.go Normal file
View file

@ -0,0 +1,252 @@
// Package config holds the service's runtime configuration, loaded by
// merging (in order of increasing precedence): built-in defaults, a YAML
// config file, environment variables, and command-line flags.
//
// Validation is performed once on load; downstream consumers receive an
// immutable struct.
package config
import (
"flag"
"fmt"
"os"
"strconv"
"time"
"gopkg.in/yaml.v2"
)
// Config is the top-level configuration tree.
type Config struct {
// HTTP holds listener settings.
HTTP HTTPConfig `yaml:"http"`
// Data holds dataset and elevation storage locations.
Data DataConfig `yaml:"data"`
// Download holds downloader and refresh settings.
Download DownloadConfig `yaml:"download"`
// Metrics holds metrics-endpoint settings.
Metrics MetricsConfig `yaml:"metrics"`
// Log holds logging settings.
Log LogConfig `yaml:"log"`
}
// HTTPConfig configures the HTTP server.
type HTTPConfig struct {
// Port is the listen port; Validate requires it to be in [0, 65535].
// The default is 8080.
Port int `yaml:"port"`
}
// DataConfig configures dataset and elevation storage.
type DataConfig struct {
// Dir is the directory for dataset files; required.
Dir string `yaml:"dir"`
// ElevationPath is the path to the ruaumoko elevation dataset.
ElevationPath string `yaml:"elevation_path"`
// Source is the dataset source identifier; only "noaa-gfs-0p50" is supported today.
// Validate only checks that it is non-empty.
Source string `yaml:"source"`
}
// DownloadConfig configures the dataset downloader.
type DownloadConfig struct {
// Parallel is the maximum number of concurrent GRIB downloads; must be > 0.
Parallel int `yaml:"parallel"`
// BandwidthBytesPerSecond limits download bandwidth; 0 means unlimited.
BandwidthBytesPerSecond int64 `yaml:"bandwidth_bytes_per_second"`
// UpdateInterval is the scheduler refresh interval; must be > 0.
UpdateInterval time.Duration `yaml:"update_interval"`
// FreshnessTTL is the maximum age before a dataset is considered stale;
// must be > 0.
FreshnessTTL time.Duration `yaml:"freshness_ttl"`
}
// MetricsConfig configures the metrics endpoint.
type MetricsConfig struct {
// Enabled turns the metrics endpoint on; the default is true.
Enabled bool `yaml:"enabled"`
// Path is the HTTP path of the endpoint; required when Enabled is true.
Path string `yaml:"path"`
}
// LogConfig configures logging.
type LogConfig struct {
// Validate rejects any value other than the four listed here.
Level string `yaml:"level"` // "debug", "info", "warn", "error"
}
// Defaults returns the built-in configuration that Load starts from before
// applying the config file, environment variables and flags.
func Defaults() Config {
	var cfg Config

	cfg.HTTP.Port = 8080

	cfg.Data.Dir = "/tmp/predictor-data"
	cfg.Data.ElevationPath = "/srv/ruaumoko-dataset"
	cfg.Data.Source = "noaa-gfs-0p50"

	cfg.Download.Parallel = 8
	cfg.Download.BandwidthBytesPerSecond = 0 // unlimited
	cfg.Download.UpdateInterval = 6 * time.Hour
	cfg.Download.FreshnessTTL = 48 * time.Hour

	cfg.Metrics.Enabled = true
	cfg.Metrics.Path = "/metrics"

	cfg.Log.Level = "info"

	return cfg
}
// Load resolves the configuration by merging built-in defaults with
// (in increasing precedence): a YAML file (path from PREDICTOR_CONFIG_FILE
// env var or --config flag), environment variables, and command-line flags.
//
// A flag overrides the lower layers only when it was explicitly present in
// args; whatever value it was given is then validated like any other, so an
// explicit but invalid flag value is reported rather than ignored.
//
// args is os.Args[1:] in production code; tests pass a custom slice.
func Load(args []string) (Config, error) {
	cfg := Defaults()

	fs := flag.NewFlagSet("predictor", flag.ContinueOnError)
	// Flag parse errors and -h usage text are written to stderr.
	fs.SetOutput(os.Stderr)

	var (
		configPath = fs.String("config", os.Getenv("PREDICTOR_CONFIG_FILE"), "path to YAML config file")

		// Flag-driven overrides. The defaults below only affect the usage
		// text; whether a flag applies is decided by its presence in args.
		flagPort           = fs.Int("port", -1, "HTTP listen port")
		flagDataDir        = fs.String("data-dir", "", "directory for dataset files")
		flagElevation      = fs.String("elevation", "", "path to ruaumoko elevation dataset")
		flagParallel       = fs.Int("download-parallel", -1, "max concurrent GRIB downloads")
		flagBandwidth      = fs.Int64("download-bandwidth", -1, "download bandwidth limit in bytes/sec (0 = unlimited)")
		flagInterval       = fs.Duration("update-interval", 0, "scheduler refresh interval")
		flagTTL            = fs.Duration("freshness-ttl", 0, "max age before a dataset is considered stale")
		flagMetricsEnabled = fs.Bool("metrics", true, "enable Prometheus-compatible metrics endpoint")
		flagMetricsPath    = fs.String("metrics-path", "", "HTTP path for the metrics endpoint")
		flagLogLevel       = fs.String("log-level", "", "log level: debug|info|warn|error")
	)
	if err := fs.Parse(args); err != nil {
		return Config{}, fmt.Errorf("parse flags: %w", err)
	}

	// 1. File.
	if *configPath != "" {
		data, err := os.ReadFile(*configPath)
		if err != nil {
			return Config{}, fmt.Errorf("read config %s: %w", *configPath, err)
		}
		if err := yaml.UnmarshalStrict(data, &cfg); err != nil {
			return Config{}, fmt.Errorf("parse config %s: %w", *configPath, err)
		}
	}

	// 2. Env vars.
	applyEnv(&cfg)

	// 3. Flags. Visit reports only the flags that were actually set, so a flag
	// left at its default never overrides the file or the environment.
	fs.Visit(func(f *flag.Flag) {
		switch f.Name {
		case "port":
			cfg.HTTP.Port = *flagPort
		case "data-dir":
			cfg.Data.Dir = *flagDataDir
		case "elevation":
			cfg.Data.ElevationPath = *flagElevation
		case "download-parallel":
			cfg.Download.Parallel = *flagParallel
		case "download-bandwidth":
			cfg.Download.BandwidthBytesPerSecond = *flagBandwidth
		case "update-interval":
			cfg.Download.UpdateInterval = *flagInterval
		case "freshness-ttl":
			cfg.Download.FreshnessTTL = *flagTTL
		case "metrics":
			cfg.Metrics.Enabled = *flagMetricsEnabled
		case "metrics-path":
			cfg.Metrics.Path = *flagMetricsPath
		case "log-level":
			cfg.Log.Level = *flagLogLevel
		}
	})

	if err := cfg.Validate(); err != nil {
		return Config{}, err
	}
	return cfg, nil
}
// isFlagSet reports whether the named flag was explicitly provided on the
// command line parsed by fs, as opposed to being left at its default.
func isFlagSet(fs *flag.FlagSet, name string) bool {
	found := false
	// Visit walks only the flags that have been set.
	fs.Visit(func(f *flag.Flag) {
		found = found || f.Name == name
	})
	return found
}
// applyEnv overlays PREDICTOR_* environment variables onto cfg.
//
// Unset or empty variables leave cfg untouched. A numeric or duration variable
// whose value cannot be parsed also leaves cfg untouched, and a warning naming
// the variable is written to stderr so the misconfiguration is visible.
// PREDICTOR_METRICS_ENABLED accepts everything strconv.ParseBool accepts plus
// "yes"; any other value disables metrics.
func applyEnv(cfg *Config) {
	// skip reports a variable that was present but unusable.
	skip := func(name, value string, err error) {
		fmt.Fprintf(os.Stderr, "config: ignoring %s=%q: %v\n", name, value, err)
	}

	if v := os.Getenv("PREDICTOR_PORT"); v != "" {
		if n, err := strconv.Atoi(v); err != nil {
			skip("PREDICTOR_PORT", v, err)
		} else {
			cfg.HTTP.Port = n
		}
	}
	if v := os.Getenv("PREDICTOR_DATA_DIR"); v != "" {
		cfg.Data.Dir = v
	}
	if v := os.Getenv("PREDICTOR_ELEVATION_DATASET"); v != "" {
		cfg.Data.ElevationPath = v
	}
	if v := os.Getenv("PREDICTOR_SOURCE"); v != "" {
		cfg.Data.Source = v
	}
	if v := os.Getenv("PREDICTOR_DOWNLOAD_PARALLEL"); v != "" {
		if n, err := strconv.Atoi(v); err != nil {
			skip("PREDICTOR_DOWNLOAD_PARALLEL", v, err)
		} else {
			cfg.Download.Parallel = n
		}
	}
	if v := os.Getenv("PREDICTOR_DOWNLOAD_BANDWIDTH"); v != "" {
		if n, err := strconv.ParseInt(v, 10, 64); err != nil {
			skip("PREDICTOR_DOWNLOAD_BANDWIDTH", v, err)
		} else {
			cfg.Download.BandwidthBytesPerSecond = n
		}
	}
	if v := os.Getenv("PREDICTOR_UPDATE_INTERVAL"); v != "" {
		if d, err := time.ParseDuration(v); err != nil {
			skip("PREDICTOR_UPDATE_INTERVAL", v, err)
		} else {
			cfg.Download.UpdateInterval = d
		}
	}
	if v := os.Getenv("PREDICTOR_DATASET_TTL"); v != "" {
		if d, err := time.ParseDuration(v); err != nil {
			skip("PREDICTOR_DATASET_TTL", v, err)
		} else {
			cfg.Download.FreshnessTTL = d
		}
	}
	if v := os.Getenv("PREDICTOR_METRICS_ENABLED"); v != "" {
		// ParseBool covers "1"/"true" and their case variants; "yes" is kept
		// for compatibility. Unrecognised values still mean "disabled".
		if b, err := strconv.ParseBool(v); err == nil {
			cfg.Metrics.Enabled = b
		} else {
			cfg.Metrics.Enabled = v == "yes"
		}
	}
	if v := os.Getenv("PREDICTOR_METRICS_PATH"); v != "" {
		cfg.Metrics.Path = v
	}
	if v := os.Getenv("PREDICTOR_LOG_LEVEL"); v != "" {
		cfg.Log.Level = v
	}
}
// Validate reports the first configuration error found, or nil when the
// configuration is usable.
//
// It checks the port range, that the data directory and source are set, that
// the download settings are positive (bandwidth may be 0 for "unlimited" but
// not negative), that an enabled metrics endpoint has a path beginning with
// "/", and that the log level is one of debug, info, warn or error.
func (c Config) Validate() error {
	if c.HTTP.Port < 0 || c.HTTP.Port > 65535 {
		return fmt.Errorf("http.port %d outside [0, 65535]", c.HTTP.Port)
	}
	if c.Data.Dir == "" {
		return fmt.Errorf("data.dir is required")
	}
	if c.Data.Source == "" {
		return fmt.Errorf("data.source is required")
	}
	if c.Download.Parallel <= 0 {
		return fmt.Errorf("download.parallel must be > 0")
	}
	if c.Download.BandwidthBytesPerSecond < 0 {
		return fmt.Errorf("download.bandwidth_bytes_per_second must be >= 0 (0 = unlimited)")
	}
	if c.Download.UpdateInterval <= 0 {
		return fmt.Errorf("download.update_interval must be > 0")
	}
	if c.Download.FreshnessTTL <= 0 {
		return fmt.Errorf("download.freshness_ttl must be > 0")
	}
	if c.Metrics.Enabled {
		if c.Metrics.Path == "" {
			return fmt.Errorf("metrics.path is required when metrics enabled")
		}
		// The path is registered on an http.ServeMux, where a pattern that
		// does not start with "/" is not a plain path pattern.
		if c.Metrics.Path[0] != '/' {
			return fmt.Errorf("metrics.path %q must start with \"/\"", c.Metrics.Path)
		}
	}
	switch c.Log.Level {
	case "debug", "info", "warn", "error":
	default:
		return fmt.Errorf("log.level %q is not one of debug|info|warn|error", c.Log.Level)
	}
	return nil
}

View file

@ -0,0 +1,76 @@
package config
import (
"os"
"path/filepath"
"testing"
"time"
)
// TestLoadDefaults checks that Load with no arguments yields the built-in
// defaults for the fields it asserts on. Every environment variable that could
// influence those fields (or redirect Load to a config file) is blanked first
// so the result does not depend on the environment the tests run in.
func TestLoadDefaults(t *testing.T) {
	t.Setenv("PREDICTOR_DATA_DIR", "")
	t.Setenv("PREDICTOR_PORT", "")
	t.Setenv("PREDICTOR_DOWNLOAD_PARALLEL", "")
	t.Setenv("PREDICTOR_CONFIG_FILE", "")

	cfg, err := Load(nil)
	if err != nil {
		t.Fatalf("Load: %v", err)
	}
	if cfg.HTTP.Port != 8080 {
		t.Errorf("default port = %d, want 8080", cfg.HTTP.Port)
	}
	if cfg.Download.Parallel != 8 {
		t.Errorf("default parallel = %d, want 8", cfg.Download.Parallel)
	}
}
// TestLoadEnvOverridesDefaults checks that PREDICTOR_* environment variables
// take precedence over the built-in defaults.
func TestLoadEnvOverridesDefaults(t *testing.T) {
	t.Setenv("PREDICTOR_PORT", "9090")
	t.Setenv("PREDICTOR_UPDATE_INTERVAL", "30m")

	cfg, err := Load(nil)
	if err != nil {
		t.Fatalf("Load: %v", err)
	}

	if got := cfg.HTTP.Port; got != 9090 {
		t.Errorf("env port = %d, want 9090", got)
	}
	if got := cfg.Download.UpdateInterval; got != 30*time.Minute {
		t.Errorf("env update interval = %v, want 30m", got)
	}
}
// TestLoadFlagsOverrideEnv checks that a command-line flag wins over the
// corresponding environment variable.
func TestLoadFlagsOverrideEnv(t *testing.T) {
	t.Setenv("PREDICTOR_PORT", "9090")

	args := []string{"-port", "7777"}
	cfg, err := Load(args)
	if err != nil {
		t.Fatalf("Load: %v", err)
	}
	if got := cfg.HTTP.Port; got != 7777 {
		t.Errorf("flag should override env: port = %d, want 7777", got)
	}
}
// TestLoadFileOverridesDefaults checks that a value from the YAML config file
// replaces the built-in default.
func TestLoadFileOverridesDefaults(t *testing.T) {
	path := filepath.Join(t.TempDir(), "predictor.yml")
	contents := []byte("http:\n port: 12345\n")
	if err := os.WriteFile(path, contents, 0o644); err != nil {
		t.Fatal(err)
	}

	cfg, err := Load([]string{"-config", path})
	if err != nil {
		t.Fatalf("Load: %v", err)
	}
	if got := cfg.HTTP.Port; got != 12345 {
		t.Errorf("file port = %d, want 12345", got)
	}
}
// TestValidate checks that an otherwise default configuration is rejected once
// its data directory is emptied.
func TestValidate(t *testing.T) {
	invalid := Defaults()
	invalid.Data.Dir = ""

	err := invalid.Validate()
	if err != nil {
		return
	}
	t.Error("expected validation error for empty data dir")
}

View file

@ -1,158 +0,0 @@
package dataset
import "fmt"
// Dataset shape constants.
//
// The dataset is a five-dimensional array indexed as
// (hour, pressure_level, variable, latitude, longitude) with shape
// (65, 47, 3, 361, 720). This matches the reference predictor exactly.
const (
NumHours = 65 // 0, 3, 6, ..., 192
NumLevels = 47 // pressure levels
NumVariables = 3 // height, wind_u, wind_v
NumLatitudes = 361 // -90.0 to +90.0 in 0.5 degree steps
NumLongitudes = 720 // 0.0 to 359.5 in 0.5 degree steps
HourStep = 3 // hours between forecast time steps
MaxHour = 192 // maximum forecast hour
Resolution = 0.5 // grid resolution in degrees
LatStart = -90.0 // first latitude in the dataset
LonStart = 0.0 // first longitude in the dataset
// Variable indices within the dataset (positions along the variable axis).
VarHeight = 0
VarWindU = 1
VarWindV = 2
ElementSize = 4 // float32 = 4 bytes
)
// DatasetSize is the total size of the dataset file in bytes: the product of
// the five dimension lengths and the element size.
// 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200
// Each factor is converted to int64 before multiplying because the product
// does not fit in 32 bits.
const DatasetSize int64 = int64(NumHours) * int64(NumLevels) * int64(NumVariables) *
int64(NumLatitudes) * int64(NumLongitudes) * int64(ElementSize)
// LevelSet identifies which GRIB file set a pressure level belongs to.
type LevelSet int
// The two file sets a level can come from. LevelSetA is the zero value.
const (
LevelSetA LevelSet = iota // pgrb2 (primary)
LevelSetB // pgrb2b (secondary)
)
// Pressures contains the 47 pressure levels in hPa, sorted descending.
// Index 0 = 1000 hPa (near surface), Index 46 = 1 hPa (high atmosphere).
// The position of a level in this array is its index along the dataset's
// pressure-level axis (see PressureIndex).
var Pressures = [NumLevels]int{
1000, 975, 950, 925, 900, 875, 850, 825, 800, 775,
750, 725, 700, 675, 650, 625, 600, 575, 550, 525,
500, 475, 450, 425, 400, 375, 350, 325, 300, 275,
250, 225, 200, 175, 150, 125, 100, 70, 50, 30,
20, 10, 7, 5, 3, 2, 1,
}
// pressureIndex maps pressure in hPa to its index in the Pressures array.
// It is populated by init.
var pressureIndex map[int]int
// pressureLevelSet maps pressure in hPa to its GRIB file set.
// It is populated by init from PressuresPgrb2 and PressuresPgrb2b.
var pressureLevelSet map[int]LevelSet
// init builds the two package-level lookup tables: pressure → index in
// Pressures, and pressure → GRIB file set.
func init() {
	byPressure := make(map[int]int, NumLevels)
	for idx := 0; idx < len(Pressures); idx++ {
		byPressure[Pressures[idx]] = idx
	}
	pressureIndex = byPressure

	sets := make(map[int]LevelSet, NumLevels)
	for _, hPa := range PressuresPgrb2 {
		sets[hPa] = LevelSetA
	}
	for _, hPa := range PressuresPgrb2b {
		sets[hPa] = LevelSetB
	}
	pressureLevelSet = sets
}
// PressuresPgrb2 contains levels found in the primary pgrb2 file (26 levels).
// Values are in hPa; init maps each of them to LevelSetA.
var PressuresPgrb2 = []int{
10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 350, 400,
450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 925,
950, 975, 1000,
}
// PressuresPgrb2b contains levels found in the secondary pgrb2b file (21 levels).
// Values are in hPa; init maps each of them to LevelSetB.
var PressuresPgrb2b = []int{
1, 2, 3, 5, 7, 125, 175, 225, 275, 325, 375, 425,
475, 525, 575, 625, 675, 725, 775, 825, 875,
}
// PressureIndex returns the position of the given pressure level (hPa) along
// the dataset's pressure axis, or -1 when the level is not part of the dataset.
func PressureIndex(hPa int) int {
	if idx, found := pressureIndex[hPa]; found {
		return idx
	}
	return -1
}
// PressureLevelSet looks up the GRIB file set that carries the given pressure
// level (hPa). The boolean is false, and the set is the zero value, when the
// level is unknown.
func PressureLevelSet(hPa int) (LevelSet, bool) {
	if set, known := pressureLevelSet[hPa]; known {
		return set, true
	}
	var none LevelSet
	return none, false
}
// HourIndex converts a forecast hour into its index along the dataset's time
// axis. It returns -1 when the hour lies outside [0, MaxHour] or is not a
// multiple of HourStep.
func HourIndex(hour int) int {
	inRange := hour >= 0 && hour <= MaxHour
	if !inRange || hour%HourStep != 0 {
		return -1
	}
	return hour / HourStep
}
// Hours lists every forecast hour in ascending order, from 0 up to and
// including MaxHour in HourStep increments: [0, 3, 6, ..., 192].
func Hours() []int {
	hours := make([]int, 0, NumHours)
	for i := 0; i*HourStep <= MaxHour; i++ {
		hours = append(hours, i*HourStep)
	}
	return hours
}
// S3BaseURL is the base URL, without a trailing slash, under which GribURL and
// GribURLB build the download locations for NOAA GFS GRIB files.
const S3BaseURL = "https://noaa-gfs-bdp-pds.s3.amazonaws.com"
// GribURL builds the download URL of a primary (pgrb2) GRIB file. date is
// inserted verbatim after "gfs."; runHour is zero-padded to two digits and
// forecastStep to three.
func GribURL(date string, runHour, forecastStep int) string {
	dir := fmt.Sprintf("%s/gfs.%s/%02d/atmos", S3BaseURL, date, runHour)
	file := fmt.Sprintf("gfs.t%02dz.pgrb2.0p50.f%03d", runHour, forecastStep)
	return dir + "/" + file
}
// GribURLB builds the S3 URL of a secondary (pgrb2b) GRIB file.
//
// date is inserted verbatim after "gfs." in the path; runHour is zero-padded
// to two digits and forecastStep to three.
func GribURLB(date string, runHour, forecastStep int) string {
	const layout = "%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2b.0p50.f%03d"
	return fmt.Sprintf(layout, S3BaseURL, date, runHour, runHour, forecastStep)
}
// GribFileName builds the local filename of a primary (pgrb2) GRIB file:
// runHour zero-padded to two digits, forecastStep to three.
func GribFileName(runHour, forecastStep int) string {
	const layout = "gfs.t%02dz.pgrb2.0p50.f%03d"
	return fmt.Sprintf(layout, runHour, forecastStep)
}
// GribFileNameB builds the local filename of a secondary (pgrb2b) GRIB file:
// runHour zero-padded to two digits, forecastStep to three.
func GribFileNameB(runHour, forecastStep int) string {
	const layout = "gfs.t%02dz.pgrb2b.0p50.f%03d"
	return fmt.Sprintf(layout, runHour, forecastStep)
}
// VariableIndex maps a GRIB (parameter category, parameter number) pair to
// the dataset variable index. Unrecognised pairs yield -1.
func VariableIndex(parameterCategory, parameterNumber int) int {
	switch parameterCategory {
	case 3:
		if parameterNumber == 5 {
			return VarHeight // Geopotential Height
		}
	case 2:
		switch parameterNumber {
		case 2:
			return VarWindU // U-component of wind
		case 3:
			return VarWindV // V-component of wind
		}
	}
	return -1
}

View file

@ -1,152 +0,0 @@
package dataset
import (
"testing"
)
// TestDatasetShape pins the five cube dimensions.
func TestDatasetShape(t *testing.T) {
	if got := NumHours; got != 65 {
		t.Errorf("NumHours = %d, want 65", got)
	}
	if got := NumLevels; got != 47 {
		t.Errorf("NumLevels = %d, want 47", got)
	}
	if got := NumVariables; got != 3 {
		t.Errorf("NumVariables = %d, want 3", got)
	}
	if got := NumLatitudes; got != 361 {
		t.Errorf("NumLatitudes = %d, want 361", got)
	}
	if got := NumLongitudes; got != 720 {
		t.Errorf("NumLongitudes = %d, want 720", got)
	}
}
// TestDatasetSize pins the total byte size of the cube:
// 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200.
func TestDatasetSize(t *testing.T) {
	if want := int64(9_528_667_200); DatasetSize != want {
		t.Errorf("DatasetSize = %d, want %d", DatasetSize, want)
	}
}
// TestPressureLevels checks the size, end points and ordering of Pressures
// and the sizes of the two per-file level lists.
func TestPressureLevels(t *testing.T) {
	if n := len(Pressures); n != 47 {
		t.Fatalf("len(Pressures) = %d, want 47", n)
	}

	// Index 0 is the highest pressure (nearest the surface).
	if first := Pressures[0]; first != 1000 {
		t.Errorf("Pressures[0] = %d, want 1000", first)
	}
	// The final index is the lowest pressure (highest altitude).
	if last := Pressures[46]; last != 1 {
		t.Errorf("Pressures[46] = %d, want 1", last)
	}

	// Levels must be strictly descending.
	for i := 1; i < len(Pressures); i++ {
		prev, cur := Pressures[i-1], Pressures[i]
		if cur >= prev {
			t.Errorf("Pressures not descending at [%d]: %d >= %d", i, cur, prev)
		}
	}

	// 26 levels from pgrb2 + 21 from pgrb2b = 47 in total.
	if n := len(PressuresPgrb2); n != 26 {
		t.Errorf("len(PressuresPgrb2) = %d, want 26", n)
	}
	if n := len(PressuresPgrb2b); n != 21 {
		t.Errorf("len(PressuresPgrb2b) = %d, want 21", n)
	}
}
// TestPressureIndex checks known level indices and the -1 sentinel for an
// unknown level.
func TestPressureIndex(t *testing.T) {
	if got := PressureIndex(1000); got != 0 {
		t.Errorf("PressureIndex(1000) = %d, want 0", got)
	}
	if got := PressureIndex(1); got != 46 {
		t.Errorf("PressureIndex(1) = %d, want 46", got)
	}
	if got := PressureIndex(500); got != 20 {
		t.Errorf("PressureIndex(500) = %d, want 20", got)
	}
	if got := PressureIndex(9999); got != -1 {
		t.Errorf("PressureIndex(9999) = %d, want -1", got)
	}
}
// TestPressureLevelSet checks representative level-to-file-set assignments
// and that every entry of Pressures has one.
func TestPressureLevelSet(t *testing.T) {
	// 1000 mb lives in pgrb2 (set A).
	set, found := PressureLevelSet(1000)
	if !(found && set == LevelSetA) {
		t.Errorf("PressureLevelSet(1000) = %v, %v; want A, true", set, found)
	}

	// 125 mb lives in pgrb2b (set B).
	set, found = PressureLevelSet(125)
	if !(found && set == LevelSetB) {
		t.Errorf("PressureLevelSet(125) = %v, %v; want B, true", set, found)
	}

	// The five lowest pressures all live in pgrb2b (set B).
	for _, hPa := range []int{1, 2, 3, 5, 7} {
		got, present := PressureLevelSet(hPa)
		if !(present && got == LevelSetB) {
			t.Errorf("PressureLevelSet(%d) = %v, %v; want B, true", hPa, got, present)
		}
	}

	// No pressure level may be left without a level-set assignment.
	for _, hPa := range Pressures {
		if _, present := PressureLevelSet(hPa); !present {
			t.Errorf("PressureLevelSet(%d) not found", hPa)
		}
	}
}
// TestHourIndex covers valid hours, a non-multiple of the step, and an
// out-of-range hour.
func TestHourIndex(t *testing.T) {
	if got := HourIndex(0); got != 0 {
		t.Errorf("HourIndex(0) = %d, want 0", got)
	}
	if got := HourIndex(3); got != 1 {
		t.Errorf("HourIndex(3) = %d, want 1", got)
	}
	if got := HourIndex(192); got != 64 {
		t.Errorf("HourIndex(192) = %d, want 64", got)
	}
	if got := HourIndex(1); got != -1 {
		t.Errorf("HourIndex(1) = %d, want -1 (not multiple of 3)", got)
	}
	if got := HourIndex(195); got != -1 {
		t.Errorf("HourIndex(195) = %d, want -1 (out of range)", got)
	}
}
// TestHours checks the length and both end points of Hours().
func TestHours(t *testing.T) {
	all := Hours()
	if n := len(all); n != NumHours {
		t.Fatalf("len(Hours()) = %d, want %d", n, NumHours)
	}
	if first := all[0]; first != 0 {
		t.Errorf("Hours()[0] = %d, want 0", first)
	}
	if last := all[len(all)-1]; last != MaxHour {
		t.Errorf("Hours()[last] = %d, want %d", last, MaxHour)
	}
}
// TestVariableIndex checks the three recognised GRIB parameters and the -1
// sentinel for an unknown one.
func TestVariableIndex(t *testing.T) {
	if got := VariableIndex(3, 5); got != VarHeight {
		t.Errorf("HGT: got %d, want %d", got, VarHeight)
	}
	if got := VariableIndex(2, 2); got != VarWindU {
		t.Errorf("UGRD: got %d, want %d", got, VarWindU)
	}
	if got := VariableIndex(2, 3); got != VarWindV {
		t.Errorf("VGRD: got %d, want %d", got, VarWindV)
	}
	if got := VariableIndex(0, 0); got != -1 {
		t.Errorf("unknown: got %d, want -1", got)
	}
}

View file

@ -1,140 +0,0 @@
package dataset
import (
"encoding/binary"
"fmt"
"math"
"os"
"time"
mmap "github.com/edsrzf/mmap-go"
)
// File represents an mmap-backed wind dataset file.
//
// Instances come from Open (read-only mapping) or Create (read-write
// mapping) and must be released with Close.
type File struct {
mm mmap.MMap // memory mapping of the whole file; set to nil by Close
file *os.File // underlying file handle; set to nil by Close
writable bool // true when produced by Create, false when produced by Open
DSTime time.Time // forecast run time (UTC)
}
// Open maps an existing dataset file read-only.
//
// The file must be exactly DatasetSize bytes long. dsTime is stored
// unchanged in the returned File's DSTime field. On any failure the file
// handle is closed before the error is returned.
func Open(path string, dsTime time.Time) (ds *File, err error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("open dataset: %w", err)
	}
	// Every failure path below must release the handle.
	defer func() {
		if err != nil {
			f.Close()
		}
	}()

	info, err := f.Stat()
	if err != nil {
		return nil, fmt.Errorf("stat dataset: %w", err)
	}
	if size := info.Size(); size != DatasetSize {
		return nil, fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, size)
	}

	mapping, err := mmap.Map(f, mmap.RDONLY, 0)
	if err != nil {
		return nil, fmt.Errorf("mmap dataset: %w", err)
	}
	return &File{mm: mapping, file: f, writable: false, DSTime: dsTime}, nil
}
// Create makes (or truncates) the file at path, sizes it to DatasetSize and
// maps it read-write. On any failure after the file is created, the handle
// is closed before the error is returned.
func Create(path string) (ds *File, err error) {
	f, err := os.Create(path)
	if err != nil {
		return nil, fmt.Errorf("create dataset: %w", err)
	}
	// Every failure path below must release the handle.
	defer func() {
		if err != nil {
			f.Close()
		}
	}()

	if err = f.Truncate(DatasetSize); err != nil {
		return nil, fmt.Errorf("truncate dataset: %w", err)
	}

	mapping, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0)
	if err != nil {
		return nil, fmt.Errorf("mmap dataset: %w", err)
	}
	return &File{mm: mapping, file: f, writable: true}, nil
}
// offset returns the byte offset of element [hour][level][variable][lat][lon]
// in the cube, using row-major (C-order) indexing over
// (hours, NumLevels, NumVariables, NumLatitudes, NumLongitudes) and
// ElementSize bytes per element.
func offset(hour, level, variable, lat, lon int) int64 {
	// Horner-style accumulation: fold one axis at a time, outermost first.
	linear := int64(hour)*int64(NumLevels) + int64(level)
	linear = linear*int64(NumVariables) + int64(variable)
	linear = linear*int64(NumLatitudes) + int64(lat)
	linear = linear*int64(NumLongitudes) + int64(lon)
	return int64(ElementSize) * linear
}
// Val returns the float32 stored at [hour][level][variable][lat][lon],
// decoded as little-endian. Indices are not range-checked here.
func (d *File) Val(hour, level, variable, lat, lon int) float32 {
	start := offset(hour, level, variable, lat, lon)
	raw := d.mm[start : start+4]
	return math.Float32frombits(binary.LittleEndian.Uint32(raw))
}
// SetVal stores val at [hour][level][variable][lat][lon] as a little-endian
// float32. It is only valid on writable (created) datasets; indices are not
// range-checked here.
func (d *File) SetVal(hour, level, variable, lat, lon int, val float32) {
	start := offset(hour, level, variable, lat, lon)
	cell := d.mm[start : start+4]
	binary.LittleEndian.PutUint32(cell, math.Float32bits(val))
}
// BlitGribData copies one decoded GRIB grid into the dataset at
// [hourIdx][levelIdx][varIdx].
//
// gribData must hold exactly NumLatitudes*NumLongitudes values in GRIB scan
// order (north-to-south, west-to-east). Rows are flipped while copying so
// that dataset latitude index 0 is the southernmost row and
// NumLatitudes-1 the northernmost. Values are narrowed to float32.
//
// An error is returned, and nothing is written, when the dataset is not
// writable, when an index lies outside the cube, or when gribData has the
// wrong length.
func (d *File) BlitGribData(hourIdx, levelIdx, varIdx int, gribData []float64) error {
	// Writing through a read-only mapping would fault instead of failing cleanly.
	if !d.writable {
		return fmt.Errorf("dataset is read-only")
	}
	if hourIdx < 0 || hourIdx >= NumHours ||
		levelIdx < 0 || levelIdx >= NumLevels ||
		varIdx < 0 || varIdx >= NumVariables {
		return fmt.Errorf("index out of range: hour=%d level=%d variable=%d", hourIdx, levelIdx, varIdx)
	}
	expected := NumLatitudes * NumLongitudes
	if len(gribData) != expected {
		return fmt.Errorf("grib data has %d values, expected %d", len(gribData), expected)
	}

	for lat := 0; lat < NumLatitudes; lat++ {
		// GRIB row 0 is the northernmost row; dataset row 0 is the southernmost.
		srcRow := (NumLatitudes - 1 - lat) * NumLongitudes
		for lon := 0; lon < NumLongitudes; lon++ {
			d.SetVal(hourIdx, levelIdx, varIdx, lat, lon, float32(gribData[srcRow+lon]))
		}
	}
	return nil
}
// Flush writes the mapping back to disk. It is a no-op when the file has no
// active mapping.
func (d *File) Flush() error {
	if d.mm == nil {
		return nil
	}
	return d.mm.Flush()
}
// Close releases the mapping and then the file handle. Calling it again
// after a successful Close is a no-op.
//
// If unmapping fails, the file handle is still closed and the unmap error is
// returned wrapped.
func (d *File) Close() error {
	if d.mm != nil {
		unmapErr := d.mm.Unmap()
		if unmapErr != nil {
			d.file.Close()
			return fmt.Errorf("unmap: %w", unmapErr)
		}
		d.mm = nil
	}

	if d.file == nil {
		return nil
	}
	handle := d.file
	d.file = nil
	return handle.Close()
}

11
internal/datasets/doc.go Normal file
View file

@ -0,0 +1,11 @@
// Package datasets manages the lifecycle of atmospheric datasets. It exposes:
//
// - A Source interface for pluggable dataset origins (GFS now, ECMWF later).
// - A Storage interface for transactional, resumable on-disk persistence.
// - A Manager that coordinates downloads, tracks job state, and owns the
// currently-active weather.WindField.
//
// The package is the only one in the service that knows about download
// scheduling, manifests, or bandwidth throttling — engine and API layers
// only see WindField + Manager-as-admin.
package datasets

View file

@ -0,0 +1,125 @@
package gfs
import (
"fmt"
"strconv"
"strings"
)
// IdxEntry is one parsed line from a NOAA GRIB .idx file.
//
// Example line: "15:1207405:d=2024010100:HGT:1000 mb:0 hour fcst:"
//
// Fields are taken from the colon-separated columns of the line; EndOffset
// is not part of the line and is filled in by ParseIdx.
type IdxEntry struct {
Index int // first column: message number
Offset int64 // second column: byte offset of the message
Variable string // fourth column, e.g. "HGT"
LevelMB int // 0 when the level is not isobaric
Hour int // forecast hour; 0 for analysis ("anl"); -1 if unparseable
EndOffset int64 // computed from the next entry's Offset; -1 for the final entry
}
// Length reports how many bytes this GRIB message occupies, i.e.
// EndOffset - Offset. It returns -1 when EndOffset is not positive, which is
// the case for the final entry of an idx file.
func (e *IdxEntry) Length() int64 {
	if e.EndOffset > 0 {
		return e.EndOffset - e.Offset
	}
	return -1
}
// ParseIdx decodes the body of a .idx file into entries, one per line.
//
// Blank lines, lines with fewer than seven colon-separated fields, and lines
// whose index or offset is not an integer are skipped without error. Each
// entry's EndOffset is set to the Offset of the entry that follows it; the
// final entry keeps EndOffset == -1. The result is nil when nothing parsed.
func ParseIdx(body []byte) []IdxEntry {
	var entries []IdxEntry
	for _, raw := range strings.Split(string(body), "\n") {
		line := strings.TrimSpace(raw)
		if line == "" {
			continue
		}
		fields := strings.Split(line, ":")
		if len(fields) < 7 {
			continue
		}
		number, numErr := strconv.Atoi(fields[0])
		if numErr != nil {
			continue
		}
		start, offErr := strconv.ParseInt(fields[1], 10, 64)
		if offErr != nil {
			continue
		}
		entry := IdxEntry{
			Index:     number,
			Offset:    start,
			Variable:  fields[3],
			LevelMB:   parseLevelMB(fields[4]),
			Hour:      parseHour(fields[5]),
			EndOffset: -1,
		}
		entries = append(entries, entry)
	}

	// A message ends where the next one begins.
	for next := 1; next < len(entries); next++ {
		entries[next-1].EndOffset = entries[next].Offset
	}
	return entries
}
// FilterIdx keeps the entries whose Variable is marked true in wanted, whose
// LevelMB is positive, and whose Length is positive. Input order is
// preserved; the result is nil when nothing qualifies.
func FilterIdx(entries []IdxEntry, wanted map[string]bool) []IdxEntry {
	var kept []IdxEntry
	for i := range entries {
		candidate := entries[i]
		if wanted[candidate.Variable] && candidate.LevelMB > 0 && candidate.Length() > 0 {
			kept = append(kept, candidate)
		}
	}
	return kept
}
// parseLevelMB extracts the integer from a level description of the form
// "<n> mb" (surrounding whitespace ignored). Any other description, or a
// non-integer prefix, yields 0.
func parseLevelMB(s string) int {
	const unit = " mb"
	level := strings.TrimSpace(s)
	if !strings.HasSuffix(level, unit) {
		return 0
	}
	if mb, err := strconv.Atoi(level[:len(level)-len(unit)]); err == nil {
		return mb
	}
	return 0
}
// parseHour extracts the forecast hour from an idx time column.
// "anl" yields 0; "<n> hour fcst" (or a bare integer) yields n; anything
// else yields -1. Surrounding whitespace is ignored.
func parseHour(s string) int {
	label := strings.TrimSpace(s)
	if label == "anl" {
		return 0
	}
	hours, err := strconv.Atoi(strings.TrimSuffix(label, " hour fcst"))
	if err == nil {
		return hours
	}
	return -1
}
// ByteRange is one HTTP range download corresponding to one GRIB message.
type ByteRange struct {
Start int64 // first byte of the range (the entry's Offset)
End int64 // inclusive
Entry IdxEntry // the idx entry this range was derived from
}
// EntriesToRanges turns idx entries into inclusive HTTP byte ranges
// [Offset, EndOffset-1]. Entries without a positive Length are omitted.
// The result is never nil.
func EntriesToRanges(entries []IdxEntry) []ByteRange {
	ranges := make([]ByteRange, 0, len(entries))
	for i := range entries {
		entry := entries[i]
		if entry.Length() > 0 {
			ranges = append(ranges, ByteRange{
				Start: entry.Offset,
				End:   entry.EndOffset - 1,
				Entry: entry,
			})
		}
	}
	return ranges
}
// FormatRange renders the range as an HTTP Range header value of the form
// "bytes=<Start>-<End>".
func (r ByteRange) FormatRange() string {
	const layout = "bytes=%d-%d"
	return fmt.Sprintf(layout, r.Start, r.End)
}

View file

@ -0,0 +1,70 @@
package gfs
import "testing"
// sampleIdx is a nine-line idx fixture: wanted and unwanted variables (TMP),
// one non-isobaric level ("2 m above ground"), and a final entry, which has
// no successor to derive an end offset from.
const sampleIdx = `1:0:d=2024010100:HGT:1000 mb:0 hour fcst:
2:289012:d=2024010100:HGT:975 mb:0 hour fcst:
3:541876:d=2024010100:TMP:1000 mb:0 hour fcst:
4:789012:d=2024010100:UGRD:1000 mb:0 hour fcst:
5:1045678:d=2024010100:VGRD:1000 mb:0 hour fcst:
6:1298765:d=2024010100:UGRD:975 mb:0 hour fcst:
7:1567890:d=2024010100:UGRD:2 m above ground:0 hour fcst:
8:1812345:d=2024010100:VGRD:975 mb:0 hour fcst:
9:2098765:d=2024010100:HGT:500 mb:3 hour fcst:
`
// TestParseIdx checks the entry count, the first entry's fields, the
// non-isobaric level handling and the final entry's EndOffset.
func TestParseIdx(t *testing.T) {
	entries := ParseIdx([]byte(sampleIdx))
	if n := len(entries); n != 9 {
		t.Fatalf("expected 9 entries, got %d", n)
	}

	first := entries[0]
	firstOK := first.Index == 1 &&
		first.Offset == 0 &&
		first.Variable == "HGT" &&
		first.LevelMB == 1000 &&
		first.Hour == 0 &&
		first.EndOffset == 289012
	if !firstOK {
		t.Errorf("entry 0: %+v", first)
	}

	if level := entries[6].LevelMB; level != 0 {
		t.Errorf("non-pressure level should have LevelMB=0, got %d", level)
	}
	if end := entries[len(entries)-1].EndOffset; end != -1 {
		t.Errorf("last entry EndOffset: got %d, want -1", end)
	}
}
func TestFilterIdx(t *testing.T) {
entries := ParseIdx([]byte(sampleIdx))
want := map[string]bool{"HGT": true, "UGRD": true, "VGRD": true}
filtered := FilterIdx(entries, want)
// HGT@1000, HGT@975, UGRD@1000, VGRD@1000, UGRD@975, VGRD@975 = 6
// HGT@500 at 3hr is last entry (no EndOffset), so dropped.
if len(filtered) != 6 {
t.Errorf("expected 6, got %d", len(filtered))
}
}
// TestParseLevelMB covers isobaric levels and several non-isobaric
// descriptions, which must all map to 0.
func TestParseLevelMB(t *testing.T) {
	type levelCase struct {
		in   string
		want int
	}
	table := []levelCase{
		{in: "1000 mb", want: 1000},
		{in: "975 mb", want: 975},
		{in: "1 mb", want: 1},
		{in: "2 m above ground", want: 0},
		{in: "surface", want: 0},
		{in: "tropopause", want: 0},
	}
	for i := range table {
		tc := table[i]
		got := parseLevelMB(tc.in)
		if got != tc.want {
			t.Errorf("parseLevelMB(%q) = %d, want %d", tc.in, got, tc.want)
		}
	}
}
// TestParseHour covers forecast-hour strings and the analysis marker.
func TestParseHour(t *testing.T) {
	type hourCase struct {
		in   string
		want int
	}
	table := []hourCase{
		{in: "0 hour fcst", want: 0},
		{in: "3 hour fcst", want: 3},
		{in: "192 hour fcst", want: 192},
		{in: "anl", want: 0},
	}
	for i := range table {
		tc := table[i]
		got := parseHour(tc.in)
		if got != tc.want {
			t.Errorf("parseHour(%q) = %d, want %d", tc.in, got, tc.want)
		}
	}
}

View file

@ -0,0 +1,430 @@
// Package gfs implements datasets.Source for NOAA GFS 0.5-degree forecasts.
package gfs
import (
"context"
"errors"
"fmt"
"io"
"math"
"net/http"
"os"
"sync"
"time"
"github.com/nilsmagnus/grib/griblib"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/weather"
wgfs "predictor-refactored/internal/weather/gfs"
)
// Source is the GFS implementation of datasets.Source.
//
// Zero-valued fields fall back to defaults when used: Parallel <= 0 means 8,
// a nil Client means a client with a 2-minute timeout, and a nil Log means a
// no-op logger.
type Source struct {
Parallel int // max concurrent step downloads
Client *http.Client // optional; defaults to a 2-minute-timeout client
Log *zap.Logger // optional; nil disables logging
}
// NewSource builds a Source with the default settings: 8 parallel downloads
// and an HTTP client with a 2-minute timeout. log may be nil.
func NewSource(log *zap.Logger) *Source {
	src := new(Source)
	src.Parallel = 8
	src.Client = &http.Client{Timeout: 2 * time.Minute}
	src.Log = log
	return src
}
// ID reports the fixed identifier of this source.
func (s *Source) ID() string {
	const sourceID = "noaa-gfs-0p50"
	return sourceID
}
// log returns the configured logger, or a no-op logger when none is set.
func (s *Source) log() *zap.Logger {
	if logger := s.Log; logger != nil {
		return logger
	}
	return zap.NewNop()
}
// client returns the configured HTTP client, or a fresh client with a
// 2-minute timeout when none is set.
func (s *Source) client() *http.Client {
	if c := s.Client; c != nil {
		return c
	}
	return &http.Client{Timeout: 2 * time.Minute}
}
// parallel returns the configured download concurrency, or 8 when the
// configured value is not positive.
func (s *Source) parallel() int {
	if n := s.Parallel; n > 0 {
		return n
	}
	return 8
}
// LatestEpoch returns the most recent run NOAA has finished publishing.
//
// A run counts as published when a HEAD request for the .idx of its final
// forecast hour (wgfs.MaxHour) answers 200 OK. The search starts at the
// current 6-hourly run slot (UTC) and walks back one slot at a time, probing
// at most 8 runs (48 hours). Transport errors and non-200 answers move on to
// the next older run.
//
// If ctx is cancelled or expires, its error is returned instead of the
// generic "no recent GFS run found" error.
func (s *Source) LatestEpoch(ctx context.Context) (time.Time, error) {
	const maxRuns = 8

	now := time.Now().UTC()
	slot := now.Hour() - now.Hour()%6
	current := time.Date(now.Year(), now.Month(), now.Day(), slot, 0, 0, 0, time.UTC)

	for range maxRuns {
		// Stop probing as soon as the caller gives up.
		if err := ctx.Err(); err != nil {
			return time.Time{}, err
		}

		url := wgfs.GribURL(current.Format("20060102"), current.Hour(), wgfs.MaxHour) + ".idx"
		req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
		if err != nil {
			return time.Time{}, fmt.Errorf("build HEAD request for %s: %w", url, err)
		}

		resp, err := s.client().Do(req)
		if err == nil {
			resp.Body.Close()
			if resp.StatusCode == http.StatusOK {
				s.log().Info("latest GFS run discovered",
					zap.Time("run", current),
					zap.String("verified_url", url))
				return current, nil
			}
		}
		current = current.Add(-6 * time.Hour)
	}

	// A cancellation during the last probe must not be reported as "not found".
	if err := ctx.Err(); err != nil {
		return time.Time{}, err
	}
	return time.Time{}, fmt.Errorf("no recent GFS run found (checked %d runs)", maxRuns)
}
// Open wraps the stored dataset for epoch in a WindField. It fails when the
// store does not report the epoch as existing or when the cube file cannot
// be opened. The context argument is unused.
func (s *Source) Open(_ context.Context, epoch time.Time, store datasets.Storage) (weather.WindField, error) {
	if present := store.Exists(epoch); !present {
		return nil, fmt.Errorf("epoch %s not found", epoch.Format(time.RFC3339))
	}
	cube, err := wgfs.Open(store.Path(epoch), epoch.UTC())
	if err == nil {
		return wgfs.NewWind(cube), nil
	}
	return nil, err
}
// neededVariables is the GRIB variable set we extract. downloadAndBlit passes
// it to FilterIdx to select idx entries by their variable name.
var neededVariables = map[string]bool{"HGT": true, "UGRD": true, "VGRD": true}
// Download fetches the full dataset for epoch in parallel, resuming any
// previously-completed work units. Honours ctx cancellation and prog
// (which may be nil).
//
// One work unit is one (forecast step, level set) pair, so the total is
// twice the number of forecast steps. Units already recorded in the write
// handle's manifest are skipped. Up to s.parallel() units run concurrently.
//
// On success the cube is flushed and closed and the handle is committed.
// On failure the cube is closed; the handle is aborted only when the error
// is not a context cancellation and the manifest records no units.
func (s *Source) Download(ctx context.Context, epoch time.Time, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error {
if prog == nil {
prog = noopSink{}
}
handle, err := store.BeginWrite(epoch)
if err != nil {
return fmt.Errorf("begin write: %w", err)
}
manifest := handle.Manifest()
// Open or create the temp file. If a previous attempt left a partial
// file of the right size, reuse it (resume); otherwise Create.
// NOTE(review): if the cube is re-created here while the manifest still
// lists units, those units are skipped below — confirm the manifest cannot
// outlive the cube file.
file, err := openOrCreateCube(handle.Path())
if err != nil {
_ = handle.Abort()
return err
}
date := epoch.UTC().Format("20060102")
runHour := epoch.UTC().Hour()
steps := wgfs.Hours()
// Two units per step: one for each level set.
totalUnits := len(steps) * 2
prog.SetTotal(totalUnits)
// Pre-count already-done units so progress is accurate on resume.
for _, u := range manifest.Units() {
_ = u
prog.StepComplete()
}
start := time.Now()
// ctx is replaced by the group's context, which is cancelled when any
// unit fails.
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(s.parallel())
// fileMu serialises concurrent BlitGribData calls because the underlying
// mmap is shared and SetVal isn't atomic.
var fileMu sync.Mutex
for _, step := range steps {
hourIdx := wgfs.HourIndex(step)
if hourIdx < 0 {
continue
}
for _, ls := range []wgfs.LevelSet{wgfs.LevelSetA, wgfs.LevelSetB} {
unit := unitKey(step, ls)
if manifest.Has(unit) {
continue
}
g.Go(func() error {
var url string
switch ls {
case wgfs.LevelSetA:
url = wgfs.GribURL(date, runHour, step)
case wgfs.LevelSetB:
url = wgfs.GribURLB(date, runHour, step)
}
if err := s.downloadAndBlit(ctx, file, &fileMu, url, hourIdx, ls, prog, throttle); err != nil {
return fmt.Errorf("step %d %s: %w", step, levelSetLabel(ls), err)
}
// NOTE(review): the unit is marked before the cube is flushed at the end
// of Download — confirm a crash in between cannot leave a marked unit
// without its data on disk.
if err := manifest.Mark(unit); err != nil {
return fmt.Errorf("mark unit: %w", err)
}
prog.StepComplete()
return nil
})
}
}
if err := g.Wait(); err != nil {
_ = file.Close()
// Don't Abort on context cancellation — preserve progress for resume.
if errors.Is(err, context.Canceled) {
return err
}
// Other errors: abort if no progress was made; otherwise leave for resume.
if len(manifest.Units()) == 0 {
_ = handle.Abort()
}
return err
}
if err := file.Flush(); err != nil {
_ = file.Close()
return fmt.Errorf("flush: %w", err)
}
if err := file.Close(); err != nil {
return fmt.Errorf("close: %w", err)
}
if err := handle.Commit(); err != nil {
return fmt.Errorf("commit: %w", err)
}
s.log().Info("download complete",
zap.Time("epoch", epoch),
zap.Duration("elapsed", time.Since(start)))
return nil
}
// openOrCreateCube yields a writable cube file at path. An existing file of
// exactly wgfs.DatasetSize bytes is reopened for writing; a missing or
// wrong-sized file is (re)created. Any stat failure other than "does not
// exist" is returned as an error.
func openOrCreateCube(path string) (*wgfs.File, error) {
	info, statErr := os.Stat(path)
	switch {
	case statErr == nil && info.Size() == wgfs.DatasetSize:
		return wgfs.OpenWritable(path)
	case statErr != nil && !errors.Is(statErr, os.ErrNotExist):
		return nil, fmt.Errorf("stat cube: %w", statErr)
	default:
		// Wrong-size or missing — truncate-create.
		return wgfs.Create(path)
	}
}
// downloadAndBlit fetches and decodes one (URL, level-set) chunk and writes
// it into the dataset.
//
// Steps: fetch baseURL+".idx"; keep the entries for neededVariables whose
// pressure level belongs to ls; download each entry's byte range from
// baseURL into one temp file; decode that file with griblib; and blit every
// decoded isobaric message for a known variable and level into file at
// hourIdx. Blits are serialised through fileMu.
//
// Returns nil without downloading anything when the idx holds no relevant
// entries.
// NOTE(review): the caller then marks the unit complete — confirm an idx
// with no relevant entries is a legitimate outcome rather than an error.
func (s *Source) downloadAndBlit(
ctx context.Context,
file *wgfs.File,
fileMu *sync.Mutex,
baseURL string,
hourIdx int,
ls wgfs.LevelSet,
prog datasets.ProgressSink,
throttle datasets.Throttle,
) error {
idxBody, err := s.httpGet(ctx, baseURL+".idx", throttle, prog)
if err != nil {
return fmt.Errorf("idx: %w", err)
}
entries := ParseIdx(idxBody)
filtered := FilterIdx(entries, neededVariables)
// Keep only the levels that belong to the level set being fetched.
var relevant []IdxEntry
for _, e := range filtered {
set, ok := wgfs.PressureLevelSet(e.LevelMB)
if ok && set == ls {
relevant = append(relevant, e)
}
}
if len(relevant) == 0 {
return nil
}
ranges := EntriesToRanges(relevant)
// The selected messages are concatenated into one temp file so they can be
// decoded in a single pass.
tmp, err := os.CreateTemp("", "gfs-msg-*.tmp")
if err != nil {
return fmt.Errorf("temp: %w", err)
}
tmpPath := tmp.Name()
defer os.Remove(tmpPath)
for _, r := range ranges {
body, err := s.httpGetRange(ctx, baseURL, r.Start, r.End, throttle, prog)
if err != nil {
tmp.Close()
return fmt.Errorf("range %d-%d: %w", r.Start, r.End, err)
}
if _, err := tmp.Write(body); err != nil {
tmp.Close()
return fmt.Errorf("write tmp: %w", err)
}
}
if err := tmp.Close(); err != nil {
return err
}
f, err := os.Open(tmpPath)
if err != nil {
return err
}
messages, err := griblib.ReadMessages(f)
f.Close()
if err != nil {
return fmt.Errorf("read grib: %w", err)
}
for _, msg := range messages {
// Only product definition template 0 is handled; others are skipped.
if msg.Section4.ProductDefinitionTemplateNumber != 0 {
continue
}
p := msg.Section4.ProductDefinitionTemplate
varIdx := wgfs.VariableIndex(int(p.ParameterCategory), int(p.ParameterNumber))
if varIdx < 0 {
continue
}
if p.FirstSurface.Type != 100 { // isobaric only
continue
}
// NOTE(review): presumably FirstSurface.Value is in Pa, hence the
// division by 100 to get hPa — confirm against griblib.
pressureMB := int(math.Round(float64(p.FirstSurface.Value) / 100.0))
levelIdx := wgfs.PressureIndex(pressureMB)
if levelIdx < 0 {
continue
}
data := msg.Data()
fileMu.Lock()
err := file.BlitGribData(hourIdx, levelIdx, varIdx, data)
fileMu.Unlock()
if err != nil {
return fmt.Errorf("blit: %w", err)
}
}
return nil
}
// httpGet downloads the body of url, making up to 3 attempts.
//
// Attempts after the first are preceded by a pause of 2s and then 4s; a
// cancelled ctx ends the pause and returns ctx.Err(). Transport errors,
// non-200 statuses and body read errors trigger a retry. The status is
// checked before the body is consumed, so error-response bodies are neither
// throttled nor reported to prog; they are drained (bounded) only so the
// connection can be reused.
//
// throttle may be nil; prog must not be nil. The error returned after the
// final attempt wraps the last failure.
func (s *Source) httpGet(ctx context.Context, url string, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
	const attempts = 3

	var lastErr error
	for attempt := range attempts {
		if attempt > 0 {
			select {
			case <-time.After(time.Duration(attempt*2) * time.Second):
			case <-ctx.Done():
				return nil, ctx.Err()
			}
		}

		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			return nil, err
		}
		resp, err := s.client().Do(req)
		if err != nil {
			lastErr = err
			continue
		}

		if resp.StatusCode != http.StatusOK {
			// Discard a bounded amount of the error body; it is not payload.
			_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 64<<10))
			resp.Body.Close()
			lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url)
			continue
		}

		body, err := readThrottled(ctx, resp.Body, throttle, prog)
		resp.Body.Close()
		if err != nil {
			lastErr = err
			continue
		}
		return body, nil
	}
	return nil, fmt.Errorf("after %d attempts: %w", attempts, lastErr)
}
// httpGetRange downloads the inclusive byte range [start, end] of url,
// making up to 3 attempts.
//
// Attempts after the first are preceded by a pause of 2s and then 4s; a
// cancelled ctx ends the pause and returns ctx.Err(). Transport errors,
// unexpected statuses, body read errors and bodies of the wrong length
// trigger a retry.
//
// A 206 answer is accepted when the body is exactly end-start+1 bytes. A 200
// answer means the server sent the whole object; it is accepted only when
// its declared Content-Length equals the requested length, so a full-object
// reply is never mistaken for the range. Rejected responses are drained
// (bounded) without being throttled or reported to prog.
//
// throttle may be nil; prog must not be nil. The error returned after the
// final attempt wraps the last failure.
func (s *Source) httpGetRange(ctx context.Context, url string, start, end int64, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
	const attempts = 3
	want := end - start + 1

	var lastErr error
	for attempt := range attempts {
		if attempt > 0 {
			select {
			case <-time.After(time.Duration(attempt*2) * time.Second):
			case <-ctx.Done():
				return nil, ctx.Err()
			}
		}

		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			return nil, err
		}
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
		resp, err := s.client().Do(req)
		if err != nil {
			lastErr = err
			continue
		}

		// Decide from the headers alone whether the body is worth reading.
		var reject error
		switch {
		case resp.StatusCode == http.StatusPartialContent:
			// Expected answer; the length is verified after reading.
		case resp.StatusCode == http.StatusOK && resp.ContentLength == want:
			// Whole object, and the object is exactly the requested range.
		case resp.StatusCode == http.StatusOK:
			reject = fmt.Errorf("server ignored range %d-%d of %s (Content-Length %d, want %d)",
				start, end, url, resp.ContentLength, want)
		default:
			reject = fmt.Errorf("HTTP %d for range %d-%d of %s", resp.StatusCode, start, end, url)
		}
		if reject != nil {
			_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 64<<10))
			resp.Body.Close()
			lastErr = reject
			continue
		}

		body, err := readThrottled(ctx, resp.Body, throttle, prog)
		resp.Body.Close()
		if err != nil {
			lastErr = err
			continue
		}
		// A short or oversized body would corrupt the GRIB message.
		if got := int64(len(body)); got != want {
			lastErr = fmt.Errorf("range %d-%d of %s: got %d bytes, want %d", start, end, url, got, want)
			continue
		}
		return body, nil
	}
	return nil, fmt.Errorf("after %d attempts: %w", attempts, lastErr)
}
// readThrottled drains r into memory in 32 KiB reads. When throttle is
// non-nil it is asked to admit one full chunk before every read; each
// non-empty read is reported to prog. Returns the accumulated bytes at EOF,
// or nil and the first throttle or read error.
func readThrottled(ctx context.Context, r io.Reader, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
	var (
		collected = make([]byte, 0, 64*1024)
		scratch   = make([]byte, 32*1024)
	)
	for {
		if throttle != nil {
			if waitErr := throttle.Wait(ctx, len(scratch)); waitErr != nil {
				return nil, waitErr
			}
		}

		got, readErr := r.Read(scratch)
		if got > 0 {
			collected = append(collected, scratch[:got]...)
			prog.Bytes(int64(got))
		}

		switch {
		case errors.Is(readErr, io.EOF):
			return collected, nil
		case readErr != nil:
			return nil, readErr
		}
	}
}
// unitKey builds the manifest key of one work unit: the zero-padded forecast
// step followed by the level-set label.
func unitKey(step int, ls wgfs.LevelSet) string {
	label := levelSetLabel(ls)
	return fmt.Sprintf("step%03d-%s", step, label)
}
// levelSetLabel returns "B" for wgfs.LevelSetB and "A" for every other value.
func levelSetLabel(ls wgfs.LevelSet) string {
	switch ls {
	case wgfs.LevelSetB:
		return "B"
	default:
		return "A"
	}
}
// noopSink is a progress sink that ignores every event. Download substitutes
// it when the caller passes a nil sink.
type noopSink struct{}

// SetTotal ignores the announced unit count.
func (noopSink) SetTotal(_ int) {}

// StepComplete ignores a completed unit.
func (noopSink) StepComplete() {}

// Bytes ignores a byte-count report.
func (noopSink) Bytes(_ int64) {}

View file

@ -0,0 +1,383 @@
package datasets
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"predictor-refactored/internal/weather"
)
// JobStatus is the lifecycle state of a download job.
type JobStatus string
// Job lifecycle states. A job is created as JobPending, becomes JobRunning
// when its download starts, and ends in exactly one of the terminal states
// JobComplete, JobFailed or JobCancelled.
const (
JobPending JobStatus = "pending"
JobRunning JobStatus = "running"
JobComplete JobStatus = "complete"
JobFailed JobStatus = "failed"
JobCancelled JobStatus = "cancelled"
)
// JobInfo is the externally-visible snapshot of a download job.
type JobInfo struct {
ID string // job identifier returned by Manager.Download
Source string // ID of the source the job downloads from
Epoch time.Time // dataset epoch being downloaded
Status JobStatus // lifecycle state at snapshot time
StartedAt time.Time // when the job was created
EndedAt *time.Time // nil until the job has ended
Err string // error text; set only for failed jobs
Total int // total work units announced by the source
Done int // work units completed so far
Bytes int64 // bytes reported by the source so far
}
// jobEntry is the Manager's mutable record for one job.
//
// id, source, epoch, startedAt and cancel are set at creation. status,
// endedAt and errStr are accessed under mu. The counters are atomics updated
// through jobProgress.
type jobEntry struct {
id string
source string
epoch time.Time
startedAt time.Time
cancel context.CancelFunc // cancels the job's context
mu sync.Mutex // guards status, endedAt and errStr
status JobStatus
endedAt time.Time // zero until the job has ended
errStr string
total atomic.Int64
done atomic.Int64
bytes atomic.Int64
}
// snapshot captures the entry's current state as a JobInfo. The
// mutex-guarded fields are read under the lock; the counters are read
// atomically afterwards. EndedAt is left nil while the job has not ended.
func (e *jobEntry) snapshot() JobInfo {
	var info JobInfo

	e.mu.Lock()
	info.ID = e.id
	info.Source = e.source
	info.Epoch = e.epoch
	info.StartedAt = e.startedAt
	info.Status = e.status
	info.Err = e.errStr
	if ended := e.endedAt; !ended.IsZero() {
		info.EndedAt = &ended
	}
	e.mu.Unlock()

	info.Total = int(e.total.Load())
	info.Done = int(e.done.Load())
	info.Bytes = e.bytes.Load()
	return info
}
// jobProgress forwards progress events into a jobEntry's atomic counters.
type jobProgress struct {
	e *jobEntry
}

// SetTotal records the total number of work units.
func (p jobProgress) SetTotal(n int) {
	p.e.total.Store(int64(n))
}

// StepComplete counts one finished work unit.
func (p jobProgress) StepComplete() {
	p.e.done.Add(1)
}

// Bytes adds n to the running byte count.
func (p jobProgress) Bytes(n int64) {
	p.e.bytes.Add(n)
}
// Manager coordinates dataset downloads and exposes the active WindField.
type Manager struct {
src Source // where datasets are downloaded from and opened with
store Storage // on-disk persistence for datasets
throttle Throttle // passed to every download; may be nil
log *zap.Logger // never nil after New
activeMu sync.RWMutex // guards active
active weather.WindField // currently loaded field, or nil
jobsMu sync.RWMutex // guards jobs
jobs map[string]*jobEntry // every job recorded since startup, by job ID
// inFlight maps an epoch's RFC3339 representation to its jobID, enforcing
// single-flight per epoch.
inFlight sync.Map
}
// New builds a Manager from a source, a store and an optional throttle.
// A nil log is replaced by a no-op logger. A source/store ID mismatch is
// logged as a warning but does not prevent construction.
func New(src Source, store Storage, throttle Throttle, log *zap.Logger) *Manager {
	logger := log
	if logger == nil {
		logger = zap.NewNop()
	}

	if srcID, storeID := src.ID(), store.SourceID(); srcID != storeID {
		logger.Warn("source/store ID mismatch",
			zap.String("src", srcID),
			zap.String("store", storeID))
	}

	m := &Manager{jobs: make(map[string]*jobEntry)}
	m.src = src
	m.store = store
	m.throttle = throttle
	m.log = logger
	return m
}
// Source reports the ID of the source this Manager downloads from.
func (m *Manager) Source() string {
	id := m.src.ID()
	return id
}
// Active returns the WindField that is currently loaded, or nil when none is.
func (m *Manager) Active() weather.WindField {
	m.activeMu.RLock()
	field := m.active
	m.activeMu.RUnlock()
	return field
}
// Ready reports whether a dataset is currently loaded, i.e. whether Active
// returns a non-nil field.
func (m *Manager) Ready() bool {
	field := m.Active()
	return field != nil
}
// ListEpochs returns the epochs of all stored datasets, exactly as reported
// by the store's List method.
func (m *Manager) ListEpochs() ([]time.Time, error) {
	return m.store.List()
}
// ListJobs returns a snapshot of every job recorded since startup. The order
// follows map iteration and is therefore unspecified.
func (m *Manager) ListJobs() []JobInfo {
	m.jobsMu.RLock()
	defer m.jobsMu.RUnlock()

	snapshots := make([]JobInfo, 0, len(m.jobs))
	for id := range m.jobs {
		snapshots = append(snapshots, m.jobs[id].snapshot())
	}
	return snapshots
}
// GetJob returns a snapshot of the job with the given id. The boolean is
// false, and the JobInfo zero, when the id is unknown.
func (m *Manager) GetJob(id string) (JobInfo, bool) {
	m.jobsMu.RLock()
	entry, known := m.jobs[id]
	m.jobsMu.RUnlock()

	if known {
		return entry.snapshot(), true
	}
	return JobInfo{}, false
}
// CancelJob requests cancellation of a job by cancelling its context.
// It returns false when id is unknown or the job has already reached a
// terminal state (complete, failed or cancelled), and true otherwise.
func (m *Manager) CancelJob(id string) bool {
	m.jobsMu.RLock()
	entry, known := m.jobs[id]
	m.jobsMu.RUnlock()
	if !known {
		return false
	}

	entry.mu.Lock()
	status := entry.status
	entry.mu.Unlock()

	switch status {
	case JobComplete, JobFailed, JobCancelled:
		return false
	}
	entry.cancel()
	return true
}
// RemoveEpoch deletes a stored dataset.
//
// If epoch is the currently active dataset, the active field is cleared and,
// when it has a Close method, closed before the store is asked to remove the
// data — the same treatment swapActive gives a replaced field. The check and
// the clear happen under a single write lock, so a field swapped in
// concurrently for a different epoch is never cleared by mistake.
//
// A failure to close the old field is logged, not returned; the returned
// error is the store's Remove result.
func (m *Manager) RemoveEpoch(epoch time.Time) error {
	epoch = epoch.UTC()

	var detached weather.WindField
	m.activeMu.Lock()
	if m.active != nil && m.active.Epoch().Equal(epoch) {
		detached = m.active
		m.active = nil
	}
	m.activeMu.Unlock()

	// Release the old field's resources before its backing data is removed.
	if c, ok := detached.(interface{ Close() error }); ok && c != nil {
		if err := c.Close(); err != nil {
			m.log.Warn("close removed dataset", zap.Error(err))
		}
	}
	return m.store.Remove(epoch)
}
// Download starts (or resumes) a download job for epoch in the background.
// Returns the JobID. If a job for the same epoch is already running, its
// existing JobID is returned.
//
// If the dataset is already present on disk, a synthetic completed JobInfo
// is recorded and its JobID returned.
//
// The in-flight claim and the insertion into the job table happen under the
// same jobs lock. Any caller that obtains a JobID from the in-flight map is
// therefore guaranteed that GetJob already knows that ID.
func (m *Manager) Download(epoch time.Time) string {
	epoch = epoch.UTC()
	key := epoch.Format(time.RFC3339)

	// Fast path: a job for this epoch is already in flight.
	if existing, ok := m.inFlight.Load(key); ok {
		return existing.(string)
	}

	jobID := uuid.New().String()
	ctx, cancel := context.WithCancel(context.Background())
	e := &jobEntry{
		id:        jobID,
		source:    m.src.ID(),
		epoch:     epoch,
		startedAt: time.Now().UTC(),
		status:    JobPending,
		cancel:    cancel,
	}

	m.jobsMu.Lock()
	if other, loaded := m.inFlight.LoadOrStore(key, jobID); loaded {
		// Lost the race: another caller claimed this epoch first.
		m.jobsMu.Unlock()
		cancel()
		return other.(string)
	}
	m.jobs[jobID] = e
	m.jobsMu.Unlock()

	if m.store.Exists(epoch) {
		// Skip the download but still record the job for traceability.
		go m.completeShortCircuit(ctx, e)
		return jobID
	}
	go m.runDownload(ctx, e)
	return jobID
}
// LoadEpoch opens the stored dataset for epoch and makes it the active
// WindField, replacing (and closing, via swapActive) the previous one.
// It fails when the epoch is not on disk or cannot be opened.
func (m *Manager) LoadEpoch(ctx context.Context, epoch time.Time) error {
	utc := epoch.UTC()
	if present := m.store.Exists(utc); !present {
		return fmt.Errorf("epoch %s not present on disk", utc.Format(time.RFC3339))
	}

	opened, openErr := m.src.Open(ctx, utc, m.store)
	if openErr != nil {
		return fmt.Errorf("open epoch: %w", openErr)
	}

	m.swapActive(opened)
	m.log.Info("loaded dataset",
		zap.Time("epoch", utc),
		zap.String("source", m.src.ID()))
	return nil
}
// Refresh ensures the most recent upstream dataset is downloaded and active.
//
// It works in three stages:
//
//  1. If the active dataset's epoch is younger than freshnessTTL, nothing is
//     done.
//  2. Otherwise the stored epochs are scanned. Epochs older than
//     freshnessTTL are skipped. If a fresh epoch is already active, or can be
//     loaded, Refresh returns without touching the network.
//  3. Otherwise the source's LatestEpoch is consulted. If it is newer than
//     the active dataset (or nothing is active), a download is started and a
//     background watcher loads the dataset once the job completes.
//
// Returns the JobID started, or empty string when nothing was scheduled.
// NOTE(review): the watcher polls every 2 seconds and does not observe ctx;
// it ends only when the job reaches a terminal state or becomes unknown.
func (m *Manager) Refresh(ctx context.Context, freshnessTTL time.Duration) (string, error) {
if active := m.Active(); active != nil && time.Since(active.Epoch()) < freshnessTTL {
return "", nil
}
// Try loading the freshest existing dataset before going to the network.
// A List error is ignored here; the network path below is the fallback.
if epochs, err := m.store.List(); err == nil {
for _, e := range epochs {
if time.Since(e) > freshnessTTL {
continue
}
if active := m.Active(); active != nil && active.Epoch().Equal(e) {
return "", nil
}
if err := m.LoadEpoch(ctx, e); err == nil {
return "", nil
}
}
}
latest, err := m.src.LatestEpoch(ctx)
if err != nil {
return "", fmt.Errorf("latest epoch: %w", err)
}
if active := m.Active(); active != nil && !latest.After(active.Epoch()) {
return "", nil
}
jobID := m.Download(latest)
// Spawn a watcher that loads the dataset on successful completion.
go func() {
for {
info, ok := m.GetJob(jobID)
if !ok {
return
}
switch info.Status {
case JobComplete:
// The caller's ctx is not used for this load; the watcher runs detached.
if err := m.LoadEpoch(context.Background(), latest); err != nil {
m.log.Error("load after download", zap.Error(err))
}
return
case JobFailed, JobCancelled:
return
}
time.Sleep(2 * time.Second)
}
}()
return jobID, nil
}
// runDownload performs a single Source.Download call for e and stores the
// result on the job entry.
//
// The entry is marked running before the call. Afterwards it is marked
// cancelled when the error is (or wraps) context.Canceled, failed with the
// error text for any other error, and complete otherwise. The epoch's
// in-flight marker is removed when the function returns.
func (m *Manager) runDownload(ctx context.Context, e *jobEntry) {
	inFlightKey := e.epoch.Format(time.RFC3339)
	defer m.inFlight.Delete(inFlightKey)

	e.mu.Lock()
	e.status = JobRunning
	e.mu.Unlock()

	m.log.Info("download started",
		zap.String("job", e.id),
		zap.Time("epoch", e.epoch))

	err := m.src.Download(ctx, e.epoch, m.store, jobProgress{e: e}, m.throttle)
	finishedAt := time.Now().UTC()

	e.mu.Lock()
	e.endedAt = finishedAt
	if cancelled := errors.Is(err, context.Canceled); cancelled {
		e.status = JobCancelled
	} else if err == nil {
		e.status = JobComplete
	} else {
		e.status = JobFailed
		e.errStr = err.Error()
	}
	// Copy the status while the lock is held so it can be logged without it.
	outcome := e.status
	e.mu.Unlock()

	m.log.Info("download finished",
		zap.String("job", e.id),
		zap.String("status", string(outcome)),
		zap.NamedError("err", err))
}
// completeShortCircuit marks e as complete right away, stamping its end time,
// without doing any download work. The epoch's in-flight marker is removed
// when the function returns.
func (m *Manager) completeShortCircuit(ctx context.Context, e *jobEntry) {
	_ = ctx // unused; nothing here can block or be cancelled

	inFlightKey := e.epoch.Format(time.RFC3339)
	defer m.inFlight.Delete(inFlightKey)

	finishedAt := time.Now().UTC()

	e.mu.Lock()
	e.endedAt = finishedAt
	e.status = JobComplete
	e.mu.Unlock()
}
// swapActive installs f as the active field. The field it replaces is closed
// afterwards, outside the lock, when it has a Close() error method; a close
// failure is only logged.
func (m *Manager) swapActive(f weather.WindField) {
	m.activeMu.Lock()
	previous := m.active
	m.active = f
	m.activeMu.Unlock()

	closer, closable := previous.(interface{ Close() error })
	if !closable || closer == nil {
		return
	}
	if closeErr := closer.Close(); closeErr != nil {
		m.log.Warn("close old dataset", zap.Error(closeErr))
	}
}
// Close invokes the cancel function of every recorded job, detaches the
// active field and, when that field has a Close() error method, closes it
// and returns the result. It returns nil otherwise.
func (m *Manager) Close() error {
	m.jobsMu.Lock()
	for _, job := range m.jobs {
		job.cancel()
	}
	m.jobsMu.Unlock()

	m.activeMu.Lock()
	current := m.active
	m.active = nil
	m.activeMu.Unlock()

	closer, closable := current.(interface{ Close() error })
	if !closable || closer == nil {
		return nil
	}
	return closer.Close()
}

View file

@ -0,0 +1,118 @@
package datasets
import (
"encoding/json"
"errors"
"fmt"
"os"
"sort"
"sync"
)
// Manifest tracks completed work units for a partial dataset download.
// Units are arbitrary opaque strings; sources choose the format
// (e.g. "step12-A" for "forecast step 12, level set A").
//
// A Manifest is persisted as a JSON object: {"units": ["step0-A", "step0-B", ...]}.
//
// All methods are safe for concurrent use. The in-memory set is kept in step
// with the file on disk: an update that cannot be persisted is not recorded
// in memory either.
type Manifest struct {
	path  string              // location of the JSON file backing this manifest
	mu    sync.Mutex          // guards units and serialises writes to path
	units map[string]struct{} // set of completed unit names
}

// LoadManifest opens or creates the manifest at path.
//
// A missing or zero-length file yields an empty manifest. Any other read
// failure, or a file that is not valid JSON, returns an error.
func LoadManifest(path string) (*Manifest, error) {
	m := &Manifest{path: path, units: make(map[string]struct{})}
	data, err := os.ReadFile(path)
	if errors.Is(err, os.ErrNotExist) {
		return m, nil
	}
	if err != nil {
		return nil, fmt.Errorf("read manifest %s: %w", path, err)
	}
	if len(data) == 0 {
		return m, nil
	}
	var doc struct {
		Units []string `json:"units"`
	}
	if err := json.Unmarshal(data, &doc); err != nil {
		return nil, fmt.Errorf("parse manifest %s: %w", path, err)
	}
	for _, u := range doc.Units {
		m.units[u] = struct{}{}
	}
	return m, nil
}

// Has reports whether unit has been recorded as completed.
func (m *Manifest) Has(unit string) bool {
	m.mu.Lock()
	defer m.mu.Unlock()
	_, ok := m.units[unit]
	return ok
}

// Mark records unit as completed and persists the manifest to disk.
//
// Marking a unit that is already recorded is a no-op. When persisting fails
// the unit is NOT recorded, so Has keeps returning false and a later Mark
// retries the write instead of silently reporting success.
func (m *Manifest) Mark(unit string) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	if _, ok := m.units[unit]; ok {
		return nil
	}
	if m.units == nil {
		// Tolerate a zero-value Manifest instead of panicking on a nil map.
		m.units = make(map[string]struct{})
	}
	m.units[unit] = struct{}{}
	if err := m.persistLocked(); err != nil {
		// Roll back: memory must never claim more than the file holds,
		// otherwise a resumed download would skip work that was never saved.
		delete(m.units, unit)
		return err
	}
	return nil
}

// Units returns the completed units in sorted order.
func (m *Manifest) Units() []string {
	m.mu.Lock()
	defer m.mu.Unlock()
	out := make([]string, 0, len(m.units))
	for u := range m.units {
		out = append(out, u)
	}
	sort.Strings(out)
	return out
}

// Reset removes the manifest file and clears all recorded units.
//
// The file is removed first; when that fails (for a reason other than the
// file being absent) the in-memory units are left untouched so that memory
// and disk still agree.
func (m *Manifest) Reset() error {
	m.mu.Lock()
	defer m.mu.Unlock()
	if err := os.Remove(m.path); err != nil && !errors.Is(err, os.ErrNotExist) {
		return fmt.Errorf("remove manifest %s: %w", m.path, err)
	}
	m.units = make(map[string]struct{})
	return nil
}

// persistLocked writes the manifest to disk via temp+rename, so readers see
// either the previous or the new content, never a partial file.
// The caller must hold m.mu.
func (m *Manifest) persistLocked() error {
	units := make([]string, 0, len(m.units))
	for u := range m.units {
		units = append(units, u)
	}
	// Sorted output keeps the file stable across runs.
	sort.Strings(units)
	data, err := json.Marshal(struct {
		Units []string `json:"units"`
	}{Units: units})
	if err != nil {
		return err
	}
	tmp := m.path + ".new"
	if err := os.WriteFile(tmp, data, 0o644); err != nil {
		// Best effort: do not leave a partially written temp file behind.
		_ = os.Remove(tmp)
		return fmt.Errorf("write manifest temp: %w", err)
	}
	if err := os.Rename(tmp, m.path); err != nil {
		// Best effort cleanup; the rename error is the one worth reporting.
		_ = os.Remove(tmp)
		return fmt.Errorf("rename manifest: %w", err)
	}
	return nil
}

View file

@ -0,0 +1,167 @@
package datasets
import (
"errors"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"time"
)
// LocalStore keeps dataset files in one directory on the local filesystem.
//
// Files found under Root:
//
//	<epoch>.bin                — committed dataset (binary cube)
//	<epoch>.bin.downloading    — in-progress dataset
//	<epoch>.bin.manifest.json  — manifest of completed work units
//
// <epoch> is the UTC epoch formatted as "20060102T150405Z". The ".bin" suffix
// is the default value of Extension and tells committed datasets apart from
// their sidecar files.
type LocalStore struct {
	Root      string // directory that holds every file of the store
	Source    string // source ID; stored and reported, not otherwise checked here
	Extension string // suffix of committed dataset files; empty means ".bin"
}

// NewLocalStore creates root (and any missing parents) and returns a
// LocalStore rooted there that uses the ".bin" extension.
func NewLocalStore(root, sourceID string) (*LocalStore, error) {
	mkErr := os.MkdirAll(root, 0o755)
	if mkErr != nil {
		return nil, fmt.Errorf("create store root %s: %w", root, mkErr)
	}
	store := &LocalStore{
		Root:      root,
		Source:    sourceID,
		Extension: ".bin",
	}
	return store, nil
}

// SourceID returns the source ID this store was created with.
func (s *LocalStore) SourceID() string { return s.Source }

// epochFormat is the time layout used for the epoch part of file names.
const epochFormat = "20060102T150405Z"

// ext returns the configured extension, falling back to ".bin" when unset.
func (s *LocalStore) ext() string {
	if suffix := s.Extension; suffix != "" {
		return suffix
	}
	return ".bin"
}

// Path returns the location of the committed dataset file for epoch. The
// epoch is converted to UTC before formatting; the file need not exist.
func (s *LocalStore) Path(epoch time.Time) string {
	fileName := epoch.UTC().Format(epochFormat) + s.ext()
	return filepath.Join(s.Root, fileName)
}

// tempPath returns the location of epoch's in-progress download file.
func (s *LocalStore) tempPath(epoch time.Time) string {
	committed := s.Path(epoch)
	return committed + ".downloading"
}

// manifestPath returns the location of epoch's manifest sidecar.
func (s *LocalStore) manifestPath(epoch time.Time) string {
	committed := s.Path(epoch)
	return committed + ".manifest.json"
}

// Exists reports whether epoch's committed dataset path exists and is not a
// directory.
func (s *LocalStore) Exists(epoch time.Time) bool {
	info, statErr := os.Stat(s.Path(epoch))
	if statErr != nil {
		return false
	}
	return !info.IsDir()
}

// List returns the epochs of all committed datasets, newest first.
//
// A directory entry counts only when it is not a directory, its name ends in
// the store's extension, the remaining stem contains no further dot, and the
// stem parses as an epoch. Everything else is skipped without error.
func (s *LocalStore) List() ([]time.Time, error) {
	entries, readErr := os.ReadDir(s.Root)
	if readErr != nil {
		return nil, fmt.Errorf("read store: %w", readErr)
	}

	suffix := s.ext()
	var epochs []time.Time
	for _, entry := range entries {
		fileName := entry.Name()
		if entry.IsDir() || !strings.HasSuffix(fileName, suffix) {
			continue
		}
		stem := fileName[:len(fileName)-len(suffix)]
		// A dot left in the stem means a compound name, not "<epoch><ext>".
		if strings.ContainsRune(stem, '.') {
			continue
		}
		parsed, parseErr := time.Parse(epochFormat, stem)
		if parseErr != nil {
			continue
		}
		epochs = append(epochs, parsed.UTC())
	}

	sort.Slice(epochs, func(a, b int) bool { return epochs[b].Before(epochs[a]) })
	return epochs, nil
}

// Remove deletes epoch's committed dataset, in-progress file and manifest.
// Files that do not exist are ignored; all other failures are collected and
// reported together.
func (s *LocalStore) Remove(epoch time.Time) error {
	targets := [...]string{s.Path(epoch), s.tempPath(epoch), s.manifestPath(epoch)}

	var failures []error
	for i := range targets {
		rmErr := os.Remove(targets[i])
		if rmErr == nil || errors.Is(rmErr, os.ErrNotExist) {
			continue
		}
		failures = append(failures, rmErr)
	}

	if len(failures) == 0 {
		return nil
	}
	return fmt.Errorf("remove dataset: %v", failures)
}
// BeginWrite returns a TempHandle for downloading epoch's dataset.
//
// The manifest sidecar for epoch is loaded first, so units recorded by an
// earlier, unfinished download are visible through the handle and the new
// download can continue from there. A manifest that cannot be loaded aborts
// the call with that error.
func (s *LocalStore) BeginWrite(epoch time.Time) (TempHandle, error) {
	manifest, loadErr := LoadManifest(s.manifestPath(epoch))
	if loadErr != nil {
		return nil, loadErr
	}

	handle := new(localHandle)
	handle.store = s
	handle.epoch = epoch
	handle.manifest = manifest
	return handle, nil
}
// localHandle is the TempHandle handed out by LocalStore.BeginWrite.
type localHandle struct {
	store    *LocalStore // store that owns the files
	epoch    time.Time   // epoch being written
	manifest *Manifest   // completed-unit tracker for this download
	closed   bool        // set by the first Commit or Abort
}

// Path returns the in-progress file location for the handle's epoch.
func (h *localHandle) Path() string {
	return h.store.tempPath(h.epoch)
}

// Manifest returns the manifest loaded when the handle was created.
func (h *localHandle) Manifest() *Manifest {
	return h.manifest
}

// Commit renames the in-progress file to the committed path and then deletes
// the manifest sidecar (a manifest that is already absent is fine). Calls
// made after the first Commit or Abort return nil without doing anything.
//
// NOTE(review): the handle is marked closed before the rename is attempted,
// so after a failed Commit both Commit and Abort become no-ops — confirm
// this is intended.
func (h *localHandle) Commit() error {
	if h.closed {
		return nil
	}
	h.closed = true

	from, to := h.store.tempPath(h.epoch), h.store.Path(h.epoch)
	if renameErr := os.Rename(from, to); renameErr != nil {
		return fmt.Errorf("commit rename: %w", renameErr)
	}

	rmErr := os.Remove(h.store.manifestPath(h.epoch))
	if rmErr == nil || errors.Is(rmErr, os.ErrNotExist) {
		return nil
	}
	return fmt.Errorf("commit remove manifest: %w", rmErr)
}

// Abort deletes the in-progress file and the manifest sidecar. Missing files
// are ignored, both removals are always attempted, and the first real
// failure is returned. Calls made after the first Commit or Abort return nil
// without doing anything.
func (h *localHandle) Abort() error {
	if h.closed {
		return nil
	}
	h.closed = true

	victims := [...]string{h.store.tempPath(h.epoch), h.store.manifestPath(h.epoch)}
	var first error
	for i := range victims {
		rmErr := os.Remove(victims[i])
		switch {
		case rmErr == nil, errors.Is(rmErr, os.ErrNotExist):
			// Nothing to report.
		case first == nil:
			first = rmErr
		}
	}
	return first
}

View file

@ -0,0 +1,82 @@
package datasets
import (
"os"
"path/filepath"
"testing"
"time"
)
// TestLocalStoreBeginWriteResume walks one epoch through the full lifecycle:
// a partial download is started, resumed through a second handle, committed,
// listed and finally removed.
func TestLocalStoreBeginWriteResume(t *testing.T) {
	dir := t.TempDir()
	store, err := NewLocalStore(dir, "gfs-test")
	if err != nil {
		t.Fatalf("NewLocalStore: %v", err)
	}
	epoch := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)

	h, err := store.BeginWrite(epoch)
	if err != nil {
		t.Fatalf("BeginWrite: %v", err)
	}
	if err := os.WriteFile(h.Path(), []byte("partial"), 0o644); err != nil {
		t.Fatalf("write partial: %v", err)
	}
	if err := h.Manifest().Mark("step000-A"); err != nil {
		t.Fatalf("mark: %v", err)
	}

	// Re-open should see the previous manifest entry.
	h2, err := store.BeginWrite(epoch)
	if err != nil {
		t.Fatalf("BeginWrite resume: %v", err)
	}
	if !h2.Manifest().Has("step000-A") {
		t.Errorf("resumed manifest missing step000-A; units = %v", h2.Manifest().Units())
	}

	// Commit promotes the temp file and removes the manifest.
	if err := h2.Commit(); err != nil {
		t.Fatalf("Commit: %v", err)
	}
	if !store.Exists(epoch) {
		t.Errorf("Exists after commit returned false")
	}
	// manifestPath already includes the store root, so only its base name is
	// joined onto dir. Joining the full path onto dir would name a file that
	// can never exist and make this check pass unconditionally.
	manifest := filepath.Join(dir, filepath.Base(store.manifestPath(epoch)))
	if _, err := os.Stat(manifest); !os.IsNotExist(err) {
		t.Errorf("manifest should be removed, got err=%v", err)
	}

	// Listing finds the committed epoch.
	epochs, err := store.List()
	if err != nil {
		t.Fatalf("List: %v", err)
	}
	if len(epochs) != 1 || !epochs[0].Equal(epoch) {
		t.Errorf("List = %v, want [%v]", epochs, epoch)
	}

	// Remove cleans up.
	if err := store.Remove(epoch); err != nil {
		t.Fatalf("Remove: %v", err)
	}
	if store.Exists(epoch) {
		t.Errorf("Exists after remove returned true")
	}
}
// TestLocalStoreAbort verifies that Abort discards both the in-progress file
// and the manifest sidecar of an unfinished download.
func TestLocalStoreAbort(t *testing.T) {
	dir := t.TempDir()
	store, err := NewLocalStore(dir, "gfs-test")
	if err != nil {
		t.Fatalf("NewLocalStore: %v", err)
	}
	epoch := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)

	h, err := store.BeginWrite(epoch)
	if err != nil {
		t.Fatalf("BeginWrite: %v", err)
	}
	// Check the setup steps: if they failed silently, the assertions below
	// would pass without Abort having removed anything.
	if err := os.WriteFile(h.Path(), []byte("x"), 0o644); err != nil {
		t.Fatalf("write partial: %v", err)
	}
	if err := h.Manifest().Mark("step000-A"); err != nil {
		t.Fatalf("mark: %v", err)
	}

	if err := h.Abort(); err != nil {
		t.Fatalf("Abort: %v", err)
	}
	if _, err := os.Stat(h.Path()); !os.IsNotExist(err) {
		t.Errorf("temp file should be removed after abort, got %v", err)
	}
	if _, err := os.Stat(store.manifestPath(epoch)); !os.IsNotExist(err) {
		t.Errorf("manifest should be removed after abort, got %v", err)
	}
}

View file

@ -0,0 +1,63 @@
package datasets
import (
"context"
"sync"
"time"
)
// TokenBucket is a simple bytes-per-second rate limiter.
//
// The bucket is initialised full (capacity = rate × 1 second). Calls to Wait
// block until enough tokens have accumulated. Requests larger than the
// capacity are served in capacity-sized instalments, so they are slowed down
// to the configured rate instead of blocking forever.
type TokenBucket struct {
	mu     sync.Mutex // guards tokens and last
	rate   float64    // tokens per second; <= 0 disables throttling
	tokens float64    // tokens currently available
	cap    float64    // maximum number of tokens the bucket can hold
	last   time.Time  // time of the most recent refill
}

// NewTokenBucket returns a TokenBucket emitting at most bytesPerSecond.
// A non-positive rate disables throttling (Wait becomes a no-op).
func NewTokenBucket(bytesPerSecond int64) *TokenBucket {
	if bytesPerSecond <= 0 {
		return &TokenBucket{rate: 0}
	}
	r := float64(bytesPerSecond)
	return &TokenBucket{rate: r, tokens: r, cap: r, last: time.Now()}
}

// Wait blocks until n tokens have been consumed or ctx is cancelled.
//
// It returns nil immediately when throttling is disabled or n is not
// positive. When ctx is cancelled while waiting, ctx's error is returned;
// instalments already taken for an oversized request are not refunded.
func (t *TokenBucket) Wait(ctx context.Context, n int) error {
	if t.rate <= 0 || n <= 0 {
		return nil
	}
	// The bucket never holds more than limit tokens, so no single instalment
	// may ask for more than that. A bucket built without a capacity falls
	// back to one second's worth of tokens.
	limit := t.cap
	if limit <= 0 {
		limit = t.rate
	}

	remaining := float64(n)
	for remaining > 0 {
		want := remaining
		if want > limit {
			want = limit
		}

		t.mu.Lock()
		now := time.Now()
		elapsed := now.Sub(t.last).Seconds()
		t.last = now
		t.tokens += elapsed * t.rate
		if t.tokens > limit {
			t.tokens = limit
		}
		if t.tokens >= want {
			t.tokens -= want
			t.mu.Unlock()
			remaining -= want
			continue
		}
		// Sleep until we expect enough tokens for this instalment.
		need := want - t.tokens
		sleep := time.Duration(need / t.rate * float64(time.Second))
		t.mu.Unlock()

		timer := time.NewTimer(sleep)
		select {
		case <-ctx.Done():
			timer.Stop()
			return ctx.Err()
		case <-timer.C:
		}
	}
	return nil
}

View file

@ -0,0 +1,97 @@
package datasets
import (
"context"
"time"
"predictor-refactored/internal/weather"
)
// Source is a pluggable origin for atmospheric datasets.
//
// Implementations download dataset files in a transactional, resumable
// manner and load them as weather.WindField. A Source must be safe for
// concurrent use across multiple Manager calls.
type Source interface {
	// ID is a stable identifier, e.g. "noaa-gfs-0p50".
	ID() string
	// LatestEpoch returns the most recent dataset epoch this source can provide.
	LatestEpoch(ctx context.Context) (time.Time, error)
	// Download fetches the dataset for epoch into store. Sources must honour
	// any partial progress recorded in store's manifest and skip
	// already-completed work, so re-invocation after a crash resumes cleanly.
	//
	// prog receives progress events; nil is acceptable.
	// throttle, if non-nil, is consulted before each network read for
	// bandwidth limiting; nil means no throttling.
	//
	// Implementations are expected to surface cancellation of ctx as an
	// error that wraps context.Canceled: Manager.runDownload uses that to
	// tell a cancelled job from a failed one.
	Download(ctx context.Context, epoch time.Time, store Storage, prog ProgressSink, throttle Throttle) error
	// Open loads epoch's stored dataset and returns it as a WindField.
	// Manager.LoadEpoch only calls it after store.Exists(epoch) returned true.
	Open(ctx context.Context, epoch time.Time, store Storage) (weather.WindField, error)
}
// Storage abstracts the on-disk location of dataset files and their manifests.
//
// Atomicity: only datasets promoted via TempHandle.Commit appear in Exists or
// List. Aborted or in-progress downloads are invisible to readers.
//
// LocalStore is the filesystem-backed implementation in this package.
type Storage interface {
	// SourceID identifies the data source these files belong to. Mixing
	// sources in one Storage is not supported.
	SourceID() string
	// Path returns the canonical local path for epoch's dataset. The path
	// is valid even when the dataset has not been written.
	Path(epoch time.Time) string
	// Exists reports whether a committed dataset for epoch is present.
	Exists(epoch time.Time) bool
	// List returns all committed epochs available, newest first.
	List() ([]time.Time, error)
	// Remove deletes the dataset and any sidecar manifest for epoch.
	Remove(epoch time.Time) error
	// BeginWrite opens (or resumes) a transactional handle for downloading
	// epoch's dataset. Callers must Commit or Abort the returned handle.
	BeginWrite(epoch time.Time) (TempHandle, error)
}
// TempHandle is the storage state for one in-progress download.
//
// A handle is finished by exactly one effective call to Commit or Abort;
// once either has run, further calls to both are no-ops.
type TempHandle interface {
	// Path returns the path of the in-progress file. Sources write directly here.
	Path() string
	// Manifest is the tracker of completed work units for resume support.
	Manifest() *Manifest
	// Commit promotes the temp file to its canonical location and removes
	// the manifest. Subsequent calls are no-ops.
	Commit() error
	// Abort discards the temp file and manifest. Subsequent calls are no-ops.
	Abort() error
}
// ProgressSink receives progress events during a download.
//
// All methods are safe to call concurrently. Source.Download accepts a nil
// ProgressSink, so sources must check for nil before reporting.
type ProgressSink interface {
	// SetTotal sets the total number of work units this download expects.
	// May be called multiple times if discovery happens incrementally.
	SetTotal(n int)
	// StepComplete records one work unit as completed.
	StepComplete()
	// Bytes records n bytes received from the network.
	Bytes(n int64)
}
// Throttle is an optional bandwidth limiter consulted by sources before
// each network read. *TokenBucket in this package has a matching Wait method
// and can be used as a Throttle.
type Throttle interface {
	// Wait blocks until n bytes can be consumed from the budget,
	// or returns ctx's error if the context is cancelled while waiting.
	Wait(ctx context.Context, n int) error
}

View file

@ -1,58 +0,0 @@
package downloader
import (
"os"
"strconv"
"time"
)
// Config holds downloader configuration, loaded from environment variables.
type Config struct {
	// DataDir is the directory for storing dataset files and temporary GRIB data.
	DataDir string
	// Parallel is the maximum number of concurrent GRIB downloads.
	Parallel int
	// UpdateInterval is how often the scheduler checks for new forecast data.
	UpdateInterval time.Duration
	// DatasetTTL is how long a dataset is considered fresh before a new one is needed.
	DatasetTTL time.Duration
}

// DefaultConfig returns the default configuration.
func DefaultConfig() *Config {
	return &Config{
		DataDir:        "/tmp/predictor-data",
		Parallel:       8,
		UpdateInterval: 6 * time.Hour,
		DatasetTTL:     48 * time.Hour,
	}
}

// LoadConfig loads configuration from environment variables, falling back to defaults.
//
// Recognised variables:
//
//	PREDICTOR_DATA_DIR          — DataDir
//	PREDICTOR_DOWNLOAD_PARALLEL — Parallel (positive integer)
//	PREDICTOR_UPDATE_INTERVAL   — UpdateInterval (positive Go duration, e.g. "6h")
//	PREDICTOR_DATASET_TTL       — DatasetTTL (positive Go duration)
//
// A variable that is unset, empty, unparsable or not positive leaves the
// default in place. Durations are validated the same way as Parallel so that
// a zero or negative interval can never reach the scheduler.
func LoadConfig() *Config {
	cfg := DefaultConfig()
	if v := os.Getenv("PREDICTOR_DATA_DIR"); v != "" {
		cfg.DataDir = v
	}
	if v := os.Getenv("PREDICTOR_DOWNLOAD_PARALLEL"); v != "" {
		if n, err := strconv.Atoi(v); err == nil && n > 0 {
			cfg.Parallel = n
		}
	}
	if v := os.Getenv("PREDICTOR_UPDATE_INTERVAL"); v != "" {
		if d, err := time.ParseDuration(v); err == nil && d > 0 {
			cfg.UpdateInterval = d
		}
	}
	if v := os.Getenv("PREDICTOR_DATASET_TTL"); v != "" {
		if d, err := time.ParseDuration(v); err == nil && d > 0 {
			cfg.DatasetTTL = d
		}
	}
	return cfg
}

View file

@ -1,441 +0,0 @@
package downloader
import (
"context"
"fmt"
"io"
"math"
"net/http"
"os"
"path/filepath"
"sync/atomic"
"time"
"predictor-refactored/internal/dataset"
"github.com/nilsmagnus/grib/griblib"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
)
// Downloader fetches GFS forecast data over HTTP and assembles dataset files
// from it.
type Downloader struct {
	cfg    *Config      // downloader settings (data directory, parallelism, ...)
	client *http.Client // shared HTTP client used for every request
	log    *zap.Logger  // destination for progress and diagnostic logs
}

// NewDownloader builds a Downloader around cfg and log. Its HTTP client has
// a two-minute timeout.
func NewDownloader(cfg *Config, log *zap.Logger) *Downloader {
	httpClient := &http.Client{Timeout: 2 * time.Minute}

	d := new(Downloader)
	d.cfg = cfg
	d.client = httpClient
	d.log = log
	return d
}
// neededVariables is the set of GRIB variable names we need. downloadAndBlit
// passes it to FilterIdx, which keeps only idx entries whose variable name
// maps to true here.
var neededVariables = map[string]bool{
	"HGT":  true,
	"UGRD": true,
	"VGRD": true,
}
// FindLatestRun returns the newest GFS model run that is available upstream.
//
// Starting at the current 6-hour UTC slot and stepping back 6 hours at a
// time, it issues a HEAD request for the .idx file of each run's last
// forecast step. The first run answering 200 OK is returned. After 8
// candidates without success an error is returned. Request construction and
// transport errors simply move on to the next candidate.
func (d *Downloader) FindLatestRun(ctx context.Context) (time.Time, error) {
	const candidates = 8
	const runSpacing = 6 * time.Hour

	now := time.Now().UTC()
	slotHour := now.Hour() - (now.Hour() % 6)
	current := time.Date(now.Year(), now.Month(), now.Day(), slotHour, 0, 0, 0, time.UTC)

	for tried := 0; tried < candidates; tried, current = tried+1, current.Add(-runSpacing) {
		url := dataset.GribURL(current.Format("20060102"), current.Hour(), dataset.MaxHour) + ".idx"

		req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
		if err != nil {
			continue
		}
		resp, err := d.client.Do(req)
		if err != nil {
			continue
		}
		resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			continue
		}

		d.log.Info("found latest model run",
			zap.Time("run", current),
			zap.String("verified_url", url))
		return current, nil
	}
	return time.Time{}, fmt.Errorf("no recent GFS forecast found (checked 8 runs)")
}
// progress accumulates download statistics that several goroutines update
// concurrently through atomic counters.
type progress struct {
	bytesDownloaded atomic.Int64 // bytes received so far
	stepsCompleted  atomic.Int64 // work units finished so far
	totalSteps      int64        // work units expected in total
	startTime       time.Time    // when tracking started
	log             *zap.Logger  // where progress lines are written
}

// newProgress returns a tracker for totalSteps work units whose clock starts
// now.
func newProgress(totalSteps int, log *zap.Logger) *progress {
	p := new(progress)
	p.totalSteps = int64(totalSteps)
	p.startTime = time.Now()
	p.log = log
	return p
}

// addBytes adds n to the number of bytes downloaded.
func (p *progress) addBytes(n int64) {
	p.bytesDownloaded.Add(n)
}

// completeStep counts one more finished work unit and logs a progress line
// with the completion ratio, volume, average speed and, while work remains,
// an estimated time to completion.
func (p *progress) completeStep() {
	done, total := p.stepsCompleted.Add(1), p.totalSteps
	received := p.bytesDownloaded.Load()
	elapsed := time.Since(p.startTime).Seconds()

	percent := float64(done) / float64(total) * 100
	megabytes := float64(received) / (1024 * 1024)

	var speed float64
	if elapsed > 0 {
		speed = megabytes / elapsed
	}

	// The estimate assumes the remaining steps take as long, on average, as
	// the ones already finished.
	var eta string
	if done > 0 && done < total {
		left := elapsed / float64(done) * float64(total-done)
		switch {
		case left > 60:
			eta = fmt.Sprintf("%.0fm%02.0fs", math.Floor(left/60), math.Mod(left, 60))
		default:
			eta = fmt.Sprintf("%.0fs", left)
		}
	}

	p.log.Info("download progress",
		zap.String("progress", fmt.Sprintf("%d/%d", done, total)),
		zap.String("percent", fmt.Sprintf("%.1f%%", percent)),
		zap.String("downloaded", fmt.Sprintf("%.1f MB", megabytes)),
		zap.String("speed", fmt.Sprintf("%.1f MB/s", speed)),
		zap.String("eta", eta))
}
// Download downloads a complete forecast and assembles a dataset file.
// Returns the path to the completed dataset file.
//
// The file is assembled at "<final>.downloading" and renamed to its final
// name (run formatted as "2006010215" inside cfg.DataDir) only after every
// step succeeded and the data was flushed. If a file of exactly
// dataset.DatasetSize bytes already exists at the final path it is returned
// without downloading. On any failure the temp file is removed and the first
// error is returned.
func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error) {
	date := run.Format("20060102")
	runHour := run.Hour()
	finalPath := filepath.Join(d.cfg.DataDir, run.Format("2006010215"))
	tempPath := finalPath + ".downloading"
	// Check if final dataset already exists
	if info, err := os.Stat(finalPath); err == nil && info.Size() == dataset.DatasetSize {
		d.log.Info("dataset already exists", zap.String("path", finalPath))
		return finalPath, nil
	}
	steps := dataset.Hours()
	totalSteps := len(steps) * 2 // pgrb2 + pgrb2b per step
	prog := newProgress(totalSteps, d.log)
	d.log.Info("starting dataset download",
		zap.Time("run", run),
		zap.Int("total_steps", totalSteps),
		zap.String("temp_path", tempPath))
	// Create the dataset file
	ds, err := dataset.Create(tempPath)
	if err != nil {
		return "", fmt.Errorf("create dataset: %w", err)
	}
	// NOTE(review): ds is also closed explicitly before the rename below, so
	// on the success path Close runs twice — confirm dataset.File.Close
	// tolerates that.
	defer ds.Close()
	// Process each forecast step with bounded concurrency
	// ctx is replaced by the errgroup's context, which is cancelled as soon
	// as one task fails. sem caps the number of tasks running at once at
	// cfg.Parallel; acquiring a slot happens here, before the task starts.
	g, ctx := errgroup.WithContext(ctx)
	sem := make(chan struct{}, d.cfg.Parallel)
	for _, step := range steps {
		step := step // per-iteration copy for the closures below
		hourIdx := dataset.HourIndex(step)
		if hourIdx < 0 {
			continue
		}
		// Download pgrb2 (level set A)
		sem <- struct{}{}
		g.Go(func() error {
			defer func() { <-sem }()
			url := dataset.GribURL(date, runHour, step)
			err := d.downloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetA, prog)
			if err != nil {
				return fmt.Errorf("step %d pgrb2: %w", step, err)
			}
			prog.completeStep()
			return nil
		})
		// Download pgrb2b (level set B)
		sem <- struct{}{}
		g.Go(func() error {
			defer func() { <-sem }()
			url := dataset.GribURLB(date, runHour, step)
			err := d.downloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetB, prog)
			if err != nil {
				return fmt.Errorf("step %d pgrb2b: %w", step, err)
			}
			prog.completeStep()
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		os.Remove(tempPath)
		return "", err
	}
	elapsed := time.Since(prog.startTime)
	totalMB := float64(prog.bytesDownloaded.Load()) / (1024 * 1024)
	d.log.Info("download complete, flushing to disk",
		zap.String("downloaded", fmt.Sprintf("%.1f MB", totalMB)),
		zap.Duration("elapsed", elapsed),
		zap.String("avg_speed", fmt.Sprintf("%.1f MB/s", totalMB/elapsed.Seconds())))
	// Flush to disk
	if err := ds.Flush(); err != nil {
		os.Remove(tempPath)
		return "", fmt.Errorf("flush dataset: %w", err)
	}
	// Close before rename
	ds.Close()
	// Atomic rename
	if err := os.Rename(tempPath, finalPath); err != nil {
		os.Remove(tempPath)
		return "", fmt.Errorf("rename dataset: %w", err)
	}
	d.log.Info("dataset ready", zap.String("path", finalPath))
	return finalPath, nil
}
// DownloadAndBlit downloads needed GRIB fields from a URL and writes them into the dataset.
// It is the exported form of downloadAndBlit, called without progress tracking.
func (d *Downloader) DownloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet) error {
	return d.downloadAndBlit(ctx, ds, baseURL, hourIdx, levelSet, nil)
}
// downloadAndBlit is the internal implementation with optional progress tracking.
//
// It fetches baseURL's .idx file, keeps the entries for the needed variables
// whose pressure level belongs to levelSet, downloads just those byte ranges
// into a temp file, decodes the GRIB messages and writes each one into ds at
// hourIdx. prog may be nil.
//
// Returns nil without touching ds when the idx lists no relevant entries.
// Messages that are not recognised, and individual blit failures, are
// skipped rather than reported as errors.
func (d *Downloader) downloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet, prog *progress) error {
	// 1. Download .idx
	idxURL := baseURL + ".idx"
	idxBody, err := d.httpGet(ctx, idxURL)
	if err != nil {
		return fmt.Errorf("download idx: %w", err)
	}
	// 2. Parse and filter
	entries := ParseIdx(idxBody)
	filtered := FilterIdx(entries, neededVariables)
	// Further filter to only levels in this level set
	var relevant []IdxEntry
	for _, e := range filtered {
		ls, ok := dataset.PressureLevelSet(e.LevelMB)
		if ok && ls == levelSet {
			relevant = append(relevant, e)
		}
	}
	if len(relevant) == 0 {
		d.log.Warn("no relevant entries found in idx",
			zap.String("url", idxURL),
			zap.Int("total_entries", len(entries)),
			zap.Int("filtered", len(filtered)))
		return nil
	}
	// 3. Download byte ranges and write to temp file
	ranges := EntriesToRanges(relevant)
	tmpFile, err := d.downloadRangesToTempFile(ctx, baseURL, ranges, prog)
	if err != nil {
		return fmt.Errorf("download ranges: %w", err)
	}
	// The temp file is only needed until its messages have been decoded.
	defer os.Remove(tmpFile)
	// 4. Read GRIB messages from temp file
	f, err := os.Open(tmpFile)
	if err != nil {
		return fmt.Errorf("open temp grib: %w", err)
	}
	messages, err := griblib.ReadMessages(f)
	f.Close()
	if err != nil {
		return fmt.Errorf("read grib messages: %w", err)
	}
	// 5. Decode and blit each message into the dataset
	for _, msg := range messages {
		// Only product definition template 0 is handled.
		if msg.Section4.ProductDefinitionTemplateNumber != 0 {
			continue
		}
		product := msg.Section4.ProductDefinitionTemplate
		varIdx := dataset.VariableIndex(int(product.ParameterCategory), int(product.ParameterNumber))
		if varIdx < 0 {
			continue
		}
		if product.FirstSurface.Type != 100 { // isobaric surface
			continue
		}
		// The surface value is taken to be in Pa and converted to whole mb.
		pressurePa := float64(product.FirstSurface.Value)
		pressureMB := int(math.Round(pressurePa / 100.0))
		levelIdx := dataset.PressureIndex(pressureMB)
		if levelIdx < 0 {
			continue
		}
		data := msg.Data()
		if err := ds.BlitGribData(hourIdx, levelIdx, varIdx, data); err != nil {
			d.log.Warn("blit failed",
				zap.Int("var", varIdx),
				zap.Int("level_mb", pressureMB),
				zap.Error(err))
			continue
		}
	}
	return nil
}
// downloadRangesToTempFile fetches each byte range of baseURL in order and
// appends the bytes to a fresh temp file in cfg.DataDir, so the file ends up
// holding the selected GRIB messages back to back.
//
// It returns the temp file's path; the caller is responsible for deleting
// it. On any failure the temp file is removed and an error is returned.
// prog, when non-nil, is told the size of every range received.
func (d *Downloader) downloadRangesToTempFile(ctx context.Context, baseURL string, ranges []ByteRange, prog *progress) (string, error) {
	out, err := os.CreateTemp(d.cfg.DataDir, "grib-*.tmp")
	if err != nil {
		return "", fmt.Errorf("create temp file: %w", err)
	}
	outPath := out.Name()

	// discard closes and deletes the partially written file.
	discard := func() {
		out.Close()
		os.Remove(outPath)
	}

	for _, rng := range ranges {
		chunk, getErr := d.httpGetRange(ctx, baseURL, rng.Start, rng.End)
		if getErr != nil {
			discard()
			return "", fmt.Errorf("download range %d-%d: %w", rng.Start, rng.End, getErr)
		}
		if _, writeErr := out.Write(chunk); writeErr != nil {
			discard()
			return "", fmt.Errorf("write temp: %w", writeErr)
		}
		if prog != nil {
			prog.addBytes(int64(len(chunk)))
		}
	}

	if closeErr := out.Close(); closeErr != nil {
		os.Remove(outPath)
		return "", closeErr
	}
	return outPath, nil
}
// httpGet fetches url and returns the response body.
//
// Up to three attempts are made, waiting 2s before the second and 4s before
// the third; cancelling ctx during a wait returns ctx's error. Transport
// errors, non-200 statuses and body read errors all trigger a retry. The
// last of them is wrapped in the returned error once the attempts run out.
func (d *Downloader) httpGet(ctx context.Context, url string) ([]byte, error) {
	const maxAttempts = 3

	var lastErr error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		if attempt > 0 {
			backoff := time.Duration(attempt*2) * time.Second
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			case <-time.After(backoff):
			}
		}

		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			return nil, err
		}
		resp, err := d.client.Do(req)
		if err != nil {
			lastErr = err
			continue
		}

		payload, readErr := io.ReadAll(resp.Body)
		resp.Body.Close()

		// The status is judged before the read error.
		switch {
		case resp.StatusCode != http.StatusOK:
			lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url)
		case readErr != nil:
			lastErr = readErr
		default:
			return payload, nil
		}
	}
	return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}
// httpGetRange fetches bytes start..end (inclusive) of url using an HTTP
// Range request and returns the response body.
//
// Both 206 Partial Content and 200 OK are accepted. Up to three attempts are
// made, waiting 2s before the second and 4s before the third; cancelling ctx
// during a wait returns ctx's error. The last failure is wrapped in the
// returned error once the attempts run out.
func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64) ([]byte, error) {
	const maxAttempts = 3

	var lastErr error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		if attempt > 0 {
			backoff := time.Duration(attempt*2) * time.Second
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			case <-time.After(backoff):
			}
		}

		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			return nil, err
		}
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))

		resp, err := d.client.Do(req)
		if err != nil {
			lastErr = err
			continue
		}

		payload, readErr := io.ReadAll(resp.Body)
		resp.Body.Close()

		accepted := resp.StatusCode == http.StatusPartialContent || resp.StatusCode == http.StatusOK
		// The status is judged before the read error.
		switch {
		case !accepted:
			lastErr = fmt.Errorf("HTTP %d for range %d-%d of %s", resp.StatusCode, start, end, url)
		case readErr != nil:
			lastErr = readErr
		default:
			return payload, nil
		}
	}
	return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}

View file

@ -1,157 +0,0 @@
package downloader
import (
"fmt"
"strconv"
"strings"
)
// IdxEntry is one parsed line of a GRIB .idx file, for example
// "15:1207405:d=2024010100:HGT:1000 mb:0 hour fcst:".
type IdxEntry struct {
	Index     int    // message number (first field of the line)
	Offset    int64  // byte offset at which the message starts
	Variable  string // "HGT", "UGRD", "VGRD", etc.
	LevelMB   int    // pressure level in mb (0 if not a pressure level)
	Hour      int    // forecast hour
	EndOffset int64  // byte after this message (from next entry's offset, or -1 if last)
}

// Length returns the size of the message in bytes, or -1 when EndOffset is
// not positive and the size is therefore unknown.
func (e *IdxEntry) Length() int64 {
	if end := e.EndOffset; end > 0 {
		return end - e.Offset
	}
	return -1
}
// ParseIdx turns the body of a .idx file into entries, one per usable line.
//
// Blank lines, lines with fewer than seven colon-separated fields and lines
// whose index or offset is not a number are skipped silently. Each entry's
// EndOffset is the offset of the entry that follows it; the final entry
// keeps EndOffset -1.
func ParseIdx(body []byte) []IdxEntry {
	var entries []IdxEntry
	for _, raw := range strings.Split(string(body), "\n") {
		line := strings.TrimSpace(raw)
		if line == "" {
			continue
		}
		fields := strings.Split(line, ":")
		if len(fields) < 7 {
			continue
		}

		index, idxErr := strconv.Atoi(fields[0])
		if idxErr != nil {
			continue
		}
		start, offErr := strconv.ParseInt(fields[1], 10, 64)
		if offErr != nil {
			continue
		}

		// fields[2] (the "d=..." field) is not used.
		entry := IdxEntry{
			Index:     index,
			Offset:    start,
			Variable:  fields[3],
			LevelMB:   parseLevelMB(fields[4]),
			Hour:      parseHour(fields[5]),
			EndOffset: -1, // replaced below for all but the last entry
		}
		entries = append(entries, entry)
	}

	// A message ends where the next one begins.
	for next := 1; next < len(entries); next++ {
		entries[next-1].EndOffset = entries[next].Offset
	}
	return entries
}
// FilterIdx keeps the entries whose variable maps to true in variables, that
// sit on a pressure level (LevelMB > 0) and whose length is known and
// positive. Entries without a known length, such as the last line of an idx
// file, are dropped.
func FilterIdx(entries []IdxEntry, variables map[string]bool) []IdxEntry {
	var kept []IdxEntry
	for i := range entries {
		candidate := entries[i]
		wanted := variables[candidate.Variable]
		if !wanted || candidate.LevelMB <= 0 || candidate.Length() <= 0 {
			continue
		}
		kept = append(kept, candidate)
	}
	return kept
}
// parseLevelMB extracts the pressure in mb from a level string such as
// "1000 mb". Surrounding whitespace is ignored. It returns 0 when the string
// does not end in " mb" or the part before it is not an integer.
func parseLevelMB(s string) int {
	const unit = " mb"

	level := strings.TrimSpace(s)
	if !strings.HasSuffix(level, unit) {
		return 0
	}
	value, err := strconv.Atoi(level[:len(level)-len(unit)])
	if err != nil {
		return 0
	}
	return value
}
// parseHour extracts the forecast hour from strings such as "0 hour fcst".
// The analysis marker "anl" counts as hour 0. Surrounding whitespace is
// ignored. It returns -1 when no integer hour can be read.
func parseHour(s string) int {
	label := strings.TrimSpace(s)
	if label == "anl" {
		return 0
	}
	if hour, err := strconv.Atoi(strings.TrimSuffix(label, " hour fcst")); err == nil {
		return hour
	}
	return -1
}
// ByteRange is a byte range suitable for an HTTP Range download.
// Each range covers one contiguous GRIB message.
type ByteRange struct {
	Start int64    // offset of the first byte of the message
	End   int64    // inclusive
	Entry IdxEntry // idx entry this range was derived from
}
// EntriesToRanges builds one ByteRange per entry with a known, positive
// length, preserving order. Entries without such a length are left out.
func EntriesToRanges(entries []IdxEntry) []ByteRange {
	out := make([]ByteRange, 0, len(entries))
	for i := range entries {
		entry := entries[i]
		if entry.Length() > 0 {
			out = append(out, ByteRange{
				Start: entry.Offset,
				End:   entry.EndOffset - 1, // EndOffset is exclusive, End is inclusive
				Entry: entry,
			})
		}
	}
	return out
}
// FormatRange renders the range as an HTTP Range header value of the form
// "bytes=<start>-<end>".
func (r ByteRange) FormatRange() string {
	first := strconv.FormatInt(r.Start, 10)
	last := strconv.FormatInt(r.End, 10)
	return "bytes=" + first + "-" + last
}

View file

@ -1,110 +0,0 @@
package downloader
import (
"testing"
)
// sampleIdx is the idx listing used as the fixture by the tests in this file.
// Each line has the colon-separated form
// "index:offset:d=date:variable:level:forecast:". The listing mixes pressure
// levels with one "2 m above ground" level and ends with an entry that has no
// successor.
const sampleIdx = `1:0:d=2024010100:HGT:1000 mb:0 hour fcst:
2:289012:d=2024010100:HGT:975 mb:0 hour fcst:
3:541876:d=2024010100:TMP:1000 mb:0 hour fcst:
4:789012:d=2024010100:UGRD:1000 mb:0 hour fcst:
5:1045678:d=2024010100:VGRD:1000 mb:0 hour fcst:
6:1298765:d=2024010100:UGRD:975 mb:0 hour fcst:
7:1567890:d=2024010100:UGRD:2 m above ground:0 hour fcst:
8:1812345:d=2024010100:VGRD:975 mb:0 hour fcst:
9:2098765:d=2024010100:HGT:500 mb:3 hour fcst:
`
// TestParseIdx checks that ParseIdx reads every line of the fixture, fills in
// the fields of the first entry, maps a non-pressure level to LevelMB 0 and
// leaves the last entry with EndOffset -1.
func TestParseIdx(t *testing.T) {
	parsed := ParseIdx([]byte(sampleIdx))
	if got := len(parsed); got != 9 {
		t.Fatalf("expected 9 entries, got %d", got)
	}

	// First entry: all parsed fields, plus the end offset taken from entry 2.
	first := parsed[0]
	firstOK := first.Index == 1 &&
		first.Offset == 0 &&
		first.Variable == "HGT" &&
		first.LevelMB == 1000 &&
		first.Hour == 0
	if !firstOK {
		t.Errorf("entry 0: got %+v", first)
	}
	if first.EndOffset != 289012 {
		t.Errorf("entry 0 EndOffset: got %d, want 289012", first.EndOffset)
	}

	// "2 m above ground" is not a pressure level.
	if aboveGround := parsed[6]; aboveGround.LevelMB != 0 {
		t.Errorf("non-pressure level should have LevelMB=0, got %d", aboveGround.LevelMB)
	}

	// The final entry has no successor, so its end offset stays at -1.
	if tail := parsed[len(parsed)-1]; tail.EndOffset != -1 {
		t.Errorf("last entry EndOffset: got %d, want -1", tail.EndOffset)
	}
}
func TestFilterIdx(t *testing.T) {
entries := ParseIdx([]byte(sampleIdx))
filtered := FilterIdx(entries, neededVariables)
// Should include HGT/UGRD/VGRD at pressure levels, exclude TMP and "above ground"
// And exclude last entry (no EndOffset)
for _, e := range filtered {
if !neededVariables[e.Variable] {
t.Errorf("unexpected variable %s", e.Variable)
}
if e.LevelMB <= 0 {
t.Errorf("non-pressure level included: %+v", e)
}
if e.Length() <= 0 {
t.Errorf("entry with unknown length included: %+v", e)
}
}
// Count expected: HGT@1000, HGT@975, UGRD@1000, VGRD@1000, UGRD@975, VGRD@975 = 6
// But HGT@500 at 3hr fcst is the last entry (no EndOffset), so excluded
if len(filtered) != 6 {
t.Errorf("expected 6 filtered entries, got %d", len(filtered))
for _, e := range filtered {
t.Logf(" %s %d mb (offset %d, len %d)", e.Variable, e.LevelMB, e.Offset, e.Length())
}
}
}
// TestParseLevelMB runs parseLevelMB over pressure and non-pressure level
// strings.
func TestParseLevelMB(t *testing.T) {
	type levelCase struct {
		input string
		want  int
	}
	cases := []levelCase{
		{input: "1000 mb", want: 1000},
		{input: "975 mb", want: 975},
		{input: "1 mb", want: 1},
		{input: "2 m above ground", want: 0},
		{input: "surface", want: 0},
		{input: "tropopause", want: 0},
	}
	for i := range cases {
		tc := cases[i]
		if got := parseLevelMB(tc.input); got != tc.want {
			t.Errorf("parseLevelMB(%q) = %d, want %d", tc.input, got, tc.want)
		}
	}
}
// TestParseHour runs parseHour over forecast-hour strings and the "anl"
// analysis marker.
func TestParseHour(t *testing.T) {
	type hourCase struct {
		input string
		want  int
	}
	cases := []hourCase{
		{input: "0 hour fcst", want: 0},
		{input: "3 hour fcst", want: 3},
		{input: "192 hour fcst", want: 192},
		{input: "anl", want: 0},
	}
	for i := range cases {
		tc := cases[i]
		if got := parseHour(tc.input); got != tc.want {
			t.Errorf("parseHour(%q) = %d, want %d", tc.input, got, tc.want)
		}
	}
}

View file

@ -71,9 +71,10 @@ func (d *Dataset) getCell(latIdx, lngIdx int) int16 {
return int16(binary.LittleEndian.Uint16(d.mm[off : off+2])) return int16(binary.LittleEndian.Uint16(d.mm[off : off+2]))
} }
// Get returns the interpolated elevation in metres at the given coordinates. // Elevation returns the bilinearly-interpolated ground elevation in metres at
// lat: -90 to +90, lng: 0 to 360 (or -180 to 180, will be normalised). // the given coordinates. lat is in [-90, +90]; lng accepts either [0, 360) or
func (d *Dataset) Get(lat, lng float64) float64 { // [-180, 180) and is normalised internally.
func (d *Dataset) Elevation(lat, lng float64) float64 {
// Normalise longitude to [0, 360) // Normalise longitude to [0, 360)
if lng < 0 { if lng < 0 {
lng += 360 lng += 360

View file

@ -0,0 +1,47 @@
package engine
// MaxAltitude is a ceiling constraint: it is violated as soon as altitude is
// at or above Limit (in metres). Used as the burst condition for ascent
// stages.
type MaxAltitude struct {
	// Limit is the altitude threshold in metres.
	Limit float64
	// On is the action reported when the constraint is violated.
	On Action
}

// Name returns the identifier "max_altitude".
func (m MaxAltitude) Name() string {
	return "max_altitude"
}

// Violated reports whether the state's altitude has reached Limit. The time
// argument is ignored.
func (m MaxAltitude) Violated(_ float64, s State) bool {
	return m.Limit <= s.Altitude
}

// Action returns the configured On action.
func (m MaxAltitude) Action() Action {
	return m.On
}
// MinAltitude is a floor constraint: it is violated when altitude is at or
// below Limit (in metres). With Limit=0 this is the "sea level" terminator.
type MinAltitude struct {
	// Limit is the altitude threshold in metres.
	Limit float64
	// On is the action reported when the constraint is violated.
	On Action
}

// Name returns the identifier "min_altitude".
func (m MinAltitude) Name() string {
	return "min_altitude"
}

// Violated reports whether the state's altitude is at or below Limit. The
// time argument is ignored.
func (m MinAltitude) Violated(_ float64, s State) bool {
	return m.Limit >= s.Altitude
}

// Action returns the configured On action.
func (m MinAltitude) Action() Action {
	return m.On
}
// MaxTime is a deadline constraint: it is violated once t is strictly greater
// than Limit (UNIX seconds). Used as a stop condition for float profiles.
type MaxTime struct {
	// Limit is the deadline, on the same time axis as the integrator's t.
	Limit float64
	// On is the action reported when the constraint is violated.
	On Action
}

// Name returns the identifier "max_time".
func (m MaxTime) Name() string {
	return "max_time"
}

// Violated reports whether t lies after Limit. The state argument is ignored.
func (m MaxTime) Violated(t float64, _ State) bool {
	return m.Limit < t
}

// Action returns the configured On action.
func (m MaxTime) Action() Action {
	return m.On
}
// TerrainContact is violated when the state's altitude is strictly below the
// ground elevation reported by Provider at the state's position. Equivalent
// to Tawhiri's elevation termination.
type TerrainContact struct {
	// Provider supplies the ground elevation in metres. It must be non-nil:
	// Violated calls it unconditionally.
	Provider TerrainProvider
	// On is the action reported when the constraint is violated.
	On Action
}

// Name returns the identifier "terrain_contact".
func (tc TerrainContact) Name() string {
	return "terrain_contact"
}

// Violated looks up the ground elevation at (s.Lat, s.Lng) and reports
// whether the altitude lies below it. The time argument is ignored.
func (tc TerrainContact) Violated(_ float64, s State) bool {
	ground := tc.Provider.Elevation(s.Lat, s.Lng)
	return s.Altitude < ground
}

// Action returns the configured On action.
func (tc TerrainContact) Action() Action {
	return tc.On
}

View file

@ -0,0 +1,176 @@
package engine
import (
"math"
"testing"
"time"
"predictor-refactored/internal/weather"
)
// noWind is a wind field stub whose Wind method always yields the zero
// sample and no error. It lets vertical-only profiles be integrated
// deterministically.
type noWind struct{ epoch time.Time }

// Wind ignores its arguments and returns calm air.
func (n noWind) Wind(_ float64, _, _, _ float64) (weather.Sample, error) {
	var calm weather.Sample
	return calm, nil
}

// Epoch returns the epoch the stub was constructed with.
func (n noWind) Epoch() time.Time {
	return n.epoch
}

// Source returns the fixed label "test".
func (n noWind) Source() string {
	return "test"
}
// flatGround is a terrain stub that reports sea level everywhere.
type flatGround struct{}

// Elevation ignores the coordinates and returns 0 metres.
func (flatGround) Elevation(_, _ float64) float64 {
	const seaLevel = 0
	return seaLevel
}
// TestConstantAscentToBurst climbs at a constant rate in still air and checks
// that the single stage stops close to the burst altitude at the expected
// time.
func TestConstantAscentToBurst(t *testing.T) {
	var (
		burstAlt  = 30000.0
		climbRate = 5.0
	)
	stage := &Propagator{
		Name:        "ascent",
		Step:        60,
		Model:       Sum(ConstantRate(climbRate), WindTransport(noWind{}, nil)),
		Constraints: []Constraint{MaxAltitude{Limit: burstAlt, On: ActionStop}},
	}
	profile := Profile{Stages: []*Propagator{stage}, Direction: Forward}

	results := profile.Run(0, State{Lat: 0, Lng: 0, Altitude: 0})
	if len(results) != 1 || results[0].Outcome != OutcomeStopped {
		t.Fatalf("expected one stopped stage, got %+v", results)
	}

	points := results[0].Points
	end := points[len(points)-1]

	// Refinement tolerance is 0.01 in parameter space over a 60s step, so the
	// returned point sits within ±0.6s × rate ≈ ±3m of the boundary.
	if miss := math.Abs(end.Altitude - burstAlt); miss > 5 {
		t.Errorf("burst altitude = %v, want within 5m of %v", end.Altitude, burstAlt)
	}
	expectedTime := burstAlt / climbRate
	if miss := math.Abs(end.Time - expectedTime); miss > 1 {
		t.Errorf("burst time = %v, want within 1s of %v", end.Time, expectedTime)
	}
}
// TestProfileWithFallback checks that an ascent whose ceiling constraint uses
// ActionFallback hands over to its Fallback descent, and that the descent
// stops close to the flat ground at 0 m.
func TestProfileWithFallback(t *testing.T) {
	var (
		ceiling = 1000.0
		speed   = 5.0
	)
	down := &Propagator{
		Name:        "descent",
		Step:        60,
		Model:       ParachuteDescent(speed),
		Constraints: []Constraint{TerrainContact{Provider: flatGround{}, On: ActionStop}},
	}
	up := &Propagator{
		Name:        "ascent",
		Step:        60,
		Model:       ConstantRate(speed),
		Constraints: []Constraint{MaxAltitude{Limit: ceiling, On: ActionFallback}},
		Fallback:    down,
	}
	profile := Profile{Stages: []*Propagator{up}, Direction: Forward}

	results := profile.Run(0, State{Altitude: 0})
	if len(results) != 2 {
		t.Fatalf("expected 2 results (ascent then descent fallback), got %d", len(results))
	}

	ascentRes, descentRes := results[0], results[1]
	if ascentRes.Outcome != OutcomeFallback {
		t.Errorf("first outcome = %v, want OutcomeFallback", ascentRes.Outcome)
	}
	if descentRes.Outcome != OutcomeStopped {
		t.Errorf("second outcome = %v, want OutcomeStopped", descentRes.Outcome)
	}

	landing := descentRes.Points[len(descentRes.Points)-1]
	if math.Abs(landing.Altitude) > 5 {
		t.Errorf("final altitude = %v, want within 5m of 0", landing.Altitude)
	}
}
// TestReverseDirection integrates a descending model backwards in time: from
// 100 m the altitude must rise to the 200 m ceiling while time runs negative.
func TestReverseDirection(t *testing.T) {
	rewind := &Propagator{
		Name:        "rewind",
		Step:        1,
		Model:       ConstantRate(-1), // forward: alt decreases at 1 m/s
		Constraints: []Constraint{MaxAltitude{Limit: 200, On: ActionStop}},
	}
	profile := Profile{Stages: []*Propagator{rewind}, Direction: Reverse}

	results := profile.Run(0, State{Altitude: 100})
	points := results[0].Points
	end := points[len(points)-1]

	if math.Abs(end.Altitude-200) > 1 {
		t.Errorf("reverse final altitude = %v, want ~200", end.Altitude)
	}
	if end.Time >= 0 {
		t.Errorf("reverse final time = %v, want < 0", end.Time)
	}
}
func TestPiecewiseRate(t *testing.T) {
m := Piecewise([]RateSegment{
{Until: 100, Rate: 5},
{Until: 200, Rate: 3},
{Until: math.Inf(1), Rate: 0},
})
if r := m(50, State{}); r.Altitude != 5 {
t.Errorf("rate at t=50 = %v, want 5", r.Altitude)
}
if r := m(150, State{}); r.Altitude != 3 {
t.Errorf("rate at t=150 = %v, want 3", r.Altitude)
}
if r := m(300, State{}); r.Altitude != 0 {
t.Errorf("rate at t=300 = %v, want 0", r.Altitude)
}
}
// fixedWind is a wind field stub that reports the same (u, v) wind for every
// query.
type fixedWind struct{ u, v float64 }

// Wind ignores its arguments and returns the configured components.
func (w fixedWind) Wind(_ float64, _, _, _ float64) (weather.Sample, error) {
	sample := weather.Sample{U: w.u, V: w.v}
	return sample, nil
}

// Epoch returns the UNIX epoch.
func (fixedWind) Epoch() time.Time {
	return time.Unix(0, 0)
}

// Source returns the fixed label "test-fixed".
func (fixedWind) Source() string {
	return "test-fixed"
}
// TestWindTransportUnitConversion checks the m/s to deg/s conversion done by
// WindTransport for a purely eastward and a purely northward wind, including
// that the component with no wind stays exactly zero.
func TestWindTransportUnitConversion(t *testing.T) {
	// Pure eastward wind of 10 m/s at the equator at sea level.
	// Expected dlng/dt = (180/pi) * 10 / (6371009 * cos(0)) ≈ 0.00008991 deg/s.
	// Expected dlat/dt = 0.
	wind := WindTransport(fixedWind{u: 10, v: 0}, nil)
	d := wind(0, State{Lat: 0, Lng: 0, Altitude: 0})
	wantLng := (180.0 / math.Pi) * 10.0 / 6371009.0
	if math.Abs(d.Lng-wantLng) > 1e-12 {
		t.Errorf("dlng = %v, want %v", d.Lng, wantLng)
	}
	if math.Abs(d.Lat) > 1e-12 {
		t.Errorf("dlat = %v, want 0 for u=10 v=0", d.Lat)
	}

	// Pure northward at 60° latitude: dlat = (180/pi) * v / R, dlng = 0.
	wind2 := WindTransport(fixedWind{u: 0, v: 5}, nil)
	d = wind2(0, State{Lat: 60, Lng: 0, Altitude: 0})
	wantLat := (180.0 / math.Pi) * 5.0 / 6371009.0
	if math.Abs(d.Lat-wantLat) > 1e-12 {
		t.Errorf("dlat at lat=60 = %v, want %v", d.Lat, wantLat)
	}
	// The comment above promises dlng = 0; assert it so that a regression in
	// the cos(lat) term cannot leak zonal motion unnoticed.
	if math.Abs(d.Lng) > 1e-12 {
		t.Errorf("dlng at lat=60 = %v, want 0 for u=0 v=5", d.Lng)
	}
}
// TestStateAddWrapsLongitude demonstrates the state algebra used by the
// integrator and the refinement: stateAdd wraps longitude into [0, 360) and
// stateLerp interpolates across the 0/360 seam.
func TestStateAddWrapsLongitude(t *testing.T) {
	// 350° + 20° crosses the seam and must come back as 10°.
	s := stateAdd(State{Lat: 0, Lng: 350, Altitude: 0}, 1, State{Lng: 20})
	if math.Abs(s.Lng-10) > 1e-9 {
		t.Errorf("stateAdd wrap: lng = %v, want 10", s.Lng)
	}

	// Halfway between 350° and 10° along the short arc is the seam itself;
	// either representation of it is accepted.
	mid := stateLerp(State{Lng: 350}, State{Lng: 10}, 0.5)
	if math.Abs(mid.Lng-0) > 1e-9 && math.Abs(mid.Lng-360) > 1e-9 {
		t.Errorf("stateLerp lng wrap: %v, want 0 or 360", mid.Lng)
	}
}

151
internal/engine/models.go Normal file
View file

@ -0,0 +1,151 @@
package engine
import (
"math"
"sort"
"sync/atomic"
"predictor-refactored/internal/weather"
)
// Sum combines several models into one whose derivative is the
// component-wise sum of the individual derivatives at the same (t, s).
//
// Typical use is pairing a vertical-rate model with a horizontal wind model
// inside one propagator. Equivalent to Tawhiri's LinearModel.
//
// A single model is returned as is. With no models the result always yields
// the zero derivative.
func Sum(models ...Model) Model {
	if len(models) == 1 {
		return models[0]
	}
	return func(t float64, s State) State {
		var total State
		for i := range models {
			part := models[i](t, s)
			total = State{
				Lat:      total.Lat + part.Lat,
				Lng:      total.Lng + part.Lng,
				Altitude: total.Altitude + part.Altitude,
			}
		}
		return total
	}
}
// ConstantRate builds a model whose only non-zero derivative is a fixed
// vertical velocity in m/s. Positive rates climb, negative rates descend.
func ConstantRate(rate float64) Model {
	climb := State{Altitude: rate}
	return func(float64, State) State {
		return climb
	}
}
// ParachuteDescent builds a descent model whose speed increases with
// altitude, because thinner air provides less drag.
//
// seaLevelRate is the descent speed at sea level (m/s, given as a positive
// number). The vertical velocity at altitude is
//
//	v = -k / sqrt(rho(alt)),   k = seaLevelRate * 1.1045,
//
// where rho comes from nasaDensity. Equivalent to Tawhiri's drag_descent.
func ParachuteDescent(seaLevelRate float64) Model {
	const dragFactor = 1.1045
	coeff := seaLevelRate * dragFactor
	return func(_ float64, s State) State {
		rho := nasaDensity(s.Altitude)
		return State{Altitude: -coeff / math.Sqrt(rho)}
	}
}
// nasaDensity returns the air density (kg/m^3) at the given altitude in
// metres according to the NASA simple atmosphere model, see
// https://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html.
//
// The model has three layers, each with its own temperature and pressure
// formula: above 25000 m, above 11000 m up to 25000 m, and everything else.
func nasaDensity(alt float64) float64 {
	const (
		kelvinOffset = 273.1
		gasConstant  = 0.2869
	)

	var temp, pressure float64
	if alt > 25000 {
		temp = -131.21 + 0.00299*alt
		pressure = 2.488 * math.Pow((temp+kelvinOffset)/216.6, -11.388)
	} else if alt > 11000 {
		temp = -56.46
		pressure = 22.65 * math.Exp(1.73-0.000157*alt)
	} else {
		temp = 15.04 - 0.00649*alt
		pressure = 101.29 * math.Pow((temp+kelvinOffset)/288.08, 5.256)
	}

	// Equation of state: rho = p / (R * T), with T converted to kelvin.
	return pressure / (gasConstant * (temp + kelvinOffset))
}
// RateSegment is one entry in a Piecewise rate schedule: a vertical rate
// together with the time at which it stops applying.
type RateSegment struct {
	// Until is the UNIX timestamp at which this segment ends.
	// The model applies the segment's Rate for all t < Until (the bound is
	// exclusive). Use +Inf for a segment that never ends.
	Until float64
	// Rate is the vertical velocity (m/s) during the segment. Positive is up.
	Rate float64
}
// Piecewise builds a model with a piecewise-constant vertical rate over a
// sequence of time intervals.
//
// The active segment at time t is the first one, in ascending Until order,
// whose Until is strictly greater than t. For t at or after the last Until
// the final segment's Rate is held indefinitely. The input is copied and the
// copy is sorted by Until, so the caller's slice is left untouched. An empty
// schedule yields ConstantRate(0).
func Piecewise(segments []RateSegment) Model {
	if len(segments) == 0 {
		return ConstantRate(0)
	}

	schedule := make([]RateSegment, len(segments))
	copy(schedule, segments)
	sort.Slice(schedule, func(a, b int) bool {
		return schedule[a].Until < schedule[b].Until
	})

	count := len(schedule)
	tailRate := schedule[count-1].Rate

	return func(t float64, _ State) State {
		active := sort.Search(count, func(i int) bool { return t < schedule[i].Until })
		rate := tailRate
		if active < count {
			rate = schedule[active].Rate
		}
		return State{Altitude: rate}
	}
}
// Warnings collects counters for non-fatal conditions seen during
// integration. The counters are atomic, so one Warnings value may be updated
// from several goroutines; it must be shared by pointer.
type Warnings struct {
	// AltitudeTooHigh counts evaluations where the wind sampler reported
	// that altitude was above the highest pressure level of the dataset.
	AltitudeTooHigh atomic.Int64
}

// ToMap renders the warnings as a map suitable for JSON output. A counter
// appears only if it has fired at least once, so the result is empty (but
// never nil) when nothing was recorded.
func (w *Warnings) ToMap() map[string]any {
	result := map[string]any{}

	count := w.AltitudeTooHigh.Load()
	if count <= 0 {
		return result
	}
	result["altitude_too_high"] = map[string]any{
		"count":       count,
		"description": "altitude exceeded the highest pressure level of the wind dataset; samples were extrapolated",
	}
	return result
}
// WindTransport builds a model that moves the state laterally at the wind
// velocity sampled from field. The vertical component of the derivative is
// always zero. Wind is converted from m/s to deg/s on a sphere of radius
// earthR + altitude.
//
// If warnings is non-nil, its AltitudeTooHigh counter is incremented for each
// sample the wind field flags as AboveModel.
//
// NOTE(review): when field.Wind returns an error the model silently yields the
// zero derivative, so the trajectory stops drifting instead of failing. The
// longitude rate divides by cos(lat), which is zero at the poles. Confirm
// both are acceptable to callers.
func WindTransport(field weather.WindField, warnings *Warnings) Model {
	const (
		earthR    = 6371009.0
		piOver180 = math.Pi / 180.0
		degPerRad = 180.0 / math.Pi
	)

	return func(t float64, s State) State {
		sample, err := field.Wind(t, s.Lat, s.Lng, s.Altitude)
		if err != nil {
			return State{}
		}
		if warnings != nil && sample.AboveModel {
			warnings.AltitudeTooHigh.Add(1)
		}

		radius := earthR + s.Altitude
		latRad := s.Lat * piOver180
		dLat := degPerRad * sample.V / radius
		dLng := degPerRad * sample.U / (radius * math.Cos(latRad))
		return State{Lat: dLat, Lng: dLng}
	}
}

View file

@ -0,0 +1,55 @@
package engine
// Profile is an ordered chain of propagators executed one after another.
// Each propagator starts at the point where the previous one finished.
type Profile struct {
	// Stages are run in slice order for both directions; in Reverse each
	// propagator simply integrates with a negative dt.
	Stages []*Propagator

	// Direction sets the sign of dt for the whole profile. The zero value is
	// replaced by Forward when Run is called.
	Direction Direction

	// Globals are constraints evaluated alongside each stage's own
	// Constraints. Useful for profile-wide bounds like "stop after N hours
	// total".
	Globals []Constraint
}

// Run executes the profile starting at time t0 from the launch state and
// returns one Result per executed propagator, in execution order. Fallback
// propagators that were activated contribute their own Result right after the
// stage that triggered them.
//
// Run writes Forward into p.Direction when it is still zero.
func (p *Profile) Run(t0 float64, launch State) []Result {
	if p.Direction == 0 {
		p.Direction = Forward
	}

	results := make([]Result, 0, len(p.Stages))
	now, state := t0, launch

	// advance runs one propagator from the current point, records its result
	// and moves the current point to the last point of that run.
	advance := func(stage *Propagator) Outcome {
		res := stage.run(now, state, p.Direction, p.Globals)
		results = append(results, res)
		end := res.Points[len(res.Points)-1]
		now = end.Time
		state = State{Lat: end.Lat, Lng: end.Lng, Altitude: end.Altitude}
		return res.Outcome
	}

	for _, stage := range p.Stages {
		outcome := advance(stage)

		// Follow the Fallback chain for as long as the run that just ended
		// asked for a fallback and one is configured. Each fallback starts
		// where its predecessor stopped.
		for current := stage; outcome == OutcomeFallback && current.Fallback != nil; {
			current = current.Fallback
			outcome = advance(current)
		}

		// NOTE(review): an OutcomeStopped result does NOT end the profile;
		// the next stage always runs. The previous code had an
		// `if OutcomeStopped { continue }` here, commented as ending the
		// profile, which was a no-op at the end of the loop body. Confirm
		// whether a break was intended before changing this.
	}
	return results
}

View file

@ -0,0 +1,156 @@
package engine
import (
"predictor-refactored/internal/numerics"
)
// Propagator advances state under one Model, checking a set of Constraints
// after every integration step.
//
// When a constraint fires, the propagator binary-search refines the violation
// point and emits it as a trajectory point. The Action of the triggering
// constraint decides what happens next: the run ends with OutcomeStopped, the
// run ends with OutcomeFallback so that the Profile can switch to Fallback,
// or the state is clipped and integration continues.
type Propagator struct {
	// Name identifies the propagator in trajectory metadata.
	Name string
	// Step is the magnitude of the integration step in seconds (always positive).
	// It is multiplied by the Direction passed to run, which flips its sign
	// for Reverse.
	Step float64
	// Model produces the per-second time derivative of state.
	Model Model
	// Constraints are evaluated after each step, before any profile-wide
	// constraints. When several are violated at the same point, the first
	// one in this slice wins.
	Constraints []Constraint
	// Fallback is the propagator to switch to when a constraint with
	// ActionFallback fires. Optional.
	Fallback *Propagator
	// Tolerance is the binary-search refinement tolerance in parameter space.
	// Zero selects the default of 0.01 (matching Tawhiri).
	Tolerance float64
}
// Outcome describes how a propagator's run ended.
type Outcome int
const (
	// OutcomeStopped means a Constraint with ActionStop fired. Its doc
	// historically says the profile should end here.
	// NOTE(review): confirm against Profile.Run, which decides what actually
	// follows a stopped stage.
	OutcomeStopped Outcome = iota
	// OutcomeFallback means a Constraint with ActionFallback fired and the
	// profile should transfer to the propagator's Fallback chain.
	OutcomeFallback
	// OutcomeContinued is the outcome a Result starts with and keeps when no
	// constraint ended the run. In practice this is only seen when a
	// propagator runs unbounded, which means the profile is misconfigured.
	OutcomeContinued
)
// Result is the output of running one propagator.
type Result struct {
	// Propagator is the Name of the propagator that produced this result.
	Propagator string
	// Points is the sampled trajectory. It always starts with the point the
	// run was started from, so it is never empty.
	Points []TrajectoryPoint
	// Outcome tells how the run ended.
	Outcome Outcome
	// Constraint is the constraint that fired, or nil if Outcome == OutcomeContinued.
	Constraint Constraint
}
// run integrates the model from (t0, s0) in direction dir and returns the
// resulting trajectory segment. globals are constraints injected by the
// Profile and checked alongside the propagator's local Constraints.
//
// Each iteration takes one RK4 step of Step*dir seconds and then checks the
// constraints at the new point. If none fires, the point is recorded and
// integration goes on. If one fires, the violation point is refined by
// bisection between the previous and the new point, and the constraint's
// Action decides what follows:
//   - ActionClip: the refined state is clipped, recorded, and integration
//     continues from it;
//   - ActionFallback: the refined point is recorded and the run ends with
//     OutcomeFallback;
//   - anything else (ActionStop): the refined point is recorded and the run
//     ends with OutcomeStopped.
//
// The returned Result always contains the start point. If the effective step
// is zero or NaN (Step == 0 or dir == 0, for example), time cannot advance
// and the run returns immediately with OutcomeContinued instead of looping
// forever.
func (p *Propagator) run(t0 float64, s0 State, dir Direction, globals []Constraint) Result {
	dt := p.Step * float64(dir)
	tol := p.Tolerance
	if tol == 0 {
		tol = 0.01
	}

	out := Result{
		Propagator: p.Name,
		Outcome:    OutcomeContinued,
		Points: []TrajectoryPoint{{
			Time: t0, Lat: s0.Lat, Lng: s0.Lng, Altitude: s0.Altitude,
		}},
	}

	// Without a usable step the loop below would append the same point until
	// memory runs out. The negated form is deliberate: it is also true for
	// NaN.
	if !(dt > 0 || dt < 0) {
		return out
	}

	deriv := numerics.Deriv[State](func(t float64, s State) State { return p.Model(t, s) })
	add := numerics.VecAdd[State](stateAdd)
	lerp := numerics.VecLerp[State](stateLerp)

	// record appends one sample to the trajectory.
	record := func(t float64, s State) {
		out.Points = append(out.Points, TrajectoryPoint{
			Time: t, Lat: s.Lat, Lng: s.Lng, Altitude: s.Altitude,
		})
	}

	t, s := t0, s0
	for {
		s2 := numerics.RK4Step(t, s, dt, deriv, add)
		t2 := t + dt

		c, fired := firstFiring(p.Constraints, globals, t2, s2)
		if !fired {
			t, s = t2, s2
			record(t, s)
			continue
		}

		// Narrow the violation down to a point between the last good state
		// and the first violating one.
		trig := numerics.Trigger[State](func(tt float64, ss State) bool { return c.Violated(tt, ss) })
		t3, s3 := numerics.RefineTrigger(t, s, t2, s2, trig, lerp, tol)

		switch c.Action() {
		case ActionClip:
			s3 = clipToConstraint(c, s3)
			record(t3, s3)
			t, s = t3, s3
		case ActionFallback:
			record(t3, s3)
			out.Outcome = OutcomeFallback
			out.Constraint = c
			return out
		default: // ActionStop
			record(t3, s3)
			out.Outcome = OutcomeStopped
			out.Constraint = c
			return out
		}
	}
}
// firstFiring returns the first constraint violated at (t, s), looking
// through local in slice order and then through globals. The boolean is false,
// and the constraint nil, when nothing is violated.
func firstFiring(local, globals []Constraint, t float64, s State) (Constraint, bool) {
	for _, group := range [2][]Constraint{local, globals} {
		for _, candidate := range group {
			if candidate.Violated(t, s) {
				return candidate, true
			}
		}
	}
	return nil, false
}
// clipToConstraint moves s onto the boundary of the given constraint.
//
// Only constraints with a well-defined boundary are handled: for MaxAltitude
// and MinAltitude values the altitude is set to their Limit. Every other
// constraint leaves s unchanged.
func clipToConstraint(c Constraint, s State) State {
	if ceiling, ok := c.(MaxAltitude); ok {
		s.Altitude = ceiling.Limit
		return s
	}
	if floor, ok := c.(MinAltitude); ok {
		s.Altitude = floor.Limit
		return s
	}
	return s
}

50
internal/engine/state.go Normal file
View file

@ -0,0 +1,50 @@
package engine
import "math"
// pymod returns a % b with Python semantics: a non-zero result has the sign
// of b, so for b > 0 the result is in [0, b) and for b < 0 it is in (b, 0].
//
// math.Mod yields a remainder with the sign of a; whenever that sign differs
// from the sign of b the remainder is shifted by b. A zero remainder is
// returned as is.
func pymod(a, b float64) float64 {
	r := math.Mod(a, b)
	if r != 0 && (r < 0) != (b < 0) {
		r += b
	}
	return r
}
// stateAdd computes y + k*dy, the update used by the RK4 integrator. The
// resulting longitude is wrapped with pymod(·, 360); latitude and altitude
// are plain sums.
//
// Time is not part of State: the integrator tracks it separately and hands
// it to the Model.
func stateAdd(y State, k float64, dy State) State {
	next := y
	next.Lat += k * dy.Lat
	next.Lng = pymod(next.Lng+k*dy.Lng, 360)
	next.Altitude += k * dy.Altitude
	return next
}
// stateLerp linearly interpolates between two states: l = 0 gives a and
// l = 1 gives b. Latitude and altitude are blended directly; longitude goes
// through lngLerp so that the 0/360 wrap-around is handled.
func stateLerp(a, b State, l float64) State {
	wa := 1 - l
	return State{
		Lat:      wa*a.Lat + l*b.Lat,
		Lng:      lngLerp(a.Lng, b.Lng, l),
		Altitude: wa*a.Altitude + l*b.Altitude,
	}
}
// lngLerp interpolates between two longitudes in [0, 360) along the shorter
// arc; l = 0 gives a and l = 1 gives b.
//
// The two longitudes are ordered first, with the weights swapped along with
// them. If they are less than 180° apart a plain weighted mean is returned.
// Otherwise the smaller one is shifted up by 360° before blending and the
// result is wrapped back with pymod.
func lngLerp(a, b, l float64) float64 {
	lo, hi := a, b
	wLo, wHi := 1-l, l
	if lo > hi {
		lo, hi = hi, lo
		wLo, wHi = wHi, wLo
	}

	if hi-lo < 180 {
		return wLo*lo + wHi*hi
	}
	return pymod(wLo*(lo+360)+wHi*hi, 360)
}

80
internal/engine/types.go Normal file
View file

@ -0,0 +1,80 @@
// Package engine is the trajectory calculation engine. It composes
// propagators (model-driven integrators) into profiles (ordered chains) and
// runs them over a wind field.
//
// The engine has no direct dependency on any specific data source: wind data
// is consumed through weather.WindField and terrain data through any type
// satisfying TerrainProvider.
package engine
// State holds the spatial state of the balloon. When returned by a Model
// the same struct is interpreted as the per-second time derivative of state.
// Time is not part of State; it is passed alongside it.
type State struct {
	// Lat is degrees latitude in [-90, 90] (or deg/s when returned as a derivative).
	Lat float64
	// Lng is degrees longitude in [0, 360) (or deg/s as a derivative).
	Lng float64
	// Altitude is metres above mean sea level (or m/s as a derivative).
	Altitude float64
}
// Model returns the time derivative of state at (t, s), expressed per second
// in the fields of the returned State.
//
// The derivative is direction-independent; the integrator applies the sign
// of dt for reverse propagation.
type Model func(t float64, s State) State
// TrajectoryPoint is one sampled point of an integration result: a State
// together with the time at which it was reached.
type TrajectoryPoint struct {
	Time float64 // UNIX seconds
	// Lat is degrees latitude.
	Lat float64
	// Lng is degrees longitude.
	Lng float64
	// Altitude is metres above mean sea level.
	Altitude float64
}
// Direction is the time direction of integration. Forward (+1) integrates
// from launch to landing; Reverse (-1) integrates from a known landing back
// to a candidate launch point.
//
// The value is used as a multiplier for the integration step, so only +1 and
// -1 are meaningful.
type Direction int8
const (
	// Forward integrates with a positive time step.
	Forward Direction = +1
	// Reverse integrates with a negative time step.
	Reverse Direction = -1
)
// Action describes what the profile runner should do when a Constraint
// reports a violation.
type Action int
const (
	// ActionStop ends the current propagator at the (refined) violation point.
	// This matches the only behaviour available in the reference Tawhiri solver.
	// It is the zero value, so it is what a constraint reports when no action
	// was configured.
	ActionStop Action = iota
	// ActionFallback ends the current propagator and starts its Fallback
	// propagator from the violation point. Useful for "if max altitude is
	// reached during ascent, switch to descent" profiles.
	ActionFallback
	// ActionClip clips the violated coordinate to the boundary and continues
	// integration. Useful for soft constraints such as an altitude ceiling or
	// floor.
	ActionClip
)
// Constraint reports when integration should stop, branch, or clip.
//
// A constraint is direction-agnostic: it reads state and decides. The profile
// runner is responsible for refining the trigger point via binary search and
// dispatching the configured Action.
type Constraint interface {
	// Name identifies the constraint in logs and result metadata.
	Name() string
	// Violated reports whether the constraint is breached at (t, s).
	// It is called repeatedly during refinement of the trigger point.
	Violated(t float64, s State) bool
	// Action is the behaviour to take on violation.
	Action() Action
}
// TerrainProvider returns ground elevation in metres at a coordinate.
// Implementations must be safe for concurrent use.
type TerrainProvider interface {
	// Elevation returns the ground elevation in metres at the given latitude
	// and longitude, both in degrees.
	Elevation(lat, lng float64) float64
}

146
internal/metrics/prom.go Normal file
View file

@ -0,0 +1,146 @@
package metrics
import (
"fmt"
"io"
"net/http"
"sort"
"strings"
"sync"
"time"
)
// Prom is a minimal Sink that exposes counters and gauges in Prometheus's
// text exposition format. No external dependencies.
//
// The Prom sink supports labelled counters, sums (for durations and byte
// counts), and labelled gauges. Histograms are intentionally omitted; if
// they are needed later, swap Prom for an OTel-based sink.
//
// All methods are safe for concurrent use. A Prom must be created with
// NewProm and must not be copied.
type Prom struct {
	mu sync.Mutex
	// counters maps metric name → rendered label suffix → accumulated value.
	counters map[string]map[string]float64
	// gauges maps metric name → rendered label suffix → last value set.
	gauges map[string]map[string]float64
}

// NewProm returns an empty Prom sink.
func NewProm() *Prom {
	return &Prom{
		counters: make(map[string]map[string]float64),
		gauges:   make(map[string]map[string]float64),
	}
}

// Prediction implements Sink. It counts the prediction and adds its duration
// to the running sum, labelled by profile and by status ("ok" when err is
// nil, "error" otherwise).
func (p *Prom) Prediction(profile string, d time.Duration, err error) {
	status := "ok"
	if err != nil {
		status = "error"
	}
	labels := map[string]string{"profile": profile, "status": status}
	p.incCounter("predictor_predictions_total", labels, 1)
	p.incCounter("predictor_prediction_duration_seconds_sum", labels, d.Seconds())
	p.incCounter("predictor_prediction_duration_seconds_count", labels, 1)
}

// Download implements Sink. It counts the download and adds its duration,
// labelled by source and status, and adds the transferred bytes labelled by
// source only.
func (p *Prom) Download(source string, d time.Duration, status string, bytes int64) {
	labels := map[string]string{"source": source, "status": status}
	p.incCounter("predictor_downloads_total", labels, 1)
	p.incCounter("predictor_download_duration_seconds_sum", labels, d.Seconds())
	p.incCounter("predictor_download_bytes_total", map[string]string{"source": source}, float64(bytes))
}

// ActiveEpoch implements Sink. It sets the dataset-epoch gauge to t in UNIX
// seconds, or to 0 when t is the zero time.
func (p *Prom) ActiveEpoch(t time.Time) {
	var v float64
	if !t.IsZero() {
		v = float64(t.Unix())
	}
	p.setGauge("predictor_active_dataset_epoch_seconds", map[string]string{}, v)
}

// ServeHTTP writes the metrics in Prometheus text exposition format.
func (p *Prom) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
	w.Header().Set("Content-Type", "text/plain; version=0.0.4")
	p.Write(w)
}

// Write writes the metrics in Prometheus exposition format to w. Metric
// families are ordered by name and samples by label suffix, so the output is
// deterministic.
//
// The text is rendered while holding the lock and written afterwards, so a
// slow reader cannot block concurrent metric updates.
func (p *Prom) Write(w io.Writer) {
	// The signature has no error result; a failed write simply means the
	// scrape is incomplete, which the reader detects on its side.
	_, _ = io.WriteString(w, p.render())
}

// render returns a snapshot of all metrics in exposition format.
func (p *Prom) render() string {
	p.mu.Lock()
	defer p.mu.Unlock()

	names := make([]string, 0, len(p.counters)+len(p.gauges))
	for n := range p.counters {
		names = append(names, n)
	}
	for n := range p.gauges {
		// A name present in both maps is listed once; both of its families
		// are emitted by the loop below.
		if _, dup := p.counters[n]; !dup {
			names = append(names, n)
		}
	}
	sort.Strings(names)

	var sb strings.Builder
	for _, name := range names {
		if series, ok := p.counters[name]; ok {
			fmt.Fprintf(&sb, "# TYPE %s counter\n", name)
			writeMetricFamily(&sb, name, series)
		}
		if series, ok := p.gauges[name]; ok {
			fmt.Fprintf(&sb, "# TYPE %s gauge\n", name)
			writeMetricFamily(&sb, name, series)
		}
	}
	return sb.String()
}

// writeMetricFamily writes one sample line per label suffix, ordered by
// suffix. labels maps the rendered label suffix to the sample value.
func writeMetricFamily(w io.Writer, name string, labels map[string]float64) {
	keys := make([]string, 0, len(labels))
	for k := range labels {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	for _, key := range keys {
		fmt.Fprintf(w, "%s%s %g\n", name, key, labels[key])
	}
}

// incCounter adds n to the counter series identified by name and labels,
// creating it on first use.
func (p *Prom) incCounter(name string, labels map[string]string, n float64) {
	key := labelKey(labels)
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.counters[name] == nil {
		p.counters[name] = make(map[string]float64)
	}
	p.counters[name][key] += n
}

// setGauge sets the gauge series identified by name and labels to v,
// creating it on first use.
func (p *Prom) setGauge(name string, labels map[string]string, v float64) {
	key := labelKey(labels)
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.gauges[name] == nil {
		p.gauges[name] = make(map[string]float64)
	}
	p.gauges[name][key] = v
}

// labelValueEscaper applies the only three escapes the Prometheus text format
// defines for label values: backslash, double quote and line feed.
var labelValueEscaper = strings.NewReplacer(`\`, `\\`, `"`, `\"`, "\n", `\n`)

// labelKey renders the labels into a Prometheus-format "{k1="v1",k2="v2"}"
// suffix, empty if no labels. Keys are sorted so that equal label sets always
// produce the same suffix.
//
// Values are escaped per the exposition format rather than with Go's %q,
// whose extra escapes (\t, \x01, \u....) are not valid there. Invalid UTF-8
// is replaced by U+FFFD.
func labelKey(labels map[string]string) string {
	if len(labels) == 0 {
		return ""
	}
	keys := make([]string, 0, len(labels))
	for k := range labels {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	var sb strings.Builder
	sb.WriteByte('{')
	for i, k := range keys {
		if i > 0 {
			sb.WriteByte(',')
		}
		sb.WriteString(k)
		sb.WriteString(`="`)
		sb.WriteString(labelValueEscaper.Replace(strings.ToValidUTF8(labels[k], "\uFFFD")))
		sb.WriteByte('"')
	}
	sb.WriteByte('}')
	return sb.String()
}

View file

@ -0,0 +1,49 @@
package metrics
import (
"bytes"
"strings"
"testing"
"time"
)
// TestPromCounters records three predictions over two profiles and checks the
// per-profile counts and the presence of the duration sum in the rendered
// output.
func TestPromCounters(t *testing.T) {
	sink := NewProm()
	sink.Prediction("standard_profile", 100*time.Millisecond, nil)
	sink.Prediction("standard_profile", 200*time.Millisecond, nil)
	sink.Prediction("float_profile", 50*time.Millisecond, nil)

	var rendered bytes.Buffer
	sink.Write(&rendered)
	text := rendered.String()

	standardLine := `predictor_predictions_total{profile="standard_profile",status="ok"} 2`
	if found := strings.Contains(text, standardLine); !found {
		t.Errorf("expected count=2 for standard_profile, got: %s", text)
	}
	floatLine := `predictor_predictions_total{profile="float_profile",status="ok"} 1`
	if found := strings.Contains(text, floatLine); !found {
		t.Errorf("expected count=1 for float_profile, got: %s", text)
	}
	// Only the presence of the duration sum is checked, not its value.
	if found := strings.Contains(text, "predictor_prediction_duration_seconds_sum"); !found {
		t.Errorf("expected sum present, got: %s", text)
	}
}
// TestPromGauge sets the active epoch and checks the gauge line in the
// rendered output. The value appears in %g form, hence 1.7e+09.
func TestPromGauge(t *testing.T) {
	sink := NewProm()
	sink.ActiveEpoch(time.Unix(1700000000, 0))

	var rendered bytes.Buffer
	sink.Write(&rendered)
	text := rendered.String()

	const wantLine = "predictor_active_dataset_epoch_seconds 1.7e+09"
	if found := strings.Contains(text, wantLine); !found {
		t.Errorf("expected gauge with epoch 1700000000, got: %s", text)
	}
}
// TestNoop calls every method of the no-op sink once; the test passes as
// long as none of them panics.
func TestNoop(t *testing.T) {
	discard := Noop()
	elapsed := time.Second

	discard.Prediction("any", elapsed, nil)
	discard.Download("any", elapsed, "complete", 0)
	discard.ActiveEpoch(time.Now())
}

36
internal/metrics/types.go Normal file
View file

@ -0,0 +1,36 @@
// Package metrics defines the Sink interface used to record service metrics
// and ships two implementations: a Noop sink (default, zero-cost) and a Prom
// sink that exposes counters in the Prometheus text exposition format.
//
// The metrics layer is optional: if no Sink is wired (or Noop is wired), the
// service runs unchanged.
package metrics
import "time"
// Sink collects observations from the rest of the service.
//
// Implementations must be safe for concurrent use across many goroutines.
// All methods are advisory; implementations may ignore any observation.
type Sink interface {
	// Prediction records the duration and outcome of one prediction.
	// err is nil on success. How a non-nil err is reported is up to the
	// implementation.
	// NOTE(review): this doc used to promise the error's class as a label; the
	// Prom sink in this package only records status="error". Confirm which is
	// intended.
	Prediction(profile string, duration time.Duration, err error)
	// Download records the outcome of one dataset download job.
	// status is "complete", "failed", or "cancelled"; bytes is the amount of
	// data transferred.
	Download(source string, duration time.Duration, status string, bytes int64)
	// ActiveEpoch reports the forecast time of the currently-loaded dataset.
	// Pass time.Time{} when no dataset is loaded.
	ActiveEpoch(t time.Time)
}
// Noop returns a Sink whose methods do nothing, for use when metrics are
// disabled.
func Noop() Sink {
	var discard noop
	return discard
}

// noop is the stateless Sink returned by Noop.
type noop struct{}

// Prediction discards the observation.
func (noop) Prediction(string, time.Duration, error) {
}

// Download discards the observation.
func (noop) Download(string, time.Duration, string, int64) {
}

// ActiveEpoch discards the observation.
func (noop) ActiveEpoch(time.Time) {
}

11
internal/numerics/doc.go Normal file
View file

@ -0,0 +1,11 @@
// Package numerics provides the numerical primitives used by the trajectory
// engine: regular-grid multilinear interpolation, monotone bisection, and
// a generic explicit Runge-Kutta-4 integrator with binary-search refinement
// of a termination point.
//
// The package has no dependencies on any domain type. State and derivative
// types are generic, and all coordinate-wrap or unit-conversion semantics
// live in the caller.
//
// All algorithms are documented in docs/numerics.tex.
package numerics

86
internal/numerics/grid.go Normal file
View file

@ -0,0 +1,86 @@
package numerics
import "fmt"
// Axis describes a regularly-spaced grid axis with N grid points at
// Left, Left+Step, Left+2*Step, ..., Left+(N-1)*Step.
//
// If Wrap is true, the axis is periodic with period N*Step (e.g. longitude):
// the cell after the last grid point closes back onto grid point 0, so for
// values in [Left+(N-1)*Step, Left+N*Step) Locate returns Lo = N-1, Hi = 0.
type Axis struct {
	Left float64 // coordinate of grid point 0
	Step float64 // spacing between consecutive grid points
	N    int     // number of grid points
	Wrap bool    // true if the axis is periodic with period N*Step
	Name string  // label reported in AxisError
}

// AxisError is returned by Axis.Locate when value lies outside the range the
// axis can bracket.
type AxisError struct {
	Axis  string  // Name of the axis that rejected the value
	Value float64 // the rejected query value
}

// Error implements the error interface.
func (e *AxisError) Error() string {
	return fmt.Sprintf("%s=%v out of range", e.Axis, e.Value)
}

// Bracket holds the two surrounding grid indices and the fractional position
// of a value within an axis. The weight at Lo is (1 - Frac); the weight at Hi
// is Frac. Frac lies in [0, 1).
type Bracket struct {
	Lo, Hi int
	Frac   float64
}

// Locate returns the bracket containing value within the axis.
// For a non-wrapping axis, value must lie in [Left, Left + (N-1)*Step);
// for a wrapping axis, value must lie in [Left, Left + N*Step).
//
// Any other value yields a zero Bracket and an *AxisError. That includes
// values just below Left, NaN, infinities, and every value when Step is zero.
func (a Axis) Locate(value float64) (Bracket, error) {
	pos := (value - a.Left) / a.Step

	// Number of interpolation cells: N-1 between the grid points, plus the
	// closing cell (last point back to point 0) on a periodic axis.
	cells := a.N - 1
	if a.Wrap {
		cells = a.N
	}

	// Range-check in floating point BEFORE converting to int. int(pos)
	// truncates toward zero, so any pos in (-1, 0) would become index 0 and
	// slip through an integer check with a negative Frac. The negated form
	// also rejects NaN, and keeps ±Inf and huge magnitudes away from the
	// implementation-defined float-to-int conversion.
	if !(pos >= 0 && pos < float64(cells)) {
		return Bracket{}, &AxisError{Axis: a.Name, Value: value}
	}

	lo := int(pos) // pos is in [0, cells), so truncation is a floor here
	hi := lo + 1
	if a.Wrap && hi == a.N {
		hi = 0
	}
	return Bracket{Lo: lo, Hi: hi, Frac: pos - float64(lo)}, nil
}
// EvalTrilinear interpolates a 3D field trilinearly. The field is sampled
// through f at the eight (Lo/Hi) index combinations of the three brackets in
// b3, and each sample is weighted by the product of the matching per-axis
// weights (1-Frac for Lo, Frac for Hi).
//
// Corners are visited with axis 0 outermost and axis 2 innermost, matching
// the Cython reference. With f(i,j,k) = a*i + b*j + c*k + d this returns
// a*pos0 + b*pos1 + c*pos2 + d exactly, modulo floating-point rounding.
func EvalTrilinear(b3 [3]Bracket, f func(i, j, k int) float64) float64 {
	ax0, ax1, ax2 := b3[0], b3[1], b3[2]

	// Combined weights of the first two axes (first digit: axis 0,
	// second digit: axis 1; 0 = Lo side, 1 = Hi side).
	w00 := (1 - ax0.Frac) * (1 - ax1.Frac)
	w01 := (1 - ax0.Frac) * ax1.Frac
	w10 := ax0.Frac * (1 - ax1.Frac)
	w11 := ax0.Frac * ax1.Frac

	// Axis-2 weights.
	wLo, wHi := 1-ax2.Frac, ax2.Frac

	return w00*wLo*f(ax0.Lo, ax1.Lo, ax2.Lo) +
		w00*wHi*f(ax0.Lo, ax1.Lo, ax2.Hi) +
		w01*wLo*f(ax0.Lo, ax1.Hi, ax2.Lo) +
		w01*wHi*f(ax0.Lo, ax1.Hi, ax2.Hi) +
		w10*wLo*f(ax0.Hi, ax1.Lo, ax2.Lo) +
		w10*wHi*f(ax0.Hi, ax1.Lo, ax2.Hi) +
		w11*wLo*f(ax0.Hi, ax1.Hi, ax2.Lo) +
		w11*wHi*f(ax0.Hi, ax1.Hi, ax2.Hi)
}
// Lerp linearly interpolates between a and b: it returns (1-l)*a + l*b, which
// is a at l = 0 and b at l = 1.
func Lerp(a, b, l float64) float64 {
	weightA := 1 - l
	return weightA*a + l*b
}

View file

@ -0,0 +1,94 @@
package numerics
import (
"math"
"testing"
)
// TestAxisLocate covers a non-wrapping latitude axis: values on grid points,
// a value halfway inside the first cell, and values outside the valid range.
func TestAxisLocate(t *testing.T) {
	lat := Axis{Left: -90, Step: 0.5, N: 361, Name: "lat"}

	// First grid point.
	if got, err := lat.Locate(-90); err != nil || got.Lo != 0 || got.Hi != 1 || got.Frac != 0 {
		t.Errorf("Locate(-90) = %+v, %v; want {0 1 0}, nil", got, err)
	}

	// Grid point in the middle of the axis.
	if got, err := lat.Locate(0); err != nil || got.Lo != 180 || got.Hi != 181 || got.Frac != 0 {
		t.Errorf("Locate(0) = %+v, %v; want {180 181 0}, nil", got, err)
	}

	// Halfway between grid points 0 and 1.
	if got, err := lat.Locate(-89.75); err != nil || got.Lo != 0 || got.Hi != 1 || math.Abs(got.Frac-0.5) > 1e-12 {
		t.Errorf("Locate(-89.75) = %+v, %v; want frac=0.5", got, err)
	}

	// 90 is exactly on the upper boundary — there's no Hi above it
	_, upperErr := lat.Locate(90)
	if upperErr == nil {
		t.Errorf("Locate(90) should error, got nil")
	}

	// Two full steps below the lower boundary.
	_, lowerErr := lat.Locate(-91)
	if lowerErr == nil {
		t.Errorf("Locate(-91) should error, got nil")
	}
}
// TestAxisLocateWrap covers a periodic longitude axis, including the closing
// cell whose upper index wraps back to 0.
func TestAxisLocateWrap(t *testing.T) {
	lng := Axis{Left: 0, Step: 0.5, N: 720, Wrap: true, Name: "lng"}

	if got, err := lng.Locate(0); err != nil || got.Lo != 0 || got.Hi != 1 || got.Frac != 0 {
		t.Errorf("Locate(0) = %+v, %v", got, err)
	}

	// Right up against the wrap boundary
	if got, err := lng.Locate(359.75); err != nil || got.Lo != 719 || got.Hi != 0 || math.Abs(got.Frac-0.5) > 1e-12 {
		t.Errorf("Locate(359.75) = %+v, %v; want {719 0 0.5}", got, err)
	}

	// 360 is outside the half-open interval
	_, err := lng.Locate(360)
	if err == nil {
		t.Errorf("Locate(360) should error, got nil")
	}
}
// TestEvalTrilinear checks the interpolator on the unit cell (indices 0 and 1
// on every axis): at the cell centre, at the all-Lo corner, and on a field
// that depends on the first index only.
func TestEvalTrilinear(t *testing.T) {
	// unitCell builds the three brackets of the 0..1 cell with the given
	// per-axis fractions.
	unitCell := func(f0, f1, f2 float64) [3]Bracket {
		return [3]Bracket{
			{Lo: 0, Hi: 1, Frac: f0},
			{Lo: 0, Hi: 1, Frac: f1},
			{Lo: 0, Hi: 1, Frac: f2},
		}
	}

	// Field f(i,j,k) = 100*i + 10*j + k.
	digits := func(i, j, k int) float64 {
		return 100*float64(i) + 10*float64(j) + float64(k)
	}

	// At all fractions = 0.5, expected value is the mean of the 8 corners.
	centre := EvalTrilinear(unitCell(0.5, 0.5, 0.5), digits)
	wantCentre := (0 + 1 + 10 + 11 + 100 + 101 + 110 + 111) / 8.0
	if math.Abs(centre-wantCentre) > 1e-12 {
		t.Errorf("EvalTrilinear at center = %v, want %v", centre, wantCentre)
	}

	// At all fractions = 0, expected value is f(lo, lo, lo) = 0.
	if corner := EvalTrilinear(unitCell(0, 0, 0), digits); corner != 0 {
		t.Errorf("EvalTrilinear at (lo,lo,lo) = %v, want 0", corner)
	}

	// Asymmetric: linear field f(i,j,k) = i should give frac of axis 0 exactly.
	firstIndex := func(i, _, _ int) float64 { return float64(i) }
	skewed := EvalTrilinear(unitCell(0.3, 0.7, 0.9), firstIndex)
	if math.Abs(skewed-0.3) > 1e-12 {
		t.Errorf("EvalTrilinear of i-field = %v, want 0.3", skewed)
	}
}
// TestLerp checks both endpoints and one interior point of Lerp.
func TestLerp(t *testing.T) {
	const from, to = 10, 20

	if atStart := Lerp(from, to, 0); atStart != 10 {
		t.Errorf("Lerp(10, 20, 0) != 10")
	}
	if atEnd := Lerp(from, to, 1); atEnd != 20 {
		t.Errorf("Lerp(10, 20, 1) != 20")
	}
	if quarter := Lerp(from, to, 0.25); math.Abs(quarter-12.5) > 1e-12 {
		t.Errorf("Lerp(10, 20, 0.25) != 12.5")
	}
}

61
internal/numerics/ode.go Normal file
View file

@ -0,0 +1,61 @@
package numerics
// VecAdd computes y + k*dy on the domain state type S.
// Any coordinate-wrap or other domain-specific operation lives here.
type VecAdd[S any] func(y S, k float64, dy S) S

// VecLerp computes (1-l)*a + l*b on the domain state type S.
type VecLerp[S any] func(a, b S, l float64) S

// Deriv computes the time derivative of state y at time t.
type Deriv[S any] func(t float64, y S) S

// Trigger reports whether a termination condition holds at (t, y).
type Trigger[S any] func(t float64, y S) bool

// RK4Step performs one classical Runge-Kutta-4 step from (t, y) with step dt
// and returns the state at t+dt. dt may be negative to integrate backwards
// in time.
//
// deriv is evaluated four times (at t, twice at t+dt/2, and at t+dt). All
// state arithmetic goes through add, so the caller controls how increments
// are applied to S.
func RK4Step[S any](t float64, y S, dt float64, deriv Deriv[S], add VecAdd[S]) S {
	k1 := deriv(t, y)
	k2 := deriv(t+dt/2, add(y, dt/2, k1))
	k3 := deriv(t+dt/2, add(y, dt/2, k2))
	k4 := deriv(t+dt, add(y, dt, k3))

	// y + dt/6*k1 + dt/3*k2 + dt/3*k3 + dt/6*k4, accumulated one term at a
	// time so that add is applied after every increment.
	y2 := y
	y2 = add(y2, dt/6, k1)
	y2 = add(y2, dt/3, k2)
	y2 = add(y2, dt/3, k3)
	y2 = add(y2, dt/6, k4)
	return y2
}

// RefineTrigger locates the trigger point between (t1, y1) (trigger not fired)
// and (t2, y2) (trigger fired) via binary search in the linear-interpolation
// parameter space, stopping when the parameter interval is narrower than tol.
//
// Returns the final midpoint sampled, matching the behavior of Tawhiri's
// solver.pyx (the returned point is *not* guaranteed to satisfy the trigger;
// for tol << 1 the difference is at most one tolerance-width either side).
// If the loop body never runs (tol >= 1), (t2, y2) is returned unchanged.
//
// The search also stops once the interval can no longer be halved in
// floating point, so the call terminates for tol <= 0 or a tol below the
// resolution of float64.
func RefineTrigger[S any](
	t1 float64, y1 S,
	t2 float64, y2 S,
	trigger Trigger[S],
	lerp VecLerp[S],
	tol float64,
) (float64, S) {
	left, right := 0.0, 1.0
	t3, y3 := t2, y2
	for right-left > tol {
		mid := (left + right) / 2
		// Once left and right are adjacent floats the midpoint rounds onto
		// one of them and the interval stops shrinking; bail out instead of
		// looping forever.
		if mid == left || mid == right {
			break
		}
		// Time is interpolated linearly; state goes through the caller's lerp.
		t3 = (1-mid)*t1 + mid*t2
		y3 = lerp(y1, y2, mid)
		if trigger(t3, y3) {
			right = mid
		} else {
			left = mid
		}
	}
	return t3, y3
}

View file

@ -0,0 +1,61 @@
package numerics
import (
"math"
"testing"
)
// scalarAdd / scalarLerp let us drive RK4 on a plain float64.
func scalarAdd(y float64, k float64, dy float64) float64 { return y + k*dy }
// scalarLerpF is the state-interpolation callback for a plain float64 state;
// it simply forwards to Lerp.
func scalarLerpF(a, b float64, l float64) float64 {
	blended := Lerp(a, b, l)
	return blended
}
func TestRK4ExponentialDecay(t *testing.T) {
// dy/dt = -y → exact: y(t) = y0 * exp(-t).
deriv := func(_ float64, y float64) float64 { return -y }
y := 1.0
tnow := 0.0
dt := 0.01
for range 100 {
y = RK4Step(tnow, y, dt, deriv, scalarAdd)
tnow += dt
}
want := math.Exp(-1.0)
if math.Abs(y-want) > 1e-8 {
t.Errorf("RK4 exp decay at t=1: got %v, want %v (diff %v)", y, want, y-want)
}
}
func TestRK4ReverseTime(t *testing.T) {
// dy/dt = y → exact: y(t) = y0 * exp(t).
// Integrating from t=1 backwards with dt=-0.01 over 100 steps should give y0.
deriv := func(_ float64, y float64) float64 { return y }
y := math.E
tnow := 1.0
dt := -0.01
for range 100 {
y = RK4Step(tnow, y, dt, deriv, scalarAdd)
tnow += dt
}
if math.Abs(y-1.0) > 1e-8 {
t.Errorf("RK4 reverse: got %v, want 1.0 (diff %v)", y, y-1.0)
}
}
// TestRefineTrigger refines a zero crossing on a straight line between
// (t=0, y=1) and (t=1, y=-1.5) and checks both refined coordinates.
func TestRefineTrigger(t *testing.T) {
	// y crosses 0 at l=0.4 between y1=1 and y2=-1.5.
	const (
		startT, endT = 0.0, 1.0
		startY, endY = 1.0, -1.5
	)
	atOrBelowZero := func(_ float64, y float64) bool { return y <= 0 }

	gotT, gotY := RefineTrigger(startT, startY, endT, endY, atOrBelowZero, scalarLerpF, 0.001)

	// The exact crossing is at l = 1/(1+1.5) = 0.4 → t = 0.4, y = 0.
	if math.Abs(gotT-0.4) > 0.01 {
		t.Errorf("Refined t = %v, want ~0.4", gotT)
	}
	if math.Abs(gotY) > 0.01 {
		t.Errorf("Refined y = %v, want ~0", gotY)
	}
}

View file

@ -0,0 +1,19 @@
package numerics
// Bisect searches the index range [imin, imax] for the largest index i with
// f(i) < target, assuming f is monotonically nondecreasing on that range.
//
// If target <= f(imin), returns imin. If target > f(imax), returns imax.
// Performs O(log(imax-imin)) evaluations of f.
func Bisect(imin, imax int, target float64, f func(i int) float64) int {
	low, high := imin, imax
	for low < high {
		// Upper midpoint: probe is always > low, so "low = probe" makes progress.
		probe := (low + high + 1) / 2
		if target <= f(probe) {
			// f(probe) is not below target; the answer lies strictly left of probe.
			high = probe - 1
			continue
		}
		low = probe
	}
	return low
}

View file

@ -0,0 +1,28 @@
package numerics
import "testing"
// TestBisect checks Bisect on f(i) = 10*i over [0, 10]: a target between two
// values, a target equal to a value, and targets below and above the range.
func TestBisect(t *testing.T) {
	// f(i) = 10*i, monotone increasing.
	tenTimes := func(i int) float64 { return 10 * float64(i) }

	// target = 25 → largest i with 10i < 25 is i=2
	between := Bisect(0, 10, 25, tenTimes)
	if between != 2 {
		t.Errorf("Bisect target=25 = %d, want 2", between)
	}

	// target on boundary: target = 30, condition is target <= f(mid) so f(3)=30 → not less; want 2
	onValue := Bisect(0, 10, 30, tenTimes)
	if onValue != 2 {
		t.Errorf("Bisect target=30 = %d, want 2", onValue)
	}

	// target below all values
	below := Bisect(0, 10, -5, tenTimes)
	if below != 0 {
		t.Errorf("Bisect target=-5 = %d, want 0", below)
	}

	// target above all values
	above := Bisect(0, 10, 1000, tenTimes)
	if above != 10 {
		t.Errorf("Bisect target=1000 = %d, want 10", above)
	}
}

View file

@ -1,153 +0,0 @@
package prediction
import (
"fmt"
"predictor-refactored/internal/dataset"
)
// Exact port of the reference interpolation logic (interpolate.pyx).
// 4D interpolation: time, latitude, longitude, altitude (via geopotential height).
// lerp1 holds an index and interpolation weight for one axis.
// pick returns these in pairs: the lower grid index with weight 1-l and the
// upper grid index with weight l.
type lerp1 struct {
index int // grid index along the axis
lerp float64 // weight applied to the sample at index
}
// lerp3 holds indices and a combined weight for the (hour, lat, lon) axes.
// It describes one corner of the interpolation cell; pick3 sets lerp to the
// product of the three per-axis weights.
type lerp3 struct {
hour, lat, lng int
lerp float64
}
// RangeError reports that a coordinate falls outside the dataset bounds.
type RangeError struct {
	Variable string  // name of the offending axis
	Value    float64 // the rejected coordinate value
}

// Error implements the error interface, formatting the value with six
// decimal places (e.g. "lat=91.000000 out of range").
func (e *RangeError) Error() string {
	name, val := e.Variable, e.Value
	return fmt.Sprintf("%s=%f out of range", name, val)
}
// pick locates value on a regular axis and returns the two bracketing grid
// points with their linear-interpolation weights.
//
// left is the axis start, step the spacing, n the number of points, and
// variableName the label used in the returned *RangeError. The first element
// of the result is the lower grid point (weight 1-l), the second the upper
// one (weight l).
//
// NOTE(review): the index comes from truncation toward zero, so a value in
// (left-step, left) maps to index 0 and is accepted with a negative l.
func pick(left, step float64, n int, value float64, variableName string) ([2]lerp1, error) {
	pos := (value - left) / step
	idx := int(pos) // truncation toward zero, same as Cython <long> cast
	if inRange := idx >= 0 && idx < n-1; !inRange {
		return [2]lerp1{}, &RangeError{Variable: variableName, Value: value}
	}

	frac := pos - float64(idx)
	var bracket [2]lerp1
	bracket[0] = lerp1{index: idx, lerp: 1 - frac}
	bracket[1] = lerp1{index: idx + 1, lerp: frac}
	return bracket, nil
}
// pick3 builds the eight corner entries for trilinear interpolation over
// (hour, lat, lng). Each entry carries the three grid indices and the product
// of the three per-axis weights. Entries are ordered with hour outermost and
// longitude innermost. A coordinate outside its axis yields a *RangeError.
func pick3(hour, lat, lng float64) ([8]lerp3, error) {
	var none [8]lerp3

	hours, err := pick(0, 3, 65, hour, "hour")
	if err != nil {
		return none, err
	}
	lats, err := pick(-90, 0.5, 361, lat, "lat")
	if err != nil {
		return none, err
	}
	// Longitude wraps: tell pick the axis is one larger, then wrap index 720 → 0
	lngs, err := pick(0, 0.5, 720+1, lng, "lng")
	if err != nil {
		return none, err
	}
	if lngs[1].index == 720 {
		lngs[1].index = 0
	}

	var corners [8]lerp3
	for h := range hours {
		for la := range lats {
			for ln := range lngs {
				// Slot 4*h + 2*la + ln reproduces the nested visiting order.
				corners[4*h+2*la+ln] = lerp3{
					hour: hours[h].index,
					lat:  lats[la].index,
					lng:  lngs[ln].index,
					lerp: hours[h].lerp * lats[la].lerp * lngs[ln].lerp,
				}
			}
		}
	}
	return corners, nil
}
// interp3 returns the weighted sum of the dataset values at the eight corners
// in lerps, for the given variable and pressure level.
func interp3(ds *dataset.File, lerps [8]lerp3, variable, level int) float64 {
	sum := 0.0
	for _, corner := range lerps {
		sample := ds.Val(corner.hour, level, variable, corner.lat, corner.lng)
		sum += float64(sample) * corner.lerp
	}
	return sum
}
// search bisects over pressure levels 0..45 (the topmost level is excluded)
// for the largest level index whose interpolated geopotential height is below
// target. It returns 0 when no level in that range is below target.
func search(ds *dataset.File, lerps [8]lerp3, target float64) int {
	low, high := 0, 45
	for low < high {
		probe := (low + high + 1) / 2
		height := interp3(ds, lerps, dataset.VarHeight, probe)
		if target <= height {
			high = probe - 1
			continue
		}
		low = probe
	}
	return low
}
// interp4 interpolates variable between two adjacent pressure levels:
// altLerp.index (weighted by altLerp.lerp) and altLerp.index+1 (weighted by
// 1 - altLerp.lerp). Each level is first interpolated horizontally and in
// time via interp3.
func interp4(ds *dataset.File, lerps [8]lerp3, altLerp lerp1, variable int) float64 {
	below := interp3(ds, lerps, variable, altLerp.index)
	above := interp3(ds, lerps, variable, altLerp.index+1)
	weightBelow := altLerp.lerp
	return below*weightBelow + above*(1-weightBelow)
}
// GetWind returns the interpolated (u, v) wind components at a position.
//
// hour: fractional hours since dataset start.
// lat: latitude in degrees (-90 to +90).
// lng: longitude in degrees (0 to 360).
// alt: altitude in metres above sea level.
//
// A coordinate outside the grid yields the *RangeError from pick3. When alt
// lies above the upper bracketing level, the altitude weight goes negative,
// the AltitudeTooHigh warning counter is incremented, and the result is an
// extrapolation.
func GetWind(ds *dataset.File, warnings *Warnings, hour, lat, lng, alt float64) (u, v float64, err error) {
	corners, err := pick3(hour, lat, lng)
	if err != nil {
		return 0, 0, err
	}

	// Bracket alt between two adjacent pressure levels by geopotential height.
	level := search(ds, corners, alt)
	heightBelow := interp3(ds, corners, dataset.VarHeight, level)
	heightAbove := interp3(ds, corners, dataset.VarHeight, level+1)

	// Weight of the lower level; 0.5 when both levels sit at the same height.
	weight := 0.5
	if heightBelow != heightAbove {
		weight = (heightAbove - alt) / (heightAbove - heightBelow)
	}
	if weight < 0 {
		warnings.AltitudeTooHigh.Add(1)
	}

	bracket := lerp1{index: level, lerp: weight}
	u = interp4(ds, corners, bracket, dataset.VarWindU)
	v = interp4(ds, corners, bracket, dataset.VarWindV)
	return u, v, nil
}

View file

@ -1,188 +0,0 @@
package prediction
import (
"math"
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/elevation"
)
// Exact port of the reference flight models (models.py).
// Degree/radian conversion factors used by the wind model.
const (
pi180 = math.Pi / 180.0 // multiply degrees by this to get radians
_180pi = 180.0 / math.Pi // multiply radians by this to get degrees
)
// --- Up/Down Models ---
// ConstantAscent returns a model whose only motion is a constant vertical
// velocity of ascentRate (m/s); the lateral derivatives are always zero.
func ConstantAscent(ascentRate float64) Model {
	climb := func(_, _, _, _ float64) (dlat, dlng, dalt float64) {
		dalt = ascentRate
		return dlat, dlng, dalt
	}
	return climb
}
// DragDescent returns a descent-under-parachute model.
// seaLevelDescentRate is the descent rate at sea level (m/s, positive value).
// The vertical speed is scaled by 1/sqrt(density), with density taken from
// the NASA atmosphere model (nasaDensity); the lateral derivatives are zero.
func DragDescent(seaLevelDescentRate float64) Model {
	coefficient := seaLevelDescentRate * 1.1045
	return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
		density := nasaDensity(alt)
		sinkRate := -coefficient / math.Sqrt(density)
		return 0, 0, sinkRate
	}
}
// nasaDensity computes air density at altitude alt using the NASA atmosphere
// model, which is piecewise over three altitude bands: above 25000, above
// 11000, and everything else.
// Reference: http://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html
func nasaDensity(alt float64) float64 {
	var temp, pressure float64
	if alt > 25000 {
		temp = -131.21 + 0.00299*alt
		pressure = 2.488 * math.Pow((temp+273.1)/216.6, -11.388)
	} else if alt > 11000 {
		temp = -56.46
		pressure = 22.65 * math.Exp(1.73-0.000157*alt)
	} else {
		temp = 15.04 - 0.00649*alt
		pressure = 101.29 * math.Pow((temp+273.1)/288.08, 5.256)
	}

	// Equation of state from the reference: density = pressure / (0.2869 * (temp + 273.1)).
	absolute := temp + 273.1
	return pressure / (0.2869 * absolute)
}
// --- Sideways Models ---
// WindVelocity returns a model that moves the balloon laterally at the wind
// velocity. ds is the wind dataset, dsEpoch is the dataset start time as a
// UNIX timestamp, and warnings receives any warning raised by GetWind.
//
// The model converts t to hours since dsEpoch, looks up (u, v) with GetWind
// and turns them into degrees per second on a sphere of radius 6371009 + alt.
// The vertical derivative is always zero. If GetWind fails, the model returns
// all zeros and the error is dropped.
func WindVelocity(ds *dataset.File, dsEpoch float64, warnings *Warnings) Model {
	return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
		hoursIn := (t - dsEpoch) / 3600.0
		u, v, err := GetWind(ds, warnings, hoursIn, lat, lng, alt)
		if err != nil {
			return 0, 0, 0
		}

		radius := 6371009.0 + alt
		northward := _180pi * v / radius
		// The radius of the latitude circle shrinks with cos(lat).
		eastward := _180pi * u / (radius * math.Cos(lat*pi180))
		return northward, eastward, 0
	}
}
// --- Model Combinations ---
// LinearModel returns a model whose derivatives are the component-wise sum of
// the derivatives of all given models, each evaluated at the same state.
// With no component models the result is always (0, 0, 0).
func LinearModel(models ...Model) Model {
	return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
		var sumLat, sumLng, sumAlt float64
		for i := range models {
			partLat, partLng, partAlt := models[i](t, lat, lng, alt)
			sumLat += partLat
			sumLng += partLng
			sumAlt += partAlt
		}
		return sumLat, sumLng, sumAlt
	}
}
// --- Termination Criteria ---
// BurstTermination returns a terminator that fires once the altitude reaches
// or exceeds burstAltitude.
func BurstTermination(burstAltitude float64) Terminator {
	reached := func(_, _, _, alt float64) bool {
		return alt >= burstAltitude
	}
	return reached
}
// SeaLevelTermination is a terminator that fires when the altitude is at or
// below zero; time and position are ignored.
func SeaLevelTermination(t, lat, lng, alt float64) bool {
	atOrBelowSeaLevel := alt <= 0
	return atOrBelowSeaLevel
}
// TimeTermination returns a terminator that fires once t is strictly later
// than maxTime; position is ignored.
func TimeTermination(maxTime float64) Terminator {
	expired := func(t, _, _, _ float64) bool {
		return t > maxTime
	}
	return expired
}
// ElevationTermination returns a terminator that fires when the altitude is
// below the ground elevation reported by elev at (lat, lng).
// Uses ruaumoko-compatible elevation data. Longitude is normalised internally.
func ElevationTermination(elev *elevation.Dataset) Terminator {
	return func(t, lat, lng, alt float64) bool {
		ground := elev.Get(lat, lng)
		return ground > alt
	}
}
// --- Pre-Defined Profiles ---
// Stage pairs a model with its termination criterion.
// A profile is a slice of stages; RunPrediction hands them to Solve in order.
type Stage struct {
Model Model // derivative function integrated during this stage
Terminator Terminator // condition that ends this stage
}
// StandardProfile builds the two-stage chain of a standard high-altitude
// balloon flight: constant-rate ascent with wind drift until burstAltitude,
// then parachute descent with wind drift.
// If elev is non-nil, descent terminates at ground level; otherwise at sea level.
func StandardProfile(ascentRate, burstAltitude, descentRate float64,
	ds *dataset.File, dsEpoch float64, warnings *Warnings,
	elev *elevation.Dataset) []Stage {

	drift := WindVelocity(ds, dsEpoch, warnings)

	ascent := Stage{
		Model:      LinearModel(ConstantAscent(ascentRate), drift),
		Terminator: BurstTermination(burstAltitude),
	}

	// Default to sea level; switch to terrain when elevation data exists.
	landing := Terminator(SeaLevelTermination)
	if elev != nil {
		landing = ElevationTermination(elev)
	}
	descent := Stage{
		Model:      LinearModel(DragDescent(descentRate), drift),
		Terminator: landing,
	}

	return []Stage{ascent, descent}
}
// FloatProfile builds the two-stage chain of a floating balloon flight:
// constant-rate ascent with wind drift until floatAltitude, then pure wind
// drift (no vertical motion) until stopTime.
func FloatProfile(ascentRate, floatAltitude float64, stopTime time.Time,
	ds *dataset.File, dsEpoch float64, warnings *Warnings) []Stage {

	drift := WindVelocity(ds, dsEpoch, warnings)
	deadline := float64(stopTime.Unix())

	ascent := Stage{
		Model:      LinearModel(ConstantAscent(ascentRate), drift),
		Terminator: BurstTermination(floatAltitude),
	}
	float := Stage{
		Model:      drift,
		Terminator: TimeTermination(deadline),
	}
	return []Stage{ascent, float}
}
// RunPrediction integrates the given profile stages in order, starting from
// the launch state, and returns what Solve returns for them.
// launchTime is a UNIX timestamp.
func RunPrediction(launchTime float64, lat, lng, alt float64, stages []Stage) []StageResult {
	// Solve takes an anonymous struct with the same fields as Stage, so each
	// stage converts directly.
	type link = struct {
		Model      Model
		Terminator Terminator
	}

	chain := make([]link, len(stages))
	for i := range stages {
		chain[i] = link(stages[i])
	}
	return Solve(launchTime, lat, lng, alt, chain)
}

View file

@ -1,180 +0,0 @@
package prediction
import "math"
// Exact port of the reference RK4 solver (solver.pyx).
// Integrates balloon state using RK4 with dt=60 seconds.
// Termination uses binary search refinement (tolerance 0.01).
// Vec holds the balloon state: latitude, longitude, altitude.
// The solver also uses it for derivatives of that state (the k1..k4 slopes).
type Vec struct {
Lat float64
Lng float64
Alt float64
}
// Model is a function that returns (dlat/dt, dlng/dt, dalt/dt) given state.
// t is UNIX timestamp, lat/lng in degrees, alt in metres.
type Model func(t float64, lat, lng, alt float64) (dlat, dlng, dalt float64)
// Terminator returns true when integration should stop.
// It receives the same (t, lat, lng, alt) state arguments as Model.
type Terminator func(t float64, lat, lng, alt float64) bool
// StageResult holds the trajectory points for one flight stage.
// Solve returns one StageResult per stage of the chain.
type StageResult struct {
Points []TrajectoryPoint
}
// TrajectoryPoint is a single point in a trajectory (used by solver).
type TrajectoryPoint struct {
T float64 // UNIX timestamp
Lat float64
Lng float64
Alt float64
}
// pymod returns a % b with Python semantics (always non-negative when b > 0).
// math.Mod keeps the sign of a, so a negative remainder is shifted up by b.
func pymod(a, b float64) float64 {
	rem := math.Mod(a, b)
	if rem < 0 {
		return rem + b
	}
	return rem
}
// vecadd returns a + k*b component-wise. The longitude component is passed
// through pymod with modulus 360 to wrap it; latitude and altitude are left
// as computed.
func vecadd(a Vec, k float64, b Vec) Vec {
	var sum Vec
	sum.Lat = a.Lat + k*b.Lat
	sum.Lng = pymod(a.Lng+k*b.Lng, 360.0)
	sum.Alt = a.Alt + k*b.Alt
	return sum
}
// scalarLerp linearly interpolates between a (at l = 0) and b (at l = 1):
// it returns (1-l)*a + l*b.
func scalarLerp(a, b, l float64) float64 {
	weightA := 1 - l
	return weightA*a + l*b
}
// lngLerp interpolates between longitudes a and b with parameter l (l = 0
// gives a, l = 1 gives b), handling the 0/360 wrap-around. When the two
// longitudes are less than 180 apart it is a plain linear interpolation;
// otherwise the smaller one is shifted up by 360 first and the result is
// reduced modulo 360.
func lngLerp(a, b, l float64) float64 {
	// Order the endpoints and carry each endpoint's weight along with it.
	low, high := a, b
	weightLow, weightHigh := 1-l, l
	if a > b {
		low, high = b, a
		weightLow, weightHigh = l, 1-l
	}

	// distance round one way: high - low
	// distance around other: (low + 360) - high
	if high-low < 180.0 {
		return weightLow*low + weightHigh*high
	}
	return pymod(weightLow*(low+360)+weightHigh*high, 360.0)
}
// vecLerp interpolates between states a and b with parameter l. Latitude and
// altitude are interpolated linearly; longitude goes through lngLerp so the
// 0/360 wrap-around is handled.
func vecLerp(a, b Vec, l float64) Vec {
	var mixed Vec
	mixed.Lat = scalarLerp(a.Lat, b.Lat, l)
	mixed.Lng = lngLerp(a.Lng, b.Lng, l)
	mixed.Alt = scalarLerp(a.Alt, b.Alt, l)
	return mixed
}
// rk4 integrates from initial conditions using RK4.
// dt=60.0 seconds, terminationTolerance=0.01.
//
// The returned slice starts with the initial state, then holds one point per
// accepted 60-second step, and ends with the refined termination point.
// The terminator is only checked on the state AFTER each step, never on the
// initial state, so at least one step is always taken. The loop has no other
// exit: it runs until the terminator fires.
func rk4(t float64, lat, lng, alt float64, model Model, terminator Terminator) []TrajectoryPoint {
const dt = 60.0
const terminationTolerance = 0.01
y := Vec{Lat: lat, Lng: lng, Alt: alt}
result := []TrajectoryPoint{{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt}}
for {
// Evaluate model at 4 points (standard RK4)
// k1: slope at the start of the step.
k1lat, k1lng, k1alt := model(t, y.Lat, y.Lng, y.Alt)
k1 := Vec{Lat: k1lat, Lng: k1lng, Alt: k1alt}
// k2: slope at the midpoint, reached along k1.
mid1 := vecadd(y, dt/2, k1)
k2lat, k2lng, k2alt := model(t+dt/2, mid1.Lat, mid1.Lng, mid1.Alt)
k2 := Vec{Lat: k2lat, Lng: k2lng, Alt: k2alt}
// k3: slope at the midpoint, reached along k2.
mid2 := vecadd(y, dt/2, k2)
k3lat, k3lng, k3alt := model(t+dt/2, mid2.Lat, mid2.Lng, mid2.Alt)
k3 := Vec{Lat: k3lat, Lng: k3lng, Alt: k3alt}
// k4: slope at the end of the step, reached along k3.
end := vecadd(y, dt, k3)
k4lat, k4lng, k4alt := model(t+dt, end.Lat, end.Lng, end.Alt)
k4 := Vec{Lat: k4lat, Lng: k4lng, Alt: k4alt}
// y2 = y + dt/6*k1 + dt/3*k2 + dt/3*k3 + dt/6*k4
// Applied one term at a time; vecadd wraps the longitude after each term.
y2 := y
y2 = vecadd(y2, dt/6, k1)
y2 = vecadd(y2, dt/3, k2)
y2 = vecadd(y2, dt/3, k3)
y2 = vecadd(y2, dt/6, k4)
t2 := t + dt
if terminator(t2, y2.Lat, y2.Lng, y2.Alt) {
// Binary search to refine the termination point.
// Find l in [0, 1] such that (t3, y3) = lerp((t, y), (t2, y2), l)
// is near where the terminator fires.
left := 0.0
right := 1.0
var t3 float64
var y3 Vec
// Start from the full step so (t3, y3) is defined even if the loop
// below never runs.
t3 = t2
y3 = y2
for right-left > terminationTolerance {
mid := (left + right) / 2
t3 = scalarLerp(t, t2, mid)
y3 = vecLerp(y, y2, mid)
if terminator(t3, y3.Lat, y3.Lng, y3.Alt) {
right = mid
} else {
left = mid
}
}
// The last midpoint sampled is recorded as the final point; it may lie
// on either side of the true termination point.
result = append(result, TrajectoryPoint{T: t3, Lat: y3.Lat, Lng: y3.Lng, Alt: y3.Alt})
break
}
// Update current state
t = t2
y = y2
result = append(result, TrajectoryPoint{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt})
}
return result
}
// Solve integrates a chain of (model, terminator) stages one after another.
// Each stage starts from the last point of the previous stage; the first one
// starts from (t, lat, lng, alt). It returns one StageResult per stage, or a
// nil slice for an empty chain.
func Solve(t, lat, lng, alt float64, chain []struct {
	Model      Model
	Terminator Terminator
}) []StageResult {
	var stages []StageResult
	for i := range chain {
		track := rk4(t, lat, lng, alt, chain[i].Model, chain[i].Terminator)
		stages = append(stages, StageResult{Points: track})

		// Next stage starts where this one ended
		if len(track) == 0 {
			continue
		}
		end := track[len(track)-1]
		t, lat, lng, alt = end.T, end.Lat, end.Lng, end.Alt
	}
	return stages
}

View file

@ -1,21 +0,0 @@
package prediction
import "sync/atomic"
// Warnings counts the warning conditions raised during a prediction run.
type Warnings struct {
	// AltitudeTooHigh counts how many times the altitude-too-high warning
	// was raised.
	AltitudeTooHigh atomic.Int64
}

// ToMap returns the warnings as a map suitable for JSON serialization.
// Only warnings that have fired are included, so the result is an empty
// (non-nil) map when nothing was raised.
func (w *Warnings) ToMap() map[string]any {
	fired := map[string]any{}

	count := w.AltitudeTooHigh.Load()
	if count <= 0 {
		return fired
	}

	fired["altitude_too_high"] = map[string]any{
		"count":       count,
		"description": "The altitude went too high, above the max forecast wind. Wind data will be unreliable",
	}
	return fired
}

View file

@ -1,245 +0,0 @@
package service
import (
"context"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/downloader"
"predictor-refactored/internal/elevation"
"go.uber.org/zap"
)
// Service orchestrates the dataset lifecycle and provides prediction capabilities.
//
// mu guards ds. elev is written by LoadElevation and read by Elevation
// without taking mu.
// NOTE(review): this presumably relies on LoadElevation running once at
// start-up before any reader — confirm.
type Service struct {
mu sync.RWMutex
ds *dataset.File // currently loaded dataset; nil until one is loaded
elev *elevation.Dataset // optional elevation data; nil if unavailable
cfg *downloader.Config
dl *downloader.Downloader
log *zap.Logger
updating sync.Mutex // prevents concurrent downloads
}
// New creates a Service with no dataset loaded. It keeps cfg and log and
// builds the downloader from them.
func New(cfg *downloader.Config, log *zap.Logger) *Service {
	svc := new(Service)
	svc.cfg = cfg
	svc.dl = downloader.NewDownloader(cfg, log)
	svc.log = log
	return svc
}
// LoadElevation loads the ruaumoko-compatible elevation dataset from path.
// If it cannot be opened, a warning is logged and the service keeps running
// without elevation data, so descent terminates at sea level.
func (s *Service) LoadElevation(path string) {
	pathField := zap.String("path", path)

	elev, err := elevation.Open(path)
	if err != nil {
		s.log.Warn("elevation dataset not available, using sea-level termination",
			pathField, zap.Error(err))
		return
	}

	s.elev = elev
	s.log.Info("elevation dataset loaded", pathField)
}
// Elevation returns the elevation dataset, or nil if none has been loaded.
func (s *Service) Elevation() *elevation.Dataset {
	elev := s.elev
	return elev
}
// Ready reports whether a dataset is currently loaded.
func (s *Service) Ready() bool {
	s.mu.RLock()
	loaded := s.ds != nil
	s.mu.RUnlock()
	return loaded
}
// DatasetTime returns the forecast time of the currently loaded dataset.
// The boolean is false, and the time is the zero value, when no dataset is
// loaded.
func (s *Service) DatasetTime() (time.Time, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if current := s.ds; current != nil {
		return current.DSTime, true
	}
	return time.Time{}, false
}
// Dataset returns the currently loaded dataset for reading, or nil if none is
// loaded. The lock is only held while the pointer is read.
func (s *Service) Dataset() *dataset.File {
	s.mu.RLock()
	current := s.ds
	s.mu.RUnlock()
	return current
}
// Update makes sure a fresh dataset is loaded, downloading one if necessary.
//
// It returns nil immediately if another update is already running or the
// loaded dataset is younger than the configured TTL. Otherwise it first tries
// a usable dataset already on disk; only if that fails does it look up the
// latest model run, download it, load it, and remove older files.
func (s *Service) Update(ctx context.Context) error {
	if acquired := s.updating.TryLock(); !acquired {
		s.log.Info("update already in progress, skipping")
		return nil
	}
	defer s.updating.Unlock()

	// Check if current dataset is still fresh
	if loadedRun, have := s.DatasetTime(); have && time.Since(loadedRun) < s.cfg.DatasetTTL {
		s.log.Info("dataset still fresh, skipping update",
			zap.Time("dataset_time", loadedRun),
			zap.Duration("age", time.Since(loadedRun)))
		return nil
	}

	// Try loading an existing dataset from disk first
	if loadErr := s.loadExistingDataset(); loadErr == nil {
		return nil
	}

	// Find latest available model run
	latest, err := s.dl.FindLatestRun(ctx)
	if err != nil {
		return err
	}

	// Download and assemble
	file, err := s.dl.Download(ctx, latest)
	if err != nil {
		return err
	}

	// Open the new dataset
	fresh, err := dataset.Open(file, latest)
	if err != nil {
		return err
	}

	// Swap in the new dataset
	s.setDataset(fresh)
	s.log.Info("dataset loaded", zap.Time("run", latest), zap.String("path", file))

	// Clean old datasets
	s.cleanOldDatasets(file)
	return nil
}
// loadExistingDataset looks in the data directory for dataset files that are
// still usable and loads the newest one.
//
// A file qualifies when its name is a 10-character YYYYMMDDHH run time with
// no dot, its size equals dataset.DatasetSize, and the run is not older than
// the configured TTL. Returns os.ErrNotExist if nothing qualifies.
func (s *Service) loadExistingDataset() error {
	dirEntries, err := os.ReadDir(s.cfg.DataDir)
	if err != nil {
		return err
	}

	// Collect valid dataset files (name is YYYYMMDDHH, no extension, correct size)
	type candidate struct {
		name string
		path string
		run  time.Time
	}
	var usable []candidate
	for _, entry := range dirEntries {
		name := entry.Name()
		if entry.IsDir() || strings.Contains(name, ".") || len(name) != 10 {
			continue
		}
		runTime, err := time.Parse("2006010215", name)
		if err != nil {
			continue
		}
		full := filepath.Join(s.cfg.DataDir, name)
		if st, err := os.Stat(full); err != nil || st.Size() != dataset.DatasetSize {
			continue
		}
		if time.Since(runTime) > s.cfg.DatasetTTL {
			continue
		}
		usable = append(usable, candidate{name: name, path: full, run: runTime})
	}
	if len(usable) == 0 {
		return os.ErrNotExist
	}

	// Pick the newest
	sort.Slice(usable, func(i, j int) bool {
		return usable[i].run.After(usable[j].run)
	})
	newest := usable[0]

	ds, err := dataset.Open(newest.path, newest.run)
	if err != nil {
		return err
	}
	s.setDataset(ds)
	s.log.Info("loaded existing dataset",
		zap.Time("run", newest.run),
		zap.String("path", newest.path))
	return nil
}
// setDataset installs ds as the current dataset and closes the one it
// replaces, if any. The swap happens under the write lock; the old dataset is
// closed after the lock is released, and a close failure is only logged.
func (s *Service) setDataset(ds *dataset.File) {
	s.mu.Lock()
	previous := s.ds
	s.ds = ds
	s.mu.Unlock()

	if previous == nil {
		return
	}
	if err := previous.Close(); err != nil {
		s.log.Error("failed to close old dataset", zap.Error(err))
	}
}
// cleanOldDatasets removes stale files from the data directory, keeping the
// file at keepPath.
//
// Only two kinds of regular files are removed: dataset files, recognised by a
// 10-character name that parses as a YYYYMMDDHH run time, and leftover
// temporary files ending in ".downloading". Everything else in the directory
// is left alone. Failures are logged and never returned: cleanup is
// best-effort.
func (s *Service) cleanOldDatasets(keepPath string) {
	entries, err := os.ReadDir(s.cfg.DataDir)
	if err != nil {
		s.log.Warn("failed to list data directory for cleanup",
			zap.String("dir", s.cfg.DataDir), zap.Error(err))
		return
	}
	for _, e := range entries {
		if e.IsDir() {
			continue
		}
		name := e.Name()
		path := filepath.Join(s.cfg.DataDir, name)
		if path == keepPath {
			continue
		}

		// A bare length check would also match unrelated files such as
		// "README.txt", so require the name to be a real run timestamp.
		_, parseErr := time.Parse("2006010215", name)
		isDataset := len(name) == 10 && parseErr == nil
		isTemp := strings.HasSuffix(name, ".downloading")
		if !isDataset && !isTemp {
			continue
		}

		// Remove old datasets and temp files
		if err := os.Remove(path); err != nil {
			s.log.Warn("failed to remove old file", zap.String("path", path), zap.Error(err))
			continue
		}
		s.log.Info("removed old dataset", zap.String("path", path))
	}
}
// Close closes the currently loaded dataset, if any, and clears it from the
// service. It returns the dataset's close error, or nil when nothing was
// loaded.
func (s *Service) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	current := s.ds
	if current == nil {
		return nil
	}
	closeErr := current.Close()
	s.ds = nil
	return closeErr
}

View file

@ -1,30 +0,0 @@
package middleware
import (
"time"
"github.com/ogen-go/ogen/middleware"
"go.uber.org/zap"
)
// Logging returns an ogen middleware that times each request and logs the
// outcome, tagged with the operation ID: an error-level entry when the next
// handler fails, an info-level entry otherwise. The response and error are
// passed through unchanged.
func Logging(log *zap.Logger) middleware.Middleware {
	return func(req middleware.Request, next func(req middleware.Request) (middleware.Response, error)) (middleware.Response, error) {
		opLog := log.With(zap.String("operation", req.OperationID))

		began := time.Now()
		resp, err := next(req)
		elapsed := zap.Duration("duration", time.Since(began))

		if err == nil {
			opLog.Info("request completed", elapsed)
			return resp, nil
		}
		opLog.Error("request failed", elapsed, zap.Error(err))
		return resp, err
	}
}

View file

@ -1,16 +0,0 @@
package handler
import (
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/elevation"
)
// Service defines the interface the handler needs from the service layer.
type Service interface {
// Ready reports whether a dataset is loaded; the handler answers 503 otherwise.
Ready() bool
// DatasetTime returns the loaded dataset's time; the bool is presumably
// false when no dataset is loaded — confirm against the implementation.
DatasetTime() (time.Time, bool)
// Dataset returns the current dataset; the handler treats nil as unavailable.
Dataset() *dataset.File
// Elevation returns the elevation dataset, which the handler passes on to
// the standard profile.
Elevation() *elevation.Dataset
}

View file

@ -1,216 +0,0 @@
package handler
import (
"context"
"net/http"
"time"
"predictor-refactored/internal/prediction"
api "predictor-refactored/pkg/rest"
"go.uber.org/zap"
)
// Compile-time check that *Handler satisfies api.Handler.
var _ api.Handler = (*Handler)(nil)
// Handler implements the ogen-generated api.Handler interface.
type Handler struct {
svc Service // source of the dataset and elevation data
log *zap.Logger
}
// New creates a Handler that serves requests using svc and logs through log.
func New(svc Service, log *zap.Logger) *Handler {
	h := new(Handler)
	h.svc, h.log = svc, log
	return h
}
// PerformPrediction implements the prediction endpoint.
//
// It resolves optional request parameters to defaults, builds the stage chain
// for the requested flight profile, runs the prediction and converts the
// result into the generated response type.
//
// Errors are returned as *api.ErrorStatusCode: 503 while the service is not
// ready or has no dataset, 400 for an unknown profile.
func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredictionParams) (*api.PredictionResponse, error) {
if !h.svc.Ready() {
return nil, newError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up")
}
// Ready() and Dataset() are separate calls, so the dataset is re-checked.
ds := h.svc.Dataset()
if ds == nil {
return nil, newError(http.StatusServiceUnavailable, "dataset unavailable")
}
// Dataset time as UNIX seconds, passed on to the profile builders.
dsEpoch := float64(ds.DSTime.Unix())
// Parse parameters with defaults
profile := "standard_profile"
if p, ok := params.Profile.Get(); ok {
profile = string(p)
}
ascentRate := 5.0
if v, ok := params.AscentRate.Get(); ok {
ascentRate = v
}
burstAltitude := 28000.0
if v, ok := params.BurstAltitude.Get(); ok {
burstAltitude = v
}
descentRate := 5.0
if v, ok := params.DescentRate.Get(); ok {
descentRate = v
}
launchAlt := 0.0
if v, ok := params.LaunchAltitude.Get(); ok {
launchAlt = v
}
// Normalize longitude to [0, 360)
// NOTE(review): only negative inputs are shifted; a value of 360 or more
// passes through unchanged — confirm the API layer bounds the input.
lng := params.LaunchLongitude
if lng < 0 {
lng += 360.0
}
launchTime := float64(params.LaunchDatetime.Unix())
// Collector handed to the profile builders; inspected after the run.
warnings := &prediction.Warnings{}
// Build profile chain
elev := h.svc.Elevation()
var stages []prediction.Stage
switch profile {
case "standard_profile":
stages = prediction.StandardProfile(
ascentRate, burstAltitude, descentRate,
ds, dsEpoch, warnings, elev)
case "float_profile":
floatAlt := 25000.0
if v, ok := params.FloatAltitude.Get(); ok {
floatAlt = v
}
// Without an explicit stop time the float runs for 24 hours after launch.
stopTime := params.LaunchDatetime.Add(24 * time.Hour)
if v, ok := params.StopDatetime.Get(); ok {
stopTime = v
}
stages = prediction.FloatProfile(
ascentRate, floatAlt, stopTime,
ds, dsEpoch, warnings)
default:
return nil, newError(http.StatusBadRequest, "unknown profile: "+profile)
}
// Run prediction
// The wall-clock bracket around the run is reported in the response metadata.
startTime := time.Now().UTC()
results := prediction.RunPrediction(launchTime, params.LaunchLatitude, lng, launchAlt, stages)
completeTime := time.Now().UTC()
// Build response
// Stage labels are assigned by position in the result list.
stageNames := []string{"ascent", "descent"}
if profile == "float_profile" {
stageNames = []string{"ascent", "float"}
}
var predItems []api.PredictionResponsePredictionItem
for i, sr := range results {
// Results beyond the known stage names fall back to "ascent".
stageName := "ascent"
if i < len(stageNames) {
stageName = stageNames[i]
}
var stageEnum api.PredictionResponsePredictionItemStage
switch stageName {
case "ascent":
stageEnum = api.PredictionResponsePredictionItemStageAscent
case "descent":
stageEnum = api.PredictionResponsePredictionItemStageDescent
case "float":
stageEnum = api.PredictionResponsePredictionItemStageFloat
}
var traj []api.PredictionResponsePredictionItemTrajectoryItem
for _, pt := range sr.Points {
// Map longitudes above 180 back to the negative (western) range for output.
ptLng := pt.Lng
if ptLng > 180 {
ptLng -= 360
}
traj = append(traj, api.PredictionResponsePredictionItemTrajectoryItem{
// Point times are truncated to whole seconds.
Datetime: time.Unix(int64(pt.T), 0).UTC(),
Latitude: pt.Lat,
Longitude: ptLng,
Altitude: pt.Alt,
})
}
predItems = append(predItems, api.PredictionResponsePredictionItem{
Stage: stageEnum,
Trajectory: traj,
})
}
resp := &api.PredictionResponse{
Prediction: predItems,
Metadata: api.PredictionResponseMetadata{
StartDatetime: startTime,
CompleteDatetime: completeTime,
},
}
// Echo request
// The longitude echoed back is the caller's original value, not the
// normalized one used for the run.
resp.Request = api.NewOptPredictionResponseRequest(api.PredictionResponseRequest{
Dataset: api.NewOptString(ds.DSTime.Format("2006-01-02T15:04:05Z")),
LaunchLatitude: api.NewOptFloat64(params.LaunchLatitude),
LaunchLongitude: api.NewOptFloat64(params.LaunchLongitude),
LaunchDatetime: api.NewOptString(params.LaunchDatetime.Format(time.RFC3339)),
LaunchAltitude: params.LaunchAltitude,
})
// Warnings
// NOTE(review): only the presence of warnings is signalled; the contents of
// warnMap are not copied into the response — confirm this is intended.
warnMap := warnings.ToMap()
if len(warnMap) > 0 {
resp.Warnings = api.NewOptPredictionResponseWarnings(api.PredictionResponseWarnings{})
}
h.log.Info("prediction complete",
zap.String("profile", profile),
zap.Int("stages", len(results)),
zap.Duration("elapsed", completeTime.Sub(startTime)))
return resp, nil
}
// ReadinessCheck implements the health check endpoint. It never fails: the
// readiness state is carried in the response body. When the service is ready
// the dataset time is attached if the service can provide one; otherwise the
// response explains that no dataset is loaded.
func (h *Handler) ReadinessCheck(ctx context.Context) (*api.ReadinessResponse, error) {
	if !h.svc.Ready() {
		notReady := &api.ReadinessResponse{}
		notReady.Status = api.ReadinessResponseStatusNotReady
		notReady.ErrorMessage = api.NewOptString("no dataset loaded")
		return notReady, nil
	}

	ready := &api.ReadinessResponse{}
	ready.Status = api.ReadinessResponseStatusOk
	if dsTime, ok := h.svc.DatasetTime(); ok {
		ready.DatasetTime = api.NewOptDateTime(dsTime)
	}
	return ready, nil
}
// NewError converts an error returned by a handler into the response ogen
// should send. Errors that already are an *api.ErrorStatusCode pass through
// untouched; anything else is logged and reported as a 500 carrying the
// error's message.
func (h *Handler) NewError(ctx context.Context, err error) *api.ErrorStatusCode {
	statusErr, ok := err.(*api.ErrorStatusCode)
	if !ok {
		h.log.Error("unhandled error", zap.Error(err))
		return newError(http.StatusInternalServerError, err.Error())
	}
	return statusErr
}
// newError builds an API error response for the given HTTP status. The error
// type is the standard reason phrase for that status and description is the
// human-readable detail.
func newError(status int, description string) *api.ErrorStatusCode {
	var out api.ErrorStatusCode
	out.StatusCode = status
	out.Response.Error.Type = http.StatusText(status)
	out.Response.Error.Description = description
	return &out
}

View file

@ -1,75 +0,0 @@
package rest
import (
	"context"
	"fmt"
	"net/http"
	"time"

	"predictor-refactored/internal/transport/middleware"
	"predictor-refactored/internal/transport/rest/handler"
	api "predictor-refactored/pkg/rest"

	"go.uber.org/zap"
)
// Transport wraps the ogen HTTP server.
type Transport struct {
// srv is the ogen-generated server; Run mounts it at "/".
srv *api.Server
// handler is the API implementation srv was built from.
handler *handler.Handler
// port is the TCP port Run listens on (all interfaces).
port int
// log receives the startup message.
log *zap.Logger
}
// New builds a REST transport that serves h on the given port. The ogen
// server is created with the request-logging middleware installed; a failure
// to create it is returned wrapped.
func New(h *handler.Handler, port int, log *zap.Logger) (*Transport, error) {
	withLogging := api.WithMiddleware(middleware.Logging(log))
	srv, err := api.NewServer(h, withLogging)
	if err != nil {
		return nil, fmt.Errorf("create ogen server: %w", err)
	}

	t := new(Transport)
	t.srv = srv
	t.handler = h
	t.port = port
	t.log = log
	return t, nil
}
// Run starts the HTTP server on the configured port and blocks until the
// server stops, returning the error from ListenAndServe.
//
// The ogen server is mounted at "/" behind the CORS middleware.
func (t *Transport) Run() error {
	mux := http.NewServeMux()
	mux.Handle("/", t.srv)
	httpSrv := &http.Server{
		Addr:    fmt.Sprintf(":%d", t.port),
		Handler: corsMiddleware(mux),
		// Bound how long a client may take to send its request headers so
		// slow or stalled connections cannot be held open indefinitely.
		// Read-body and write deadlines are deliberately left unset so a
		// long-running prediction is never cut off mid-response.
		ReadHeaderTimeout: 10 * time.Second,
	}
	t.log.Info("starting HTTP server", zap.Int("port", t.port))
	return httpSrv.ListenAndServe()
}
// Shutdown gracefully stops the HTTP server.
// As written it is a no-op: ctx is ignored and nil is always returned.
func (t *Transport) Shutdown(ctx context.Context) error {
// The ogen server doesn't have a shutdown method;
// shutdown is handled by the http.Server in main.go
// NOTE(review): Run builds its own http.Server and does not keep a reference
// to it, so nothing reachable from Transport can stop that listener — confirm
// how the process is expected to shut the server down.
return nil
}
// corsMiddleware wraps next with permissive CORS handling. Every response
// gets the same three Access-Control-Allow-* headers; OPTIONS requests are
// answered directly with 204 No Content and never reach next.
func corsMiddleware(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		hdr := w.Header()
		hdr.Set("Access-Control-Allow-Origin", "*")
		hdr.Set("Access-Control-Allow-Methods", "GET, OPTIONS")
		hdr.Set("Access-Control-Allow-Headers", "Content-Type")

		switch r.Method {
		case http.MethodOptions:
			// Preflight: the headers above are the whole answer.
			w.WriteHeader(http.StatusNoContent)
		default:
			next.ServeHTTP(w, r)
		}
	}
	return http.HandlerFunc(serve)
}

View file

@ -0,0 +1,141 @@
package gfs
import "fmt"
// Dataset shape: (hour, pressure_level, variable, latitude, longitude).
// Matches the cube layout used by the reference Tawhiri implementation.
const (
// Axis lengths, outermost to innermost.
NumHours = 65 // 0, 3, 6, ..., 192 hours forecast
NumLevels = 47 // pressure levels
NumVariables = 3 // geopotential height, U-wind, V-wind
NumLatitudes = 361 // -90.0 to +90.0 inclusive in 0.5° steps
NumLongitudes = 720 // 0.0 to 359.5 in 0.5° steps
// Forecast-hour axis: hours between consecutive steps and the last hour
// stored (MaxHour/HourStep + 1 == NumHours).
HourStep = 3
MaxHour = 192
// Horizontal grid: spacing in degrees and the coordinate of index 0 on the
// latitude and longitude axes.
Resolution = 0.5
LatStart = -90.0
LonStart = 0.0
// Indices along the variable axis.
VarHeight = 0
VarWindU = 1
VarWindV = 2
// Bytes per stored cell.
ElementSize = 4 // float32
// DatasetSize is the canonical file size: every grid cell × element size.
// With the values above this is 9,528,667,200 bytes, hence the explicit
// int64 arithmetic.
DatasetSize int64 = int64(NumHours) * int64(NumLevels) * int64(NumVariables) *
int64(NumLatitudes) * int64(NumLongitudes) * int64(ElementSize)
)
// LevelSet identifies which GRIB file (primary/secondary) carries a level.
type LevelSet int
// The two level sets. LevelSetA is the zero value, so a LevelSet obtained
// from a failed lookup must be interpreted together with its ok flag.
const (
LevelSetA LevelSet = iota // pgrb2 — primary file
LevelSetB // pgrb2b — secondary file
)
// Pressures lists the 47 pressure levels (hPa) in dataset index order,
// descending from surface to top of atmosphere.
// The position of a value in this array is its index along the level axis;
// PressureIndex performs the reverse lookup.
var Pressures = [NumLevels]int{
1000, 975, 950, 925, 900, 875, 850, 825, 800, 775,
750, 725, 700, 675, 650, 625, 600, 575, 550, 525,
500, 475, 450, 425, 400, 375, 350, 325, 300, 275,
250, 225, 200, 175, 150, 125, 100, 70, 50, 30,
20, 10, 7, 5, 3, 2, 1,
}
// PressuresPgrb2 lists the levels carried by the primary GRIB file.
// The 26 values are in ascending hPa order (not dataset index order) and,
// together with PressuresPgrb2b, cover every entry of Pressures exactly once.
var PressuresPgrb2 = []int{
10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 350, 400,
450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 925,
950, 975, 1000,
}
// PressuresPgrb2b lists the levels carried by the secondary GRIB file.
// The 21 values are in ascending hPa order (not dataset index order) and are
// disjoint from PressuresPgrb2.
var PressuresPgrb2b = []int{
1, 2, 3, 5, 7, 125, 175, 225, 275, 325, 375, 425,
475, 525, 575, 625, 675, 725, 775, 825, 875,
}
// Reverse lookup tables keyed by pressure in hPa, filled once by init.
var (
	// pressureIndex maps a pressure to its position in Pressures.
	pressureIndex map[int]int
	// pressureLevelSet maps a pressure to the GRIB file set carrying it.
	pressureLevelSet map[int]LevelSet
)

// init derives the lookup tables from the Pressures, PressuresPgrb2 and
// PressuresPgrb2b lists so they cannot drift out of sync with them.
func init() {
	pressureIndex = make(map[int]int, NumLevels)
	for i := 0; i < len(Pressures); i++ {
		pressureIndex[Pressures[i]] = i
	}

	pressureLevelSet = make(map[int]LevelSet, NumLevels)
	assign := func(levels []int, set LevelSet) {
		for i := range levels {
			pressureLevelSet[levels[i]] = set
		}
	}
	assign(PressuresPgrb2, LevelSetA)
	assign(PressuresPgrb2b, LevelSetB)
}
// PressureIndex looks up the level-axis index of a pressure given in hPa.
// Pressures that are not part of the dataset yield -1.
func PressureIndex(hPa int) int {
	if idx, known := pressureIndex[hPa]; known {
		return idx
	}
	return -1
}
// PressureLevelSet reports which GRIB file set carries the given pressure
// (hPa). known is false for pressures outside the dataset, in which case set
// is the zero LevelSet and must not be relied upon.
func PressureLevelSet(hPa int) (set LevelSet, known bool) {
	set, known = pressureLevelSet[hPa]
	return set, known
}
// HourIndex converts a forecast hour into its index along the time axis.
// It returns -1 for hours below zero, beyond MaxHour, or not falling on a
// HourStep boundary.
func HourIndex(hour int) int {
	switch {
	case hour < 0, hour > MaxHour:
		return -1
	case hour%HourStep != 0:
		return -1
	}
	return hour / HourStep
}
// Hours enumerates every forecast hour stored in the dataset, in ascending
// order: 0, HourStep, 2*HourStep, ... up to and including MaxHour. A fresh
// slice is returned on every call.
func Hours() []int {
	hours := make([]int, 0, NumHours)
	for step := 0; step*HourStep <= MaxHour; step++ {
		hours = append(hours, step*HourStep)
	}
	return hours
}
// VariableIndex translates a GRIB (category, number) parameter pair into the
// dataset's variable-axis index. Only three pairs are stored:
//
//	(3, 5) -> VarHeight
//	(2, 2) -> VarWindU
//	(2, 3) -> VarWindV
//
// Any other pair yields -1.
func VariableIndex(parameterCategory, parameterNumber int) int {
	param := [2]int{parameterCategory, parameterNumber}
	switch param {
	case [2]int{3, 5}:
		return VarHeight
	case [2]int{2, 2}:
		return VarWindU
	case [2]int{2, 3}:
		return VarWindV
	}
	return -1
}
// S3BaseURL is the root of the public S3 mirror of NOAA GFS data. It has no
// trailing slash; GribURL and GribURLB append the object path to it.
const S3BaseURL = "https://noaa-gfs-bdp-pds.s3.amazonaws.com"
// GribURL builds the S3 URL of a primary (pgrb2) 0.5-degree GRIB file.
//
// date is inserted verbatim after "gfs."; runHour is zero-padded to two
// digits and forecastStep to three.
func GribURL(date string, runHour, forecastStep int) string {
	dir := fmt.Sprintf("gfs.%s/%02d/atmos", date, runHour)
	name := fmt.Sprintf("gfs.t%02dz.pgrb2.0p50.f%03d", runHour, forecastStep)
	return S3BaseURL + "/" + dir + "/" + name
}
// GribURLB builds the S3 URL of a secondary (pgrb2b) 0.5-degree GRIB file.
//
// date is inserted verbatim after "gfs."; runHour is zero-padded to two
// digits and forecastStep to three.
func GribURLB(date string, runHour, forecastStep int) string {
	dir := fmt.Sprintf("gfs.%s/%02d/atmos", date, runHour)
	name := fmt.Sprintf("gfs.t%02dz.pgrb2b.0p50.f%03d", runHour, forecastStep)
	return S3BaseURL + "/" + dir + "/" + name
}

View file

@ -0,0 +1,150 @@
package gfs
import (
"encoding/binary"
"fmt"
"math"
"os"
"time"
mmap "github.com/edsrzf/mmap-go"
)
// File is an mmap-backed wind dataset file. The layout is a flat C-order
// row-major array of float32 values, shape (hour, level, variable, lat, lng).
// Values are stored little-endian.
type File struct {
// mm is the memory mapping of the whole file; nil once Close has unmapped it.
mm mmap.MMap
// file is the handle backing mm; nil once closed.
file *os.File
// writable records whether mm was mapped read-write (Create, OpenWritable)
// or read-only (Open).
writable bool
// Epoch is the forecast run time (UTC) the file represents.
// It is set by Open only; Create and OpenWritable leave it at the zero value.
Epoch time.Time
}
// Open maps an existing dataset file read-only and tags it with epoch, the
// forecast run time it represents.
//
// The file must be exactly DatasetSize bytes. On any failure the underlying
// handle is closed before the error is returned.
func Open(path string, epoch time.Time) (*File, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("open dataset: %w", err)
	}
	// Until the mapping succeeds the handle is ours to release.
	handedOver := false
	defer func() {
		if !handedOver {
			f.Close()
		}
	}()

	info, err := f.Stat()
	if err != nil {
		return nil, fmt.Errorf("stat dataset: %w", err)
	}
	if size := info.Size(); size != DatasetSize {
		return nil, fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, size)
	}

	mm, err := mmap.Map(f, mmap.RDONLY, 0)
	if err != nil {
		return nil, fmt.Errorf("mmap dataset: %w", err)
	}

	handedOver = true
	return &File{mm: mm, file: f, writable: false, Epoch: epoch}, nil
}
// Create makes (or truncates) the file at path, sizes it to DatasetSize and
// maps it read-write. The returned File has a zero Epoch.
//
// If sizing or mapping fails the handle is closed before the error is
// returned; the file itself is left on disk.
func Create(path string) (*File, error) {
	f, err := os.Create(path)
	if err != nil {
		return nil, fmt.Errorf("create dataset: %w", err)
	}
	// abort releases the handle and reports which stage failed.
	abort := func(stage string, cause error) (*File, error) {
		f.Close()
		return nil, fmt.Errorf("%s: %w", stage, cause)
	}

	if err := f.Truncate(DatasetSize); err != nil {
		return abort("truncate dataset", err)
	}
	mm, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0)
	if err != nil {
		return abort("mmap dataset", err)
	}

	created := &File{mm: mm, file: f}
	created.writable = true
	return created, nil
}
// OpenWritable maps an existing dataset file read-write, e.g. to resume a
// partial download. The file must already be exactly DatasetSize bytes; the
// returned File has a zero Epoch.
//
// On any failure after the file has been opened, the handle is closed before
// the error is returned.
func OpenWritable(path string) (*File, error) {
	f, err := os.OpenFile(path, os.O_RDWR, 0o644)
	if err != nil {
		return nil, fmt.Errorf("open dataset rw: %w", err)
	}
	// Until the mapping succeeds the handle is ours to release.
	handedOver := false
	defer func() {
		if !handedOver {
			f.Close()
		}
	}()

	info, err := f.Stat()
	if err != nil {
		return nil, fmt.Errorf("stat dataset: %w", err)
	}
	if size := info.Size(); size != DatasetSize {
		return nil, fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, size)
	}

	mm, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0)
	if err != nil {
		return nil, fmt.Errorf("mmap dataset: %w", err)
	}

	handedOver = true
	return &File{mm: mm, file: f, writable: true}, nil
}
// offset computes the byte position of cell [hour][level][variable][lat][lng]
// in the flat row-major array. All arithmetic is done in int64 because the
// full dataset exceeds the 32-bit range. Indices are not range-checked.
func offset(hour, level, variable, lat, lng int) int64 {
	// Fold the inner axes in one at a time: idx = idx*size + position.
	inner := [...]struct{ size, pos int }{
		{NumLevels, level},
		{NumVariables, variable},
		{NumLatitudes, lat},
		{NumLongitudes, lng},
	}
	cell := int64(hour)
	for _, axis := range inner {
		cell = cell*int64(axis.size) + int64(axis.pos)
	}
	return cell * int64(ElementSize)
}
// Val returns the float32 stored at [hour][level][variable][lat][lng],
// decoding four little-endian bytes from the mapping. Indices are not
// validated here.
func (d *File) Val(hour, level, variable, lat, lng int) float32 {
	start := offset(hour, level, variable, lat, lng)
	raw := d.mm[start : start+4]
	bits := binary.LittleEndian.Uint32(raw)
	return math.Float32frombits(bits)
}
// SetVal stores val at [hour][level][variable][lat][lng] as four
// little-endian bytes. It must only be called on files mapped read-write;
// neither that nor the indices are checked here.
func (d *File) SetVal(hour, level, variable, lat, lng int, val float32) {
	start := offset(hour, level, variable, lat, lng)
	cell := d.mm[start : start+4]
	bits := math.Float32bits(val)
	binary.LittleEndian.PutUint32(cell, bits)
}
// BlitGribData copies one decoded GRIB grid into the dataset, flipping the
// latitude axis from GRIB's north-to-south scan order to our south-to-north
// storage order. gribData must hold NumLatitudes*NumLongitudes values
// (361*720 = 259920 for the 0.5-degree grid), one row of longitudes per
// latitude.
//
// An error is returned, and nothing is written, when:
//   - the file was not opened for writing (storing through a read-only
//     mapping would fault the whole process);
//   - any of hourIdx, levelIdx or varIdx is negative, which is how HourIndex,
//     PressureIndex and VariableIndex report "unknown" and would otherwise
//     land the grid in a neighbouring slab;
//   - gribData has the wrong length.
//
// Upper bounds of the three indices are not checked here.
func (d *File) BlitGribData(hourIdx, levelIdx, varIdx int, gribData []float64) error {
	if !d.writable {
		return fmt.Errorf("blit grib data: dataset is not writable")
	}
	if hourIdx < 0 || levelIdx < 0 || varIdx < 0 {
		return fmt.Errorf("blit grib data: negative index (hour=%d, level=%d, variable=%d)",
			hourIdx, levelIdx, varIdx)
	}
	expected := NumLatitudes * NumLongitudes
	if len(gribData) != expected {
		return fmt.Errorf("grib data has %d values, expected %d", len(gribData), expected)
	}
	for lat := 0; lat < NumLatitudes; lat++ {
		// Storage row lat (south-to-north) is GRIB row NumLatitudes-1-lat.
		rowStart := (NumLatitudes - 1 - lat) * NumLongitudes
		row := gribData[rowStart : rowStart+NumLongitudes]
		for lng, v := range row {
			d.SetVal(hourIdx, levelIdx, varIdx, lat, lng, float32(v))
		}
	}
	return nil
}
// Flush writes pending changes in the mapping back to disk. It is a no-op
// returning nil when the file has no mapping (e.g. after Close).
func (d *File) Flush() error {
	if d.mm == nil {
		return nil
	}
	return d.mm.Flush()
}
// Close unmaps the dataset and closes the underlying file.
//
// Both resources are always released and cleared, even when unmapping fails,
// so Close is safe to call more than once: a repeated call is a no-op that
// returns nil. An unmap error takes precedence over a file-close error and
// is returned wrapped with "unmap:".
func (d *File) Close() error {
	var unmapErr error
	if d.mm != nil {
		unmapErr = d.mm.Unmap()
		d.mm = nil
	}

	var closeErr error
	if d.file != nil {
		closeErr = d.file.Close()
		// Clear the handle unconditionally so a later Close cannot close it a
		// second time.
		d.file = nil
	}

	if unmapErr != nil {
		return fmt.Errorf("unmap: %w", unmapErr)
	}
	return closeErr
}

View file

@ -0,0 +1,109 @@
package gfs
import (
"time"
"predictor-refactored/internal/numerics"
"predictor-refactored/internal/weather"
)
// Wind is a WindField backed by a GFS dataset file.
type Wind struct {
// file is the dataset every sample is read from; released by Close.
file *File
}
// NewWind wraps file in a Wind. The file is used as-is (not copied), and
// calling Close on the returned Wind closes it.
func NewWind(file *File) *Wind {
	w := new(Wind)
	w.file = file
	return w
}
// Epoch reports the forecast run time recorded on the underlying file.
func (w *Wind) Epoch() time.Time {
	backing := w.file
	return backing.Epoch
}
// Source names the dataset behind this field; it is always "noaa-gfs-0p50".
func (w *Wind) Source() string {
	const id = "noaa-gfs-0p50"
	return id
}
// Close closes the underlying dataset file and returns its error, if any.
func (w *Wind) Close() error {
	err := w.file.Close()
	return err
}
// Grid axes for the GFS 0.5-degree dataset.
// Wind uses them to locate the interpolation bracket along each dimension.
var (
// hourAxis covers forecast hours: NumHours points starting at 0, spaced
// HourStep hours apart.
hourAxis = numerics.Axis{
Left: 0,
Step: float64(HourStep),
N: NumHours,
Name: "hour",
}
// latAxis covers latitude: NumLatitudes points starting at LatStart,
// spaced Resolution degrees apart.
latAxis = numerics.Axis{
Left: LatStart,
Step: Resolution,
N: NumLatitudes,
Name: "lat",
}
// lngAxis covers longitude: NumLongitudes points starting at LonStart,
// spaced Resolution degrees apart. It is the only axis with Wrap set.
// NOTE(review): presumably Wrap lets Locate bracket across the 359.5°→0°
// seam — confirm against numerics.Axis.
lngAxis = numerics.Axis{
Left: LonStart,
Step: Resolution,
N: NumLongitudes,
Wrap: true,
Name: "lng",
}
)
// Wind samples the field at UNIX time t, geographic position (lat, lng) and
// altitude alt.
//
// The time is converted to hours since the file's epoch, and each of hour,
// latitude and longitude is located on its axis; a failure to locate any of
// them is returned as-is with a zero Sample. Vertical interpolation matches
// Tawhiri: find the pair of adjacent pressure levels whose interpolated
// geopotential heights bracket alt, then blend U and V linearly between
// them. When both levels have the same height the two are weighted equally.
// AboveModel is set when the weight of the lower level is negative, i.e. alt
// lies above the upper level of the chosen pair.
func (w *Wind) Wind(t, lat, lng, alt float64) (weather.Sample, error) {
	var none weather.Sample

	elapsedHours := (t - float64(w.file.Epoch.Unix())) / 3600.0
	hourB, err := hourAxis.Locate(elapsedHours)
	if err != nil {
		return none, err
	}
	latB, err := latAxis.Locate(lat)
	if err != nil {
		return none, err
	}
	lngB, err := lngAxis.Locate(lng)
	if err != nil {
		return none, err
	}
	brackets := [3]numerics.Bracket{hourB, latB, lngB}

	// sample interpolates one variable on one pressure level at the located
	// (hour, lat, lng) position.
	sample := func(level, variable int) float64 {
		return numerics.EvalTrilinear(brackets, func(i, j, k int) float64 {
			return float64(w.file.Val(i, level, variable, j, k))
		})
	}
	heightAt := func(level int) float64 {
		return sample(level, VarHeight)
	}

	// The search stops at NumLevels-2 so that lower+1 is always a valid level.
	lower := numerics.Bisect(0, NumLevels-2, alt, heightAt)
	upper := lower + 1

	lowerHGT := heightAt(lower)
	upperHGT := heightAt(upper)

	// altFrac is the weight given to the lower level.
	altFrac := 0.5
	if lowerHGT != upperHGT {
		altFrac = (upperHGT - alt) / (upperHGT - lowerHGT)
	}

	lowerU := sample(lower, VarWindU)
	upperU := sample(upper, VarWindU)
	lowerV := sample(lower, VarWindV)
	upperV := sample(upper, VarWindV)

	result := weather.Sample{
		U:          lowerU*altFrac + upperU*(1-altFrac),
		V:          lowerV*altFrac + upperV*(1-altFrac),
		AboveModel: altFrac < 0,
	}
	return result, nil
}

37
internal/weather/types.go Normal file
View file

@ -0,0 +1,37 @@
// Package weather defines the abstract interface trajectory engines use
// to sample atmospheric data, and contains source-specific implementations
// in its subpackages.
package weather
import "time"
// Sample is the result of sampling a wind field at one point.
// The zero value means calm wind inside the model's vertical range.
type Sample struct {
// U is the eastward wind component in m/s.
U float64
// V is the northward wind component in m/s.
V float64
// AboveModel is set when the query altitude was above the highest
// pressure level represented in the underlying dataset. The returned
// U/V values are linear extrapolations and should be treated as unreliable.
// U and V are still populated in that case; callers decide whether to use them.
AboveModel bool
}
// WindField provides 3D wind data interpolated at arbitrary points.
//
// Implementations must be safe for concurrent use.
type WindField interface {
// Wind samples the field at (t, lat, lng, alt).
//
// t is UNIX seconds. lat is in degrees, -90 to +90. lng is in degrees,
// 0 to 360 (callers must normalize). alt is metres above mean sea level.
//
// Returns an error if any coordinate is outside the field's domain.
// An altitude above the model top is not an error; it is reported through
// Sample.AboveModel instead.
Wind(t, lat, lng, alt float64) (Sample, error)
// Epoch returns the time the field is anchored to (forecast run time).
Epoch() time.Time
// Source identifies the underlying dataset for logs and metrics.
Source() string
}