From 9e663db9dc7d77fb10ef065e3a3231860f3f161f Mon Sep 17 00:00:00 2001 From: "a.antonov" Date: Mon, 18 May 2026 03:17:17 +0900 Subject: [PATCH] step one --- Makefile | 45 ++- README.md | 415 +++++++++---------- cmd/api/main.go | 98 ----- cmd/compare-tawhiri/main.go | 153 +++++++ cmd/compare_prediction/main.go | 195 --------- cmd/compare_step0/main.go | 104 ----- cmd/predictor-cli/main.go | 216 ++++++++++ cmd/predictor/main.go | 181 +++++++++ docs/numerics.tex | 160 ++++++++ go.mod | 4 +- internal/api/admin/datasets.go | 206 ++++++++++ internal/api/middleware/cors.go | 20 + internal/api/middleware/log.go | 51 +++ internal/api/tawhiri/handler.go | 252 ++++++++++++ internal/api/transport.go | 109 +++++ internal/api/v2/handler.go | 173 ++++++++ internal/api/v2/profile.go | 145 +++++++ internal/api/v2/types.go | 114 ++++++ internal/config/config.go | 252 ++++++++++++ internal/config/config_test.go | 76 ++++ internal/dataset/dataset.go | 158 -------- internal/dataset/dataset_test.go | 152 ------- internal/dataset/file.go | 140 ------- internal/datasets/doc.go | 11 + internal/datasets/gfs/idx.go | 125 ++++++ internal/datasets/gfs/idx_test.go | 70 ++++ internal/datasets/gfs/source.go | 430 ++++++++++++++++++++ internal/datasets/manager.go | 383 ++++++++++++++++++ internal/datasets/manifest.go | 118 ++++++ internal/datasets/store_local.go | 167 ++++++++ internal/datasets/store_test.go | 82 ++++ internal/datasets/throttle.go | 63 +++ internal/datasets/types.go | 97 +++++ internal/downloader/config.go | 58 --- internal/downloader/downloader.go | 441 --------------------- internal/downloader/idx.go | 157 -------- internal/downloader/idx_test.go | 110 ----- internal/elevation/elevation.go | 7 +- internal/engine/constraints.go | 47 +++ internal/engine/engine_test.go | 176 ++++++++ internal/engine/models.go | 151 +++++++ internal/engine/profile.go | 55 +++ internal/engine/propagator.go | 156 ++++++++ internal/engine/state.go | 50 +++ internal/engine/types.go | 80 ++++ 
internal/metrics/prom.go | 146 +++++++ internal/metrics/prom_test.go | 49 +++ internal/metrics/types.go | 36 ++ internal/numerics/doc.go | 11 + internal/numerics/grid.go | 86 ++++ internal/numerics/grid_test.go | 94 +++++ internal/numerics/ode.go | 61 +++ internal/numerics/ode_test.go | 61 +++ internal/numerics/search.go | 19 + internal/numerics/search_test.go | 28 ++ internal/prediction/interpolate.go | 153 ------- internal/prediction/models.go | 188 --------- internal/prediction/solver.go | 180 --------- internal/prediction/warnings.go | 21 - internal/service/service.go | 245 ------------ internal/transport/middleware/log.go | 30 -- internal/transport/rest/handler/deps.go | 16 - internal/transport/rest/handler/handler.go | 216 ---------- internal/transport/rest/transport.go | 75 ---- internal/weather/gfs/constants.go | 141 +++++++ internal/weather/gfs/file.go | 150 +++++++ internal/weather/gfs/wind.go | 109 +++++ internal/weather/types.go | 37 ++ 68 files changed, 5647 insertions(+), 2958 deletions(-) delete mode 100644 cmd/api/main.go create mode 100644 cmd/compare-tawhiri/main.go delete mode 100644 cmd/compare_prediction/main.go delete mode 100644 cmd/compare_step0/main.go create mode 100644 cmd/predictor-cli/main.go create mode 100644 cmd/predictor/main.go create mode 100644 docs/numerics.tex create mode 100644 internal/api/admin/datasets.go create mode 100644 internal/api/middleware/cors.go create mode 100644 internal/api/middleware/log.go create mode 100644 internal/api/tawhiri/handler.go create mode 100644 internal/api/transport.go create mode 100644 internal/api/v2/handler.go create mode 100644 internal/api/v2/profile.go create mode 100644 internal/api/v2/types.go create mode 100644 internal/config/config.go create mode 100644 internal/config/config_test.go delete mode 100644 internal/dataset/dataset.go delete mode 100644 internal/dataset/dataset_test.go delete mode 100644 internal/dataset/file.go create mode 100644 internal/datasets/doc.go create mode 
100644 internal/datasets/gfs/idx.go create mode 100644 internal/datasets/gfs/idx_test.go create mode 100644 internal/datasets/gfs/source.go create mode 100644 internal/datasets/manager.go create mode 100644 internal/datasets/manifest.go create mode 100644 internal/datasets/store_local.go create mode 100644 internal/datasets/store_test.go create mode 100644 internal/datasets/throttle.go create mode 100644 internal/datasets/types.go delete mode 100644 internal/downloader/config.go delete mode 100644 internal/downloader/downloader.go delete mode 100644 internal/downloader/idx.go delete mode 100644 internal/downloader/idx_test.go create mode 100644 internal/engine/constraints.go create mode 100644 internal/engine/engine_test.go create mode 100644 internal/engine/models.go create mode 100644 internal/engine/profile.go create mode 100644 internal/engine/propagator.go create mode 100644 internal/engine/state.go create mode 100644 internal/engine/types.go create mode 100644 internal/metrics/prom.go create mode 100644 internal/metrics/prom_test.go create mode 100644 internal/metrics/types.go create mode 100644 internal/numerics/doc.go create mode 100644 internal/numerics/grid.go create mode 100644 internal/numerics/grid_test.go create mode 100644 internal/numerics/ode.go create mode 100644 internal/numerics/ode_test.go create mode 100644 internal/numerics/search.go create mode 100644 internal/numerics/search_test.go delete mode 100644 internal/prediction/interpolate.go delete mode 100644 internal/prediction/models.go delete mode 100644 internal/prediction/solver.go delete mode 100644 internal/prediction/warnings.go delete mode 100644 internal/service/service.go delete mode 100644 internal/transport/middleware/log.go delete mode 100644 internal/transport/rest/handler/deps.go delete mode 100644 internal/transport/rest/handler/handler.go delete mode 100644 internal/transport/rest/transport.go create mode 100644 internal/weather/gfs/constants.go create mode 100644 
internal/weather/gfs/file.go create mode 100644 internal/weather/gfs/wind.go create mode 100644 internal/weather/types.go diff --git a/Makefile b/Makefile index 7c14792..1409e86 100644 --- a/Makefile +++ b/Makefile @@ -1,12 +1,20 @@ -.PHONY: build run test fmt lint clean generate-ogen help +.PHONY: build server cli compare run test fmt lint clean generate-ogen docs help -# Build the application -build: - go build -o predictor ./cmd/api +# Build all binaries +build: server cli compare + +server: + go build -o bin/predictor ./cmd/predictor + +cli: + go build -o bin/predictor-cli ./cmd/predictor-cli + +compare: + go build -o bin/compare-tawhiri ./cmd/compare-tawhiri # Run locally run: - go run ./cmd/api + go run ./cmd/predictor # Run tests test: @@ -20,21 +28,28 @@ fmt: lint: golangci-lint run -# Generate ogen API code from swagger spec +# Build the numerics LaTeX doc (requires pdflatex) +docs: + cd docs && pdflatex numerics.tex + +# Regenerate ogen API code from the OpenAPI spec generate-ogen: go run github.com/ogen-go/ogen/cmd/ogen@latest --target pkg/rest --package rest --clean api/rest/predictor.swagger.yml # Clean build artifacts clean: - rm -f predictor + rm -rf bin/ docs/numerics.pdf docs/numerics.aux docs/numerics.log -# Show help help: @echo "Available commands:" - @echo " build - Build binary" - @echo " run - Run locally" - @echo " test - Run tests" - @echo " fmt - Format code" - @echo " lint - Lint code" - @echo " generate-ogen - Generate API code from swagger spec" - @echo " clean - Remove build artifacts" + @echo " build - Build all binaries to bin/" + @echo " server - Build the HTTP server (cmd/predictor)" + @echo " cli - Build the CLI client (cmd/predictor-cli)" + @echo " compare - Build the validation tool (cmd/compare-tawhiri)" + @echo " run - Run the server with default config" + @echo " test - Run unit tests" + @echo " fmt - Format code" + @echo " lint - Lint code (golangci-lint)" + @echo " docs - Build the numerics LaTeX doc (requires pdflatex)" + 
@echo " generate-ogen - Regenerate ogen code from the OpenAPI spec" + @echo " clean - Remove build artifacts" diff --git a/README.md b/README.md index 579a255..967fca2 100644 --- a/README.md +++ b/README.md @@ -1,261 +1,274 @@ -# Balloon Trajectory Predictor +# stratoflights-predictor -High-altitude balloon trajectory prediction service. Predicts ascent, burst, and descent trajectories using GFS wind forecast data from NOAA. +High-altitude balloon trajectory prediction service. Forecasts ascent, descent, +and float trajectories from NOAA GFS wind data, exposed as a REST API. -The prediction algorithms are an exact port of [Tawhiri](https://github.com/cuspaceflight/tawhiri) (Cambridge University Spaceflight) to Go, verified to produce identical results. +The trajectory engine is a propagator-and-constraint system: any flight +profile can be expressed as a chain of propagators (constant-rate ascent, +parachute descent, piecewise rates, wind drift) with attached constraints +(altitude, time, terrain contact). The legacy Tawhiri request shape is kept +as a compatibility endpoint so existing clients work unchanged. 
-## Quick Start +## Quick start ```bash -# Build +# Build all three binaries (server, CLI, validation tool) make build -# Run (downloads ~9 GB of GFS data on first start, takes 30-60 min) -PREDICTOR_DATA_DIR=/tmp/predictor-data go run ./cmd/api +# Run the server (first start downloads ~9 GB of GFS data over 30-60 min) +./bin/predictor # Check readiness -curl http://localhost:8080/ready +./bin/predictor-cli ready -# Run a prediction -curl 'http://localhost:8080/api/v1/prediction?launch_latitude=52.2&launch_longitude=0.1&launch_datetime=2026-03-28T12:00:00Z&launch_altitude=0&ascent_rate=5&burst_altitude=30000&descent_rate=5' +# Run a Tawhiri-style prediction +./bin/predictor-cli predict \ + launch_latitude=52.2 launch_longitude=0.1 \ + launch_datetime=2026-03-28T12:00:00Z \ + ascent_rate=5 burst_altitude=30000 descent_rate=5 ``` ## Configuration -All configuration is via environment variables. +Configuration is layered: built-in defaults, then a YAML file +(`--config path.yml` or `PREDICTOR_CONFIG_FILE=path.yml`), then env vars, +then CLI flags. Flags override env vars override file values override defaults. 
-| Variable | Default | Description | -|---|---|---| -| `PREDICTOR_PORT` | `8080` | HTTP server port | -| `PREDICTOR_DATA_DIR` | `/tmp/predictor-data` | Directory for wind datasets and temp files | -| `PREDICTOR_DOWNLOAD_PARALLEL` | `8` | Max concurrent GRIB download goroutines | -| `PREDICTOR_UPDATE_INTERVAL` | `6h` | How often to check for new forecasts | -| `PREDICTOR_DATASET_TTL` | `48h` | Max age before a dataset is considered stale | -| `PREDICTOR_ELEVATION_DATASET` | `/srv/ruaumoko-dataset` | Path to elevation dataset (optional) | +| Setting | Env var | CLI flag | Default | +|---|---|---|---| +| HTTP port | `PREDICTOR_PORT` | `-port` | `8080` | +| Data directory | `PREDICTOR_DATA_DIR` | `-data-dir` | `/tmp/predictor-data` | +| Elevation dataset | `PREDICTOR_ELEVATION_DATASET` | `-elevation` | `/srv/ruaumoko-dataset` | +| Source | `PREDICTOR_SOURCE` | — | `noaa-gfs-0p50` | +| Download parallelism | `PREDICTOR_DOWNLOAD_PARALLEL` | `-download-parallel` | `8` | +| Download bandwidth (bytes/s; 0 = unlimited) | `PREDICTOR_DOWNLOAD_BANDWIDTH` | `-download-bandwidth` | `0` | +| Scheduler interval | `PREDICTOR_UPDATE_INTERVAL` | `-update-interval` | `6h` | +| Dataset freshness TTL | `PREDICTOR_DATASET_TTL` | `-freshness-ttl` | `48h` | +| Metrics enabled | `PREDICTOR_METRICS_ENABLED` | `-metrics` | `true` | +| Metrics HTTP path | `PREDICTOR_METRICS_PATH` | `-metrics-path` | `/metrics` | +| Log level | `PREDICTOR_LOG_LEVEL` | `-log-level` | `info` | -## API +A YAML config file mirrors the same structure: -### `GET /api/v1/prediction` +```yaml +http: + port: 8080 +data: + dir: /var/lib/predictor + elevation_path: /var/lib/predictor/elevation + source: noaa-gfs-0p50 +download: + parallel: 8 + bandwidth_bytes_per_second: 0 + update_interval: 6h + freshness_ttl: 48h +metrics: + enabled: true + path: /metrics +log: + level: info +``` -Run a balloon trajectory prediction. 
+## REST API -**Parameters** (query string): +### Tawhiri-compatible + +`GET /api/v1/prediction` — preserves the exact request and response shape of +the upstream Cambridge University Spaceflight predictor. Query parameters: | Parameter | Required | Description | |---|---|---| -| `launch_latitude` | yes | Launch latitude in degrees (-90 to 90) | -| `launch_longitude` | yes | Launch longitude in degrees (-180 to 180 or 0 to 360) | -| `launch_datetime` | yes | Launch time in RFC 3339 format | -| `launch_altitude` | no | Launch altitude in metres ASL (default: 0) | +| `launch_latitude` | yes | Degrees, -90 to 90 | +| `launch_longitude` | yes | Degrees, -180 to 180 or 0 to 360 | +| `launch_datetime` | yes | RFC 3339 | +| `launch_altitude` | no | Metres ASL (default 0) | | `profile` | no | `standard_profile` (default) or `float_profile` | -| `ascent_rate` | no | Ascent rate in m/s (default: 5) | -| `burst_altitude` | no | Burst altitude in metres (default: 28000) | -| `descent_rate` | no | Sea-level descent rate in m/s (default: 5) | -| `float_altitude` | no | Float altitude in metres (float_profile only) | -| `stop_datetime` | no | Float end time (float_profile only, default: +24h) | +| `ascent_rate` | no | m/s (default 5) | +| `burst_altitude` | no | Metres (default 28000) | +| `descent_rate` | no | m/s (default 5) | +| `float_altitude` | no | Metres (float profile only) | +| `stop_datetime` | no | Float-profile end time | -**Response** (Tawhiri-compatible): +`GET /ready` — returns `{"status": "ok", "dataset_time": "..."}` once a +dataset is loaded; `{"status": "not_ready", ...}` before then. + +### Profile-driven (new primary) + +`POST /api/v2/prediction` — accepts an arbitrary chain of propagators with +optional constraints. Useful when the frontend wants flight profiles the +Tawhiri shape can't express (e.g. piecewise rates, fallback on constraint +violation). 
```json { - "prediction": [ + "launch": { + "time": "2026-03-28T12:00:00Z", + "latitude": 52.2, + "longitude": 0.1, + "altitude": 0 + }, + "profile": [ { - "stage": "ascent", - "trajectory": [ - {"datetime": "2026-03-28T12:00:00Z", "latitude": 52.2, "longitude": 0.1, "altitude": 0}, - ... - ] + "name": "ascent", + "model": {"type": "constant_rate", "rate": 5, "include_wind": true}, + "constraints": [{"type": "max_altitude", "limit": 30000}] }, { - "stage": "descent", - "trajectory": [...] + "name": "descent", + "model": {"type": "parachute_descent", "sea_level_rate": 5, "include_wind": true}, + "constraints": [{"type": "terrain_contact"}] } - ], - "metadata": { - "start_datetime": "...", - "complete_datetime": "..." - }, - "request": { - "dataset": "2026-03-28T06:00:00Z", - "launch_latitude": 52.2, - ... - } + ] } ``` -### `GET /ready` +Model types: `constant_rate`, `parachute_descent`, `piecewise`, `wind`. +Constraint types: `max_altitude`, `min_altitude`, `max_time`, +`terrain_contact`. Constraint actions: `stop` (default), `fallback`, `clip`. +Set `"direction": "reverse"` to integrate backward from a known landing. -Health check. Returns `{"status": "ok"}` when a dataset is loaded. +### Dataset admin -## Elevation Dataset - -Without elevation data, descent terminates at sea level (altitude <= 0). With elevation data, descent terminates at ground level, matching Tawhiri's behaviour. - -### Building the elevation dataset - -The elevation dataset uses ETOPO 2022 at 30 arc-second resolution, converted to a ruaumoko-compatible binary format (21601 x 43200 grid of int16 little-endian elevation values in metres). - -**Requirements**: Python 3, xarray, netcdf4, numpy. 
- -```bash -pip install xarray netcdf4 numpy - -# Downloads ~1.1 GB from NOAA, produces ~1.74 GB binary file -python3 scripts/build_elevation.py /tmp/predictor-data/ruaumoko-dataset +``` +GET /api/v1/admin/datasets list stored epochs +POST /api/v1/admin/datasets {epoch | latest} trigger a download +DELETE /api/v1/admin/datasets/{epoch} delete a stored dataset +GET /api/v1/admin/jobs list every job +GET /api/v1/admin/jobs/{id} fetch one job +DELETE /api/v1/admin/jobs/{id} cancel a running job ``` -To skip the download if you already have the ETOPO NetCDF file: +Returns `JobInfo`: -```bash -ETOPO_NC_PATH=/path/to/ETOPO_2022_v1_30s_N90W180_surface.nc \ - python3 scripts/build_elevation.py /tmp/predictor-data/ruaumoko-dataset +```json +{"id":"…","source":"noaa-gfs-0p50","epoch":"…","status":"running", + "started_at":"…","total_units":130,"done_units":47,"bytes":510000000} ``` -The ETOPO 2022 NetCDF can be manually downloaded from: -https://www.ncei.noaa.gov/products/etopo-global-relief-model +### Metrics -### Using the elevation dataset - -```bash -PREDICTOR_ELEVATION_DATASET=/tmp/predictor-data/ruaumoko-dataset go run ./cmd/api -``` - -If the file doesn't exist or can't be read, the service starts normally with a warning and falls back to sea-level termination. +`GET /metrics` — Prometheus text exposition. Counters: +`predictor_predictions_total{profile,status}`, +`predictor_downloads_total{source,status}`, +`predictor_download_bytes_total{source}`, +and a gauge `predictor_active_dataset_epoch_seconds`. 
## Architecture ``` -cmd/api/main.go Entry point, config, scheduler, HTTP server +cmd/ + predictor/main.go main server entry point + predictor-cli/main.go HTTP client + compare-tawhiri/main.go end-to-end validation against the public Tawhiri instance internal/ - dataset/ - dataset.go Shape constants, pressure levels, S3 URLs - file.go mmap-backed dataset file (read/write/blit) - downloader/ - downloader.go S3 partial GRIB download (idx + range requests) - idx.go NOAA .idx file parser - config.go Environment-based configuration - elevation/ - elevation.go Ruaumoko-compatible elevation dataset (mmap int16) - prediction/ - interpolate.go 4D wind interpolation (time, lat, lon, altitude) - solver.go RK4 integrator with binary search termination - models.go Ascent, descent, wind models; flight profiles - warnings.go Prediction warning counters - service/ - service.go Dataset lifecycle, concurrent-safe access - transport/ - middleware/log.go Request logging middleware - rest/ - handler/handler.go ogen API handler implementation - handler/deps.go Service interface - transport.go ogen HTTP server, CORS -api/rest/predictor.swagger.yml OpenAPI 3.0 spec -pkg/rest/ Generated ogen code (17 files) -scripts/ - build_elevation.py ETOPO 2022 to ruaumoko converter + numerics/ pure numerical primitives (interp, bisect, RK4, refinement) + engine/ propagator + constraint system + concrete models + weather/ WindField interface; gfs/ — NOAA GFS file format + impl + datasets/ Source/Storage/Manager + transactional, resumable downloads + gfs/ — NOAA GFS source impl + elevation/ ruaumoko-format ground elevation reader + config/ layered file+env+CLI config + metrics/ Sink interface + Prometheus text impl + api/ HTTP transport + tawhiri/ — legacy v1 endpoint via ogen + v2/ — profile-driven endpoint + admin/ — dataset/job admin endpoints + middleware/ +api/rest/predictor.swagger.yml OpenAPI 3 spec for v1 + /ready +pkg/rest/ ogen-generated code (regenerate via `make generate-ogen`) 
+docs/numerics.tex LaTeX math reference for the numerics package +scripts/build_elevation.py ETOPO 2022 → ruaumoko converter ``` -## Wind Dataset +## Deployment -The service downloads GFS 0.5-degree forecast data from NOAA S3: +### Local single instance + +```bash +./bin/predictor --data-dir /var/lib/predictor +``` + +No external dependencies beyond the NOAA S3 mirror. + +### Docker single container + +```dockerfile +FROM golang:1.25 AS build +WORKDIR /src +COPY . . +RUN go build -o /predictor ./cmd/predictor + +FROM gcr.io/distroless/base +COPY --from=build /predictor /predictor +EXPOSE 8080 +ENTRYPOINT ["/predictor"] +``` + +Mount a volume at `/data` and set `PREDICTOR_DATA_DIR=/data`. + +### Load-balanced cluster + +The server is stateless apart from the on-disk dataset cache and in-memory +job table. For multiple replicas, point all replicas at a shared filesystem +(NFS or similar) for `data_dir`; each replica opens its own read-only mmap. Active +download coordination across replicas is not implemented — run downloads on +one node, or accept that two nodes may download the same epoch concurrently +(only one Commit wins via atomic rename). + +## Elevation dataset + +Without elevation data, descent terminates at sea level. With elevation, +descent terminates at ground level, matching upstream Tawhiri. + +```bash +pip install xarray netcdf4 numpy +python3 scripts/build_elevation.py /var/lib/predictor/elevation +``` + +`PREDICTOR_ELEVATION_DATASET=/var/lib/predictor/elevation ./bin/predictor` + +## Numerical methods + +The numerics package (`internal/numerics`) provides: + +- regular-grid multilinear interpolation, +- monotone bisection, +- classical RK4 (forward and reverse time), +- binary-search refinement of a termination point. + +Detailed math reference: `docs/numerics.tex`. The package has no +domain dependencies and is small enough for manual verification (~300 +lines of Go), enabling a future C or Rust port without changes to the +trajectory engine. 
+ +## Wind data | Property | Value | |---|---| -| Source | `noaa-gfs-bdp-pds.s3.amazonaws.com` | -| Resolution | 0.5 degrees | -| Grid | 361 lat x 720 lon | -| Time steps | 65 (every 3 hours, 0-192h) | -| Pressure levels | 47 (1000 to 1 hPa) | +| Source | NOAA GFS, S3 mirror (`noaa-gfs-bdp-pds.s3.amazonaws.com`) | +| Resolution | 0.5° | +| Grid | 361 × 720 (lat × lng) | +| Forecast steps | 65 (every 3 hours, 0–192h) | +| Pressure levels | 47 (1000 → 1 hPa) | | Variables | Geopotential height, U-wind, V-wind | -| Dataset size | 9,528,667,200 bytes (~8.87 GiB) | -| Update cadence | Every 6 hours (GFS runs at 00, 06, 12, 18 UTC) | +| File size | ~8.87 GiB (float32 flat binary, mmap-backed) | +| Update cadence | every 6 hours | -Data is downloaded using HTTP Range requests against `.idx` index files, fetching only the needed GRIB messages (HGT, UGRD, VGRD at 47 pressure levels). Full download takes 30-60 minutes depending on bandwidth. +Downloads use HTTP Range requests against `.idx` index files to fetch only +the needed GRIB messages. Downloads are transactional (temp file, manifest, +atomic rename on commit) and resumable: interrupted downloads pick up where +they left off via the manifest. -The dataset is stored as a memory-mapped flat binary file of float32 values in C-order with shape `(65, 47, 3, 361, 720)`. +## Validation -## Prediction Algorithms - -All algorithms are exact ports of the reference implementations in Tawhiri. The following sections describe the key components. - -### Interpolation (`internal/prediction/interpolate.go`) - -4D wind interpolation from the dataset grid to arbitrary coordinates. - -1. **Trilinear weights** (`pick3`): compute 8 interpolation weights for the (hour, lat, lon) cube corners. -2. **Altitude search** (`search`): binary search on interpolated geopotential height to find the two pressure levels bracketing the target altitude. -3. 
**Wind extraction** (`interp4`): 8-point weighted sum at each bracket level, then linear interpolation between levels. - -Reference: `tawhiri/interpolate.pyx` - -### Solver (`internal/prediction/solver.go`) - -4th-order Runge-Kutta integrator with dt = 60 seconds. - -- State vector: (latitude, longitude, altitude) in degrees and metres. -- Time: UNIX timestamp in seconds. -- Longitude is kept in [0, 360) via Python-style modulo after each `vecadd`. -- When a terminator fires, binary search refinement (tolerance 0.01) finds the precise termination point between the last good step and the first terminated step. -- Longitude interpolation (`lngLerp`) handles the 0/360 wrap-around. - -Reference: `tawhiri/solver.pyx` - -### Models (`internal/prediction/models.go`) - -- **Constant ascent**: vertical velocity = ascent_rate m/s. -- **Drag descent**: NASA atmosphere density model with drag coefficient = sea_level_rate * 1.1045. Descent rate increases with altitude due to thinner air. -- **Wind velocity**: u, v components from interpolation converted to degrees/second: `dlat = (180/pi) * v / (R)`, `dlng = (180/pi) * u / (R * cos(lat))` where R = 6371009 + altitude. -- **Linear model**: sum of component models (e.g., wind + ascent). -- **Elevation termination**: `ground_elevation > altitude` using ruaumoko dataset. - -Reference: `tawhiri/models.py` - -### Profiles - -- **standard_profile**: ascent (constant rate + wind) until burst altitude, then descent (drag + wind) until ground level. -- **float_profile**: ascent to float altitude, then drift at constant altitude until stop time. 
- -## Verification - -The predictor has been verified against the reference Tawhiri implementation: - -| Test | Result | -|---|---| -| Dataset (step 0): 36.6M float32 values vs Python/cfgrib | 0 mismatches, max diff = 0.0 | -| Prediction burst point vs public Tawhiri API | Identical (lat, lon, alt all match) | -| Prediction landing point vs public Tawhiri API | Identical lat/lon, 5m altitude diff (different elevation datasets) | -| Descent point count | Identical (46 points) | -| Ascent point count | Identical (101 points) | - -## Development - -```bash -# Regenerate ogen API code after modifying the swagger spec -make generate-ogen - -# Run tests -make test - -# Format -make fmt -``` - -### Comparison tools - -```bash -# Compare single dataset step against Python/cfgrib reference -go run ./cmd/compare_step0 - -# Run prediction and compare against public Tawhiri API -go run ./cmd/compare_prediction -``` +`./bin/compare-tawhiri --server http://localhost:8080` runs an identical +prediction against the local server and against the public SondeHub Tawhiri +instance, reporting the great-circle distance between landing points. 
## References -- [Tawhiri](https://github.com/cuspaceflight/tawhiri) — Reference Python/Cython predictor (Cambridge University Spaceflight) -- [tawhiri-downloader](https://github.com/cuspaceflight/tawhiri-downloader) — OCaml dataset downloader -- [ruaumoko](https://github.com/cuspaceflight/ruaumoko) — Global elevation dataset -- [NOAA GFS](https://www.ncei.noaa.gov/products/weather-climate-models/global-forecast) — Global Forecast System -- [NOAA GFS on S3](https://noaa-gfs-bdp-pds.s3.amazonaws.com/index.html) — Public S3 bucket -- [ETOPO 2022](https://www.ncei.noaa.gov/products/etopo-global-relief-model) — Global relief model for elevation data -- [SondeHub Tawhiri API](https://api.v2.sondehub.org/tawhiri) — Public Tawhiri instance for comparison +- [Tawhiri](https://github.com/cuspaceflight/tawhiri) — reference Python/Cython predictor +- [ruaumoko](https://github.com/cuspaceflight/ruaumoko) — global elevation dataset format +- [NOAA GFS](https://www.ncei.noaa.gov/products/weather-climate-models/global-forecast) +- [ETOPO 2022](https://www.ncei.noaa.gov/products/etopo-global-relief-model) +- [SondeHub Tawhiri API](https://api.v2.sondehub.org/tawhiri) — public Tawhiri instance diff --git a/cmd/api/main.go b/cmd/api/main.go deleted file mode 100644 index 08d12e3..0000000 --- a/cmd/api/main.go +++ /dev/null @@ -1,98 +0,0 @@ -package main - -import ( - "context" - "fmt" - "os" - "os/signal" - "syscall" - "time" - - "predictor-refactored/internal/downloader" - "predictor-refactored/internal/service" - "predictor-refactored/internal/transport/rest" - "predictor-refactored/internal/transport/rest/handler" - - "github.com/go-co-op/gocron" - "go.uber.org/zap" -) - -func main() { - log, err := zap.NewProduction() - if err != nil { - panic(err) - } - defer log.Sync() - - cfg := downloader.LoadConfig() - log.Info("configuration loaded", - zap.String("data_dir", cfg.DataDir), - zap.Int("parallel", cfg.Parallel), - zap.Duration("update_interval", cfg.UpdateInterval), - 
zap.Duration("dataset_ttl", cfg.DatasetTTL)) - - if err := os.MkdirAll(cfg.DataDir, 0o755); err != nil { - log.Fatal("failed to create data directory", zap.Error(err)) - } - - svc := service.New(cfg, log) - defer svc.Close() - - // Load elevation dataset (optional — falls back to sea-level termination) - elevPath := "/srv/ruaumoko-dataset" - if v := os.Getenv("PREDICTOR_ELEVATION_DATASET"); v != "" { - elevPath = v - } - svc.LoadElevation(elevPath) - - // Initial dataset load (async so the server starts immediately) - go func() { - log.Info("performing initial dataset update...") - if err := svc.Update(context.Background()); err != nil { - log.Error("initial dataset update failed", zap.Error(err)) - } else { - log.Info("initial dataset update complete") - } - }() - - // Scheduler for periodic dataset updates - scheduler := gocron.NewScheduler(time.UTC) - scheduler.Every(cfg.UpdateInterval).Do(func() { - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute) - defer cancel() - log.Info("scheduled dataset update starting") - if err := svc.Update(ctx); err != nil { - log.Error("scheduled dataset update failed", zap.Error(err)) - } else { - log.Info("scheduled dataset update complete") - } - }) - scheduler.StartAsync() - defer scheduler.Stop() - - // HTTP transport (ogen) - port := 8080 - if p := os.Getenv("PREDICTOR_PORT"); p != "" { - fmt.Sscanf(p, "%d", &port) - } - - h := handler.New(svc, log) - transport, err := rest.New(h, port, log) - if err != nil { - log.Fatal("failed to create transport", zap.Error(err)) - } - - go func() { - if err := transport.Run(); err != nil { - log.Fatal("HTTP server error", zap.Error(err)) - } - }() - - log.Info("service started") - - // Graceful shutdown - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - sig := <-sigChan - log.Info("received shutdown signal", zap.String("signal", sig.String())) -} diff --git a/cmd/compare-tawhiri/main.go b/cmd/compare-tawhiri/main.go new 
file mode 100644
index 0000000..39a1f08
--- /dev/null
+++ b/cmd/compare-tawhiri/main.go
@@ -0,0 +1,173 @@
+// Command compare-tawhiri runs the same prediction against a local predictor
+// instance and against the public SondeHub Tawhiri instance, reporting the
+// distance between the two predicted landing points.
+//
+// Intended use:
+//
+//	./compare-tawhiri --server http://localhost:8080
+package main
+
+import (
+	"encoding/json"
+	"flag"
+	"fmt"
+	"io"
+	"math"
+	"net/http"
+	"net/url"
+	"os"
+	"time"
+)
+
+const tawhiriPublicURL = "https://api.v2.sondehub.org/tawhiri"
+
+// httpClient is shared by every request this tool makes. The explicit timeout
+// keeps the tool from hanging on an unresponsive server; the default client
+// behind http.Get has no timeout.
+var httpClient = &http.Client{Timeout: 30 * time.Second}
+
+// main runs one prediction against the local server and one against the
+// public Tawhiri instance, prints both landing points and rates the distance.
+func main() {
+	server := flag.String("server", "http://localhost:8080", "local predictor server URL")
+	lat := flag.Float64("lat", 52.2135, "launch latitude")
+	lng := flag.Float64("lng", 0.0964, "launch longitude")
+	alt := flag.Float64("alt", 0, "launch altitude")
+	rate := flag.Float64("ascent-rate", 5, "ascent rate m/s")
+	burst := flag.Float64("burst", 30000, "burst altitude m")
+	descent := flag.Float64("descent-rate", 5, "descent rate m/s")
+	launch := flag.String("launch", "", "launch time RFC3339; default: 3 hours after the active dataset epoch")
+	flag.Parse()
+
+	// Discover the active dataset epoch from /ready.
+	epoch, err := fetchActiveEpoch(*server)
+	if err != nil {
+		fmt.Fprintln(os.Stderr, "ready:", err)
+		os.Exit(1)
+	}
+
+	// Launch 3 hours into the forecast unless --launch overrides it.
+	launchTime := epoch.Add(3 * time.Hour)
+	if *launch != "" {
+		t, err := time.Parse(time.RFC3339, *launch)
+		if err != nil {
+			fmt.Fprintln(os.Stderr, "invalid launch time:", err)
+			os.Exit(1)
+		}
+		launchTime = t
+	}
+
+	ourLat, ourLng, err := runPrediction(*server+"/api/v1/prediction", *lat, *lng, *alt, launchTime, *rate, *burst, *descent)
+	if err != nil {
+		fmt.Fprintln(os.Stderr, "local prediction:", err)
+		os.Exit(1)
+	}
+	fmt.Printf("local landing: lat=%.4f, lng=%.4f\n", ourLat, ourLng)
+
+	tawLat, tawLng, err := runPrediction(tawhiriPublicURL, *lat, *lng, *alt, launchTime, *rate, *burst, *descent)
+	if err != nil {
+		fmt.Fprintln(os.Stderr, "tawhiri prediction:", err)
+		os.Exit(1)
+	}
+	fmt.Printf("tawhiri landing: lat=%.4f, lng=%.4f\n", tawLat, tawLng)
+
+	d := haversine(ourLat, ourLng, tawLat, tawLng)
+	fmt.Printf("distance: %.2f km\n", d/1000)
+
+	// d is in metres.
+	switch {
+	case d < 1000:
+		fmt.Println("MATCH (< 1 km)")
+	case d < 50000:
+		fmt.Printf("MODERATE (%.1f km) — likely different forecast runs\n", d/1000)
+	default:
+		fmt.Printf("LARGE (%.1f km) — investigate\n", d/1000)
+	}
+}
+
+// readinessResp holds the fields of the /ready response that this tool reads.
+type readinessResp struct {
+	Status      string `json:"status"`
+	DatasetTime string `json:"dataset_time"`
+}
+
+// getJSON performs a GET on rawURL and decodes the body of a 200 response
+// into v. Transport errors, body read errors and non-200 statuses are all
+// returned as errors; a non-200 error includes the response body text.
+func getJSON(rawURL string, v any) error {
+	resp, err := httpClient.Get(rawURL)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+	body, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return fmt.Errorf("reading response body: %w", err)
+	}
+	if resp.StatusCode != http.StatusOK {
+		return fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
+	}
+	return json.Unmarshal(body, v)
+}
+
+// fetchActiveEpoch reads base+"/ready" and returns the reported dataset_time
+// parsed as RFC 3339. It fails unless the reported status is "ok".
+func fetchActiveEpoch(base string) (time.Time, error) {
+	var r readinessResp
+	if err := getJSON(base+"/ready", &r); err != nil {
+		return time.Time{}, err
+	}
+	if r.Status != "ok" {
+		return time.Time{}, fmt.Errorf("server status %q", r.Status)
+	}
+	return time.Parse(time.RFC3339, r.DatasetTime)
+}
+
+// runPrediction requests a prediction from endpoint using the Tawhiri query
+// parameters and returns the latitude and longitude of the final point of the
+// "descent" stage. It fails when no non-empty descent stage is returned.
+func runPrediction(endpoint string, lat, lng, alt float64, launch time.Time, rate, burst, descent float64) (float64, float64, error) {
+	q := url.Values{}
+	q.Set("launch_latitude", fmt.Sprintf("%.4f", lat))
+	q.Set("launch_longitude", fmt.Sprintf("%.4f", lng))
+	q.Set("launch_altitude", fmt.Sprintf("%.0f", alt))
+	q.Set("launch_datetime", launch.Format(time.RFC3339))
+	q.Set("ascent_rate", fmt.Sprintf("%.1f", rate))
+	q.Set("burst_altitude", fmt.Sprintf("%.0f", burst))
+	q.Set("descent_rate", fmt.Sprintf("%.1f", descent))
+
+	var result struct {
+		Prediction []struct {
+			Stage      string `json:"stage"`
+			Trajectory []struct {
+				Latitude  float64 `json:"latitude"`
+				Longitude float64 `json:"longitude"`
+			} `json:"trajectory"`
+		} `json:"prediction"`
+	}
+	if err := getJSON(endpoint+"?"+q.Encode(), &result); err != nil {
+		return 0, 0, err
+	}
+	for _, stage := range result.Prediction {
+		if stage.Stage == "descent" && len(stage.Trajectory) > 0 {
+			last := stage.Trajectory[len(stage.Trajectory)-1]
+			return last.Latitude, last.Longitude, nil
+		}
+	}
+	return 0, 0, fmt.Errorf("no descent stage in response")
+}
+
+// haversine returns the great-circle distance in metres between two points
+// given in decimal degrees, on a sphere of radius 6371 km.
+func haversine(lat1, lng1, lat2, lng2 float64) float64 {
+	const R = 6371000.0
+	phi1 := lat1 * math.Pi / 180
+	phi2 := lat2 * math.Pi / 180
+	dphi := (lat2 - lat1) * math.Pi / 180
+	dlam := (lng2 - lng1) * math.Pi / 180
+	a := math.Sin(dphi/2)*math.Sin(dphi/2) + math.Cos(phi1)*math.Cos(phi2)*math.Sin(dlam/2)*math.Sin(dlam/2)
+	// Rounding can push a slightly above 1 for near-antipodal points, which
+	// would turn Sqrt(1-a) into NaN, so clamp it.
+	a = math.Min(a, 1)
+	return R * 2 * math.Atan2(math.Sqrt(a), math.Sqrt(1-a))
+}
diff --git a/cmd/compare_prediction/main.go b/cmd/compare_prediction/main.go
deleted file mode 100644
index 13ae2a6..0000000
--- a/cmd/compare_prediction/main.go
+++ /dev/null
@@ -1,195 +0,0 @@
-package main
-
-import (
-	"context"
-	"encoding/json"
-	"fmt"
-	"io"
-	"math"
-	"net/http"
-	"os"
-	"time"
-
-	
"predictor-refactored/internal/dataset" - "predictor-refactored/internal/downloader" - "predictor-refactored/internal/prediction" - - "go.uber.org/zap" -) - -// Downloads a few forecast steps and runs a prediction, then compares -// against the public Tawhiri API. -func main() { - log, _ := zap.NewDevelopment() - - cfg := &downloader.Config{ - DataDir: os.TempDir(), - Parallel: 4, - } - dl := downloader.NewDownloader(cfg, log) - - ctx := context.Background() - - // Find latest run - run, err := dl.FindLatestRun(ctx) - if err != nil { - fmt.Fprintf(os.Stderr, "FindLatestRun: %v\n", err) - os.Exit(1) - } - fmt.Printf("Using run: %s\n", run.Format("2006010215")) - - // Create dataset and download first 10 steps (0-27 hours, enough for a prediction) - dsPath := fmt.Sprintf("/tmp/pred_test_%s.bin", run.Format("2006010215")) - defer os.Remove(dsPath) - - ds, err := dataset.Create(dsPath) - if err != nil { - fmt.Fprintf(os.Stderr, "Create: %v\n", err) - os.Exit(1) - } - - date := run.Format("20060102") - runHour := run.Hour() - stepsToDownload := []int{0, 3, 6, 9, 12, 15, 18, 21, 24, 27} - - fmt.Printf("Downloading %d steps...\n", len(stepsToDownload)) - for _, step := range stepsToDownload { - hourIdx := dataset.HourIndex(step) - fmt.Printf(" step %d (hour idx %d)...\n", step, hourIdx) - - urlA := dataset.GribURL(date, runHour, step) - if err := dl.DownloadAndBlit(ctx, ds, urlA, hourIdx, dataset.LevelSetA); err != nil { - fmt.Fprintf(os.Stderr, " pgrb2 step %d: %v\n", step, err) - os.Exit(1) - } - - urlB := dataset.GribURLB(date, runHour, step) - if err := dl.DownloadAndBlit(ctx, ds, urlB, hourIdx, dataset.LevelSetB); err != nil { - fmt.Fprintf(os.Stderr, " pgrb2b step %d: %v\n", step, err) - os.Exit(1) - } - } - ds.Flush() - fmt.Println("Download complete") - - // Set dataset time - ds.DSTime = run - - // Run our prediction - launchLat := 52.2135 - launchLon := 0.0964 // already in [0, 360) - launchAlt := 0.0 - ascentRate := 5.0 - burstAlt := 30000.0 - descentRate := 
5.0 - - // Launch 3 hours into the forecast - launchTime := run.Add(3 * time.Hour) - launchTimestamp := float64(launchTime.Unix()) - dsEpoch := float64(run.Unix()) - - warnings := &prediction.Warnings{} - stages := prediction.StandardProfile(ascentRate, burstAlt, descentRate, ds, dsEpoch, warnings, nil) - results := prediction.RunPrediction(launchTimestamp, launchLat, launchLon, launchAlt, stages) - - fmt.Printf("\n=== Our prediction ===\n") - for i, sr := range results { - stage := "ascent" - if i == 1 { - stage = "descent" - } - first := sr.Points[0] - last := sr.Points[len(sr.Points)-1] - fmt.Printf(" %s: %d points, start=(%.4f, %.4f, %.0fm) end=(%.4f, %.4f, %.0fm)\n", - stage, len(sr.Points), - first.Lat, first.Lng, first.Alt, - last.Lat, last.Lng, last.Alt) - } - - // Get landing point - var ourLandLat, ourLandLon float64 - if len(results) >= 2 { - last := results[1].Points[len(results[1].Points)-1] - ourLandLat = last.Lat - ourLandLon = last.Lng - if ourLandLon > 180 { - ourLandLon -= 360 - } - } - fmt.Printf(" Landing: lat=%.4f, lon=%.4f\n", ourLandLat, ourLandLon) - - // Compare against public Tawhiri API - fmt.Printf("\n=== Tawhiri API comparison ===\n") - tawhiriLandLat, tawhiriLandLon, err := queryTawhiri(launchLat, launchLon, launchAlt, launchTime, ascentRate, burstAlt, descentRate) - if err != nil { - fmt.Printf(" Tawhiri API error: %v\n", err) - fmt.Println(" (Cannot compare — Tawhiri may use a different dataset)") - ds.Close() - return - } - fmt.Printf(" Tawhiri landing: lat=%.4f, lon=%.4f\n", tawhiriLandLat, tawhiriLandLon) - - dist := haversine(ourLandLat, ourLandLon, tawhiriLandLat, tawhiriLandLon) - fmt.Printf(" Distance between landing points: %.2f km\n", dist/1000) - - if dist < 1000 { - fmt.Println(" CLOSE MATCH (< 1 km)") - } else if dist < 50000 { - fmt.Printf(" MODERATE DIFFERENCE (%.1f km) — likely different datasets\n", dist/1000) - } else { - fmt.Printf(" LARGE DIFFERENCE (%.1f km) — possible bug\n", dist/1000) - } - - ds.Close() -} - 
-func queryTawhiri(lat, lon, alt float64, launchTime time.Time, ascentRate, burstAlt, descentRate float64) (landLat, landLon float64, err error) { - url := fmt.Sprintf( - "https://api.v2.sondehub.org/tawhiri?launch_latitude=%.4f&launch_longitude=%.4f&launch_altitude=%.0f&launch_datetime=%s&ascent_rate=%.1f&burst_altitude=%.0f&descent_rate=%.1f", - lat, lon, alt, launchTime.Format(time.RFC3339), ascentRate, burstAlt, descentRate) - - resp, err := http.Get(url) - if err != nil { - return 0, 0, err - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != 200 { - return 0, 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) - } - - var result struct { - Prediction []struct { - Stage string `json:"stage"` - Trajectory []struct { - Latitude float64 `json:"latitude"` - Longitude float64 `json:"longitude"` - Altitude float64 `json:"altitude"` - } `json:"trajectory"` - } `json:"prediction"` - } - - if err := json.Unmarshal(body, &result); err != nil { - return 0, 0, err - } - - for _, stage := range result.Prediction { - if stage.Stage == "descent" && len(stage.Trajectory) > 0 { - last := stage.Trajectory[len(stage.Trajectory)-1] - return last.Latitude, last.Longitude, nil - } - } - - return 0, 0, fmt.Errorf("no descent stage found") -} - -func haversine(lat1, lon1, lat2, lon2 float64) float64 { - const R = 6371000.0 - phi1 := lat1 * math.Pi / 180 - phi2 := lat2 * math.Pi / 180 - dphi := (lat2 - lat1) * math.Pi / 180 - dlam := (lon2 - lon1) * math.Pi / 180 - a := math.Sin(dphi/2)*math.Sin(dphi/2) + math.Cos(phi1)*math.Cos(phi2)*math.Sin(dlam/2)*math.Sin(dlam/2) - return R * 2 * math.Atan2(math.Sqrt(a), math.Sqrt(1-a)) -} diff --git a/cmd/compare_step0/main.go b/cmd/compare_step0/main.go deleted file mode 100644 index d0d697d..0000000 --- a/cmd/compare_step0/main.go +++ /dev/null @@ -1,104 +0,0 @@ -package main - -import ( - "context" - "fmt" - "os" - "time" - - "predictor-refactored/internal/dataset" - 
"predictor-refactored/internal/downloader" - - "go.uber.org/zap" -) - -// Downloads step 0 of a given run and writes a minimal dataset for comparison. -// Usage: go run ./cmd/compare_step0 -func main() { - if len(os.Args) < 3 { - fmt.Fprintf(os.Stderr, "Usage: %s \n", os.Args[0]) - os.Exit(1) - } - - runStr := os.Args[1] - outPath := os.Args[2] - - run, err := time.Parse("2006010215", runStr) - if err != nil { - fmt.Fprintf(os.Stderr, "Invalid run time %q: %v\n", runStr, err) - os.Exit(1) - } - - log, _ := zap.NewDevelopment() - - // Create a full-size dataset (we only fill step 0) - fmt.Printf("Creating dataset at %s (%d bytes)...\n", outPath, dataset.DatasetSize) - ds, err := dataset.Create(outPath) - if err != nil { - fmt.Fprintf(os.Stderr, "Create dataset: %v\n", err) - os.Exit(1) - } - defer ds.Close() - - cfg := &downloader.Config{ - DataDir: os.TempDir(), - Parallel: 4, - } - dl := downloader.NewDownloader(cfg, log) - - ctx := context.Background() - date := run.Format("20060102") - runHour := run.Hour() - - // Download and blit step 0 from pgrb2 - fmt.Println("Downloading pgrb2 step 0...") - urlA := dataset.GribURL(date, runHour, 0) - if err := dl.DownloadAndBlit(ctx, ds, urlA, 0, dataset.LevelSetA); err != nil { - fmt.Fprintf(os.Stderr, "pgrb2: %v\n", err) - os.Exit(1) - } - fmt.Println(" done") - - // Download and blit step 0 from pgrb2b - fmt.Println("Downloading pgrb2b step 0...") - urlB := dataset.GribURLB(date, runHour, 0) - if err := dl.DownloadAndBlit(ctx, ds, urlB, 0, dataset.LevelSetB); err != nil { - fmt.Fprintf(os.Stderr, "pgrb2b: %v\n", err) - os.Exit(1) - } - fmt.Println(" done") - - if err := ds.Flush(); err != nil { - fmt.Fprintf(os.Stderr, "Flush: %v\n", err) - os.Exit(1) - } - - // Spot-check: print same values as the Python script - fmt.Println("\n=== Go dataset values (spot check) ===") - type testPoint struct { - varName string - varIdx int - levelIdx int - lat, lon int - } - - points := []testPoint{ - {"HGT", 0, 0, 0, 0}, // HGT @ 
1000mb, lat=-90, lon=0 - {"HGT", 0, 0, 180, 0}, // HGT @ 1000mb, lat=0, lon=0 - {"HGT", 0, 0, 360, 0}, // HGT @ 1000mb, lat=+90, lon=0 - {"HGT", 0, 20, 180, 360}, // HGT @ 500mb, lat=0, lon=180 - {"UGRD", 1, 0, 180, 0}, // UGRD @ 1000mb, lat=0, lon=0 - {"VGRD", 2, 0, 180, 0}, // VGRD @ 1000mb, lat=0, lon=0 - {"UGRD", 1, 20, 284, 0}, // UGRD @ 500mb, lat=52N, lon=0 - } - - for _, p := range points { - val := ds.Val(0, p.levelIdx, p.varIdx, p.lat, p.lon) - actualLat := -90.0 + float64(p.lat)*0.5 - actualLon := float64(p.lon) * 0.5 - fmt.Printf(" %-4s %4dmb lat=%+7.1f lon=%6.1f: %12.4f\n", - p.varName, dataset.Pressures[p.levelIdx], actualLat, actualLon, val) - } - - fmt.Printf("\nDataset written to %s\n", outPath) -} diff --git a/cmd/predictor-cli/main.go b/cmd/predictor-cli/main.go new file mode 100644 index 0000000..9b16562 --- /dev/null +++ b/cmd/predictor-cli/main.go @@ -0,0 +1,216 @@ +// Command predictor-cli is a small HTTP client for stratoflights-predictor. +// +// It is intended for operations and development; production callers should +// use the REST API directly. +package main + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "io" + "net/http" + "net/url" + "os" + "strings" + "time" +) + +const usage = `predictor-cli — HTTP client for stratoflights-predictor + +USAGE + predictor-cli [--server URL] [args...] + +COMMANDS + ready Check service health + predict ... 
Run a Tawhiri-compat prediction (key=value pairs) + datasets list List stored dataset epochs + datasets download [--latest|--epoch RFC3339] + Trigger a dataset download + datasets delete Delete a stored dataset + jobs list List download jobs + jobs get Show one job + jobs cancel Cancel a running job + +ENVIRONMENT + PREDICTOR_SERVER Default --server (overridden by the flag) +` + +func main() { + fs := flag.NewFlagSet("predictor-cli", flag.ContinueOnError) + fs.Usage = func() { fmt.Fprint(os.Stderr, usage) } + server := fs.String("server", envDefault("PREDICTOR_SERVER", "http://localhost:8080"), "predictor server URL") + if err := fs.Parse(os.Args[1:]); err != nil { + os.Exit(2) + } + args := fs.Args() + if len(args) == 0 { + fs.Usage() + os.Exit(2) + } + + c := &client{base: strings.TrimRight(*server, "/"), http: &http.Client{Timeout: 30 * time.Second}} + if err := dispatch(c, args); err != nil { + fmt.Fprintln(os.Stderr, "error:", err) + os.Exit(1) + } +} + +func envDefault(name, fallback string) string { + if v := os.Getenv(name); v != "" { + return v + } + return fallback +} + +func dispatch(c *client, args []string) error { + switch args[0] { + case "ready": + return c.ready() + case "predict": + return c.predict(args[1:]) + case "datasets": + if len(args) < 2 { + return fmt.Errorf("usage: datasets {list|download|delete}") + } + switch args[1] { + case "list": + return c.datasetsList() + case "download": + return c.datasetsDownload(args[2:]) + case "delete": + if len(args) < 3 { + return fmt.Errorf("usage: datasets delete ") + } + return c.datasetsDelete(args[2]) + } + case "jobs": + if len(args) < 2 { + return fmt.Errorf("usage: jobs {list|get|cancel}") + } + switch args[1] { + case "list": + return c.jobsList() + case "get": + if len(args) < 3 { + return fmt.Errorf("usage: jobs get ") + } + return c.jobsGet(args[2]) + case "cancel": + if len(args) < 3 { + return fmt.Errorf("usage: jobs cancel ") + } + return c.jobsCancel(args[2]) + } + } + return 
fmt.Errorf("unknown command %q", args[0]) +} + +type client struct { + base string + http *http.Client +} + +func (c *client) ready() error { + return c.getPrint("/ready") +} + +func (c *client) predict(kv []string) error { + q := url.Values{} + for _, p := range kv { + idx := strings.IndexByte(p, '=') + if idx <= 0 { + return fmt.Errorf("expected key=value, got %q", p) + } + q.Set(p[:idx], p[idx+1:]) + } + return c.getPrint("/api/v1/prediction?" + q.Encode()) +} + +func (c *client) datasetsList() error { + return c.getPrint("/api/v1/admin/datasets") +} + +func (c *client) datasetsDownload(args []string) error { + fs := flag.NewFlagSet("datasets download", flag.ContinueOnError) + latest := fs.Bool("latest", false, "download the latest available run") + epoch := fs.String("epoch", "", "RFC3339 epoch to download") + if err := fs.Parse(args); err != nil { + return err + } + body := map[string]any{} + if *latest { + body["latest"] = true + } + if *epoch != "" { + body["epoch"] = *epoch + } + return c.postPrint("/api/v1/admin/datasets", body) +} + +func (c *client) datasetsDelete(epoch string) error { + return c.deletePrint("/api/v1/admin/datasets/" + url.PathEscape(epoch)) +} + +func (c *client) jobsList() error { return c.getPrint("/api/v1/admin/jobs") } +func (c *client) jobsGet(id string) error { + return c.getPrint("/api/v1/admin/jobs/" + url.PathEscape(id)) +} +func (c *client) jobsCancel(id string) error { + return c.deletePrint("/api/v1/admin/jobs/" + url.PathEscape(id)) +} + +func (c *client) getPrint(path string) error { + resp, err := c.http.Get(c.base + path) + if err != nil { + return err + } + return printResp(resp) +} + +func (c *client) postPrint(path string, body any) error { + buf, err := json.Marshal(body) + if err != nil { + return err + } + resp, err := c.http.Post(c.base+path, "application/json", bytes.NewReader(buf)) + if err != nil { + return err + } + return printResp(resp) +} + +func (c *client) deletePrint(path string) error { + req, err := 
http.NewRequest(http.MethodDelete, c.base+path, nil) + if err != nil { + return err + } + resp, err := c.http.Do(req) + if err != nil { + return err + } + return printResp(resp) +} + +func printResp(resp *http.Response) error { + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode >= 400 { + return fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body))) + } + // Pretty-print JSON when possible; raw bytes otherwise. + if strings.Contains(resp.Header.Get("Content-Type"), "json") && len(body) > 0 { + var any any + if err := json.Unmarshal(body, &any); err == nil { + pretty, _ := json.MarshalIndent(any, "", " ") + fmt.Println(string(pretty)) + return nil + } + } + if len(body) > 0 { + fmt.Println(strings.TrimSpace(string(body))) + } + return nil +} + diff --git a/cmd/predictor/main.go b/cmd/predictor/main.go new file mode 100644 index 0000000..13391ce --- /dev/null +++ b/cmd/predictor/main.go @@ -0,0 +1,181 @@ +// Command predictor is the stratoflights-predictor HTTP server. +// +// It wires the configuration, dataset manager, scheduler, and API layer +// into a single process and exits cleanly on SIGINT/SIGTERM. 
+package main
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net/http"
+	"os"
+	"os/signal"
+	"syscall"
+	"time"
+
+	"github.com/go-co-op/gocron"
+	"go.uber.org/zap"
+	"go.uber.org/zap/zapcore"
+
+	"predictor-refactored/internal/api"
+	"predictor-refactored/internal/config"
+	"predictor-refactored/internal/datasets"
+	"predictor-refactored/internal/datasets/gfs"
+	"predictor-refactored/internal/elevation"
+	"predictor-refactored/internal/metrics"
+)
+
+// main runs the server and maps any error returned by run to exit status 1.
+func main() {
+	if err := run(os.Args[1:]); err != nil {
+		fmt.Fprintln(os.Stderr, "fatal:", err)
+		os.Exit(1)
+	}
+}
+
+// run builds every component from the configuration parsed out of args,
+// starts the background dataset refresh and the periodic scheduler, and then
+// serves HTTP until a shutdown signal arrives.
+//
+// It returns a wrapped error when configuration, logger, store, scheduler or
+// server initialisation fails, when the configured source is not supported,
+// or when the HTTP server stops with anything other than http.ErrServerClosed.
+func run(args []string) error {
+	// Upper bound for one dataset refresh, initial or scheduled.
+	const refreshTimeout = 60 * time.Minute
+
+	cfg, err := config.Load(args)
+	if err != nil {
+		return fmt.Errorf("load config: %w", err)
+	}
+
+	log, err := newLogger(cfg.Log.Level)
+	if err != nil {
+		return fmt.Errorf("init logger: %w", err)
+	}
+	// The Sync error is dropped on purpose: at exit there is nowhere left to
+	// report it.
+	defer func() { _ = log.Sync() }()
+
+	log.Info("configuration loaded",
+		zap.Int("port", cfg.HTTP.Port),
+		zap.String("data_dir", cfg.Data.Dir),
+		zap.String("source", cfg.Data.Source),
+		zap.Int("download_parallel", cfg.Download.Parallel),
+		zap.Duration("update_interval", cfg.Download.UpdateInterval),
+		zap.Duration("freshness_ttl", cfg.Download.FreshnessTTL),
+		zap.Bool("metrics_enabled", cfg.Metrics.Enabled),
+	)
+
+	store, err := datasets.NewLocalStore(cfg.Data.Dir, cfg.Data.Source)
+	if err != nil {
+		return fmt.Errorf("init store: %w", err)
+	}
+
+	// Source is GFS today; the spec leaves room for ECMWF later via the
+	// same datasets.Source interface.
+	if cfg.Data.Source != "noaa-gfs-0p50" {
+		return fmt.Errorf("source %q not supported", cfg.Data.Source)
+	}
+	source := gfs.NewSource(log)
+	source.Parallel = cfg.Download.Parallel
+
+	// A nil throttle is passed to the manager when no bandwidth cap is set.
+	var throttle datasets.Throttle
+	if cfg.Download.BandwidthBytesPerSecond > 0 {
+		throttle = datasets.NewTokenBucket(cfg.Download.BandwidthBytesPerSecond)
+	}
+
+	// Metrics (optional): a no-op sink and no handler unless enabled.
+	var sink metrics.Sink = metrics.Noop()
+	var metricsHandler http.Handler
+	if cfg.Metrics.Enabled {
+		prom := metrics.NewProm()
+		sink = prom
+		metricsHandler = prom
+	}
+
+	mgr := datasets.New(source, store, throttle, log)
+	defer mgr.Close()
+
+	// Optional elevation dataset. Missing or unreadable elevation is logged
+	// but non-fatal; descent terminates at sea level instead.
+	// NOTE(review): if api.Deps.Elevation is an interface type, a nil
+	// *elevation.Dataset stored in it is a non-nil interface — confirm the
+	// field type or guard the assignment.
+	var elev *elevation.Dataset
+	if cfg.Data.ElevationPath != "" {
+		if d, err := elevation.Open(cfg.Data.ElevationPath); err == nil {
+			elev = d
+			log.Info("elevation dataset loaded", zap.String("path", cfg.Data.ElevationPath))
+			defer elev.Close()
+		} else {
+			log.Warn("elevation dataset not available, using sea-level termination",
+				zap.String("path", cfg.Data.ElevationPath),
+				zap.Error(err))
+		}
+	}
+
+	// refresh performs one time-bounded dataset refresh, logs a failure under
+	// failMsg, and publishes whichever epoch is active afterwards. It is shared
+	// by the startup goroutine and the scheduled job so the two cannot drift.
+	refresh := func(failMsg string) {
+		ctx, cancel := context.WithTimeout(context.Background(), refreshTimeout)
+		defer cancel()
+		if _, err := mgr.Refresh(ctx, cfg.Download.FreshnessTTL); err != nil {
+			log.Error(failMsg, zap.Error(err))
+		}
+		if a := mgr.Active(); a != nil {
+			sink.ActiveEpoch(a.Epoch())
+		}
+	}
+
+	// Kick off the initial refresh in the background so the server can start
+	// answering /ready immediately.
+	go refresh("initial dataset refresh failed")
+
+	// A rejected job (for example an invalid update interval) used to be
+	// dropped silently, leaving the service without periodic refreshes; it now
+	// fails startup instead.
+	// NOTE(review): gocron v1 appears to run interval jobs once as soon as the
+	// scheduler starts unless WaitForSchedule is set, which would overlap with
+	// the initial refresh above — confirm Manager.Refresh tolerates that.
+	scheduler := gocron.NewScheduler(time.UTC)
+	if _, err := scheduler.Every(cfg.Download.UpdateInterval).Do(func() {
+		log.Info("scheduled dataset refresh starting")
+		refresh("scheduled dataset refresh failed")
+	}); err != nil {
+		return fmt.Errorf("schedule dataset refresh: %w", err)
+	}
+	scheduler.StartAsync()
+	defer scheduler.Stop()
+
+	server, err := api.New(cfg.HTTP.Port, api.Deps{
+		Manager:        mgr,
+		Elevation:      elev,
+		Metrics:        sink,
+		MetricsHandler: metricsHandler,
+		MetricsPath:    cfg.Metrics.Path,
+		Log:            log,
+	})
+	if err != nil {
+		return fmt.Errorf("init server: %w", err)
+	}
+
+	// Graceful shutdown: ctx is cancelled on SIGINT/SIGTERM.
+	ctx, cancel := signalContext()
+	defer cancel()
+
+	log.Info("service started")
+	if err := server.Run(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
+		return fmt.Errorf("http server: %w", err)
+	}
+	log.Info("service stopped")
+	return nil
+}
+
+// signalContext returns a context that is cancelled when the process receives
+// SIGINT or SIGTERM. Calling the returned function cancels the context and
+// also unregisters the signal handler, so nothing is left behind.
+func signalContext() (context.Context, context.CancelFunc) {
+	return signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
+}
+
+// newLogger builds a production zap logger at the named level. Recognised
+// names are "debug", "info", "warn" and "error"; any other value falls back
+// to info.
+func newLogger(level string) (*zap.Logger, error) {
+	lvl := zapcore.InfoLevel // also the fallback for unrecognised names
+	switch level {
+	case "debug":
+		lvl = zapcore.DebugLevel
+	case "warn":
+		lvl = zapcore.WarnLevel
+	case "error":
+		lvl = zapcore.ErrorLevel
+	}
+	cfg := zap.NewProductionConfig()
+	cfg.Level = zap.NewAtomicLevelAt(lvl)
+	return cfg.Build()
+}
diff --git a/docs/numerics.tex b/docs/numerics.tex
new file mode 100644
index 0000000..3706f45
--- /dev/null
+++ b/docs/numerics.tex
@@ -0,0 +1,160 @@
+\documentclass[a4paper,11pt]{article}
+\usepackage[margin=1in]{geometry}
+\usepackage{amsmath, amssymb}
+\usepackage{algorithm, algpseudocode}
+\usepackage{hyperref}
+
+\title{Numerics Library: Mathematical Reference}
+\author{stratoflights-predictor}
+\date{}
+
+\begin{document}
+\maketitle
+
+This document describes every numerical primitive in the
+\verb|internal/numerics| package. Each section pairs the mathematical
+definition with a pointer to the Go implementation and at least one
+worked example that can be reproduced manually.
+
+\section{Regular-grid bracketing}
+
+\paragraph{Definition.} An \emph{axis} is the regularly-spaced sequence
+$x_i = \ell + i \cdot s$ for $i = 0, 1, \ldots, N - 1$, parameterised by
+the left edge $\ell$, the step $s > 0$, and the point count $N$.
+ +Given a query $v$, the \emph{bracket} is the pair of adjacent indices $(i_0, i_1)$ with +$x_{i_0} \le v < x_{i_1}$ and the dimensionless position +\[ + f = \frac{v - x_{i_0}}{s} \in [0, 1). +\] +Implemented as \verb|Axis.Locate| (\verb|internal/numerics/grid.go|). + +\paragraph{Wrapping axes.} For periodic axes (e.g.\ longitude), the +sequence is extended by the convention $x_N = x_0$ so a value approaching +$x_N$ from below brackets $(N{-}1, 0)$ with fraction +$f = (v - x_{N-1})/s$. + +\paragraph{Domain.} The bracket is undefined when $v$ falls outside the +half-open interval $[\ell, \ell + (N{-}1)\,s)$ (for non-wrapping axes) or +$[\ell, \ell + N\,s)$ (for wrapping axes); the implementation returns +an \verb|AxisError| in those cases. + +\paragraph{Worked example.} Latitude axis $\ell = -90$, $s = 0{.}5$, +$N = 361$. Query $v = -89{.}75$ yields the fractional index +$p = (v - \ell)/s = (-89{.}75 - (-90))/0{.}5 = 0{.}5$, so $i_0 = \lfloor p \rfloor = 0$, $i_1 = 1$, $f = p - i_0 = 0{.}5$. + +\section{Multilinear interpolation} + +\paragraph{Definition.} For a scalar field $u$ defined at the grid nodes +of three axes, the trilinear interpolant at brackets $b_a, b_b, b_c$ is +\[ + \tilde u = \sum_{i, j, k \in \{0, 1\}} w_{a,i} \, w_{b,j} \, w_{c,k} + \; u\bigl(b_a^i, b_b^j, b_c^k\bigr), +\] +where $w_{\bullet, 0} = 1 - f_\bullet$ and $w_{\bullet, 1} = f_\bullet$. +Implemented as \verb|EvalTrilinear|. + +\paragraph{Linear exactness.} For any affine field +$u(i, j, k) = \alpha i + \beta j + \gamma k + \delta$, the formula returns +$\alpha \cdot p_a + \beta \cdot p_b + \gamma \cdot p_c + \delta$ exactly +(modulo floating-point rounding), where $p_\bullet = b_\bullet^0 + f_\bullet$. + +\paragraph{Evaluation order.} The eight corner terms are accumulated in +the order $(0,0,0), (0,0,1), \ldots, (1,1,1)$, matching the reference +Tawhiri implementation \emph{exactly} so that double-precision results +agree bit-for-bit. 
+ +\section{Monotone bisection} + +\paragraph{Definition.} For an integer-indexed monotone non-decreasing +sequence $f : \{i_{\min}, \ldots, i_{\max}\} \to \mathbb{R}$ and a target +$t$, $\mathrm{Bisect}$ returns the largest index $i^\star$ with +$f(i^\star) < t$. The implementation evaluates $f$ on a midpoint +$m = \lceil(i_{\min} + i_{\max})/2\rceil$ each iteration and halves the +interval, taking $\mathcal{O}(\log(i_{\max} - i_{\min}))$ evaluations. + +\paragraph{Boundary behaviour.} If $t \le f(i_{\min})$, the function +returns $i_{\min}$; if $t > f(i_{\max})$, it returns $i_{\max}$. + +\paragraph{Usage in this codebase.} The pressure-level search in the GFS +wind field locates the largest level whose interpolated geopotential +height is below the query altitude; vertical interpolation then runs +between that level and its successor. + +\section{Classical Runge--Kutta--4 integrator} + +\paragraph{Definition.} For a state $y$, derivative $\dot y = f(t, y)$, +and step $\Delta t$, \verb|RK4Step| applies the classical RK4 update +\[ +\begin{aligned} + k_1 &= f(t, y), \\ + k_2 &= f\bigl(t + \tfrac{\Delta t}{2}, \; y + \tfrac{\Delta t}{2} k_1\bigr), \\ + k_3 &= f\bigl(t + \tfrac{\Delta t}{2}, \; y + \tfrac{\Delta t}{2} k_2\bigr), \\ + k_4 &= f\bigl(t + \Delta t, \; y + \Delta t \cdot k_3\bigr), \\ + y(t + \Delta t) &= y + \tfrac{\Delta t}{6}\bigl(k_1 + 2 k_2 + 2 k_3 + k_4\bigr). +\end{aligned} +\] + +\paragraph{Reverse-time integration.} Passing $\Delta t < 0$ integrates +backwards in time. The derivative $f$ is treated as direction-independent: +all sign accounting lives in the integrator. The implementation contains +no explicit branch on the sign of $\Delta t$. + +\paragraph{Vector state.} \verb|RK4Step| is generic on the state type. +Domain-specific vector arithmetic (in particular longitude wrap on the +$0\!:\!360$ degree circle) is injected via the \verb|VecAdd| operation +$\mathrm{add}(y, k, \delta y) = y + k \cdot \delta y$. 
+ +\section{Termination-point refinement} + +After each integration step the propagator checks one or more +constraints. When a constraint reports a violation between $(t_1, y_1)$ +(not violated) and $(t_2, y_2)$ (violated), \verb|RefineTrigger| +locates the crossing within tolerance $\tau \in (0, 1)$ by binary +search in the linear interpolation parameter $\lambda \in [0, 1]$ (the midpoint $m$ in Algorithm~\ref{alg:refine}): + +\begin{algorithm}[H] +\caption{RefineTrigger}\label{alg:refine} +\begin{algorithmic}[1] + \State $L \gets 0,\; R \gets 1$ + \State $t_3 \gets t_2,\; y_3 \gets y_2$ + \While{$R - L > \tau$} + \State $m \gets (L + R)/2$ + \State $t_3 \gets (1 - m)\,t_1 + m\,t_2$ + \State $y_3 \gets \mathrm{lerp}(y_1, y_2, m)$ + \If{constraint violated at $(t_3, y_3)$} + \State $R \gets m$ + \Else + \State $L \gets m$ + \EndIf + \EndWhile + \State \Return $(t_3, y_3)$ +\end{algorithmic} +\end{algorithm} + +\paragraph{Termination guarantee.} After $\lceil \log_2 \tau^{-1} \rceil$ +iterations, $R - L \le \tau$. With $\tau = 0{.}01$ and $\Delta t = 60$~s, +the returned point is within $\tau$ of the true crossing in parameter +space, i.e.\ within $\tau\,\Delta t = 0{.}6$~s in time; the corresponding altitude error is bounded by $0{.}6\,|\dot y|$, +which for typical balloon ascent and parachute descent rates is at most +$\sim 3$~m. + +\paragraph{Quirk.} The returned $(t_3, y_3)$ is the \emph{last midpoint +sampled} rather than guaranteed to lie on the triggered side; this +matches the reference Tawhiri implementation byte-for-byte. + +\paragraph{Vector lerp.} As with \verb|RK4Step|, the per-coordinate +linear interpolation is delegated to the caller's \verb|VecLerp| to keep +the integrator agnostic of state semantics. The engine package's +\verb|lerpState| applies the shorter-arc convention for longitudes +crossing the $0\!:\!360$ boundary. + +\section{Implementation notes} + +The library is intentionally small (under 300 lines of Go) and uses no +runtime allocations on the hot path. 
The type-generic \verb|RK4Step| and +\verb|RefineTrigger| compile to per-type specialisations under Go's +generics, so a future C or Rust port can mirror the implementation +verbatim without changing the call sites in the trajectory engine. + +\end{document} diff --git a/go.mod b/go.mod index 35b9486..744ff76 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/go-co-op/gocron v1.37.0 github.com/go-faster/errors v0.7.1 github.com/go-faster/jx v1.2.0 + github.com/google/uuid v1.6.0 github.com/nilsmagnus/grib v1.2.8 github.com/ogen-go/ogen v1.20.2 go.opentelemetry.io/otel v1.42.0 @@ -14,6 +15,7 @@ require ( go.opentelemetry.io/otel/trace v1.42.0 go.uber.org/zap v1.27.1 golang.org/x/sync v0.20.0 + gopkg.in/yaml.v2 v2.4.0 ) require ( @@ -24,7 +26,6 @@ require ( github.com/go-faster/yaml v0.4.6 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect @@ -37,5 +38,4 @@ require ( golang.org/x/net v0.52.0 // indirect golang.org/x/sys v0.42.0 // indirect golang.org/x/text v0.35.0 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/internal/api/admin/datasets.go b/internal/api/admin/datasets.go new file mode 100644 index 0000000..7d2c624 --- /dev/null +++ b/internal/api/admin/datasets.go @@ -0,0 +1,206 @@ +// Package admin implements dataset-management HTTP endpoints used by the +// stratoflights operator console. 
+//
+// Endpoints:
+//
+//	GET    /api/v1/admin/datasets          list stored epochs
+//	POST   /api/v1/admin/datasets          trigger a download
+//	DELETE /api/v1/admin/datasets/{epoch}  delete a stored epoch
+//	GET    /api/v1/admin/jobs              list all jobs
+//	GET    /api/v1/admin/jobs/{id}         fetch one job
+//	DELETE /api/v1/admin/jobs/{id}         cancel a running job
+package admin
+
+import (
+	"context"
+	"encoding/json"
+	"net/http"
+	"time"
+
+	"go.uber.org/zap"
+
+	"predictor-refactored/internal/datasets"
+)
+
+// Handler is the receiver for every /api/v1/admin/* route. It holds the
+// dataset manager the routes operate on and a logger.
+type Handler struct {
+	mgr *datasets.Manager
+	log *zap.Logger
+}
+
+// New returns a Handler bound to mgr. A nil log is replaced by a no-op
+// logger, so the returned Handler always has a usable logger.
+func New(mgr *datasets.Manager, log *zap.Logger) *Handler {
+	h := &Handler{mgr: mgr, log: log}
+	if h.log == nil {
+		h.log = zap.NewNop()
+	}
+	return h
+}
+
+// Register attaches the admin routes to mux. Every pattern is
+// method-qualified and lives under /api/v1/admin/.
+func (h *Handler) Register(mux *http.ServeMux) {
+	routes := []struct {
+		pattern string
+		serve   http.HandlerFunc
+	}{
+		{"GET /api/v1/admin/datasets", h.listDatasets},
+		{"POST /api/v1/admin/datasets", h.triggerDownload},
+		{"DELETE /api/v1/admin/datasets/{epoch}", h.deleteDataset},
+		{"GET /api/v1/admin/jobs", h.listJobs},
+		{"GET /api/v1/admin/jobs/{id}", h.getJob},
+		{"DELETE /api/v1/admin/jobs/{id}", h.cancelJob},
+	}
+	for _, rt := range routes {
+		mux.HandleFunc(rt.pattern, rt.serve)
+	}
+}
+
+// listDatasets handles GET /api/v1/admin/datasets.
+func (h *Handler) listDatasets(w http.ResponseWriter, _ *http.Request) {
+	epochs, err := h.mgr.ListEpochs()
+	if err != nil {
+		writeError(w, http.StatusInternalServerError, err.Error())
+		return
+	}
+	// "active" is omitted from the response when no dataset is active.
+	active := ""
+	if a := h.mgr.Active(); a != nil {
+		active = a.Epoch().UTC().Format(time.RFC3339)
+	}
+	out := struct {
+		Source string   `json:"source"`
+		Active string   `json:"active,omitempty"`
+		Epochs []string `json:"epochs"`
+	}{
+		Source: h.mgr.Source(),
+		Active: active,
+		// Start from an empty, non-nil slice: encoding/json renders a nil
+		// slice as null, and "epochs" should be [] when nothing is stored,
+		// the same way listJobs renders an empty job list.
+		Epochs: make([]string, 0, len(epochs)),
+	}
+	for _, e := range epochs {
+		out.Epochs = append(out.Epochs, e.UTC().Format(time.RFC3339))
+	}
+	writeJSON(w, http.StatusOK, out)
+}
+
+// triggerDownload handles POST /api/v1/admin/datasets.
+//
+// Body: {"epoch": "2026-03-28T06:00:00Z"} OR {"latest": true}. When both are
+// supplied, latest takes precedence and epoch is ignored.
+//
+// Responses:
+//   - 202 with {"job_id": ...} carrying the ID returned by the manager
+//   - 400 for an undecodable body, a missing selector, or a malformed epoch
+//   - 500 when the refresh requested by latest=true fails
+func (h *Handler) triggerDownload(w http.ResponseWriter, r *http.Request) {
+	var body struct {
+		Epoch  string `json:"epoch,omitempty"`
+		Latest bool   `json:"latest,omitempty"`
+	}
+	if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
+		writeError(w, http.StatusBadRequest, "invalid body: "+err.Error())
+		return
+	}
+	if !body.Latest && body.Epoch == "" {
+		writeError(w, http.StatusBadRequest, "specify either epoch or latest=true")
+		return
+	}
+
+	if body.Latest {
+		// Bounded by both the request lifetime and a 30 s ceiling.
+		// NOTE(review): presumably a zero TTL forces a refresh regardless of
+		// the stored data's age — confirm against Manager.Refresh.
+		ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
+		defer cancel()
+		jobID, err := h.mgr.Refresh(ctx, 0)
+		if err != nil {
+			writeError(w, http.StatusInternalServerError, err.Error())
+			return
+		}
+		writeJSON(w, http.StatusAccepted, map[string]string{"job_id": jobID})
+		return
+	}
+
+	epoch, err := time.Parse(time.RFC3339, body.Epoch)
+	if err != nil {
+		writeError(w, http.StatusBadRequest, "invalid epoch: "+err.Error())
+		return
+	}
+	jobID := h.mgr.Download(epoch)
+	writeJSON(w, http.StatusAccepted, map[string]string{"job_id": jobID})
+}
+
+// deleteDataset handles DELETE /api/v1/admin/datasets/{epoch}.
+func (h *Handler) deleteDataset(w http.ResponseWriter, r *http.Request) { + rawEpoch := r.PathValue("epoch") + epoch, err := time.Parse(time.RFC3339, rawEpoch) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid epoch: "+err.Error()) + return + } + if err := h.mgr.RemoveEpoch(epoch); err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + w.WriteHeader(http.StatusNoContent) +} + +// listJobs handles GET /api/v1/admin/jobs. +func (h *Handler) listJobs(w http.ResponseWriter, _ *http.Request) { + jobs := h.mgr.ListJobs() + out := make([]jobDTO, 0, len(jobs)) + for _, j := range jobs { + out = append(out, toDTO(j)) + } + writeJSON(w, http.StatusOK, out) +} + +// getJob handles GET /api/v1/admin/jobs/{id}. +func (h *Handler) getJob(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + job, ok := h.mgr.GetJob(id) + if !ok { + writeError(w, http.StatusNotFound, "job not found") + return + } + writeJSON(w, http.StatusOK, toDTO(job)) +} + +// cancelJob handles DELETE /api/v1/admin/jobs/{id}. 
+func (h *Handler) cancelJob(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if !h.mgr.CancelJob(id) { + writeError(w, http.StatusConflict, "job not found or already terminal") + return + } + w.WriteHeader(http.StatusNoContent) +} + +type jobDTO struct { + ID string `json:"id"` + Source string `json:"source"` + Epoch string `json:"epoch"` + Status string `json:"status"` + StartedAt string `json:"started_at"` + EndedAt string `json:"ended_at,omitempty"` + Err string `json:"error,omitempty"` + Total int `json:"total_units"` + Done int `json:"done_units"` + Bytes int64 `json:"bytes"` +} + +func toDTO(j datasets.JobInfo) jobDTO { + dto := jobDTO{ + ID: j.ID, + Source: j.Source, + Epoch: j.Epoch.UTC().Format(time.RFC3339), + Status: string(j.Status), + StartedAt: j.StartedAt.UTC().Format(time.RFC3339), + Err: j.Err, + Total: j.Total, + Done: j.Done, + Bytes: j.Bytes, + } + if j.EndedAt != nil { + dto.EndedAt = j.EndedAt.UTC().Format(time.RFC3339) + } + return dto +} + +func writeJSON(w http.ResponseWriter, status int, body any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(body) +} + +func writeError(w http.ResponseWriter, status int, description string) { + writeJSON(w, status, map[string]any{ + "error": map[string]string{ + "type": http.StatusText(status), + "description": description, + }, + }) +} + diff --git a/internal/api/middleware/cors.go b/internal/api/middleware/cors.go new file mode 100644 index 0000000..30a322c --- /dev/null +++ b/internal/api/middleware/cors.go @@ -0,0 +1,20 @@ +package middleware + +import "net/http" + +// CORS wraps next with permissive CORS headers and short-circuits OPTIONS preflight. +// +// This service is meant to sit behind an authenticated gateway, so we set +// "Access-Control-Allow-Origin: *". Tighten this if you deploy elsewhere. 
+func CORS(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + if r.Method == http.MethodOptions { + w.WriteHeader(http.StatusNoContent) + return + } + next.ServeHTTP(w, r) + }) +} diff --git a/internal/api/middleware/log.go b/internal/api/middleware/log.go new file mode 100644 index 0000000..bb76eb5 --- /dev/null +++ b/internal/api/middleware/log.go @@ -0,0 +1,51 @@ +// Package middleware contains HTTP and ogen middleware used by the API layer. +package middleware + +import ( + "net/http" + "time" + + "github.com/ogen-go/ogen/middleware" + "go.uber.org/zap" +) + +// OgenLogging is an ogen middleware that logs request duration and outcome. +func OgenLogging(log *zap.Logger) middleware.Middleware { + return func(req middleware.Request, next func(req middleware.Request) (middleware.Response, error)) (middleware.Response, error) { + lg := log.With(zap.String("op", req.OperationID)) + start := time.Now() + resp, err := next(req) + dur := time.Since(start) + if err != nil { + lg.Error("request failed", zap.Duration("duration", dur), zap.Error(err)) + } else { + lg.Info("request completed", zap.Duration("duration", dur)) + } + return resp, err + } +} + +// statusRecorder captures the response status for HTTPLogging. +type statusRecorder struct { + http.ResponseWriter + status int +} + +func (r *statusRecorder) WriteHeader(code int) { + r.status = code + r.ResponseWriter.WriteHeader(code) +} + +// HTTPLogging wraps the given http.Handler with a per-request log line. 
+func HTTPLogging(log *zap.Logger, next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + rec := &statusRecorder{ResponseWriter: w, status: 200} + next.ServeHTTP(rec, r) + log.Info("http", + zap.String("method", r.Method), + zap.String("path", r.URL.Path), + zap.Int("status", rec.status), + zap.Duration("duration", time.Since(start))) + }) +} diff --git a/internal/api/tawhiri/handler.go b/internal/api/tawhiri/handler.go new file mode 100644 index 0000000..bf9a103 --- /dev/null +++ b/internal/api/tawhiri/handler.go @@ -0,0 +1,252 @@ +// Package tawhiri implements the legacy Tawhiri-compatible HTTP endpoint +// (GET /api/v1/prediction). The request/response shapes match the original +// Cambridge University Spaceflight predictor for drop-in compatibility. +// +// Internally the handler builds an engine.Profile from query parameters and +// dispatches it through the same engine path as the new v2 endpoint. +package tawhiri + +import ( + "context" + "errors" + "net/http" + "time" + + "go.uber.org/zap" + + "predictor-refactored/internal/datasets" + "predictor-refactored/internal/elevation" + "predictor-refactored/internal/engine" + "predictor-refactored/internal/metrics" + api "predictor-refactored/pkg/rest" +) + +// Handler implements api.Handler (the ogen-generated interface for +// performPrediction and readinessCheck). +type Handler struct { + mgr *datasets.Manager + elev *elevation.Dataset + metrics metrics.Sink + log *zap.Logger +} + +// New wires a Handler. +func New(mgr *datasets.Manager, elev *elevation.Dataset, sink metrics.Sink, log *zap.Logger) *Handler { + if log == nil { + log = zap.NewNop() + } + if sink == nil { + sink = metrics.Noop() + } + return &Handler{mgr: mgr, elev: elev, metrics: sink, log: log} +} + +// Compile-time check that Handler satisfies api.Handler. +var _ api.Handler = (*Handler)(nil) + +// PerformPrediction runs the Tawhiri-style prediction. 
+func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredictionParams) (*api.PredictionResponse, error) { + field := h.mgr.Active() + if field == nil { + return nil, newError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up") + } + + // Parameters with Tawhiri defaults. + profileKind := "standard_profile" + if v, ok := params.Profile.Get(); ok { + profileKind = string(v) + } + ascentRate := 5.0 + if v, ok := params.AscentRate.Get(); ok { + ascentRate = v + } + burstAltitude := 28000.0 + if v, ok := params.BurstAltitude.Get(); ok { + burstAltitude = v + } + descentRate := 5.0 + if v, ok := params.DescentRate.Get(); ok { + descentRate = v + } + launchAlt := 0.0 + if v, ok := params.LaunchAltitude.Get(); ok { + launchAlt = v + } + + lng := params.LaunchLongitude + if lng < 0 { + lng += 360 + } + + launchTime := float64(params.LaunchDatetime.Unix()) + warnings := &engine.Warnings{} + + // Build the profile. + var stageNames []string + var prof engine.Profile + switch profileKind { + case "standard_profile": + stageNames = []string{"ascent", "descent"} + prof = engine.Profile{ + Direction: engine.Forward, + Stages: []*engine.Propagator{ + { + Name: "ascent", + Step: 60, + Model: engine.Sum( + engine.ConstantRate(ascentRate), + engine.WindTransport(field, warnings), + ), + Constraints: []engine.Constraint{engine.MaxAltitude{Limit: burstAltitude, On: engine.ActionStop}}, + }, + { + Name: "descent", + Step: 60, + Model: engine.Sum( + engine.ParachuteDescent(descentRate), + engine.WindTransport(field, warnings), + ), + Constraints: descentConstraints(h.elev), + }, + }, + } + case "float_profile": + floatAlt := 25000.0 + if v, ok := params.FloatAltitude.Get(); ok { + floatAlt = v + } + stopTime := params.LaunchDatetime.Add(24 * time.Hour) + if v, ok := params.StopDatetime.Get(); ok { + stopTime = v + } + stageNames = []string{"ascent", "float"} + prof = engine.Profile{ + Direction: engine.Forward, + Stages: 
[]*engine.Propagator{ + { + Name: "ascent", + Step: 60, + Model: engine.Sum( + engine.ConstantRate(ascentRate), + engine.WindTransport(field, warnings), + ), + Constraints: []engine.Constraint{engine.MaxAltitude{Limit: floatAlt, On: engine.ActionStop}}, + }, + { + Name: "float", + Step: 60, + Model: engine.WindTransport(field, warnings), + Constraints: []engine.Constraint{engine.MaxTime{Limit: float64(stopTime.Unix()), On: engine.ActionStop}}, + }, + }, + } + default: + return nil, newError(http.StatusBadRequest, "unknown profile: "+profileKind) + } + + started := time.Now().UTC() + results := prof.Run(launchTime, engine.State{Lat: params.LaunchLatitude, Lng: lng, Altitude: launchAlt}) + completed := time.Now().UTC() + h.metrics.Prediction(profileKind, completed.Sub(started), nil) + + resp := &api.PredictionResponse{ + Metadata: api.PredictionResponseMetadata{ + StartDatetime: started, + CompleteDatetime: completed, + }, + } + + for i, r := range results { + stageName := "ascent" + if i < len(stageNames) { + stageName = stageNames[i] + } + stageEnum := api.PredictionResponsePredictionItemStageAscent + switch stageName { + case "descent": + stageEnum = api.PredictionResponsePredictionItemStageDescent + case "float": + stageEnum = api.PredictionResponsePredictionItemStageFloat + } + traj := make([]api.PredictionResponsePredictionItemTrajectoryItem, 0, len(r.Points)) + for _, pt := range r.Points { + ptLng := pt.Lng + if ptLng > 180 { + ptLng -= 360 + } + traj = append(traj, api.PredictionResponsePredictionItemTrajectoryItem{ + Datetime: time.Unix(int64(pt.Time), 0).UTC(), + Latitude: pt.Lat, + Longitude: ptLng, + Altitude: pt.Altitude, + }) + } + resp.Prediction = append(resp.Prediction, api.PredictionResponsePredictionItem{ + Stage: stageEnum, + Trajectory: traj, + }) + } + + resp.Request = api.NewOptPredictionResponseRequest(api.PredictionResponseRequest{ + Dataset: api.NewOptString(field.Epoch().Format("2006-01-02T15:04:05Z")), + LaunchLatitude: 
api.NewOptFloat64(params.LaunchLatitude), + LaunchLongitude: api.NewOptFloat64(params.LaunchLongitude), + LaunchDatetime: api.NewOptString(params.LaunchDatetime.Format(time.RFC3339)), + LaunchAltitude: params.LaunchAltitude, + }) + + if warns := warnings.ToMap(); len(warns) > 0 { + resp.Warnings = api.NewOptPredictionResponseWarnings(api.PredictionResponseWarnings{}) + } + + h.log.Info("prediction complete", + zap.String("profile", profileKind), + zap.Int("stages", len(results)), + zap.Duration("elapsed", completed.Sub(started))) + + return resp, nil +} + +// descentConstraints returns the descent termination set: TerrainContact if an +// elevation dataset is loaded, MinAltitude(0) otherwise. +func descentConstraints(elev *elevation.Dataset) []engine.Constraint { + if elev != nil { + return []engine.Constraint{engine.TerrainContact{Provider: elev, On: engine.ActionStop}} + } + return []engine.Constraint{engine.MinAltitude{Limit: 0, On: engine.ActionStop}} +} + +// ReadinessCheck reports whether a dataset is currently loaded. +func (h *Handler) ReadinessCheck(_ context.Context) (*api.ReadinessResponse, error) { + resp := &api.ReadinessResponse{} + if field := h.mgr.Active(); field != nil { + resp.Status = api.ReadinessResponseStatusOk + resp.DatasetTime = api.NewOptDateTime(field.Epoch()) + } else { + resp.Status = api.ReadinessResponseStatusNotReady + resp.ErrorMessage = api.NewOptString("no dataset loaded") + } + return resp, nil +} + +// NewError implements the ogen Handler interface for unhandled errors. 
+func (h *Handler) NewError(_ context.Context, err error) *api.ErrorStatusCode { + var statusErr *api.ErrorStatusCode + if errors.As(err, &statusErr) { + return statusErr + } + h.log.Error("unhandled error", zap.Error(err)) + return newError(http.StatusInternalServerError, err.Error()) +} + +func newError(status int, description string) *api.ErrorStatusCode { + return &api.ErrorStatusCode{ + StatusCode: status, + Response: api.Error{ + Error: api.ErrorError{ + Type: http.StatusText(status), + Description: description, + }, + }, + } +} diff --git a/internal/api/transport.go b/internal/api/transport.go new file mode 100644 index 0000000..50d2185 --- /dev/null +++ b/internal/api/transport.go @@ -0,0 +1,109 @@ +// Package api wires together every HTTP-facing component of the service: +// +// - Tawhiri-compatible v1 endpoints generated from the OpenAPI spec (ogen); +// - The new v2 prediction endpoint; +// - Dataset and job admin endpoints under /api/v1/admin/; +// - Optional Prometheus-format metrics endpoint. +package api + +import ( + "context" + "fmt" + "net/http" + "time" + + "go.uber.org/zap" + + "predictor-refactored/internal/api/admin" + "predictor-refactored/internal/api/middleware" + "predictor-refactored/internal/api/tawhiri" + v2 "predictor-refactored/internal/api/v2" + "predictor-refactored/internal/datasets" + "predictor-refactored/internal/elevation" + "predictor-refactored/internal/metrics" + apirest "predictor-refactored/pkg/rest" +) + +// Server is the top-level HTTP server. +type Server struct { + port int + mux *http.ServeMux + log *zap.Logger +} + +// Deps are the runtime dependencies the API layer needs. +type Deps struct { + Manager *datasets.Manager + Elevation *elevation.Dataset + Metrics metrics.Sink + MetricsHandler http.Handler // optional; mounted at MetricsPath when non-nil + MetricsPath string + Log *zap.Logger +} + +// New wires the HTTP server. The returned Server is not yet started. 
+func New(port int, d Deps) (*Server, error) { + if d.Log == nil { + d.Log = zap.NewNop() + } + if d.Metrics == nil { + d.Metrics = metrics.Noop() + } + + mux := http.NewServeMux() + + // ogen-generated server handles the Tawhiri-compat surface + // (GET /api/v1/prediction and GET /ready). + tw := tawhiri.New(d.Manager, d.Elevation, d.Metrics, d.Log) + ogenSrv, err := apirest.NewServer(tw, apirest.WithMiddleware(middleware.OgenLogging(d.Log))) + if err != nil { + return nil, fmt.Errorf("create ogen server: %w", err) + } + + // New primary prediction endpoint. + v2h := v2.New(d.Manager, d.Elevation, d.Metrics, d.Log) + mux.Handle("/api/v2/prediction", v2h) + + // Admin endpoints. + adminH := admin.New(d.Manager, d.Log) + adminH.Register(mux) + + // Metrics endpoint. + if d.MetricsHandler != nil && d.MetricsPath != "" { + mux.Handle(d.MetricsPath, d.MetricsHandler) + } + + // Fallback to the ogen-generated routes (v1 + ready) for anything else. + mux.Handle("/", ogenSrv) + + return &Server{ + port: port, + mux: mux, + log: d.Log, + }, nil +} + +// Run starts the HTTP server and blocks until it returns. +// +// The handler chain is: CORS → request logger → mux. 
+func (s *Server) Run(ctx context.Context) error { + srv := &http.Server{ + Addr: fmt.Sprintf(":%d", s.port), + Handler: middleware.CORS(middleware.HTTPLogging(s.log, s.mux)), + } + + errCh := make(chan error, 1) + go func() { + s.log.Info("HTTP server starting", zap.Int("port", s.port)) + errCh <- srv.ListenAndServe() + }() + + select { + case err := <-errCh: + return err + case <-ctx.Done(): + shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return srv.Shutdown(shutdownCtx) + } +} diff --git a/internal/api/v2/handler.go b/internal/api/v2/handler.go new file mode 100644 index 0000000..23dd886 --- /dev/null +++ b/internal/api/v2/handler.go @@ -0,0 +1,173 @@ +package v2 + +import ( + "encoding/json" + "fmt" + "net/http" + "time" + + "go.uber.org/zap" + + "predictor-refactored/internal/datasets" + "predictor-refactored/internal/elevation" + "predictor-refactored/internal/engine" + "predictor-refactored/internal/metrics" +) + +// Handler serves POST /api/v2/prediction. +type Handler struct { + mgr *datasets.Manager + elev *elevation.Dataset + metrics metrics.Sink + log *zap.Logger +} + +// New wires a v2 Handler. 
+func New(mgr *datasets.Manager, elev *elevation.Dataset, sink metrics.Sink, log *zap.Logger) *Handler { + if log == nil { + log = zap.NewNop() + } + if sink == nil { + sink = metrics.Noop() + } + return &Handler{mgr: mgr, elev: elev, metrics: sink, log: log} +} + +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + writeError(w, http.StatusMethodNotAllowed, "use POST") + return + } + + var req PredictionRequest + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + if err := dec.Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid request body: "+err.Error()) + return + } + + if err := validateRequest(req); err != nil { + writeError(w, http.StatusBadRequest, err.Error()) + return + } + + field := h.mgr.Active() + if field == nil { + writeError(w, http.StatusServiceUnavailable, "no dataset loaded, service is starting up") + return + } + + // Normalize longitude to [0, 360) for internal use. + lng := req.Launch.Longitude + if lng < 0 { + lng += 360 + } + + warnings := &engine.Warnings{} + var terrain engine.TerrainProvider + if h.elev != nil { + terrain = h.elev + } + + prof, err := buildProfile(req, field, terrain, warnings) + if err != nil { + writeError(w, http.StatusBadRequest, err.Error()) + return + } + + started := time.Now().UTC() + results := prof.Run(float64(req.Launch.Time.Unix()), engine.State{ + Lat: req.Launch.Latitude, + Lng: lng, + Altitude: req.Launch.Altitude, + }) + completed := time.Now().UTC() + h.metrics.Prediction("v2", completed.Sub(started), nil) + + resp := PredictionResponse{ + Stages: make([]StageResult, 0, len(results)), + StartedAt: started, + CompletedAt: completed, + Dataset: DatasetInfo{ + Source: field.Source(), + Epoch: field.Epoch(), + }, + } + for _, r := range results { + stage := StageResult{ + Name: r.Propagator, + Outcome: outcomeString(r.Outcome), + } + if r.Constraint != nil { + stage.Constraint = r.Constraint.Name() + } + 
stage.Trajectory = make([]TrajectoryPoint, len(r.Points)) + for i, pt := range r.Points { + ptLng := pt.Lng + if ptLng > 180 { + ptLng -= 360 + } + stage.Trajectory[i] = TrajectoryPoint{ + Time: time.Unix(int64(pt.Time), 0).UTC(), + Latitude: pt.Lat, + Longitude: ptLng, + Altitude: pt.Altitude, + } + } + resp.Stages = append(resp.Stages, stage) + } + if warns := warnings.ToMap(); len(warns) > 0 { + resp.Warnings = warns + } + + h.log.Info("v2 prediction complete", + zap.Int("stages", len(results)), + zap.Duration("elapsed", completed.Sub(started))) + writeJSON(w, http.StatusOK, resp) +} + +func validateRequest(req PredictionRequest) error { + if req.Launch.Latitude < -90 || req.Launch.Latitude > 90 { + return fmt.Errorf("launch.latitude must be in [-90, 90]") + } + if req.Launch.Longitude < -180 || req.Launch.Longitude >= 360 { + return fmt.Errorf("launch.longitude must be in [-180, 360)") + } + if len(req.Profile) == 0 { + return fmt.Errorf("profile must contain at least one stage") + } + for i, s := range req.Profile { + if s.Name == "" { + return fmt.Errorf("profile[%d].name is required", i) + } + if s.Model.Type == "" { + return fmt.Errorf("profile[%d].model.type is required", i) + } + } + return nil +} + +func outcomeString(o engine.Outcome) string { + switch o { + case engine.OutcomeStopped: + return "stopped" + case engine.OutcomeFallback: + return "fallback" + default: + return "continued" + } +} + +func writeError(w http.ResponseWriter, status int, description string) { + writeJSON(w, status, ErrorResponse{Error: ErrorBody{ + Type: http.StatusText(status), + Description: description, + }}) +} + +func writeJSON(w http.ResponseWriter, status int, body any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(body) +} diff --git a/internal/api/v2/profile.go b/internal/api/v2/profile.go new file mode 100644 index 0000000..36f11f8 --- /dev/null +++ b/internal/api/v2/profile.go @@ -0,0 +1,145 @@ +package 
v2 + +import ( + "fmt" + + "predictor-refactored/internal/engine" + "predictor-refactored/internal/weather" +) + +// buildProfile translates a PredictionRequest into an engine.Profile. +// +// elev may be nil when no terrain dataset is loaded; TerrainContact constraints +// will return an error in that case. +func buildProfile(req PredictionRequest, field weather.WindField, elev engine.TerrainProvider, warnings *engine.Warnings) (engine.Profile, error) { + if len(req.Profile) == 0 { + return engine.Profile{}, fmt.Errorf("profile must contain at least one stage") + } + + step := req.Options.StepSeconds + if step == 0 { + step = 60 + } + tol := req.Options.Tolerance + if tol == 0 { + tol = 0.01 + } + + dir := engine.Forward + switch req.Direction { + case "", "forward": + dir = engine.Forward + case "reverse": + dir = engine.Reverse + default: + return engine.Profile{}, fmt.Errorf("unknown direction %q", req.Direction) + } + + props := make([]*engine.Propagator, len(req.Profile)) + for i, stage := range req.Profile { + model, err := buildModel(stage.Model, field, warnings) + if err != nil { + return engine.Profile{}, fmt.Errorf("stage %q: %w", stage.Name, err) + } + constraints, err := buildConstraints(stage.Constraints, elev) + if err != nil { + return engine.Profile{}, fmt.Errorf("stage %q: %w", stage.Name, err) + } + props[i] = &engine.Propagator{ + Name: stage.Name, + Step: step, + Model: model, + Constraints: constraints, + Tolerance: tol, + } + } + + // Wire fallbacks once all stages exist. 
+ for i, stage := range req.Profile { + if stage.FallbackIndex == nil { + continue + } + idx := *stage.FallbackIndex + if idx < 0 || idx >= len(props) { + return engine.Profile{}, fmt.Errorf("stage %q: fallback_index %d out of range", stage.Name, idx) + } + props[i].Fallback = props[idx] + } + + return engine.Profile{Stages: props, Direction: dir}, nil +} + +func buildModel(spec ModelSpec, field weather.WindField, warnings *engine.Warnings) (engine.Model, error) { + var base engine.Model + switch spec.Type { + case "constant_rate": + base = engine.ConstantRate(spec.Rate) + case "parachute_descent": + if spec.SeaLevelRate <= 0 { + return nil, fmt.Errorf("parachute_descent requires positive sea_level_rate") + } + base = engine.ParachuteDescent(spec.SeaLevelRate) + case "piecewise": + segs := make([]engine.RateSegment, len(spec.Segments)) + for i, s := range spec.Segments { + segs[i] = engine.RateSegment{Until: s.Until, Rate: s.Rate} + } + base = engine.Piecewise(segs) + case "wind": + if field == nil { + return nil, fmt.Errorf("wind model requires a loaded dataset") + } + return engine.WindTransport(field, warnings), nil + default: + return nil, fmt.Errorf("unknown model type %q", spec.Type) + } + + if spec.IncludeWind { + if field == nil { + return nil, fmt.Errorf("include_wind requires a loaded dataset") + } + return engine.Sum(base, engine.WindTransport(field, warnings)), nil + } + return base, nil +} + +func buildConstraints(specs []ConstraintSpec, elev engine.TerrainProvider) ([]engine.Constraint, error) { + out := make([]engine.Constraint, 0, len(specs)) + for _, spec := range specs { + action, err := parseAction(spec.Action) + if err != nil { + return nil, err + } + var c engine.Constraint + switch spec.Type { + case "max_altitude": + c = engine.MaxAltitude{Limit: spec.Limit, On: action} + case "min_altitude": + c = engine.MinAltitude{Limit: spec.Limit, On: action} + case "max_time": + c = engine.MaxTime{Limit: spec.Limit, On: action} + case "terrain_contact": 
+ if elev == nil { + return nil, fmt.Errorf("terrain_contact requires an elevation dataset") + } + c = engine.TerrainContact{Provider: elev, On: action} + default: + return nil, fmt.Errorf("unknown constraint type %q", spec.Type) + } + out = append(out, c) + } + return out, nil +} + +func parseAction(s string) (engine.Action, error) { + switch s { + case "", "stop": + return engine.ActionStop, nil + case "fallback": + return engine.ActionFallback, nil + case "clip": + return engine.ActionClip, nil + default: + return 0, fmt.Errorf("unknown constraint action %q", s) + } +} diff --git a/internal/api/v2/types.go b/internal/api/v2/types.go new file mode 100644 index 0000000..7d76dd1 --- /dev/null +++ b/internal/api/v2/types.go @@ -0,0 +1,114 @@ +// Package v2 implements the new primary prediction endpoint, which accepts a +// user-defined profile (chain of propagators with optional constraints) and +// returns the resulting trajectory. +// +// Endpoint: POST /api/v2/prediction +package v2 + +import "time" + +// PredictionRequest is the request body for POST /api/v2/prediction. +type PredictionRequest struct { + Launch Launch `json:"launch"` + Profile []Stage `json:"profile"` + Options Options `json:"options,omitempty"` + Direction string `json:"direction,omitempty"` // "forward" (default) or "reverse" +} + +// Launch is the initial state of the balloon (or, for reverse predictions, +// the known landing point). +type Launch struct { + Time time.Time `json:"time"` + Latitude float64 `json:"latitude"` + Longitude float64 `json:"longitude"` + Altitude float64 `json:"altitude"` +} + +// Stage is one entry in the propagator chain. +type Stage struct { + Name string `json:"name"` + Model ModelSpec `json:"model"` + Constraints []ConstraintSpec `json:"constraints,omitempty"` + // FallbackIndex, when set, points to another stage in the same profile to + // transfer to on ActionFallback constraints. Optional. 
+ FallbackIndex *int `json:"fallback_index,omitempty"` +} + +// ModelSpec describes the per-stage propagation model. +type ModelSpec struct { + // Type selects the model: "constant_rate", "parachute_descent", "piecewise", "wind". + Type string `json:"type"` + // Rate (m/s) for constant_rate. + Rate float64 `json:"rate,omitempty"` + // SeaLevelRate (m/s, positive) for parachute_descent. + SeaLevelRate float64 `json:"sea_level_rate,omitempty"` + // Segments for piecewise. + Segments []PiecewiseSegment `json:"segments,omitempty"` + // IncludeWind sums a WindTransport model into the resulting derivative, + // allowing the same stage to model both vertical motion and wind drift. + IncludeWind bool `json:"include_wind"` +} + +// PiecewiseSegment is one entry in a piecewise rate schedule. +type PiecewiseSegment struct { + Until float64 `json:"until"` // UNIX seconds; segment applies for t < Until + Rate float64 `json:"rate"` // m/s +} + +// ConstraintSpec describes one constraint attached to a stage. +type ConstraintSpec struct { + // Type: "max_altitude", "min_altitude", "max_time", "terrain_contact". + Type string `json:"type"` + // Limit is interpreted per Type: metres for altitude, UNIX seconds for time. + Limit float64 `json:"limit,omitempty"` + // Action: "stop" (default), "fallback", "clip". + Action string `json:"action,omitempty"` +} + +// Options tweaks the integrator behaviour. +type Options struct { + StepSeconds float64 `json:"step_seconds,omitempty"` + Tolerance float64 `json:"tolerance,omitempty"` +} + +// PredictionResponse is the response body for POST /api/v2/prediction. +type PredictionResponse struct { + Stages []StageResult `json:"stages"` + Warnings map[string]any `json:"warnings,omitempty"` + Dataset DatasetInfo `json:"dataset"` + StartedAt time.Time `json:"started_at"` + CompletedAt time.Time `json:"completed_at"` +} + +// StageResult is the outcome of one stage. 
+type StageResult struct { + Name string `json:"name"` + Outcome string `json:"outcome"` // "stopped" | "fallback" | "continued" + Constraint string `json:"constraint,omitempty"` + Trajectory []TrajectoryPoint `json:"trajectory"` +} + +// TrajectoryPoint is one sampled point of the trajectory. +type TrajectoryPoint struct { + Time time.Time `json:"time"` + Latitude float64 `json:"latitude"` + Longitude float64 `json:"longitude"` + Altitude float64 `json:"altitude"` +} + +// DatasetInfo identifies the dataset the prediction was computed against. +type DatasetInfo struct { + Source string `json:"source"` + Epoch time.Time `json:"epoch"` +} + +// ErrorResponse is the JSON error shape used by both v2 and admin endpoints. +type ErrorResponse struct { + Error ErrorBody `json:"error"` +} + +// ErrorBody is the error detail. +type ErrorBody struct { + Type string `json:"type"` + Description string `json:"description"` +} diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..6ce40ab --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,252 @@ +// Package config holds the service's runtime configuration, loaded by +// merging (in order of increasing precedence): built-in defaults, a YAML +// config file, environment variables, and command-line flags. +// +// Validation is performed once on load; downstream consumers receive an +// immutable struct. +package config + +import ( + "flag" + "fmt" + "os" + "strconv" + "time" + + "gopkg.in/yaml.v2" +) + +// Config is the top-level configuration tree. +type Config struct { + HTTP HTTPConfig `yaml:"http"` + Data DataConfig `yaml:"data"` + Download DownloadConfig `yaml:"download"` + Metrics MetricsConfig `yaml:"metrics"` + Log LogConfig `yaml:"log"` +} + +// HTTPConfig configures the HTTP server. +type HTTPConfig struct { + Port int `yaml:"port"` +} + +// DataConfig configures dataset and elevation storage. 
+type DataConfig struct { + Dir string `yaml:"dir"` + ElevationPath string `yaml:"elevation_path"` + // Source is the dataset source identifier; only "noaa-gfs-0p50" is supported today. + Source string `yaml:"source"` +} + +// DownloadConfig configures the dataset downloader. +type DownloadConfig struct { + Parallel int `yaml:"parallel"` + BandwidthBytesPerSecond int64 `yaml:"bandwidth_bytes_per_second"` + UpdateInterval time.Duration `yaml:"update_interval"` + FreshnessTTL time.Duration `yaml:"freshness_ttl"` +} + +// MetricsConfig configures the metrics endpoint. +type MetricsConfig struct { + Enabled bool `yaml:"enabled"` + Path string `yaml:"path"` +} + +// LogConfig configures logging. +type LogConfig struct { + Level string `yaml:"level"` // "debug", "info", "warn", "error" +} + +// Defaults returns a Config with reasonable default values. +func Defaults() Config { + return Config{ + HTTP: HTTPConfig{Port: 8080}, + Data: DataConfig{ + Dir: "/tmp/predictor-data", + ElevationPath: "/srv/ruaumoko-dataset", + Source: "noaa-gfs-0p50", + }, + Download: DownloadConfig{ + Parallel: 8, + BandwidthBytesPerSecond: 0, + UpdateInterval: 6 * time.Hour, + FreshnessTTL: 48 * time.Hour, + }, + Metrics: MetricsConfig{ + Enabled: true, + Path: "/metrics", + }, + Log: LogConfig{Level: "info"}, + } +} + +// Load resolves the configuration by merging built-in defaults with +// (in increasing precedence): a YAML file (path from PREDICTOR_CONFIG_FILE +// env var or --config flag), environment variables, and command-line flags. +// +// args is os.Args[1:] in production code; tests pass a custom slice. +func Load(args []string) (Config, error) { + cfg := Defaults() + + fs := flag.NewFlagSet("predictor", flag.ContinueOnError) + // Surface a deterministic usage by suppressing the default output: + fs.SetOutput(os.Stderr) + + var ( + configPath = fs.String("config", os.Getenv("PREDICTOR_CONFIG_FILE"), "path to YAML config file") + // Flag-driven overrides. 
Empty / -1 means "not specified". + flagPort = fs.Int("port", -1, "HTTP listen port") + flagDataDir = fs.String("data-dir", "", "directory for dataset files") + flagElevation = fs.String("elevation", "", "path to ruaumoko elevation dataset") + flagParallel = fs.Int("download-parallel", -1, "max concurrent GRIB downloads") + flagBandwidth = fs.Int64("download-bandwidth", -1, "download bandwidth limit in bytes/sec (0 = unlimited)") + flagInterval = fs.Duration("update-interval", 0, "scheduler refresh interval") + flagTTL = fs.Duration("freshness-ttl", 0, "max age before a dataset is considered stale") + flagMetricsEnabled = fs.Bool("metrics", true, "enable Prometheus-compatible metrics endpoint") + flagMetricsPath = fs.String("metrics-path", "", "HTTP path for the metrics endpoint") + flagLogLevel = fs.String("log-level", "", "log level: debug|info|warn|error") + ) + if err := fs.Parse(args); err != nil { + return Config{}, fmt.Errorf("parse flags: %w", err) + } + + // 1. File. + if *configPath != "" { + data, err := os.ReadFile(*configPath) + if err != nil { + return Config{}, fmt.Errorf("read config %s: %w", *configPath, err) + } + if err := yaml.UnmarshalStrict(data, &cfg); err != nil { + return Config{}, fmt.Errorf("parse config %s: %w", *configPath, err) + } + } + + // 2. Env vars. + applyEnv(&cfg) + + // 3. Flags (only when explicitly set). + if *flagPort >= 0 { + cfg.HTTP.Port = *flagPort + } + if *flagDataDir != "" { + cfg.Data.Dir = *flagDataDir + } + if *flagElevation != "" { + cfg.Data.ElevationPath = *flagElevation + } + if *flagParallel >= 0 { + cfg.Download.Parallel = *flagParallel + } + if *flagBandwidth >= 0 { + cfg.Download.BandwidthBytesPerSecond = *flagBandwidth + } + if *flagInterval != 0 { + cfg.Download.UpdateInterval = *flagInterval + } + if *flagTTL != 0 { + cfg.Download.FreshnessTTL = *flagTTL + } + // flag.Bool defaults to true here so we only override if user explicitly disables it. 
+ if isFlagSet(fs, "metrics") { + cfg.Metrics.Enabled = *flagMetricsEnabled + } + if *flagMetricsPath != "" { + cfg.Metrics.Path = *flagMetricsPath + } + if *flagLogLevel != "" { + cfg.Log.Level = *flagLogLevel + } + + if err := cfg.Validate(); err != nil { + return Config{}, err + } + return cfg, nil +} + +func isFlagSet(fs *flag.FlagSet, name string) bool { + set := false + fs.Visit(func(f *flag.Flag) { + if f.Name == name { + set = true + } + }) + return set +} + +// applyEnv overlays PREDICTOR_* environment variables onto cfg. +func applyEnv(cfg *Config) { + if v := os.Getenv("PREDICTOR_PORT"); v != "" { + if n, err := strconv.Atoi(v); err == nil { + cfg.HTTP.Port = n + } + } + if v := os.Getenv("PREDICTOR_DATA_DIR"); v != "" { + cfg.Data.Dir = v + } + if v := os.Getenv("PREDICTOR_ELEVATION_DATASET"); v != "" { + cfg.Data.ElevationPath = v + } + if v := os.Getenv("PREDICTOR_SOURCE"); v != "" { + cfg.Data.Source = v + } + if v := os.Getenv("PREDICTOR_DOWNLOAD_PARALLEL"); v != "" { + if n, err := strconv.Atoi(v); err == nil { + cfg.Download.Parallel = n + } + } + if v := os.Getenv("PREDICTOR_DOWNLOAD_BANDWIDTH"); v != "" { + if n, err := strconv.ParseInt(v, 10, 64); err == nil { + cfg.Download.BandwidthBytesPerSecond = n + } + } + if v := os.Getenv("PREDICTOR_UPDATE_INTERVAL"); v != "" { + if d, err := time.ParseDuration(v); err == nil { + cfg.Download.UpdateInterval = d + } + } + if v := os.Getenv("PREDICTOR_DATASET_TTL"); v != "" { + if d, err := time.ParseDuration(v); err == nil { + cfg.Download.FreshnessTTL = d + } + } + if v := os.Getenv("PREDICTOR_METRICS_ENABLED"); v != "" { + cfg.Metrics.Enabled = v == "1" || v == "true" || v == "yes" + } + if v := os.Getenv("PREDICTOR_METRICS_PATH"); v != "" { + cfg.Metrics.Path = v + } + if v := os.Getenv("PREDICTOR_LOG_LEVEL"); v != "" { + cfg.Log.Level = v + } +} + +// Validate reports configuration errors. 
+func (c Config) Validate() error { + if c.HTTP.Port < 0 || c.HTTP.Port > 65535 { + return fmt.Errorf("http.port %d outside [0, 65535]", c.HTTP.Port) + } + if c.Data.Dir == "" { + return fmt.Errorf("data.dir is required") + } + if c.Data.Source == "" { + return fmt.Errorf("data.source is required") + } + if c.Download.Parallel <= 0 { + return fmt.Errorf("download.parallel must be > 0") + } + if c.Download.UpdateInterval <= 0 { + return fmt.Errorf("download.update_interval must be > 0") + } + if c.Download.FreshnessTTL <= 0 { + return fmt.Errorf("download.freshness_ttl must be > 0") + } + if c.Metrics.Enabled && c.Metrics.Path == "" { + return fmt.Errorf("metrics.path is required when metrics enabled") + } + switch c.Log.Level { + case "debug", "info", "warn", "error": + default: + return fmt.Errorf("log.level %q is not one of debug|info|warn|error", c.Log.Level) + } + return nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..8978357 --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,76 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestLoadDefaults(t *testing.T) { + t.Setenv("PREDICTOR_DATA_DIR", "") + t.Setenv("PREDICTOR_PORT", "") + t.Setenv("PREDICTOR_CONFIG_FILE", "") + + cfg, err := Load(nil) + if err != nil { + t.Fatalf("Load: %v", err) + } + if cfg.HTTP.Port != 8080 { + t.Errorf("default port = %d, want 8080", cfg.HTTP.Port) + } + if cfg.Download.Parallel != 8 { + t.Errorf("default parallel = %d, want 8", cfg.Download.Parallel) + } +} + +func TestLoadEnvOverridesDefaults(t *testing.T) { + t.Setenv("PREDICTOR_PORT", "9090") + t.Setenv("PREDICTOR_UPDATE_INTERVAL", "30m") + + cfg, err := Load(nil) + if err != nil { + t.Fatalf("Load: %v", err) + } + if cfg.HTTP.Port != 9090 { + t.Errorf("env port = %d, want 9090", cfg.HTTP.Port) + } + if cfg.Download.UpdateInterval != 30*time.Minute { + t.Errorf("env update interval = %v, want 30m", 
cfg.Download.UpdateInterval) + } +} + +func TestLoadFlagsOverrideEnv(t *testing.T) { + t.Setenv("PREDICTOR_PORT", "9090") + cfg, err := Load([]string{"-port", "7777"}) + if err != nil { + t.Fatalf("Load: %v", err) + } + if cfg.HTTP.Port != 7777 { + t.Errorf("flag should override env: port = %d, want 7777", cfg.HTTP.Port) + } +} + +func TestLoadFileOverridesDefaults(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "predictor.yml") + if err := os.WriteFile(path, []byte("http:\n port: 12345\n"), 0o644); err != nil { + t.Fatal(err) + } + + cfg, err := Load([]string{"-config", path}) + if err != nil { + t.Fatalf("Load: %v", err) + } + if cfg.HTTP.Port != 12345 { + t.Errorf("file port = %d, want 12345", cfg.HTTP.Port) + } +} + +func TestValidate(t *testing.T) { + cfg := Defaults() + cfg.Data.Dir = "" + if err := cfg.Validate(); err == nil { + t.Error("expected validation error for empty data dir") + } +} diff --git a/internal/dataset/dataset.go b/internal/dataset/dataset.go deleted file mode 100644 index 53367db..0000000 --- a/internal/dataset/dataset.go +++ /dev/null @@ -1,158 +0,0 @@ -package dataset - -import "fmt" - -// Dataset shape constants. -// Shape: (65, 47, 3, 361, 720) = (hour, pressure_level, variable, latitude, longitude) -// This matches the reference predictor exactly. -const ( - NumHours = 65 // 0, 3, 6, ..., 192 - NumLevels = 47 // pressure levels - NumVariables = 3 // height, wind_u, wind_v - NumLatitudes = 361 // -90.0 to +90.0 in 0.5 degree steps - NumLongitudes = 720 // 0.0 to 359.5 in 0.5 degree steps - - HourStep = 3 // hours between forecast time steps - MaxHour = 192 // maximum forecast hour - Resolution = 0.5 // grid resolution in degrees - LatStart = -90.0 // first latitude in the dataset - LonStart = 0.0 // first longitude in the dataset - - // Variable indices within the dataset. 
- VarHeight = 0 - VarWindU = 1 - VarWindV = 2 - - ElementSize = 4 // float32 = 4 bytes -) - -// DatasetSize is the total size of the dataset file in bytes. -// 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200 -const DatasetSize int64 = int64(NumHours) * int64(NumLevels) * int64(NumVariables) * - int64(NumLatitudes) * int64(NumLongitudes) * int64(ElementSize) - -// LevelSet identifies which GRIB file set a pressure level belongs to. -type LevelSet int - -const ( - LevelSetA LevelSet = iota // pgrb2 (primary) - LevelSetB // pgrb2b (secondary) -) - -// Pressures contains the 47 pressure levels in hPa, sorted descending. -// Index 0 = 1000 hPa (near surface), Index 46 = 1 hPa (high atmosphere). -var Pressures = [NumLevels]int{ - 1000, 975, 950, 925, 900, 875, 850, 825, 800, 775, - 750, 725, 700, 675, 650, 625, 600, 575, 550, 525, - 500, 475, 450, 425, 400, 375, 350, 325, 300, 275, - 250, 225, 200, 175, 150, 125, 100, 70, 50, 30, - 20, 10, 7, 5, 3, 2, 1, -} - -// pressureIndex maps pressure in hPa to its index in the Pressures array. -var pressureIndex map[int]int - -// pressureLevelSet maps pressure in hPa to its GRIB file set. -var pressureLevelSet map[int]LevelSet - -func init() { - pressureIndex = make(map[int]int, NumLevels) - for i, p := range Pressures { - pressureIndex[p] = i - } - - pressureLevelSet = make(map[int]LevelSet, NumLevels) - for _, p := range PressuresPgrb2 { - pressureLevelSet[p] = LevelSetA - } - for _, p := range PressuresPgrb2b { - pressureLevelSet[p] = LevelSetB - } -} - -// PressuresPgrb2 contains levels found in the primary pgrb2 file (26 levels). -var PressuresPgrb2 = []int{ - 10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 350, 400, - 450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 925, - 950, 975, 1000, -} - -// PressuresPgrb2b contains levels found in the secondary pgrb2b file (21 levels). 
-var PressuresPgrb2b = []int{ - 1, 2, 3, 5, 7, 125, 175, 225, 275, 325, 375, 425, - 475, 525, 575, 625, 675, 725, 775, 825, 875, -} - -// PressureIndex returns the dataset index for a given pressure level in hPa. -// Returns -1 if the level is not found. -func PressureIndex(hPa int) int { - idx, ok := pressureIndex[hPa] - if !ok { - return -1 - } - return idx -} - -// PressureLevelSet returns which GRIB file set a pressure level belongs to. -func PressureLevelSet(hPa int) (LevelSet, bool) { - ls, ok := pressureLevelSet[hPa] - return ls, ok -} - -// HourIndex returns the dataset time index for a forecast hour. -// Returns -1 if the hour is invalid (not a multiple of HourStep or out of range). -func HourIndex(hour int) int { - if hour < 0 || hour > MaxHour || hour%HourStep != 0 { - return -1 - } - return hour / HourStep -} - -// Hours returns all forecast hours as a slice: [0, 3, 6, ..., 192]. -func Hours() []int { - out := make([]int, 0, NumHours) - for h := 0; h <= MaxHour; h += HourStep { - out = append(out, h) - } - return out -} - -// S3 URL configuration for NOAA GFS data. -const S3BaseURL = "https://noaa-gfs-bdp-pds.s3.amazonaws.com" - -// GribURL returns the S3 URL for a primary (pgrb2) GRIB file. -func GribURL(date string, runHour, forecastStep int) string { - return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d", - S3BaseURL, date, runHour, runHour, forecastStep) -} - -// GribURLB returns the S3 URL for a secondary (pgrb2b) GRIB file. -func GribURLB(date string, runHour, forecastStep int) string { - return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2b.0p50.f%03d", - S3BaseURL, date, runHour, runHour, forecastStep) -} - -// GribFileName returns the local filename for a primary GRIB file. -func GribFileName(runHour, forecastStep int) string { - return fmt.Sprintf("gfs.t%02dz.pgrb2.0p50.f%03d", runHour, forecastStep) -} - -// GribFileNameB returns the local filename for a secondary GRIB file. 
-func GribFileNameB(runHour, forecastStep int) string { - return fmt.Sprintf("gfs.t%02dz.pgrb2b.0p50.f%03d", runHour, forecastStep) -} - -// VariableIndex returns the dataset variable index for a GRIB parameter. -// Returns -1 if the parameter is not recognized. -func VariableIndex(parameterCategory, parameterNumber int) int { - switch { - case parameterCategory == 3 && parameterNumber == 5: - return VarHeight // Geopotential Height - case parameterCategory == 2 && parameterNumber == 2: - return VarWindU // U-component of wind - case parameterCategory == 2 && parameterNumber == 3: - return VarWindV // V-component of wind - default: - return -1 - } -} diff --git a/internal/dataset/dataset_test.go b/internal/dataset/dataset_test.go deleted file mode 100644 index 14b36ef..0000000 --- a/internal/dataset/dataset_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package dataset - -import ( - "testing" -) - -func TestDatasetShape(t *testing.T) { - if NumHours != 65 { - t.Errorf("NumHours = %d, want 65", NumHours) - } - if NumLevels != 47 { - t.Errorf("NumLevels = %d, want 47", NumLevels) - } - if NumVariables != 3 { - t.Errorf("NumVariables = %d, want 3", NumVariables) - } - if NumLatitudes != 361 { - t.Errorf("NumLatitudes = %d, want 361", NumLatitudes) - } - if NumLongitudes != 720 { - t.Errorf("NumLongitudes = %d, want 720", NumLongitudes) - } -} - -func TestDatasetSize(t *testing.T) { - // 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200 - want := int64(9_528_667_200) - if DatasetSize != want { - t.Errorf("DatasetSize = %d, want %d", DatasetSize, want) - } -} - -func TestPressureLevels(t *testing.T) { - if len(Pressures) != 47 { - t.Fatalf("len(Pressures) = %d, want 47", len(Pressures)) - } - - // First should be 1000 (highest pressure, near surface) - if Pressures[0] != 1000 { - t.Errorf("Pressures[0] = %d, want 1000", Pressures[0]) - } - // Last should be 1 (lowest pressure, high atmosphere) - if Pressures[46] != 1 { - t.Errorf("Pressures[46] = %d, want 1", Pressures[46]) - } - - 
// Should be sorted descending - for i := 1; i < len(Pressures); i++ { - if Pressures[i] >= Pressures[i-1] { - t.Errorf("Pressures not descending at [%d]: %d >= %d", i, Pressures[i], Pressures[i-1]) - } - } - - // Total levels: 26 from pgrb2 + 21 from pgrb2b = 47 - if len(PressuresPgrb2) != 26 { - t.Errorf("len(PressuresPgrb2) = %d, want 26", len(PressuresPgrb2)) - } - if len(PressuresPgrb2b) != 21 { - t.Errorf("len(PressuresPgrb2b) = %d, want 21", len(PressuresPgrb2b)) - } -} - -func TestPressureIndex(t *testing.T) { - if PressureIndex(1000) != 0 { - t.Errorf("PressureIndex(1000) = %d, want 0", PressureIndex(1000)) - } - if PressureIndex(1) != 46 { - t.Errorf("PressureIndex(1) = %d, want 46", PressureIndex(1)) - } - if PressureIndex(500) != 20 { - t.Errorf("PressureIndex(500) = %d, want 20", PressureIndex(500)) - } - if PressureIndex(9999) != -1 { - t.Errorf("PressureIndex(9999) = %d, want -1", PressureIndex(9999)) - } -} - -func TestPressureLevelSet(t *testing.T) { - // 1000 mb should be in pgrb2 (A) - ls, ok := PressureLevelSet(1000) - if !ok || ls != LevelSetA { - t.Errorf("PressureLevelSet(1000) = %v, %v; want A, true", ls, ok) - } - - // 125 mb should be in pgrb2b (B) - ls, ok = PressureLevelSet(125) - if !ok || ls != LevelSetB { - t.Errorf("PressureLevelSet(125) = %v, %v; want B, true", ls, ok) - } - - // 1, 2, 3, 5, 7 should be in pgrb2b (B) - for _, p := range []int{1, 2, 3, 5, 7} { - ls, ok := PressureLevelSet(p) - if !ok || ls != LevelSetB { - t.Errorf("PressureLevelSet(%d) = %v, %v; want B, true", p, ls, ok) - } - } - - // Every pressure level should have a level set assignment - for _, p := range Pressures { - _, ok := PressureLevelSet(p) - if !ok { - t.Errorf("PressureLevelSet(%d) not found", p) - } - } -} - -func TestHourIndex(t *testing.T) { - if HourIndex(0) != 0 { - t.Errorf("HourIndex(0) = %d, want 0", HourIndex(0)) - } - if HourIndex(3) != 1 { - t.Errorf("HourIndex(3) = %d, want 1", HourIndex(3)) - } - if HourIndex(192) != 64 { - 
t.Errorf("HourIndex(192) = %d, want 64", HourIndex(192)) - } - if HourIndex(1) != -1 { - t.Errorf("HourIndex(1) = %d, want -1 (not multiple of 3)", HourIndex(1)) - } - if HourIndex(195) != -1 { - t.Errorf("HourIndex(195) = %d, want -1 (out of range)", HourIndex(195)) - } -} - -func TestHours(t *testing.T) { - hours := Hours() - if len(hours) != NumHours { - t.Fatalf("len(Hours()) = %d, want %d", len(hours), NumHours) - } - if hours[0] != 0 { - t.Errorf("Hours()[0] = %d, want 0", hours[0]) - } - if hours[len(hours)-1] != MaxHour { - t.Errorf("Hours()[last] = %d, want %d", hours[len(hours)-1], MaxHour) - } -} - -func TestVariableIndex(t *testing.T) { - if VariableIndex(3, 5) != VarHeight { - t.Errorf("HGT: got %d, want %d", VariableIndex(3, 5), VarHeight) - } - if VariableIndex(2, 2) != VarWindU { - t.Errorf("UGRD: got %d, want %d", VariableIndex(2, 2), VarWindU) - } - if VariableIndex(2, 3) != VarWindV { - t.Errorf("VGRD: got %d, want %d", VariableIndex(2, 3), VarWindV) - } - if VariableIndex(0, 0) != -1 { - t.Errorf("unknown: got %d, want -1", VariableIndex(0, 0)) - } -} diff --git a/internal/dataset/file.go b/internal/dataset/file.go deleted file mode 100644 index 96f14c2..0000000 --- a/internal/dataset/file.go +++ /dev/null @@ -1,140 +0,0 @@ -package dataset - -import ( - "encoding/binary" - "fmt" - "math" - "os" - "time" - - mmap "github.com/edsrzf/mmap-go" -) - -// File represents an mmap-backed wind dataset file. -type File struct { - mm mmap.MMap - file *os.File - writable bool - DSTime time.Time // forecast run time (UTC) -} - -// Open opens an existing dataset file for reading. 
-func Open(path string, dsTime time.Time) (*File, error) { - f, err := os.Open(path) - if err != nil { - return nil, fmt.Errorf("open dataset: %w", err) - } - - info, err := f.Stat() - if err != nil { - f.Close() - return nil, fmt.Errorf("stat dataset: %w", err) - } - if info.Size() != DatasetSize { - f.Close() - return nil, fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, info.Size()) - } - - mm, err := mmap.Map(f, mmap.RDONLY, 0) - if err != nil { - f.Close() - return nil, fmt.Errorf("mmap dataset: %w", err) - } - - return &File{mm: mm, file: f, writable: false, DSTime: dsTime}, nil -} - -// Create creates a new dataset file for writing. -// The file is truncated to the correct size and mmap'd read-write. -func Create(path string) (*File, error) { - f, err := os.Create(path) - if err != nil { - return nil, fmt.Errorf("create dataset: %w", err) - } - - if err := f.Truncate(DatasetSize); err != nil { - f.Close() - return nil, fmt.Errorf("truncate dataset: %w", err) - } - - mm, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0) - if err != nil { - f.Close() - return nil, fmt.Errorf("mmap dataset: %w", err) - } - - return &File{mm: mm, file: f, writable: true}, nil -} - -// offset computes the byte offset for element [hour][level][variable][lat][lon]. -// Row-major C-order indexing matching the reference implementation: -// shape = (65, 47, 3, 361, 720) -func offset(hour, level, variable, lat, lon int) int64 { - idx := int64(hour) - idx = idx*int64(NumLevels) + int64(level) - idx = idx*int64(NumVariables) + int64(variable) - idx = idx*int64(NumLatitudes) + int64(lat) - idx = idx*int64(NumLongitudes) + int64(lon) - return idx * int64(ElementSize) -} - -// Val reads a float32 value from the dataset at [hour][level][variable][lat][lon]. 
-func (d *File) Val(hour, level, variable, lat, lon int) float32 { - off := offset(hour, level, variable, lat, lon) - bits := binary.LittleEndian.Uint32(d.mm[off : off+4]) - return math.Float32frombits(bits) -} - -// SetVal writes a float32 value to the dataset at [hour][level][variable][lat][lon]. -// Only valid on writable (created) datasets. -func (d *File) SetVal(hour, level, variable, lat, lon int, val float32) { - off := offset(hour, level, variable, lat, lon) - binary.LittleEndian.PutUint32(d.mm[off:off+4], math.Float32bits(val)) -} - -// BlitGribData copies decoded GRIB grid data into the dataset at the given position. -// gribData is 361*720 float64 values in GRIB scan order (north-to-south, west-to-east). -// This function flips the latitude so that dataset index 0 = -90 (south) and 360 = +90 (north). -func (d *File) BlitGribData(hourIdx, levelIdx, varIdx int, gribData []float64) error { - expected := NumLatitudes * NumLongitudes - if len(gribData) != expected { - return fmt.Errorf("grib data has %d values, expected %d", len(gribData), expected) - } - - for lat := 0; lat < NumLatitudes; lat++ { - for lon := 0; lon < NumLongitudes; lon++ { - // GRIB scans north-to-south: row 0 = 90N, row 360 = 90S - // Dataset stores south-to-north: index 0 = -90 (90S), index 360 = +90 (90N) - gribIdx := (360-lat)*NumLongitudes + lon - val := float32(gribData[gribIdx]) - d.SetVal(hourIdx, levelIdx, varIdx, lat, lon, val) - } - } - - return nil -} - -// Flush flushes the mmap to disk. -func (d *File) Flush() error { - if d.mm != nil { - return d.mm.Flush() - } - return nil -} - -// Close unmaps and closes the dataset file. 
-func (d *File) Close() error { - if d.mm != nil { - if err := d.mm.Unmap(); err != nil { - d.file.Close() - return fmt.Errorf("unmap: %w", err) - } - d.mm = nil - } - if d.file != nil { - err := d.file.Close() - d.file = nil - return err - } - return nil -} diff --git a/internal/datasets/doc.go b/internal/datasets/doc.go new file mode 100644 index 0000000..645fc90 --- /dev/null +++ b/internal/datasets/doc.go @@ -0,0 +1,11 @@ +// Package datasets manages the lifecycle of atmospheric datasets. It exposes: +// +// - A Source interface for pluggable dataset origins (GFS now, ECMWF later). +// - A Storage interface for transactional, resumable on-disk persistence. +// - A Manager that coordinates downloads, tracks job state, and owns the +// currently-active weather.WindField. +// +// The package is the only one in the service that knows about download +// scheduling, manifests, or bandwidth throttling — engine and API layers +// only see WindField + Manager-as-admin. +package datasets diff --git a/internal/datasets/gfs/idx.go b/internal/datasets/gfs/idx.go new file mode 100644 index 0000000..093e8b4 --- /dev/null +++ b/internal/datasets/gfs/idx.go @@ -0,0 +1,125 @@ +package gfs + +import ( + "fmt" + "strconv" + "strings" +) + +// IdxEntry is one parsed line from a NOAA GRIB .idx file. +// +// Example line: "15:1207405:d=2024010100:HGT:1000 mb:0 hour fcst:" +type IdxEntry struct { + Index int + Offset int64 + Variable string + LevelMB int // 0 when the level is not isobaric + Hour int // forecast hour; 0 for analysis ("anl"); -1 if unparseable + EndOffset int64 // computed from the next entry's Offset; -1 for the final entry +} + +// Length returns the byte length of this GRIB message, or -1 if unknown +// (the final entry in an idx file). +func (e *IdxEntry) Length() int64 { + if e.EndOffset <= 0 { + return -1 + } + return e.EndOffset - e.Offset +} + +// ParseIdx parses a .idx file body. Unparseable lines are silently skipped. 
+func ParseIdx(body []byte) []IdxEntry { + lines := strings.Split(string(body), "\n") + var entries []IdxEntry + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + parts := strings.Split(line, ":") + if len(parts) < 7 { + continue + } + idx, err := strconv.Atoi(parts[0]) + if err != nil { + continue + } + off, err := strconv.ParseInt(parts[1], 10, 64) + if err != nil { + continue + } + entries = append(entries, IdxEntry{ + Index: idx, + Offset: off, + Variable: parts[3], + LevelMB: parseLevelMB(parts[4]), + Hour: parseHour(parts[5]), + EndOffset: -1, + }) + } + for i := 0; i < len(entries)-1; i++ { + entries[i].EndOffset = entries[i+1].Offset + } + return entries +} + +// FilterIdx returns entries matching one of the wanted variables at a known +// pressure level with a computable byte length. +func FilterIdx(entries []IdxEntry, wanted map[string]bool) []IdxEntry { + var out []IdxEntry + for _, e := range entries { + if !wanted[e.Variable] || e.LevelMB <= 0 || e.Length() <= 0 { + continue + } + out = append(out, e) + } + return out +} + +func parseLevelMB(s string) int { + s = strings.TrimSpace(s) + if !strings.HasSuffix(s, " mb") { + return 0 + } + n, err := strconv.Atoi(strings.TrimSuffix(s, " mb")) + if err != nil { + return 0 + } + return n +} + +func parseHour(s string) int { + s = strings.TrimSpace(s) + if s == "anl" { + return 0 + } + n, err := strconv.Atoi(strings.TrimSuffix(s, " hour fcst")) + if err != nil { + return -1 + } + return n +} + +// ByteRange is one HTTP range download corresponding to one GRIB message. +type ByteRange struct { + Start int64 + End int64 // inclusive + Entry IdxEntry +} + +// EntriesToRanges converts idx entries to inclusive HTTP byte ranges. 
+func EntriesToRanges(entries []IdxEntry) []ByteRange { + out := make([]ByteRange, 0, len(entries)) + for _, e := range entries { + if e.Length() <= 0 { + continue + } + out = append(out, ByteRange{Start: e.Offset, End: e.EndOffset - 1, Entry: e}) + } + return out +} + +// FormatRange returns an HTTP Range header value for the byte range. +func (r ByteRange) FormatRange() string { + return fmt.Sprintf("bytes=%d-%d", r.Start, r.End) +} diff --git a/internal/datasets/gfs/idx_test.go b/internal/datasets/gfs/idx_test.go new file mode 100644 index 0000000..ab04710 --- /dev/null +++ b/internal/datasets/gfs/idx_test.go @@ -0,0 +1,70 @@ +package gfs + +import "testing" + +const sampleIdx = `1:0:d=2024010100:HGT:1000 mb:0 hour fcst: +2:289012:d=2024010100:HGT:975 mb:0 hour fcst: +3:541876:d=2024010100:TMP:1000 mb:0 hour fcst: +4:789012:d=2024010100:UGRD:1000 mb:0 hour fcst: +5:1045678:d=2024010100:VGRD:1000 mb:0 hour fcst: +6:1298765:d=2024010100:UGRD:975 mb:0 hour fcst: +7:1567890:d=2024010100:UGRD:2 m above ground:0 hour fcst: +8:1812345:d=2024010100:VGRD:975 mb:0 hour fcst: +9:2098765:d=2024010100:HGT:500 mb:3 hour fcst: +` + +func TestParseIdx(t *testing.T) { + entries := ParseIdx([]byte(sampleIdx)) + if len(entries) != 9 { + t.Fatalf("expected 9 entries, got %d", len(entries)) + } + if e := entries[0]; e.Index != 1 || e.Offset != 0 || e.Variable != "HGT" || e.LevelMB != 1000 || e.Hour != 0 || e.EndOffset != 289012 { + t.Errorf("entry 0: %+v", e) + } + if e := entries[6]; e.LevelMB != 0 { + t.Errorf("non-pressure level should have LevelMB=0, got %d", e.LevelMB) + } + if e := entries[len(entries)-1]; e.EndOffset != -1 { + t.Errorf("last entry EndOffset: got %d, want -1", e.EndOffset) + } +} + +func TestFilterIdx(t *testing.T) { + entries := ParseIdx([]byte(sampleIdx)) + want := map[string]bool{"HGT": true, "UGRD": true, "VGRD": true} + filtered := FilterIdx(entries, want) + // HGT@1000, HGT@975, UGRD@1000, VGRD@1000, UGRD@975, VGRD@975 = 6 + // HGT@500 at 3hr is last 
entry (no EndOffset), so dropped. + if len(filtered) != 6 { + t.Errorf("expected 6, got %d", len(filtered)) + } +} + +func TestParseLevelMB(t *testing.T) { + cases := []struct { + in string + want int + }{ + {"1000 mb", 1000}, {"975 mb", 975}, {"1 mb", 1}, + {"2 m above ground", 0}, {"surface", 0}, {"tropopause", 0}, + } + for _, c := range cases { + if got := parseLevelMB(c.in); got != c.want { + t.Errorf("parseLevelMB(%q) = %d, want %d", c.in, got, c.want) + } + } +} + +func TestParseHour(t *testing.T) { + cases := []struct { + in string + want int + }{ + {"0 hour fcst", 0}, {"3 hour fcst", 3}, {"192 hour fcst", 192}, {"anl", 0}, + } + for _, c := range cases { + if got := parseHour(c.in); got != c.want { + t.Errorf("parseHour(%q) = %d, want %d", c.in, got, c.want) + } + } +} diff --git a/internal/datasets/gfs/source.go b/internal/datasets/gfs/source.go new file mode 100644 index 0000000..081525e --- /dev/null +++ b/internal/datasets/gfs/source.go @@ -0,0 +1,430 @@ +// Package gfs implements datasets.Source for NOAA GFS 0.5-degree forecasts. +package gfs + +import ( + "context" + "errors" + "fmt" + "io" + "math" + "net/http" + "os" + "sync" + "time" + + "github.com/nilsmagnus/grib/griblib" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + + "predictor-refactored/internal/datasets" + "predictor-refactored/internal/weather" + wgfs "predictor-refactored/internal/weather/gfs" +) + +// Source is the GFS implementation of datasets.Source. +type Source struct { + Parallel int // max concurrent step downloads + Client *http.Client // optional; defaults to a 2-minute-timeout client + Log *zap.Logger +} + +// NewSource returns a default Source. +func NewSource(log *zap.Logger) *Source { + return &Source{ + Parallel: 8, + Client: &http.Client{Timeout: 2 * time.Minute}, + Log: log, + } +} + +// ID returns the source identifier. 
+func (s *Source) ID() string { return "noaa-gfs-0p50" } + +func (s *Source) log() *zap.Logger { + if s.Log == nil { + return zap.NewNop() + } + return s.Log +} + +func (s *Source) client() *http.Client { + if s.Client == nil { + return &http.Client{Timeout: 2 * time.Minute} + } + return s.Client +} + +func (s *Source) parallel() int { + if s.Parallel <= 0 { + return 8 + } + return s.Parallel +} + +// LatestEpoch returns the most recent run NOAA has finished publishing, +// determined by HEAD-ing the .idx for the final forecast hour. Walks back +// up to 8 runs (48 hours) before giving up. +func (s *Source) LatestEpoch(ctx context.Context) (time.Time, error) { + now := time.Now().UTC() + hour := now.Hour() - (now.Hour() % 6) + current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC) + + for range 8 { + date := current.Format("20060102") + url := wgfs.GribURL(date, current.Hour(), wgfs.MaxHour) + ".idx" + + req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) + if err == nil { + resp, err := s.client().Do(req) + if err == nil { + resp.Body.Close() + if resp.StatusCode == http.StatusOK { + s.log().Info("latest GFS run discovered", + zap.Time("run", current), + zap.String("verified_url", url)) + return current, nil + } + } + } + current = current.Add(-6 * time.Hour) + } + return time.Time{}, fmt.Errorf("no recent GFS run found (checked 8 runs)") +} + +// Open loads a stored dataset as a WindField. +func (s *Source) Open(_ context.Context, epoch time.Time, store datasets.Storage) (weather.WindField, error) { + if !store.Exists(epoch) { + return nil, fmt.Errorf("epoch %s not found", epoch.Format(time.RFC3339)) + } + file, err := wgfs.Open(store.Path(epoch), epoch.UTC()) + if err != nil { + return nil, err + } + return wgfs.NewWind(file), nil +} + +// neededVariables is the GRIB variable set we extract. 
+var neededVariables = map[string]bool{"HGT": true, "UGRD": true, "VGRD": true} + +// Download fetches the full dataset for epoch in parallel, resuming any +// previously-completed work units. Honours ctx cancellation and prog +// (which may be nil). +func (s *Source) Download(ctx context.Context, epoch time.Time, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error { + if prog == nil { + prog = noopSink{} + } + + handle, err := store.BeginWrite(epoch) + if err != nil { + return fmt.Errorf("begin write: %w", err) + } + manifest := handle.Manifest() + + // Open or create the temp file. If a previous attempt left a partial + // file of the right size, reuse it (resume); otherwise Create. + file, err := openOrCreateCube(handle.Path()) + if err != nil { + _ = handle.Abort() + return err + } + + date := epoch.UTC().Format("20060102") + runHour := epoch.UTC().Hour() + steps := wgfs.Hours() + totalUnits := len(steps) * 2 + + prog.SetTotal(totalUnits) + // Pre-count already-done units so progress is accurate on resume. + for _, u := range manifest.Units() { + _ = u + prog.StepComplete() + } + + start := time.Now() + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(s.parallel()) + + // fileMu serialises concurrent BlitGribData calls because the underlying + // mmap is shared and SetVal isn't atomic. 
+ var fileMu sync.Mutex + + for _, step := range steps { + hourIdx := wgfs.HourIndex(step) + if hourIdx < 0 { + continue + } + + for _, ls := range []wgfs.LevelSet{wgfs.LevelSetA, wgfs.LevelSetB} { + unit := unitKey(step, ls) + if manifest.Has(unit) { + continue + } + + g.Go(func() error { + var url string + switch ls { + case wgfs.LevelSetA: + url = wgfs.GribURL(date, runHour, step) + case wgfs.LevelSetB: + url = wgfs.GribURLB(date, runHour, step) + } + if err := s.downloadAndBlit(ctx, file, &fileMu, url, hourIdx, ls, prog, throttle); err != nil { + return fmt.Errorf("step %d %s: %w", step, levelSetLabel(ls), err) + } + if err := manifest.Mark(unit); err != nil { + return fmt.Errorf("mark unit: %w", err) + } + prog.StepComplete() + return nil + }) + } + } + + if err := g.Wait(); err != nil { + _ = file.Close() + // Don't Abort on context cancellation — preserve progress for resume. + if errors.Is(err, context.Canceled) { + return err + } + // Other errors: abort if no progress was made; otherwise leave for resume. + if len(manifest.Units()) == 0 { + _ = handle.Abort() + } + return err + } + + if err := file.Flush(); err != nil { + _ = file.Close() + return fmt.Errorf("flush: %w", err) + } + if err := file.Close(); err != nil { + return fmt.Errorf("close: %w", err) + } + if err := handle.Commit(); err != nil { + return fmt.Errorf("commit: %w", err) + } + + s.log().Info("download complete", + zap.Time("epoch", epoch), + zap.Duration("elapsed", time.Since(start))) + return nil +} + +// openOrCreateCube returns a writable cube file at path, creating it if the +// file does not exist or has the wrong size. +func openOrCreateCube(path string) (*wgfs.File, error) { + info, err := os.Stat(path) + if err == nil && info.Size() == wgfs.DatasetSize { + return wgfs.OpenWritable(path) + } + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("stat cube: %w", err) + } + // Wrong-size or missing — truncate-create. 
+ return wgfs.Create(path) +} + +// downloadAndBlit fetches and decodes one (URL, level-set) chunk and writes +// it into the dataset. +func (s *Source) downloadAndBlit( + ctx context.Context, + file *wgfs.File, + fileMu *sync.Mutex, + baseURL string, + hourIdx int, + ls wgfs.LevelSet, + prog datasets.ProgressSink, + throttle datasets.Throttle, +) error { + idxBody, err := s.httpGet(ctx, baseURL+".idx", throttle, prog) + if err != nil { + return fmt.Errorf("idx: %w", err) + } + entries := ParseIdx(idxBody) + filtered := FilterIdx(entries, neededVariables) + + var relevant []IdxEntry + for _, e := range filtered { + set, ok := wgfs.PressureLevelSet(e.LevelMB) + if ok && set == ls { + relevant = append(relevant, e) + } + } + if len(relevant) == 0 { + return nil + } + + ranges := EntriesToRanges(relevant) + tmp, err := os.CreateTemp("", "gfs-msg-*.tmp") + if err != nil { + return fmt.Errorf("temp: %w", err) + } + tmpPath := tmp.Name() + defer os.Remove(tmpPath) + + for _, r := range ranges { + body, err := s.httpGetRange(ctx, baseURL, r.Start, r.End, throttle, prog) + if err != nil { + tmp.Close() + return fmt.Errorf("range %d-%d: %w", r.Start, r.End, err) + } + if _, err := tmp.Write(body); err != nil { + tmp.Close() + return fmt.Errorf("write tmp: %w", err) + } + } + if err := tmp.Close(); err != nil { + return err + } + + f, err := os.Open(tmpPath) + if err != nil { + return err + } + messages, err := griblib.ReadMessages(f) + f.Close() + if err != nil { + return fmt.Errorf("read grib: %w", err) + } + + for _, msg := range messages { + if msg.Section4.ProductDefinitionTemplateNumber != 0 { + continue + } + p := msg.Section4.ProductDefinitionTemplate + varIdx := wgfs.VariableIndex(int(p.ParameterCategory), int(p.ParameterNumber)) + if varIdx < 0 { + continue + } + if p.FirstSurface.Type != 100 { // isobaric only + continue + } + pressureMB := int(math.Round(float64(p.FirstSurface.Value) / 100.0)) + levelIdx := wgfs.PressureIndex(pressureMB) + if levelIdx < 0 { + 
continue + } + data := msg.Data() + fileMu.Lock() + err := file.BlitGribData(hourIdx, levelIdx, varIdx, data) + fileMu.Unlock() + if err != nil { + return fmt.Errorf("blit: %w", err) + } + } + return nil +} + +// httpGet downloads a URL body with 3 retries and optional throttling. +func (s *Source) httpGet(ctx context.Context, url string, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) { + var lastErr error + for attempt := range 3 { + if attempt > 0 { + select { + case <-time.After(time.Duration(attempt*2) * time.Second): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + resp, err := s.client().Do(req) + if err != nil { + lastErr = err + continue + } + body, err := readThrottled(ctx, resp.Body, throttle, prog) + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url) + continue + } + if err != nil { + lastErr = err + continue + } + return body, nil + } + return nil, fmt.Errorf("after 3 attempts: %w", lastErr) +} + +// httpGetRange downloads an inclusive byte range with 3 retries and throttling. 
+func (s *Source) httpGetRange(ctx context.Context, url string, start, end int64, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) { + var lastErr error + for attempt := range 3 { + if attempt > 0 { + select { + case <-time.After(time.Duration(attempt*2) * time.Second): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) + resp, err := s.client().Do(req) + if err != nil { + lastErr = err + continue + } + body, err := readThrottled(ctx, resp.Body, throttle, prog) + resp.Body.Close() + if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { + lastErr = fmt.Errorf("HTTP %d for range %d-%d of %s", resp.StatusCode, start, end, url) + continue + } + if err != nil { + lastErr = err + continue + } + return body, nil + } + return nil, fmt.Errorf("after 3 attempts: %w", lastErr) +} + +// readThrottled reads r into memory, consulting throttle (if non-nil) before +// each chunk and reporting bytes to prog. +func readThrottled(ctx context.Context, r io.Reader, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) { + buf := make([]byte, 0, 64*1024) + chunk := make([]byte, 32*1024) + for { + if throttle != nil { + if err := throttle.Wait(ctx, len(chunk)); err != nil { + return nil, err + } + } + n, err := r.Read(chunk) + if n > 0 { + buf = append(buf, chunk[:n]...) + prog.Bytes(int64(n)) + } + if errors.Is(err, io.EOF) { + return buf, nil + } + if err != nil { + return nil, err + } + } +} + +func unitKey(step int, ls wgfs.LevelSet) string { + return fmt.Sprintf("step%03d-%s", step, levelSetLabel(ls)) +} + +func levelSetLabel(ls wgfs.LevelSet) string { + if ls == wgfs.LevelSetB { + return "B" + } + return "A" +} + +// noopSink discards progress events. 
+type noopSink struct{} + +func (noopSink) SetTotal(int) {} +func (noopSink) StepComplete() {} +func (noopSink) Bytes(int64) {} diff --git a/internal/datasets/manager.go b/internal/datasets/manager.go new file mode 100644 index 0000000..a7584c1 --- /dev/null +++ b/internal/datasets/manager.go @@ -0,0 +1,383 @@ +package datasets + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" + + "predictor-refactored/internal/weather" +) + +// JobStatus is the lifecycle state of a download job. +type JobStatus string + +const ( + JobPending JobStatus = "pending" + JobRunning JobStatus = "running" + JobComplete JobStatus = "complete" + JobFailed JobStatus = "failed" + JobCancelled JobStatus = "cancelled" +) + +// JobInfo is the externally-visible snapshot of a download job. +type JobInfo struct { + ID string + Source string + Epoch time.Time + Status JobStatus + StartedAt time.Time + EndedAt *time.Time + Err string + Total int + Done int + Bytes int64 +} + +// jobEntry is the Manager's mutable record for one job. +type jobEntry struct { + id string + source string + epoch time.Time + startedAt time.Time + cancel context.CancelFunc + + mu sync.Mutex + status JobStatus + endedAt time.Time + errStr string + + total atomic.Int64 + done atomic.Int64 + bytes atomic.Int64 +} + +func (e *jobEntry) snapshot() JobInfo { + e.mu.Lock() + info := JobInfo{ + ID: e.id, Source: e.source, Epoch: e.epoch, + StartedAt: e.startedAt, Status: e.status, Err: e.errStr, + } + if !e.endedAt.IsZero() { + ts := e.endedAt + info.EndedAt = &ts + } + e.mu.Unlock() + info.Total = int(e.total.Load()) + info.Done = int(e.done.Load()) + info.Bytes = e.bytes.Load() + return info +} + +// jobProgress is the ProgressSink wired into a jobEntry. 
+type jobProgress struct{ e *jobEntry } + +func (p jobProgress) SetTotal(n int) { p.e.total.Store(int64(n)) } +func (p jobProgress) StepComplete() { p.e.done.Add(1) } +func (p jobProgress) Bytes(n int64) { p.e.bytes.Add(n) } + +// Manager coordinates dataset downloads and exposes the active WindField. +type Manager struct { + src Source + store Storage + throttle Throttle + log *zap.Logger + + activeMu sync.RWMutex + active weather.WindField + + jobsMu sync.RWMutex + jobs map[string]*jobEntry + + // inFlight maps an epoch's RFC3339 representation to its jobID, enforcing + // single-flight per epoch. + inFlight sync.Map +} + +// New returns a Manager wiring source, store, and an optional throttle. +// A nil log uses zap.NewNop(). +func New(src Source, store Storage, throttle Throttle, log *zap.Logger) *Manager { + if log == nil { + log = zap.NewNop() + } + if src.ID() != store.SourceID() { + log.Warn("source/store ID mismatch", + zap.String("src", src.ID()), + zap.String("store", store.SourceID())) + } + return &Manager{ + src: src, store: store, throttle: throttle, log: log, + jobs: make(map[string]*jobEntry), + } +} + +// Source returns the underlying source ID. +func (m *Manager) Source() string { return m.src.ID() } + +// Active returns the currently-loaded WindField, or nil. +func (m *Manager) Active() weather.WindField { + m.activeMu.RLock() + defer m.activeMu.RUnlock() + return m.active +} + +// Ready reports whether a dataset is currently loaded. +func (m *Manager) Ready() bool { return m.Active() != nil } + +// ListEpochs returns all stored dataset epochs, newest first. +func (m *Manager) ListEpochs() ([]time.Time, error) { return m.store.List() } + +// ListJobs returns snapshots of every job recorded since startup. 
+func (m *Manager) ListJobs() []JobInfo { + m.jobsMu.RLock() + defer m.jobsMu.RUnlock() + out := make([]JobInfo, 0, len(m.jobs)) + for _, e := range m.jobs { + out = append(out, e.snapshot()) + } + return out +} + +// GetJob returns the snapshot for a job, or false if id is unknown. +func (m *Manager) GetJob(id string) (JobInfo, bool) { + m.jobsMu.RLock() + e, ok := m.jobs[id] + m.jobsMu.RUnlock() + if !ok { + return JobInfo{}, false + } + return e.snapshot(), true +} + +// CancelJob cancels a running job. Returns false if id is unknown or the +// job is already terminal. +func (m *Manager) CancelJob(id string) bool { + m.jobsMu.RLock() + e, ok := m.jobs[id] + m.jobsMu.RUnlock() + if !ok { + return false + } + e.mu.Lock() + terminal := e.status == JobComplete || e.status == JobFailed || e.status == JobCancelled + e.mu.Unlock() + if terminal { + return false + } + e.cancel() + return true +} + +// RemoveEpoch deletes a stored dataset. If epoch is currently active, the +// active field is cleared. +func (m *Manager) RemoveEpoch(epoch time.Time) error { + epoch = epoch.UTC() + if active := m.Active(); active != nil && active.Epoch().Equal(epoch) { + m.activeMu.Lock() + m.active = nil + m.activeMu.Unlock() + } + return m.store.Remove(epoch) +} + +// Download starts (or resumes) a download job for epoch in the background. +// Returns the JobID. If a job for the same epoch is already running, its +// existing JobID is returned. +// +// If the dataset is already present on disk, a synthetic completed JobInfo +// is recorded and its JobID returned. 
+func (m *Manager) Download(epoch time.Time) string { + epoch = epoch.UTC() + key := epoch.Format(time.RFC3339) + + if existing, ok := m.inFlight.Load(key); ok { + return existing.(string) + } + + jobID := uuid.New().String() + if other, loaded := m.inFlight.LoadOrStore(key, jobID); loaded { + return other.(string) + } + + ctx, cancel := context.WithCancel(context.Background()) + now := time.Now().UTC() + e := &jobEntry{ + id: jobID, + source: m.src.ID(), + epoch: epoch, + startedAt: now, + status: JobPending, + cancel: cancel, + } + m.jobsMu.Lock() + m.jobs[jobID] = e + m.jobsMu.Unlock() + + if m.store.Exists(epoch) { + // Skip the download but still record the job for traceability. + go m.completeShortCircuit(ctx, e) + return jobID + } + go m.runDownload(ctx, e) + return jobID +} + +// LoadEpoch swaps the active WindField to epoch's stored dataset. +func (m *Manager) LoadEpoch(ctx context.Context, epoch time.Time) error { + epoch = epoch.UTC() + if !m.store.Exists(epoch) { + return fmt.Errorf("epoch %s not present on disk", epoch.Format(time.RFC3339)) + } + field, err := m.src.Open(ctx, epoch, m.store) + if err != nil { + return fmt.Errorf("open epoch: %w", err) + } + m.swapActive(field) + m.log.Info("loaded dataset", + zap.Time("epoch", epoch), + zap.String("source", m.src.ID())) + return nil +} + +// Refresh ensures the most recent upstream dataset is downloaded and active. +// +// If the freshest stored dataset is newer than retentionTTL old, no upstream +// check is performed. Otherwise the source's LatestEpoch is consulted; if it +// is newer than the active dataset, a download is started and on completion +// the new dataset becomes active. +// +// Returns the JobID started, or empty string when nothing was scheduled. 
+func (m *Manager) Refresh(ctx context.Context, freshnessTTL time.Duration) (string, error) { + if active := m.Active(); active != nil && time.Since(active.Epoch()) < freshnessTTL { + return "", nil + } + + // Try loading the freshest existing dataset before going to the network. + if epochs, err := m.store.List(); err == nil { + for _, e := range epochs { + if time.Since(e) > freshnessTTL { + continue + } + if active := m.Active(); active != nil && active.Epoch().Equal(e) { + return "", nil + } + if err := m.LoadEpoch(ctx, e); err == nil { + return "", nil + } + } + } + + latest, err := m.src.LatestEpoch(ctx) + if err != nil { + return "", fmt.Errorf("latest epoch: %w", err) + } + if active := m.Active(); active != nil && !latest.After(active.Epoch()) { + return "", nil + } + + jobID := m.Download(latest) + + // Spawn a watcher that loads the dataset on successful completion. + go func() { + for { + info, ok := m.GetJob(jobID) + if !ok { + return + } + switch info.Status { + case JobComplete: + if err := m.LoadEpoch(context.Background(), latest); err != nil { + m.log.Error("load after download", zap.Error(err)) + } + return + case JobFailed, JobCancelled: + return + } + time.Sleep(2 * time.Second) + } + }() + return jobID, nil +} + +// runDownload executes one Source.Download invocation and records its outcome. 
+func (m *Manager) runDownload(ctx context.Context, e *jobEntry) { + defer m.inFlight.Delete(e.epoch.Format(time.RFC3339)) + + e.mu.Lock() + e.status = JobRunning + e.mu.Unlock() + + m.log.Info("download started", + zap.String("job", e.id), + zap.Time("epoch", e.epoch)) + + err := m.src.Download(ctx, e.epoch, m.store, jobProgress{e: e}, m.throttle) + now := time.Now().UTC() + + e.mu.Lock() + e.endedAt = now + switch { + case errors.Is(err, context.Canceled): + e.status = JobCancelled + case err != nil: + e.status = JobFailed + e.errStr = err.Error() + default: + e.status = JobComplete + } + finalStatus := e.status + e.mu.Unlock() + + m.log.Info("download finished", + zap.String("job", e.id), + zap.String("status", string(finalStatus)), + zap.NamedError("err", err)) +} + +// completeShortCircuit records a job as complete without performing any work. +func (m *Manager) completeShortCircuit(ctx context.Context, e *jobEntry) { + _ = ctx + defer m.inFlight.Delete(e.epoch.Format(time.RFC3339)) + now := time.Now().UTC() + e.mu.Lock() + e.status = JobComplete + e.endedAt = now + e.mu.Unlock() +} + +// swapActive replaces the active field and closes the previous one if it +// implements io.Closer. +func (m *Manager) swapActive(f weather.WindField) { + m.activeMu.Lock() + old := m.active + m.active = f + m.activeMu.Unlock() + if c, ok := old.(interface{ Close() error }); ok && c != nil { + if err := c.Close(); err != nil { + m.log.Warn("close old dataset", zap.Error(err)) + } + } +} + +// Close releases all resources, cancelling any in-flight jobs. 
+func (m *Manager) Close() error { + m.jobsMu.Lock() + for _, e := range m.jobs { + e.cancel() + } + m.jobsMu.Unlock() + + m.activeMu.Lock() + active := m.active + m.active = nil + m.activeMu.Unlock() + if c, ok := active.(interface{ Close() error }); ok && c != nil { + return c.Close() + } + return nil +} diff --git a/internal/datasets/manifest.go b/internal/datasets/manifest.go new file mode 100644 index 0000000..15c3038 --- /dev/null +++ b/internal/datasets/manifest.go @@ -0,0 +1,118 @@ +package datasets + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "sort" + "sync" +) + +// Manifest tracks completed work units for a partial dataset download. +// Units are arbitrary opaque strings; sources choose the format +// (e.g. "step12-A" for "forecast step 12, level set A"). +// +// A Manifest is persisted as a JSON object: {"units": ["step0-A", "step0-B", ...]}. +type Manifest struct { + path string + + mu sync.Mutex + units map[string]struct{} +} + +// LoadManifest opens or creates the manifest at path. Missing or unreadable +// files are treated as empty; a corrupt file returns an error. +func LoadManifest(path string) (*Manifest, error) { + m := &Manifest{path: path, units: make(map[string]struct{})} + + data, err := os.ReadFile(path) + if errors.Is(err, os.ErrNotExist) { + return m, nil + } + if err != nil { + return nil, fmt.Errorf("read manifest %s: %w", path, err) + } + if len(data) == 0 { + return m, nil + } + + var doc struct { + Units []string `json:"units"` + } + if err := json.Unmarshal(data, &doc); err != nil { + return nil, fmt.Errorf("parse manifest %s: %w", path, err) + } + for _, u := range doc.Units { + m.units[u] = struct{}{} + } + return m, nil +} + +// Has reports whether unit has been recorded as completed. +func (m *Manifest) Has(unit string) bool { + m.mu.Lock() + defer m.mu.Unlock() + _, ok := m.units[unit] + return ok +} + +// Mark records unit as completed and persists the manifest to disk. 
+func (m *Manifest) Mark(unit string) error { + m.mu.Lock() + defer m.mu.Unlock() + if _, ok := m.units[unit]; ok { + return nil + } + m.units[unit] = struct{}{} + return m.persistLocked() +} + +// Units returns the completed units in sorted order. +func (m *Manifest) Units() []string { + m.mu.Lock() + defer m.mu.Unlock() + out := make([]string, 0, len(m.units)) + for u := range m.units { + out = append(out, u) + } + sort.Strings(out) + return out +} + +// Reset clears all recorded units and removes the manifest file. +func (m *Manifest) Reset() error { + m.mu.Lock() + defer m.mu.Unlock() + m.units = make(map[string]struct{}) + if err := os.Remove(m.path); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("remove manifest %s: %w", m.path, err) + } + return nil +} + +// persistLocked writes the manifest to disk via temp+rename. +// The caller must hold m.mu. +func (m *Manifest) persistLocked() error { + units := make([]string, 0, len(m.units)) + for u := range m.units { + units = append(units, u) + } + sort.Strings(units) + data, err := json.Marshal(struct { + Units []string `json:"units"` + }{Units: units}) + if err != nil { + return err + } + + tmp := m.path + ".new" + if err := os.WriteFile(tmp, data, 0o644); err != nil { + return fmt.Errorf("write manifest temp: %w", err) + } + if err := os.Rename(tmp, m.path); err != nil { + os.Remove(tmp) + return fmt.Errorf("rename manifest: %w", err) + } + return nil +} diff --git a/internal/datasets/store_local.go b/internal/datasets/store_local.go new file mode 100644 index 0000000..5dece03 --- /dev/null +++ b/internal/datasets/store_local.go @@ -0,0 +1,167 @@ +package datasets + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "time" +) + +// LocalStore stores dataset files on the local filesystem. 
//
// Layout under Root:
//
//	{epoch}.bin               — committed dataset (binary cube)
//	{epoch}.bin.downloading   — in-progress dataset
//	{epoch}.bin.manifest.json — manifest of completed work units
//
// The extension sets committed datasets apart from their sidecars in a
// directory listing; {epoch} is the epoch formatted as "20060102T150405Z"
// in UTC.
type LocalStore struct {
	Root      string
	Source    string // source ID, recorded for safety but currently advisory
	Extension string // default ".bin"
}

// NewLocalStore creates root (and any missing parents) and returns a
// LocalStore rooted there with the ".bin" extension.
func NewLocalStore(root, sourceID string) (*LocalStore, error) {
	err := os.MkdirAll(root, 0o755)
	if err != nil {
		return nil, fmt.Errorf("create store root %s: %w", root, err)
	}
	store := &LocalStore{
		Root:      root,
		Source:    sourceID,
		Extension: ".bin",
	}
	return store, nil
}

// SourceID reports the source ID recorded in the store.
func (s *LocalStore) SourceID() string {
	return s.Source
}

// epochFormat is the time layout used for dataset file names.
const epochFormat = "20060102T150405Z"

// ext returns the configured extension, or ".bin" when none is set.
func (s *LocalStore) ext() string {
	if configured := s.Extension; configured != "" {
		return configured
	}
	return ".bin"
}

// Path returns where the committed dataset for epoch lives: Root joined with
// the UTC-formatted epoch plus the extension.
func (s *LocalStore) Path(epoch time.Time) string {
	fileName := epoch.UTC().Format(epochFormat) + s.ext()
	return filepath.Join(s.Root, fileName)
}

// tempPath returns the path of epoch's in-progress dataset file.
func (s *LocalStore) tempPath(epoch time.Time) string {
	committed := s.Path(epoch)
	return committed + ".downloading"
}

// manifestPath returns the path of epoch's manifest sidecar.
func (s *LocalStore) manifestPath(epoch time.Time) string {
	committed := s.Path(epoch)
	return committed + ".manifest.json"
}

// Exists reports whether epoch's committed path can be stat'ed and is not a
// directory.
func (s *LocalStore) Exists(epoch time.Time) bool {
	info, err := os.Stat(s.Path(epoch))
	if err != nil {
		return false
	}
	return !info.IsDir()
}

// List returns all committed epochs, newest first.
+func (s *LocalStore) List() ([]time.Time, error) { + entries, err := os.ReadDir(s.Root) + if err != nil { + return nil, fmt.Errorf("read store: %w", err) + } + var out []time.Time + ext := s.ext() + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + if !strings.HasSuffix(name, ext) { + continue + } + stem := strings.TrimSuffix(name, ext) + // skip in-progress files (their stem already has .bin.downloading...) + if strings.Contains(stem, ".") { + continue + } + t, err := time.Parse(epochFormat, stem) + if err != nil { + continue + } + out = append(out, t.UTC()) + } + sort.Slice(out, func(i, j int) bool { return out[i].After(out[j]) }) + return out, nil +} + +// Remove deletes the committed dataset and any sidecar files for epoch. +func (s *LocalStore) Remove(epoch time.Time) error { + var errs []error + for _, p := range []string{s.Path(epoch), s.tempPath(epoch), s.manifestPath(epoch)} { + if err := os.Remove(p); err != nil && !errors.Is(err, os.ErrNotExist) { + errs = append(errs, err) + } + } + if len(errs) > 0 { + return fmt.Errorf("remove dataset: %v", errs) + } + return nil +} + +// BeginWrite opens or resumes a TempHandle for epoch. +// +// If a partial download is already present, its file and manifest are reused +// so the new download picks up where the previous one stopped. +func (s *LocalStore) BeginWrite(epoch time.Time) (TempHandle, error) { + man, err := LoadManifest(s.manifestPath(epoch)) + if err != nil { + return nil, err + } + return &localHandle{ + store: s, + epoch: epoch, + manifest: man, + }, nil +} + +type localHandle struct { + store *LocalStore + epoch time.Time + manifest *Manifest + closed bool +} + +func (h *localHandle) Path() string { return h.store.tempPath(h.epoch) } +func (h *localHandle) Manifest() *Manifest { return h.manifest } + +// Commit promotes the temp file to its final path and removes the manifest. 
+func (h *localHandle) Commit() error { + if h.closed { + return nil + } + h.closed = true + if err := os.Rename(h.store.tempPath(h.epoch), h.store.Path(h.epoch)); err != nil { + return fmt.Errorf("commit rename: %w", err) + } + if err := os.Remove(h.store.manifestPath(h.epoch)); err != nil && !errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("commit remove manifest: %w", err) + } + return nil +} + +// Abort removes the in-progress file and manifest. +func (h *localHandle) Abort() error { + if h.closed { + return nil + } + h.closed = true + var firstErr error + for _, p := range []string{h.store.tempPath(h.epoch), h.store.manifestPath(h.epoch)} { + if err := os.Remove(p); err != nil && !errors.Is(err, os.ErrNotExist) && firstErr == nil { + firstErr = err + } + } + return firstErr +} diff --git a/internal/datasets/store_test.go b/internal/datasets/store_test.go new file mode 100644 index 0000000..3ee5082 --- /dev/null +++ b/internal/datasets/store_test.go @@ -0,0 +1,82 @@ +package datasets + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestLocalStoreBeginWriteResume(t *testing.T) { + dir := t.TempDir() + store, err := NewLocalStore(dir, "gfs-test") + if err != nil { + t.Fatalf("NewLocalStore: %v", err) + } + + epoch := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + h, err := store.BeginWrite(epoch) + if err != nil { + t.Fatalf("BeginWrite: %v", err) + } + if err := os.WriteFile(h.Path(), []byte("partial"), 0o644); err != nil { + t.Fatalf("write partial: %v", err) + } + if err := h.Manifest().Mark("step000-A"); err != nil { + t.Fatalf("mark: %v", err) + } + + // Re-open should see the previous manifest entry. + h2, err := store.BeginWrite(epoch) + if err != nil { + t.Fatalf("BeginWrite resume: %v", err) + } + if !h2.Manifest().Has("step000-A") { + t.Errorf("resumed manifest missing step000-A; units = %v", h2.Manifest().Units()) + } + + // Commit promotes the temp file and removes the manifest. 
+ if err := h2.Commit(); err != nil { + t.Fatalf("Commit: %v", err) + } + if !store.Exists(epoch) { + t.Errorf("Exists after commit returned false") + } + if _, err := os.Stat(filepath.Join(dir, store.manifestPath(epoch))); !os.IsNotExist(err) { + t.Errorf("manifest should be removed, got err=%v", err) + } + + // Listing finds the committed epoch. + epochs, err := store.List() + if err != nil { + t.Fatalf("List: %v", err) + } + if len(epochs) != 1 || !epochs[0].Equal(epoch) { + t.Errorf("List = %v, want [%v]", epochs, epoch) + } + + // Remove cleans up. + if err := store.Remove(epoch); err != nil { + t.Fatalf("Remove: %v", err) + } + if store.Exists(epoch) { + t.Errorf("Exists after remove returned true") + } +} + +func TestLocalStoreAbort(t *testing.T) { + dir := t.TempDir() + store, _ := NewLocalStore(dir, "gfs-test") + epoch := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) + + h, _ := store.BeginWrite(epoch) + os.WriteFile(h.Path(), []byte("x"), 0o644) + h.Manifest().Mark("step000-A") + + if err := h.Abort(); err != nil { + t.Fatalf("Abort: %v", err) + } + if _, err := os.Stat(h.Path()); !os.IsNotExist(err) { + t.Errorf("temp file should be removed after abort, got %v", err) + } +} diff --git a/internal/datasets/throttle.go b/internal/datasets/throttle.go new file mode 100644 index 0000000..980a005 --- /dev/null +++ b/internal/datasets/throttle.go @@ -0,0 +1,63 @@ +package datasets + +import ( + "context" + "sync" + "time" +) + +// TokenBucket is a simple bytes-per-second rate limiter. +// +// The bucket is initialised full (capacity = rate × 1 second). Calls to Wait +// block until enough tokens have accumulated. +type TokenBucket struct { + mu sync.Mutex + rate float64 // tokens per second + tokens float64 + cap float64 + last time.Time +} + +// NewTokenBucket returns a TokenBucket emitting at most bytesPerSecond. +// A non-positive rate disables throttling (Wait becomes a no-op). 
+func NewTokenBucket(bytesPerSecond int64) *TokenBucket { + if bytesPerSecond <= 0 { + return &TokenBucket{rate: 0} + } + r := float64(bytesPerSecond) + return &TokenBucket{rate: r, tokens: r, cap: r, last: time.Now()} +} + +// Wait blocks until n tokens are available or ctx is cancelled. +func (t *TokenBucket) Wait(ctx context.Context, n int) error { + if t.rate <= 0 { + return nil + } + want := float64(n) + + for { + t.mu.Lock() + now := time.Now() + elapsed := now.Sub(t.last).Seconds() + t.last = now + t.tokens += elapsed * t.rate + if t.tokens > t.cap { + t.tokens = t.cap + } + if t.tokens >= want { + t.tokens -= want + t.mu.Unlock() + return nil + } + // Sleep until we expect enough tokens. + need := want - t.tokens + sleep := time.Duration(need / t.rate * float64(time.Second)) + t.mu.Unlock() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(sleep): + } + } +} diff --git a/internal/datasets/types.go b/internal/datasets/types.go new file mode 100644 index 0000000..f2cb82e --- /dev/null +++ b/internal/datasets/types.go @@ -0,0 +1,97 @@ +package datasets + +import ( + "context" + "time" + + "predictor-refactored/internal/weather" +) + +// Source is a pluggable origin for atmospheric datasets. +// +// Implementations download dataset files in a transactional, resumable +// manner and load them as weather.WindField. A Source must be safe for +// concurrent use across multiple Manager calls. +type Source interface { + // ID is a stable identifier, e.g. "noaa-gfs-0p50". + ID() string + + // LatestEpoch returns the most recent dataset epoch this source can provide. + LatestEpoch(ctx context.Context) (time.Time, error) + + // Download fetches the dataset for epoch into store. Sources must honour + // any partial progress recorded in store's manifest and skip + // already-completed work, so re-invocation after a crash resumes cleanly. + // + // prog receives progress events; nil is acceptable. 
+ // throttle, if non-nil, is consulted before each network read for + // bandwidth limiting; nil means no throttling. + Download(ctx context.Context, epoch time.Time, store Storage, prog ProgressSink, throttle Throttle) error + + // Open loads epoch's stored dataset and returns it as a WindField. + Open(ctx context.Context, epoch time.Time, store Storage) (weather.WindField, error) +} + +// Storage abstracts the on-disk location of dataset files and their manifests. +// +// Atomicity: only datasets promoted via TempHandle.Commit appear in Exists or +// List. Aborted or in-progress downloads are invisible to readers. +type Storage interface { + // SourceID identifies the data source these files belong to. Mixing + // sources in one Storage is not supported. + SourceID() string + + // Path returns the canonical local path for epoch's dataset. The path + // is valid even when the dataset has not been written. + Path(epoch time.Time) string + + // Exists reports whether a committed dataset for epoch is present. + Exists(epoch time.Time) bool + + // List returns all committed epochs available, newest first. + List() ([]time.Time, error) + + // Remove deletes the dataset and any sidecar manifest for epoch. + Remove(epoch time.Time) error + + // BeginWrite opens (or resumes) a transactional handle for downloading + // epoch's dataset. Callers must Commit or Abort the returned handle. + BeginWrite(epoch time.Time) (TempHandle, error) +} + +// TempHandle is the storage state for one in-progress download. +type TempHandle interface { + // Path returns the path of the in-progress file. Sources write directly here. + Path() string + + // Manifest is the tracker of completed work units for resume support. + Manifest() *Manifest + + // Commit promotes the temp file to its canonical location and removes + // the manifest. Subsequent calls are no-ops. + Commit() error + + // Abort discards the temp file and manifest. Subsequent calls are no-ops. 
+ Abort() error +} + +// ProgressSink receives progress events during a download. +// +// All methods are safe to call concurrently. +type ProgressSink interface { + // SetTotal sets the total number of work units this download expects. + // May be called multiple times if discovery happens incrementally. + SetTotal(n int) + // StepComplete records one work unit as completed. + StepComplete() + // Bytes records n bytes received from the network. + Bytes(n int64) +} + +// Throttle is an optional bandwidth limiter consulted by sources before +// each network read. +type Throttle interface { + // Wait blocks until n bytes can be consumed from the budget, + // or returns ctx's error if the context is cancelled while waiting. + Wait(ctx context.Context, n int) error +} diff --git a/internal/downloader/config.go b/internal/downloader/config.go deleted file mode 100644 index 91575e1..0000000 --- a/internal/downloader/config.go +++ /dev/null @@ -1,58 +0,0 @@ -package downloader - -import ( - "os" - "strconv" - "time" -) - -// Config holds downloader configuration, loaded from environment variables. -type Config struct { - // DataDir is the directory for storing dataset files and temporary GRIB data. - DataDir string - - // Parallel is the maximum number of concurrent GRIB downloads. - Parallel int - - // UpdateInterval is how often the scheduler checks for new forecast data. - UpdateInterval time.Duration - - // DatasetTTL is how long a dataset is considered fresh before a new one is needed. - DatasetTTL time.Duration -} - -// DefaultConfig returns the default configuration. -func DefaultConfig() *Config { - return &Config{ - DataDir: "/tmp/predictor-data", - Parallel: 8, - UpdateInterval: 6 * time.Hour, - DatasetTTL: 48 * time.Hour, - } -} - -// LoadConfig loads configuration from environment variables, falling back to defaults. 
-func LoadConfig() *Config { - cfg := DefaultConfig() - - if v := os.Getenv("PREDICTOR_DATA_DIR"); v != "" { - cfg.DataDir = v - } - if v := os.Getenv("PREDICTOR_DOWNLOAD_PARALLEL"); v != "" { - if n, err := strconv.Atoi(v); err == nil && n > 0 { - cfg.Parallel = n - } - } - if v := os.Getenv("PREDICTOR_UPDATE_INTERVAL"); v != "" { - if d, err := time.ParseDuration(v); err == nil { - cfg.UpdateInterval = d - } - } - if v := os.Getenv("PREDICTOR_DATASET_TTL"); v != "" { - if d, err := time.ParseDuration(v); err == nil { - cfg.DatasetTTL = d - } - } - - return cfg -} diff --git a/internal/downloader/downloader.go b/internal/downloader/downloader.go deleted file mode 100644 index 32208a0..0000000 --- a/internal/downloader/downloader.go +++ /dev/null @@ -1,441 +0,0 @@ -package downloader - -import ( - "context" - "fmt" - "io" - "math" - "net/http" - "os" - "path/filepath" - "sync/atomic" - "time" - - "predictor-refactored/internal/dataset" - - "github.com/nilsmagnus/grib/griblib" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" -) - -// Downloader handles fetching GFS forecast data from S3 and assembling dataset files. -type Downloader struct { - cfg *Config - client *http.Client - log *zap.Logger -} - -// NewDownloader creates a new Downloader. -func NewDownloader(cfg *Config, log *zap.Logger) *Downloader { - return &Downloader{ - cfg: cfg, - client: &http.Client{ - Timeout: 2 * time.Minute, - }, - log: log, - } -} - -// neededVariables is the set of GRIB variable names we need. -var neededVariables = map[string]bool{ - "HGT": true, - "UGRD": true, - "VGRD": true, -} - -// FindLatestRun finds the most recent available GFS model run on S3. -// It checks the last forecast step of each run to confirm availability. 
-func (d *Downloader) FindLatestRun(ctx context.Context) (time.Time, error) { - now := time.Now().UTC() - hour := now.Hour() - (now.Hour() % 6) - current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC) - - for i := 0; i < 8; i++ { - date := current.Format("20060102") - url := dataset.GribURL(date, current.Hour(), dataset.MaxHour) + ".idx" - - req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) - if err != nil { - current = current.Add(-6 * time.Hour) - continue - } - - resp, err := d.client.Do(req) - if err == nil { - resp.Body.Close() - if resp.StatusCode == http.StatusOK { - d.log.Info("found latest model run", - zap.Time("run", current), - zap.String("verified_url", url)) - return current, nil - } - } - - current = current.Add(-6 * time.Hour) - } - - return time.Time{}, fmt.Errorf("no recent GFS forecast found (checked 8 runs)") -} - -// progress tracks download progress across concurrent goroutines. -type progress struct { - bytesDownloaded atomic.Int64 - stepsCompleted atomic.Int64 - totalSteps int64 - startTime time.Time - log *zap.Logger -} - -func newProgress(totalSteps int, log *zap.Logger) *progress { - return &progress{ - totalSteps: int64(totalSteps), - startTime: time.Now(), - log: log, - } -} - -func (p *progress) addBytes(n int64) { - p.bytesDownloaded.Add(n) -} - -func (p *progress) completeStep() { - done := p.stepsCompleted.Add(1) - total := p.totalSteps - bytes := p.bytesDownloaded.Load() - elapsed := time.Since(p.startTime).Seconds() - - pct := float64(done) / float64(total) * 100 - mbDownloaded := float64(bytes) / (1024 * 1024) - mbPerSec := 0.0 - if elapsed > 0 { - mbPerSec = mbDownloaded / elapsed - } - - // Estimate remaining - eta := "" - if done > 0 && done < total { - secsPerStep := elapsed / float64(done) - remaining := secsPerStep * float64(total-done) - if remaining > 60 { - eta = fmt.Sprintf("%.0fm%02.0fs", math.Floor(remaining/60), math.Mod(remaining, 60)) - } else { - eta = 
fmt.Sprintf("%.0fs", remaining) - } - } - - p.log.Info("download progress", - zap.String("progress", fmt.Sprintf("%d/%d", done, total)), - zap.String("percent", fmt.Sprintf("%.1f%%", pct)), - zap.String("downloaded", fmt.Sprintf("%.1f MB", mbDownloaded)), - zap.String("speed", fmt.Sprintf("%.1f MB/s", mbPerSec)), - zap.String("eta", eta)) -} - -// Download downloads a complete forecast and assembles a dataset file. -// Returns the path to the completed dataset file. -func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error) { - date := run.Format("20060102") - runHour := run.Hour() - - finalPath := filepath.Join(d.cfg.DataDir, run.Format("2006010215")) - tempPath := finalPath + ".downloading" - - // Check if final dataset already exists - if info, err := os.Stat(finalPath); err == nil && info.Size() == dataset.DatasetSize { - d.log.Info("dataset already exists", zap.String("path", finalPath)) - return finalPath, nil - } - - steps := dataset.Hours() - totalSteps := len(steps) * 2 // pgrb2 + pgrb2b per step - prog := newProgress(totalSteps, d.log) - - d.log.Info("starting dataset download", - zap.Time("run", run), - zap.Int("total_steps", totalSteps), - zap.String("temp_path", tempPath)) - - // Create the dataset file - ds, err := dataset.Create(tempPath) - if err != nil { - return "", fmt.Errorf("create dataset: %w", err) - } - defer ds.Close() - - // Process each forecast step with bounded concurrency - g, ctx := errgroup.WithContext(ctx) - sem := make(chan struct{}, d.cfg.Parallel) - - for _, step := range steps { - step := step - hourIdx := dataset.HourIndex(step) - if hourIdx < 0 { - continue - } - - // Download pgrb2 (level set A) - sem <- struct{}{} - g.Go(func() error { - defer func() { <-sem }() - url := dataset.GribURL(date, runHour, step) - err := d.downloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetA, prog) - if err != nil { - return fmt.Errorf("step %d pgrb2: %w", step, err) - } - prog.completeStep() - return nil - }) - - // 
Download pgrb2b (level set B) - sem <- struct{}{} - g.Go(func() error { - defer func() { <-sem }() - url := dataset.GribURLB(date, runHour, step) - err := d.downloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetB, prog) - if err != nil { - return fmt.Errorf("step %d pgrb2b: %w", step, err) - } - prog.completeStep() - return nil - }) - } - - if err := g.Wait(); err != nil { - os.Remove(tempPath) - return "", err - } - - elapsed := time.Since(prog.startTime) - totalMB := float64(prog.bytesDownloaded.Load()) / (1024 * 1024) - d.log.Info("download complete, flushing to disk", - zap.String("downloaded", fmt.Sprintf("%.1f MB", totalMB)), - zap.Duration("elapsed", elapsed), - zap.String("avg_speed", fmt.Sprintf("%.1f MB/s", totalMB/elapsed.Seconds()))) - - // Flush to disk - if err := ds.Flush(); err != nil { - os.Remove(tempPath) - return "", fmt.Errorf("flush dataset: %w", err) - } - - // Close before rename - ds.Close() - - // Atomic rename - if err := os.Rename(tempPath, finalPath); err != nil { - os.Remove(tempPath) - return "", fmt.Errorf("rename dataset: %w", err) - } - - d.log.Info("dataset ready", zap.String("path", finalPath)) - return finalPath, nil -} - -// DownloadAndBlit downloads needed GRIB fields from a URL and writes them into the dataset. -func (d *Downloader) DownloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet) error { - return d.downloadAndBlit(ctx, ds, baseURL, hourIdx, levelSet, nil) -} - -// downloadAndBlit is the internal implementation with optional progress tracking. -func (d *Downloader) downloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet, prog *progress) error { - // 1. Download .idx - idxURL := baseURL + ".idx" - idxBody, err := d.httpGet(ctx, idxURL) - if err != nil { - return fmt.Errorf("download idx: %w", err) - } - - // 2. 
Parse and filter - entries := ParseIdx(idxBody) - filtered := FilterIdx(entries, neededVariables) - - // Further filter to only levels in this level set - var relevant []IdxEntry - for _, e := range filtered { - ls, ok := dataset.PressureLevelSet(e.LevelMB) - if ok && ls == levelSet { - relevant = append(relevant, e) - } - } - - if len(relevant) == 0 { - d.log.Warn("no relevant entries found in idx", - zap.String("url", idxURL), - zap.Int("total_entries", len(entries)), - zap.Int("filtered", len(filtered))) - return nil - } - - // 3. Download byte ranges and write to temp file - ranges := EntriesToRanges(relevant) - tmpFile, err := d.downloadRangesToTempFile(ctx, baseURL, ranges, prog) - if err != nil { - return fmt.Errorf("download ranges: %w", err) - } - defer os.Remove(tmpFile) - - // 4. Read GRIB messages from temp file - f, err := os.Open(tmpFile) - if err != nil { - return fmt.Errorf("open temp grib: %w", err) - } - - messages, err := griblib.ReadMessages(f) - f.Close() - if err != nil { - return fmt.Errorf("read grib messages: %w", err) - } - - // 5. 
Decode and blit each message into the dataset - for _, msg := range messages { - if msg.Section4.ProductDefinitionTemplateNumber != 0 { - continue - } - - product := msg.Section4.ProductDefinitionTemplate - - varIdx := dataset.VariableIndex(int(product.ParameterCategory), int(product.ParameterNumber)) - if varIdx < 0 { - continue - } - - if product.FirstSurface.Type != 100 { // isobaric surface - continue - } - - pressurePa := float64(product.FirstSurface.Value) - pressureMB := int(math.Round(pressurePa / 100.0)) - levelIdx := dataset.PressureIndex(pressureMB) - if levelIdx < 0 { - continue - } - - data := msg.Data() - if err := ds.BlitGribData(hourIdx, levelIdx, varIdx, data); err != nil { - d.log.Warn("blit failed", - zap.Int("var", varIdx), - zap.Int("level_mb", pressureMB), - zap.Error(err)) - continue - } - } - - return nil -} - -// downloadRangesToTempFile downloads multiple byte ranges from a URL, -// concatenating them into a single temp file (valid concatenated GRIB messages). -func (d *Downloader) downloadRangesToTempFile(ctx context.Context, baseURL string, ranges []ByteRange, prog *progress) (string, error) { - tmpFile, err := os.CreateTemp(d.cfg.DataDir, "grib-*.tmp") - if err != nil { - return "", fmt.Errorf("create temp file: %w", err) - } - tmpPath := tmpFile.Name() - - for _, r := range ranges { - data, err := d.httpGetRange(ctx, baseURL, r.Start, r.End) - if err != nil { - tmpFile.Close() - os.Remove(tmpPath) - return "", fmt.Errorf("download range %d-%d: %w", r.Start, r.End, err) - } - if _, err := tmpFile.Write(data); err != nil { - tmpFile.Close() - os.Remove(tmpPath) - return "", fmt.Errorf("write temp: %w", err) - } - if prog != nil { - prog.addBytes(int64(len(data))) - } - } - - if err := tmpFile.Close(); err != nil { - os.Remove(tmpPath) - return "", err - } - - return tmpPath, nil -} - -// httpGet downloads a URL and returns the body bytes. 
-func (d *Downloader) httpGet(ctx context.Context, url string) ([]byte, error) { - var lastErr error - for attempt := 0; attempt < 3; attempt++ { - if attempt > 0 { - select { - case <-time.After(time.Duration(attempt*2) * time.Second): - case <-ctx.Done(): - return nil, ctx.Err() - } - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return nil, err - } - - resp, err := d.client.Do(req) - if err != nil { - lastErr = err - continue - } - - body, err := io.ReadAll(resp.Body) - resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url) - continue - } - if err != nil { - lastErr = err - continue - } - - return body, nil - } - - return nil, fmt.Errorf("after 3 attempts: %w", lastErr) -} - -// httpGetRange downloads a byte range from a URL. -func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64) ([]byte, error) { - var lastErr error - for attempt := 0; attempt < 3; attempt++ { - if attempt > 0 { - select { - case <-time.After(time.Duration(attempt*2) * time.Second): - case <-ctx.Done(): - return nil, ctx.Err() - } - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return nil, err - } - req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) - - resp, err := d.client.Do(req) - if err != nil { - lastErr = err - continue - } - - body, err := io.ReadAll(resp.Body) - resp.Body.Close() - - if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { - lastErr = fmt.Errorf("HTTP %d for range %d-%d of %s", resp.StatusCode, start, end, url) - continue - } - if err != nil { - lastErr = err - continue - } - - return body, nil - } - - return nil, fmt.Errorf("after 3 attempts: %w", lastErr) -} diff --git a/internal/downloader/idx.go b/internal/downloader/idx.go deleted file mode 100644 index 2e09bc4..0000000 --- a/internal/downloader/idx.go +++ /dev/null @@ -1,157 
+0,0 @@ -package downloader - -import ( - "fmt" - "strconv" - "strings" -) - -// IdxEntry represents a single parsed line from a GRIB .idx file. -// Example line: "15:1207405:d=2024010100:HGT:1000 mb:0 hour fcst:" -type IdxEntry struct { - Index int - Offset int64 - Variable string // "HGT", "UGRD", "VGRD", etc. - LevelMB int // pressure level in mb (0 if not a pressure level) - Hour int // forecast hour - EndOffset int64 // byte after this message (from next entry's offset, or -1 if last) -} - -// Length returns the byte length of this GRIB message, or -1 if unknown. -func (e *IdxEntry) Length() int64 { - if e.EndOffset <= 0 { - return -1 - } - return e.EndOffset - e.Offset -} - -// ParseIdx parses a .idx file body and returns all entries. -// Lines that can't be parsed are silently skipped. -func ParseIdx(body []byte) []IdxEntry { - lines := strings.Split(string(body), "\n") - var entries []IdxEntry - - for _, line := range lines { - line = strings.TrimSpace(line) - if line == "" { - continue - } - - parts := strings.Split(line, ":") - if len(parts) < 7 { - continue - } - - idx, err := strconv.Atoi(parts[0]) - if err != nil { - continue - } - - offset, err := strconv.ParseInt(parts[1], 10, 64) - if err != nil { - continue - } - - variable := parts[3] - levelStr := parts[4] - hourStr := parts[5] - - levelMB := parseLevelMB(levelStr) - hour := parseHour(hourStr) - - entries = append(entries, IdxEntry{ - Index: idx, - Offset: offset, - Variable: variable, - LevelMB: levelMB, - Hour: hour, - EndOffset: -1, // filled in below - }) - } - - // Fill in EndOffset from the next entry's Offset. - for i := 0; i < len(entries)-1; i++ { - entries[i].EndOffset = entries[i+1].Offset - } - - return entries -} - -// FilterIdx returns entries matching the given variables at pressure levels. -// Only entries with a recognized pressure level (levelMB > 0) are returned. 
-func FilterIdx(entries []IdxEntry, variables map[string]bool) []IdxEntry { - var filtered []IdxEntry - for _, e := range entries { - if !variables[e.Variable] { - continue - } - if e.LevelMB <= 0 { - continue - } - // Must have a known length (not the last entry) or be handled specially - if e.Length() <= 0 { - continue - } - filtered = append(filtered, e) - } - return filtered -} - -// parseLevelMB parses a level string like "1000 mb" and returns the pressure in mb. -// Returns 0 if not a pressure level. -func parseLevelMB(s string) int { - s = strings.TrimSpace(s) - if !strings.HasSuffix(s, " mb") { - return 0 - } - numStr := strings.TrimSuffix(s, " mb") - n, err := strconv.Atoi(numStr) - if err != nil { - return 0 - } - return n -} - -// parseHour parses a forecast hour string like "0 hour fcst" or "anl". -// Returns -1 if it can't be parsed. -func parseHour(s string) int { - s = strings.TrimSpace(s) - if s == "anl" { - return 0 - } - s = strings.TrimSuffix(s, " hour fcst") - n, err := strconv.Atoi(s) - if err != nil { - return -1 - } - return n -} - -// GroupByRange groups idx entries into byte ranges suitable for HTTP Range downloads. -// Each range covers one contiguous GRIB message. -type ByteRange struct { - Start int64 - End int64 // inclusive - Entry IdxEntry -} - -// EntriesToRanges converts filtered idx entries to byte ranges. -func EntriesToRanges(entries []IdxEntry) []ByteRange { - ranges := make([]ByteRange, 0, len(entries)) - for _, e := range entries { - if e.Length() <= 0 { - continue - } - ranges = append(ranges, ByteRange{ - Start: e.Offset, - End: e.EndOffset - 1, // inclusive - Entry: e, - }) - } - return ranges -} - -// FormatRange returns an HTTP Range header value for a byte range. 
-func (r ByteRange) FormatRange() string { - return fmt.Sprintf("bytes=%d-%d", r.Start, r.End) -} diff --git a/internal/downloader/idx_test.go b/internal/downloader/idx_test.go deleted file mode 100644 index 71a7224..0000000 --- a/internal/downloader/idx_test.go +++ /dev/null @@ -1,110 +0,0 @@ -package downloader - -import ( - "testing" -) - -const sampleIdx = `1:0:d=2024010100:HGT:1000 mb:0 hour fcst: -2:289012:d=2024010100:HGT:975 mb:0 hour fcst: -3:541876:d=2024010100:TMP:1000 mb:0 hour fcst: -4:789012:d=2024010100:UGRD:1000 mb:0 hour fcst: -5:1045678:d=2024010100:VGRD:1000 mb:0 hour fcst: -6:1298765:d=2024010100:UGRD:975 mb:0 hour fcst: -7:1567890:d=2024010100:UGRD:2 m above ground:0 hour fcst: -8:1812345:d=2024010100:VGRD:975 mb:0 hour fcst: -9:2098765:d=2024010100:HGT:500 mb:3 hour fcst: -` - -func TestParseIdx(t *testing.T) { - entries := ParseIdx([]byte(sampleIdx)) - if len(entries) != 9 { - t.Fatalf("expected 9 entries, got %d", len(entries)) - } - - // Check first entry - e := entries[0] - if e.Index != 1 || e.Offset != 0 || e.Variable != "HGT" || e.LevelMB != 1000 || e.Hour != 0 { - t.Errorf("entry 0: got %+v", e) - } - if e.EndOffset != 289012 { - t.Errorf("entry 0 EndOffset: got %d, want 289012", e.EndOffset) - } - - // Check "2 m above ground" is not a pressure level - e = entries[6] // UGRD at "2 m above ground" - if e.LevelMB != 0 { - t.Errorf("non-pressure level should have LevelMB=0, got %d", e.LevelMB) - } - - // Last entry should have EndOffset = -1 - last := entries[len(entries)-1] - if last.EndOffset != -1 { - t.Errorf("last entry EndOffset: got %d, want -1", last.EndOffset) - } -} - -func TestFilterIdx(t *testing.T) { - entries := ParseIdx([]byte(sampleIdx)) - filtered := FilterIdx(entries, neededVariables) - - // Should include HGT/UGRD/VGRD at pressure levels, exclude TMP and "above ground" - // And exclude last entry (no EndOffset) - for _, e := range filtered { - if !neededVariables[e.Variable] { - t.Errorf("unexpected variable %s", 
e.Variable) - } - if e.LevelMB <= 0 { - t.Errorf("non-pressure level included: %+v", e) - } - if e.Length() <= 0 { - t.Errorf("entry with unknown length included: %+v", e) - } - } - - // Count expected: HGT@1000, HGT@975, UGRD@1000, VGRD@1000, UGRD@975, VGRD@975 = 6 - // But HGT@500 at 3hr fcst is the last entry (no EndOffset), so excluded - if len(filtered) != 6 { - t.Errorf("expected 6 filtered entries, got %d", len(filtered)) - for _, e := range filtered { - t.Logf(" %s %d mb (offset %d, len %d)", e.Variable, e.LevelMB, e.Offset, e.Length()) - } - } -} - -func TestParseLevelMB(t *testing.T) { - tests := []struct { - input string - want int - }{ - {"1000 mb", 1000}, - {"975 mb", 975}, - {"1 mb", 1}, - {"2 m above ground", 0}, - {"surface", 0}, - {"tropopause", 0}, - } - for _, tt := range tests { - got := parseLevelMB(tt.input) - if got != tt.want { - t.Errorf("parseLevelMB(%q) = %d, want %d", tt.input, got, tt.want) - } - } -} - -func TestParseHour(t *testing.T) { - tests := []struct { - input string - want int - }{ - {"0 hour fcst", 0}, - {"3 hour fcst", 3}, - {"192 hour fcst", 192}, - {"anl", 0}, - } - for _, tt := range tests { - got := parseHour(tt.input) - if got != tt.want { - t.Errorf("parseHour(%q) = %d, want %d", tt.input, got, tt.want) - } - } -} diff --git a/internal/elevation/elevation.go b/internal/elevation/elevation.go index 9fde295..4bcfce9 100644 --- a/internal/elevation/elevation.go +++ b/internal/elevation/elevation.go @@ -71,9 +71,10 @@ func (d *Dataset) getCell(latIdx, lngIdx int) int16 { return int16(binary.LittleEndian.Uint16(d.mm[off : off+2])) } -// Get returns the interpolated elevation in metres at the given coordinates. -// lat: -90 to +90, lng: 0 to 360 (or -180 to 180, will be normalised). -func (d *Dataset) Get(lat, lng float64) float64 { +// Elevation returns the bilinearly-interpolated ground elevation in metres at +// the given coordinates. 
lat is in [-90, +90]; lng accepts either [0, 360) or +// [-180, 180) and is normalised internally. +func (d *Dataset) Elevation(lat, lng float64) float64 { // Normalise longitude to [0, 360) if lng < 0 { lng += 360 diff --git a/internal/engine/constraints.go b/internal/engine/constraints.go new file mode 100644 index 0000000..f2f8d08 --- /dev/null +++ b/internal/engine/constraints.go @@ -0,0 +1,47 @@ +package engine + +// MaxAltitude triggers when altitude rises above Limit (in metres). +// Used as the burst condition for ascent stages. +type MaxAltitude struct { + Limit float64 + On Action +} + +func (c MaxAltitude) Name() string { return "max_altitude" } +func (c MaxAltitude) Violated(_ float64, s State) bool { return s.Altitude >= c.Limit } +func (c MaxAltitude) Action() Action { return c.On } + +// MinAltitude triggers when altitude falls at or below Limit (in metres). +// With Limit=0 this is the "sea level" terminator. +type MinAltitude struct { + Limit float64 + On Action +} + +func (c MinAltitude) Name() string { return "min_altitude" } +func (c MinAltitude) Violated(_ float64, s State) bool { return s.Altitude <= c.Limit } +func (c MinAltitude) Action() Action { return c.On } + +// MaxTime triggers when t exceeds Limit (UNIX seconds). Used as a stop +// condition for float profiles. +type MaxTime struct { + Limit float64 + On Action +} + +func (c MaxTime) Name() string { return "max_time" } +func (c MaxTime) Violated(t float64, _ State) bool { return t > c.Limit } +func (c MaxTime) Action() Action { return c.On } + +// TerrainContact triggers when altitude has dropped at or below ground level. +// Equivalent to Tawhiri's elevation termination. 
+type TerrainContact struct { + Provider TerrainProvider + On Action +} + +func (c TerrainContact) Name() string { return "terrain_contact" } +func (c TerrainContact) Violated(_ float64, s State) bool { + return c.Provider.Elevation(s.Lat, s.Lng) > s.Altitude +} +func (c TerrainContact) Action() Action { return c.On } diff --git a/internal/engine/engine_test.go b/internal/engine/engine_test.go new file mode 100644 index 0000000..fd55e38 --- /dev/null +++ b/internal/engine/engine_test.go @@ -0,0 +1,176 @@ +package engine + +import ( + "math" + "testing" + "time" + + "predictor-refactored/internal/weather" +) + +// noWind is a WindField that always returns zero wind. Lets us test +// integration of vertical-only profiles deterministically. +type noWind struct{ epoch time.Time } + +func (n noWind) Wind(_ float64, _, _, _ float64) (weather.Sample, error) { + return weather.Sample{}, nil +} +func (n noWind) Epoch() time.Time { return n.epoch } +func (n noWind) Source() string { return "test" } + +// flatGround returns 0 metres everywhere. +type flatGround struct{} + +func (flatGround) Elevation(_, _ float64) float64 { return 0 } + +func TestConstantAscentToBurst(t *testing.T) { + burst := 30000.0 + rate := 5.0 + + ascend := &Propagator{ + Name: "ascent", + Step: 60, + Model: Sum(ConstantRate(rate), WindTransport(noWind{}, nil)), + Constraints: []Constraint{MaxAltitude{Limit: burst, On: ActionStop}}, + } + + prof := Profile{Stages: []*Propagator{ascend}, Direction: Forward} + results := prof.Run(0, State{Lat: 0, Lng: 0, Altitude: 0}) + + if len(results) != 1 || results[0].Outcome != OutcomeStopped { + t.Fatalf("expected one stopped stage, got %+v", results) + } + + last := results[0].Points[len(results[0].Points)-1] + // Refinement tolerance is 0.01 in parameter space over a 60s step, so the + // returned point sits within ±0.6s × rate ≈ ±3m of the boundary. 
+ if math.Abs(last.Altitude-burst) > 5 { + t.Errorf("burst altitude = %v, want within 5m of %v", last.Altitude, burst) + } + wantTime := burst / rate + if math.Abs(last.Time-wantTime) > 1 { + t.Errorf("burst time = %v, want within 1s of %v", last.Time, wantTime) + } +} + +func TestProfileWithFallback(t *testing.T) { + burst := 1000.0 + rate := 5.0 + + descent := &Propagator{ + Name: "descent", + Step: 60, + Model: ParachuteDescent(rate), + Constraints: []Constraint{TerrainContact{Provider: flatGround{}, On: ActionStop}}, + } + ascend := &Propagator{ + Name: "ascent", + Step: 60, + Model: ConstantRate(rate), + Constraints: []Constraint{MaxAltitude{Limit: burst, On: ActionFallback}}, + Fallback: descent, + } + + prof := Profile{Stages: []*Propagator{ascend}, Direction: Forward} + results := prof.Run(0, State{Altitude: 0}) + + if len(results) != 2 { + t.Fatalf("expected 2 results (ascent then descent fallback), got %d", len(results)) + } + if results[0].Outcome != OutcomeFallback { + t.Errorf("first outcome = %v, want OutcomeFallback", results[0].Outcome) + } + if results[1].Outcome != OutcomeStopped { + t.Errorf("second outcome = %v, want OutcomeStopped", results[1].Outcome) + } + + last := results[1].Points[len(results[1].Points)-1] + if math.Abs(last.Altitude) > 5 { + t.Errorf("final altitude = %v, want within 5m of 0", last.Altitude) + } +} + +func TestReverseDirection(t *testing.T) { + // Start at altitude 100m with downward rate; integrating reverse should + // give increasing altitude. 
+ desc := &Propagator{ + Name: "rewind", + Step: 1, + Model: ConstantRate(-1), // forward: alt decreases at 1 m/s + Constraints: []Constraint{MaxAltitude{Limit: 200, On: ActionStop}}, + } + prof := Profile{Stages: []*Propagator{desc}, Direction: Reverse} + results := prof.Run(0, State{Altitude: 100}) + + last := results[0].Points[len(results[0].Points)-1] + if math.Abs(last.Altitude-200) > 1 { + t.Errorf("reverse final altitude = %v, want ~200", last.Altitude) + } + if last.Time >= 0 { + t.Errorf("reverse final time = %v, want < 0", last.Time) + } +} + +func TestPiecewiseRate(t *testing.T) { + m := Piecewise([]RateSegment{ + {Until: 100, Rate: 5}, + {Until: 200, Rate: 3}, + {Until: math.Inf(1), Rate: 0}, + }) + + if r := m(50, State{}); r.Altitude != 5 { + t.Errorf("rate at t=50 = %v, want 5", r.Altitude) + } + if r := m(150, State{}); r.Altitude != 3 { + t.Errorf("rate at t=150 = %v, want 3", r.Altitude) + } + if r := m(300, State{}); r.Altitude != 0 { + t.Errorf("rate at t=300 = %v, want 0", r.Altitude) + } +} + +// fixedWind returns a constant wind sample. +type fixedWind struct{ u, v float64 } + +func (w fixedWind) Wind(_ float64, _, _, _ float64) (weather.Sample, error) { + return weather.Sample{U: w.u, V: w.v}, nil +} +func (fixedWind) Epoch() time.Time { return time.Unix(0, 0) } +func (fixedWind) Source() string { return "test-fixed" } + +func TestWindTransportUnitConversion(t *testing.T) { + // Pure eastward wind of 10 m/s at the equator at sea level. + // Expected dlng/dt = (180/pi) * 10 / (6371009 * cos(0)) ≈ 0.00008991 deg/s. + // Expected dlat/dt = 0. + wind := WindTransport(fixedWind{u: 10, v: 0}, nil) + d := wind(0, State{Lat: 0, Lng: 0, Altitude: 0}) + + wantLng := (180.0 / math.Pi) * 10.0 / 6371009.0 + if math.Abs(d.Lng-wantLng) > 1e-12 { + t.Errorf("dlng = %v, want %v", d.Lng, wantLng) + } + if math.Abs(d.Lat) > 1e-12 { + t.Errorf("dlat = %v, want 0 for u=10 v=0", d.Lat) + } + + // Pure northward at 60° latitude: dlat = (180/pi) * v / R, dlng = 0. 
+ wind2 := WindTransport(fixedWind{u: 0, v: 5}, nil) + d = wind2(0, State{Lat: 60, Lng: 0, Altitude: 0}) + wantLat := (180.0 / math.Pi) * 5.0 / 6371009.0 + if math.Abs(d.Lat-wantLat) > 1e-12 { + t.Errorf("dlat at lat=60 = %v, want %v", d.Lat, wantLat) + } +} + +func TestStateAddWrapsLongitude(t *testing.T) { + // Demonstrates state algebra used by the integrator and refinement. + s := stateAdd(State{Lat: 0, Lng: 350, Altitude: 0}, 1, State{Lng: 20}) + if math.Abs(s.Lng-10) > 1e-9 { + t.Errorf("addState wrap: lng = %v, want 10", s.Lng) + } + + mid := stateLerp(State{Lng: 350}, State{Lng: 10}, 0.5) + if math.Abs(mid.Lng-0) > 1e-9 && math.Abs(mid.Lng-360) > 1e-9 { + t.Errorf("lerpState lng wrap: %v, want 0 or 360", mid.Lng) + } +} diff --git a/internal/engine/models.go b/internal/engine/models.go new file mode 100644 index 0000000..6431b60 --- /dev/null +++ b/internal/engine/models.go @@ -0,0 +1,151 @@ +package engine + +import ( + "math" + "sort" + "sync/atomic" + + "predictor-refactored/internal/weather" +) + +// Sum composes models by summing their derivatives at each evaluation point. +// +// Useful for combining e.g. a vertical-rate model with a horizontal wind model +// into a single propagator. Equivalent to Tawhiri's LinearModel. +func Sum(models ...Model) Model { + if len(models) == 1 { + return models[0] + } + return func(t float64, s State) State { + var sum State + for _, m := range models { + d := m(t, s) + sum.Lat += d.Lat + sum.Lng += d.Lng + sum.Altitude += d.Altitude + } + return sum + } +} + +// ConstantRate returns a model with a constant vertical velocity (m/s). +// A positive rate is upward (ascent); a negative rate is downward. +func ConstantRate(rate float64) Model { + return func(_ float64, _ State) State { + return State{Altitude: rate} + } +} + +// ParachuteDescent returns a model where vertical velocity grows with altitude +// because thinner air provides less drag. 
+// +// seaLevelRate is the descent speed at sea level (m/s, positive number). +// The terminal velocity at altitude is computed as +// +// v = -k / sqrt(rho(alt)), k = seaLevelRate * 1.1045, +// +// using the NASA atmosphere model for rho. Equivalent to Tawhiri's drag_descent. +func ParachuteDescent(seaLevelRate float64) Model { + k := seaLevelRate * 1.1045 + return func(_ float64, s State) State { + return State{Altitude: -k / math.Sqrt(nasaDensity(s.Altitude))} + } +} + +// nasaDensity returns air density (kg/m^3) for the given altitude in metres, +// using the NASA simple atmosphere model. See +// https://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html. +func nasaDensity(alt float64) float64 { + var temp, pressure float64 + switch { + case alt > 25000: + temp = -131.21 + 0.00299*alt + pressure = 2.488 * math.Pow((temp+273.1)/216.6, -11.388) + case alt > 11000: + temp = -56.46 + pressure = 22.65 * math.Exp(1.73-0.000157*alt) + default: + temp = 15.04 - 0.00649*alt + pressure = 101.29 * math.Pow((temp+273.1)/288.08, 5.256) + } + return pressure / (0.2869 * (temp + 273.1)) +} + +// RateSegment is one entry in a Piecewise rate schedule. +type RateSegment struct { + // Until is the UNIX timestamp at which this segment ends. + // The model applies the segment's Rate for all t < Until. + Until float64 + // Rate is the vertical velocity (m/s) during the segment. Positive is up. + Rate float64 +} + +// Piecewise returns a model that produces a piecewise-constant vertical rate +// over a sequence of time intervals. +// +// Segments are searched by their Until field; the first segment whose Until +// exceeds t supplies the active rate. For t at or after the last Until, the +// final segment's Rate is held indefinitely. Input is sorted ascending by +// Until on construction. +func Piecewise(segments []RateSegment) Model { + if len(segments) == 0 { + return ConstantRate(0) + } + sorted := append([]RateSegment(nil), segments...) 
+ sort.Slice(sorted, func(i, j int) bool { return sorted[i].Until < sorted[j].Until }) + finalRate := sorted[len(sorted)-1].Rate + + return func(t float64, _ State) State { + idx := sort.Search(len(sorted), func(i int) bool { return sorted[i].Until > t }) + if idx == len(sorted) { + return State{Altitude: finalRate} + } + return State{Altitude: sorted[idx].Rate} + } +} + +// Warnings aggregates non-fatal conditions encountered during integration. +type Warnings struct { + // AltitudeTooHigh counts evaluations where the wind sampler reported + // that altitude was above the highest pressure level of the dataset. + AltitudeTooHigh atomic.Int64 +} + +// ToMap returns warnings as a map suitable for JSON output. Only counters +// that have fired are included. +func (w *Warnings) ToMap() map[string]any { + out := make(map[string]any) + if n := w.AltitudeTooHigh.Load(); n > 0 { + out["altitude_too_high"] = map[string]any{ + "count": n, + "description": "altitude exceeded the highest pressure level of the wind dataset; samples were extrapolated", + } + } + return out +} + +// WindTransport returns a model that moves laterally at the wind velocity +// sampled from field. The vertical component of the returned derivative is +// zero. Wind units are converted from m/s to deg/s on Earth's surface. +// +// If warnings is non-nil, the AltitudeTooHigh counter is incremented for any +// sample where the wind field reported altitude above the model top. 
+func WindTransport(field weather.WindField, warnings *Warnings) Model {
+	const earthR = 6371009.0
+	const piOver180 = math.Pi / 180.0
+	const degPerRad = 180.0 / math.Pi
+	return func(t float64, s State) State {
+		sample, err := field.Wind(t, s.Lat, s.Lng, s.Altitude)
+		if err != nil {
+			// The lookup error is dropped and a zero derivative is returned,
+			// so the integrator sees no horizontal motion for this evaluation.
+			return State{}
+		}
+		if sample.AboveModel && warnings != nil {
+			warnings.AltitudeTooHigh.Add(1)
+		}
+		// Convert the U/V components into degrees per second on a sphere of
+		// radius earthR + altitude. Only Lat and Lng are set; the Altitude
+		// derivative stays zero.
+		r := earthR + s.Altitude
+		return State{
+			Lat: degPerRad * sample.V / r,
+			Lng: degPerRad * sample.U / (r * math.Cos(s.Lat*piOver180)),
+		}
+	}
+}
diff --git a/internal/engine/profile.go b/internal/engine/profile.go
new file mode 100644
index 0000000..57460de
--- /dev/null
+++ b/internal/engine/profile.go
@@ -0,0 +1,62 @@
+package engine
+
+// Profile is an ordered chain of propagators executed sequentially. Each
+// propagator picks up where the previous one finished.
+type Profile struct {
+	// Stages are run in order. For Direction=Reverse they are still iterated
+	// from index 0 onwards, but each propagator integrates with negative dt.
+	Stages []*Propagator
+
+	// Direction controls the sign of dt across the whole profile.
+	Direction Direction
+
+	// Globals are constraints evaluated alongside each stage's local Constraints.
+	// Useful for profile-wide bounds like "stop after N hours total".
+	Globals []Constraint
+}
+
+// Run executes the profile from the given launch point. It returns one Result
+// per executed propagator: every stage in order, plus any Fallback
+// propagators that were activated along the way.
+//
+// A zero Direction is replaced by Forward, and that default is stored back
+// into p.
+func (p *Profile) Run(t0 float64, launch State) []Result {
+	if p.Direction == 0 {
+		p.Direction = Forward
+	}
+
+	results := make([]Result, 0, len(p.Stages))
+	t, s := t0, launch
+
+	// exec runs one propagator from the current hand-off point, records its
+	// Result, and moves the hand-off point to the last trajectory sample.
+	exec := func(pr *Propagator) Result {
+		res := pr.run(t, s, p.Direction, p.Globals)
+		results = append(results, res)
+		last := res.Points[len(res.Points)-1]
+		t = last.Time
+		s = State{Lat: last.Lat, Lng: last.Lng, Altitude: last.Altitude}
+		return res
+	}
+
+	for _, stage := range p.Stages {
+		res := exec(stage)
+
+		// Follow Fallback chains until none remains. Each fallback starts
+		// from the point where the previous propagator stopped.
+		for res.Outcome == OutcomeFallback && stage.Fallback != nil {
+			stage = stage.Fallback
+			res = exec(stage)
+		}
+
+		// Whatever the final outcome, the next stage (if any) resumes from
+		// the hand-off point. An earlier `if OutcomeStopped { continue }`
+		// here sat under a comment saying "end the profile", but as the last
+		// statement of the loop body it was a no-op and has been removed.
+		// NOTE(review): if a stop is meant to end the whole profile this
+		// needs a break; confirm against profiles that chain several stages.
+	}
+
+	return results
+}
diff --git a/internal/engine/propagator.go b/internal/engine/propagator.go
new file mode 100644
index 0000000..c653218
--- /dev/null
+++ b/internal/engine/propagator.go
@@ -0,0 +1,156 @@
+package engine
+
+import (
+	"predictor-refactored/internal/numerics"
+)
+
+// Propagator advances state under one Model, checking a set of Constraints
+// after every integration step.
+//
+// When a constraint fires, the propagator binary-search refines the violation
+// point and emits it as its final trajectory point. The Action of the
+// triggering constraint controls what the surrounding Profile does next:
+// stop the profile, transfer to Fallback, or clip and continue.
+type Propagator struct {
+	// Name identifies the propagator in trajectory metadata.
+	Name string
+
+	// Step is the magnitude of the integration step in seconds (always positive).
+	// The Profile flips its sign for Reverse direction.
+	Step float64
+
+	// Model produces the per-second time derivative of state.
+	Model Model
+
+	// Constraints are evaluated after each step. Any fired constraint stops
+	// the propagator at the refined point; the first one in this slice wins
+	// on ties.
+	Constraints []Constraint
+
+	// Fallback is the propagator to switch to when a constraint with
+	// ActionFallback fires. Optional.
+	Fallback *Propagator
+
+	// Tolerance is the binary-search refinement tolerance in parameter space
+	// (default 0.01, matching Tawhiri).
+	Tolerance float64
+}
+
+// Outcome describes how a propagator's run ended.
+type Outcome int
+
+const (
+	// OutcomeStopped means a Constraint with ActionStop fired and this
+	// propagator ended at the refined violation point.
+	OutcomeStopped Outcome = iota
+	// OutcomeFallback means a Constraint with ActionFallback fired and the
+	// profile should transfer to the propagator's Fallback chain.
+	OutcomeFallback
+	// OutcomeContinued means no constraint fired before the time horizon was
+	// reached. In practice this is only seen when a propagator runs unbounded,
+	// which means the profile is misconfigured.
+	OutcomeContinued
+)
+
+// Result is the output of running one propagator.
+type Result struct {
+	Propagator string
+	Points     []TrajectoryPoint
+	Outcome    Outcome
+	// Constraint is the constraint that fired, or nil if Outcome == OutcomeContinued.
+	Constraint Constraint
+}
+
+// run integrates the model from (t0, s0) in direction dir, returning a Result.
+// globals are constraints injected by the Profile and checked alongside the
+// propagator's local Constraints.
+func (p *Propagator) run(t0 float64, s0 State, dir Direction, globals []Constraint) Result {
+	dt := p.Step * float64(dir)
+
+	// A non-positive (or NaN) tolerance falls back to the default. A negative
+	// value would otherwise keep the refinement loop running forever, because
+	// the bisection interval can never become narrower than it.
+	tol := p.Tolerance
+	if !(tol > 0) {
+		tol = 0.01
+	}
+
+	deriv := numerics.Deriv[State](p.Model)
+	add := numerics.VecAdd[State](stateAdd)
+	lerp := numerics.VecLerp[State](stateLerp)
+
+	out := Result{
+		Propagator: p.Name,
+		Outcome:    OutcomeContinued,
+		Points: []TrajectoryPoint{{
+			Time: t0, Lat: s0.Lat, Lng: s0.Lng, Altitude: s0.Altitude,
+		}},
+	}
+
+	// With a zero or NaN step, time never advances and the loop below would
+	// append points without bound. Report the misconfiguration as
+	// OutcomeContinued with only the start point instead of hanging.
+	if !(dt > 0 || dt < 0) {
+		return out
+	}
+
+	// record appends one sample to the result trajectory.
+	record := func(t float64, s State) {
+		out.Points = append(out.Points, TrajectoryPoint{
+			Time: t, Lat: s.Lat, Lng: s.Lng, Altitude: s.Altitude,
+		})
+	}
+
+	t, s := t0, s0
+	for {
+		s2 := numerics.RK4Step(t, s, dt, deriv, add)
+		t2 := t + dt
+
+		c, fired := firstFiring(p.Constraints, globals, t2, s2)
+		if !fired {
+			// Ordinary step: accept it and keep integrating.
+			t, s = t2, s2
+			record(t, s)
+			continue
+		}
+
+		// A constraint fired somewhere inside (t, t2]; bisect along the
+		// straight line between the two states to locate the crossing.
+		trig := numerics.Trigger[State](c.Violated)
+		t3, s3 := numerics.RefineTrigger(t, s, t2, s2, trig, lerp, tol)
+
+		switch c.Action() {
+		case ActionClip:
+			// Snap onto the boundary and resume integrating from there.
+			s3 = clipToConstraint(c, s3)
+			record(t3, s3)
+			t, s = t3, s3
+		case ActionFallback:
+			record(t3, s3)
+			out.Outcome = OutcomeFallback
+			out.Constraint = c
+			return out
+		default: // ActionStop
+			record(t3, s3)
+			out.Outcome = OutcomeStopped
+			out.Constraint = c
+			return out
+		}
+	}
+}
+
+// firstFiring scans local then global constraints for the first one whose
+// Violated returns true at (t, s).
+func firstFiring(local, globals []Constraint, t float64, s State) (Constraint, bool) {
+	for _, c := range local {
+		if c.Violated(t, s) {
+			return c, true
+		}
+	}
+	for _, c := range globals {
+		if c.Violated(t, s) {
+			return c, true
+		}
+	}
+	return nil, false
+}
+
+// clipToConstraint moves s onto the boundary of the given constraint. Only
+// MaxAltitude and MinAltitude are handled (altitude is set to their Limit);
+// any other constraint type returns s unchanged.
+func clipToConstraint(c Constraint, s State) State {
+	switch v := c.(type) {
+	case MaxAltitude:
+		s.Altitude = v.Limit
+	case MinAltitude:
+		s.Altitude = v.Limit
+	}
+	return s
+}
diff --git a/internal/engine/state.go b/internal/engine/state.go
new file mode 100644
index 0000000..f989374
--- /dev/null
+++ b/internal/engine/state.go
@@ -0,0 +1,57 @@
+package engine
+
+import "math"
+
+// pymod returns a % b with Python semantics: the result has the sign of b,
+// so for b > 0 the result is always in [0, b).
+//
+// Unlike Python's float %, a tiny negative a whose shifted value rounds up to
+// exactly b (e.g. -1e-17 % 360) yields 0 instead of b, so the half-open range
+// really holds.
+func pymod(a, b float64) float64 {
+	r := math.Mod(a, b) // carries the sign of a
+	if r != 0 && (r < 0) != (b < 0) {
+		r += b
+	}
+	if r == b {
+		r = 0
+	}
+	return r
+}
+
+// stateAdd is the RK4 integrator's update operation y + k*dy, with longitude
+// kept wrapped to [0, 360).
+//
+// Time is not stored in State — it is tracked separately by the integrator
+// and passed to Model.
+func stateAdd(y State, k float64, dy State) State {
+	return State{
+		Lat:      y.Lat + k*dy.Lat,
+		Lng:      pymod(y.Lng+k*dy.Lng, 360),
+		Altitude: y.Altitude + k*dy.Altitude,
+	}
+}
+
+// stateLerp computes the linear interpolation of two states by parameter l
+// in [0, 1]. Longitude uses lngLerp so that wrap-around is handled.
+func stateLerp(a, b State, l float64) State {
+	return State{
+		Lat:      (1-l)*a.Lat + l*b.Lat,
+		Lng:      lngLerp(a.Lng, b.Lng, l),
+		Altitude: (1-l)*a.Altitude + l*b.Altitude,
+	}
+}
+
+// lngLerp interpolates between two longitudes in [0, 360), choosing the
+// shorter great-circle arc.
+func lngLerp(a, b, l float64) float64 { + l2 := 1 - l + if a > b { + a, b = b, a + l, l2 = l2, l + } + if b-a < 180 { + return l2*a + l*b + } + return pymod(l2*(a+360)+l*b, 360) +} diff --git a/internal/engine/types.go b/internal/engine/types.go new file mode 100644 index 0000000..59504e4 --- /dev/null +++ b/internal/engine/types.go @@ -0,0 +1,80 @@ +// Package engine is the trajectory calculation engine. It composes +// propagators (model-driven integrators) into profiles (ordered chains) and +// runs them over a wind field. +// +// The engine has no direct dependency on any specific data source: wind data +// is consumed through weather.WindField and terrain data through any type +// satisfying TerrainProvider. +package engine + +// State holds the spatial state of the balloon. When returned by a Model +// the same struct is interpreted as the per-second time derivative of state. +type State struct { + // Lat is degrees latitude in [-90, 90] (or deg/s when returned as a derivative). + Lat float64 + // Lng is degrees longitude in [0, 360) (or deg/s as a derivative). + Lng float64 + // Altitude is metres above mean sea level (or m/s as a derivative). + Altitude float64 +} + +// Model returns the time derivative of state at (t, s). +// +// The derivative is direction-independent; the integrator applies the sign +// of dt for reverse propagation. +type Model func(t float64, s State) State + +// TrajectoryPoint is one sampled point of an integration result. +type TrajectoryPoint struct { + Time float64 // UNIX seconds + Lat float64 + Lng float64 + Altitude float64 +} + +// Direction is the time direction of integration. Forward (+1) integrates +// from launch to landing; Reverse (-1) integrates from a known landing back +// to a candidate launch point. +type Direction int8 + +const ( + Forward Direction = +1 + Reverse Direction = -1 +) + +// Action describes what the profile runner should do when a Constraint +// reports a violation. 
+type Action int + +const ( + // ActionStop ends the current propagator at the (refined) violation point. + // This matches the only behaviour available in the reference Tawhiri solver. + ActionStop Action = iota + // ActionFallback ends the current propagator and starts its Fallback + // propagator from the violation point. Useful for "if max altitude is + // reached during ascent, switch to descent" profiles. + ActionFallback + // ActionClip clips the violated coordinate to the boundary and continues + // integration. Useful for soft constraints such as "max altitude floor". + ActionClip +) + +// Constraint reports when integration should stop, branch, or clip. +// +// A constraint is direction-agnostic: it reads state and decides. The profile +// runner is responsible for refining the trigger point via binary search and +// dispatching the configured Action. +type Constraint interface { + // Name identifies the constraint in logs and result metadata. + Name() string + // Violated reports whether the constraint is breached at (t, s). + Violated(t float64, s State) bool + // Action is the behaviour to take on violation. + Action() Action +} + +// TerrainProvider returns ground elevation in metres at a coordinate. +// Implementations must be safe for concurrent use. +type TerrainProvider interface { + Elevation(lat, lng float64) float64 +} diff --git a/internal/metrics/prom.go b/internal/metrics/prom.go new file mode 100644 index 0000000..395205e --- /dev/null +++ b/internal/metrics/prom.go @@ -0,0 +1,146 @@ +package metrics + +import ( + "fmt" + "io" + "net/http" + "sort" + "strings" + "sync" + "time" +) + +// Prom is a minimal Sink that exposes counters and gauges in Prometheus's +// text exposition format. No external dependencies. +// +// The Prom sink supports labelled counters, sums (for durations and byte +// counts), and labelled gauges. Histograms are intentionally omitted; if +// they are needed later, swap Prom for an OTel-based sink. 
+type Prom struct { + mu sync.Mutex + counters map[string]map[string]float64 // name → label-key → value + gauges map[string]map[string]float64 // name → label-key → value +} + +// NewProm returns an empty Prom sink. +func NewProm() *Prom { + return &Prom{ + counters: make(map[string]map[string]float64), + gauges: make(map[string]map[string]float64), + } +} + +// Prediction implements Sink. +func (p *Prom) Prediction(profile string, d time.Duration, err error) { + status := "ok" + if err != nil { + status = "error" + } + labels := map[string]string{"profile": profile, "status": status} + p.incCounter("predictor_predictions_total", labels, 1) + p.incCounter("predictor_prediction_duration_seconds_sum", labels, d.Seconds()) + p.incCounter("predictor_prediction_duration_seconds_count", labels, 1) +} + +// Download implements Sink. +func (p *Prom) Download(source string, d time.Duration, status string, bytes int64) { + labels := map[string]string{"source": source, "status": status} + p.incCounter("predictor_downloads_total", labels, 1) + p.incCounter("predictor_download_duration_seconds_sum", labels, d.Seconds()) + p.incCounter("predictor_download_bytes_total", map[string]string{"source": source}, float64(bytes)) +} + +// ActiveEpoch implements Sink. +func (p *Prom) ActiveEpoch(t time.Time) { + var v float64 + if !t.IsZero() { + v = float64(t.Unix()) + } + p.setGauge("predictor_active_dataset_epoch_seconds", map[string]string{}, v) +} + +// ServeHTTP writes the metrics in Prometheus text exposition format. +func (p *Prom) ServeHTTP(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/plain; version=0.0.4") + p.Write(w) +} + +// Write writes the metrics in Prometheus exposition format to w. 
+func (p *Prom) Write(w io.Writer) {
+	// The text is rendered first and written after the lock is released, so a
+	// slow reader cannot stall incCounter/setGauge callers. Write has no
+	// error return; a failed write is dropped.
+	_, _ = io.WriteString(w, p.render())
+}
+
+// render formats every metric family, sorted by name, while holding the lock.
+func (p *Prom) render() string {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	names := make([]string, 0, len(p.counters)+len(p.gauges))
+	for n := range p.counters {
+		names = append(names, n)
+	}
+	for n := range p.gauges {
+		names = append(names, n)
+	}
+	sort.Strings(names)
+
+	var sb strings.Builder
+	for i, name := range names {
+		// A name present in both maps appears twice in the sorted list;
+		// emit its families only once.
+		if i > 0 && name == names[i-1] {
+			continue
+		}
+		if series, ok := p.counters[name]; ok {
+			fmt.Fprintf(&sb, "# TYPE %s counter\n", name)
+			writeMetricFamily(&sb, name, series)
+		}
+		if series, ok := p.gauges[name]; ok {
+			fmt.Fprintf(&sb, "# TYPE %s gauge\n", name)
+			writeMetricFamily(&sb, name, series)
+		}
+	}
+	return sb.String()
+}
+
+// writeMetricFamily writes one "name{labels} value" line per series, ordered
+// by the rendered label key so the output is deterministic.
+func writeMetricFamily(w io.Writer, name string, labels map[string]float64) {
+	keys := make([]string, 0, len(labels))
+	for k := range labels {
+		keys = append(keys, k)
+	}
+	sort.Strings(keys)
+	for _, key := range keys {
+		fmt.Fprintf(w, "%s%s %g\n", name, key, labels[key])
+	}
+}
+
+// incCounter adds n to the counter series identified by name and labels,
+// creating the family on first use.
+func (p *Prom) incCounter(name string, labels map[string]string, n float64) {
+	key := labelKey(labels)
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	if p.counters[name] == nil {
+		p.counters[name] = make(map[string]float64)
+	}
+	p.counters[name][key] += n
+}
+
+// setGauge overwrites the gauge series identified by name and labels with v,
+// creating the family on first use.
+func (p *Prom) setGauge(name string, labels map[string]string, v float64) {
+	key := labelKey(labels)
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	if p.gauges[name] == nil {
+		p.gauges[name] = make(map[string]float64)
+	}
+	p.gauges[name][key] = v
+}
+
+// labelKey renders the labels into a Prometheus-format "{k1="v1",k2="v2"}"
+// suffix, empty if no labels.
+func labelKey(labels map[string]string) string { + if len(labels) == 0 { + return "" + } + keys := make([]string, 0, len(labels)) + for k := range labels { + keys = append(keys, k) + } + sort.Strings(keys) + var sb strings.Builder + sb.WriteByte('{') + for i, k := range keys { + if i > 0 { + sb.WriteByte(',') + } + fmt.Fprintf(&sb, "%s=%q", k, labels[k]) + } + sb.WriteByte('}') + return sb.String() +} diff --git a/internal/metrics/prom_test.go b/internal/metrics/prom_test.go new file mode 100644 index 0000000..4bcce14 --- /dev/null +++ b/internal/metrics/prom_test.go @@ -0,0 +1,49 @@ +package metrics + +import ( + "bytes" + "strings" + "testing" + "time" +) + +func TestPromCounters(t *testing.T) { + p := NewProm() + p.Prediction("standard_profile", 100*time.Millisecond, nil) + p.Prediction("standard_profile", 200*time.Millisecond, nil) + p.Prediction("float_profile", 50*time.Millisecond, nil) + + var buf bytes.Buffer + p.Write(&buf) + out := buf.String() + + if !strings.Contains(out, `predictor_predictions_total{profile="standard_profile",status="ok"} 2`) { + t.Errorf("expected count=2 for standard_profile, got: %s", out) + } + if !strings.Contains(out, `predictor_predictions_total{profile="float_profile",status="ok"} 1`) { + t.Errorf("expected count=1 for float_profile, got: %s", out) + } + // Sum of durations: 0.1 + 0.2 = 0.3 seconds. 
+ if !strings.Contains(out, "predictor_prediction_duration_seconds_sum") { + t.Errorf("expected sum present, got: %s", out) + } +} + +func TestPromGauge(t *testing.T) { + p := NewProm() + p.ActiveEpoch(time.Unix(1700000000, 0)) + + var buf bytes.Buffer + p.Write(&buf) + out := buf.String() + if !strings.Contains(out, "predictor_active_dataset_epoch_seconds 1.7e+09") { + t.Errorf("expected gauge with epoch 1700000000, got: %s", out) + } +} + +func TestNoop(t *testing.T) { + sink := Noop() + sink.Prediction("any", time.Second, nil) + sink.Download("any", time.Second, "complete", 0) + sink.ActiveEpoch(time.Now()) +} diff --git a/internal/metrics/types.go b/internal/metrics/types.go new file mode 100644 index 0000000..3e0f622 --- /dev/null +++ b/internal/metrics/types.go @@ -0,0 +1,36 @@ +// Package metrics defines the Sink interface used to record service metrics +// and ships two implementations: a Noop sink (default, zero-cost) and a Prom +// sink that exposes counters in the Prometheus text exposition format. +// +// The metrics layer is optional: if no Sink is wired (or Noop is wired), the +// service runs unchanged. +package metrics + +import "time" + +// Sink collects observations from the rest of the service. +// +// Implementations must be safe for concurrent use across many goroutines. +// All methods are advisory; implementations may ignore any observation. +type Sink interface { + // Prediction records the duration and outcome of one prediction. + // err is nil on success; otherwise the error's class is used as a label. + Prediction(profile string, duration time.Duration, err error) + + // Download records the outcome of one dataset download job. + // status is "complete", "failed", or "cancelled". + Download(source string, duration time.Duration, status string, bytes int64) + + // ActiveEpoch reports the forecast time of the currently-loaded dataset. + // Pass time.Time{} when no dataset is loaded. 
+ ActiveEpoch(t time.Time) +} + +// Noop returns a Sink that discards every observation. +func Noop() Sink { return noop{} } + +type noop struct{} + +func (noop) Prediction(string, time.Duration, error) {} +func (noop) Download(string, time.Duration, string, int64) {} +func (noop) ActiveEpoch(time.Time) {} diff --git a/internal/numerics/doc.go b/internal/numerics/doc.go new file mode 100644 index 0000000..807ba3d --- /dev/null +++ b/internal/numerics/doc.go @@ -0,0 +1,11 @@ +// Package numerics provides the numerical primitives used by the trajectory +// engine: regular-grid multilinear interpolation, monotone bisection, and +// a generic explicit Runge-Kutta-4 integrator with binary-search refinement +// of a termination point. +// +// The package has no dependencies on any domain type. State and derivative +// types are generic, and all coordinate-wrap or unit-conversion semantics +// live in the caller. +// +// All algorithms are documented in docs/numerics.tex. +package numerics diff --git a/internal/numerics/grid.go b/internal/numerics/grid.go new file mode 100644 index 0000000..40ea24e --- /dev/null +++ b/internal/numerics/grid.go @@ -0,0 +1,86 @@ +package numerics + +import "fmt" + +// Axis describes a regularly-spaced grid axis with N grid points, +// values left, left+step, left+2*step, ..., left+(N-1)*step. +// +// If Wrap is true, the axis is periodic with period N*step (e.g. longitude). +// A query value at left+N*step wraps to the value at left+0*step. Locate +// returns Hi = 0 in that case. +type Axis struct { + Left float64 + Step float64 + N int + Wrap bool + Name string +} + +// AxisError is returned by Axis.Locate when value lies outside a non-wrapping axis. +type AxisError struct { + Axis string + Value float64 +} + +func (e *AxisError) Error() string { + return fmt.Sprintf("%s=%v out of range", e.Axis, e.Value) +} + +// Bracket holds the two surrounding grid indices and the fractional position +// of a value within an axis. 
+// The weight at Lo is (1 - Frac); the weight at Hi is Frac. Frac lies in
+// [0, 1).
+type Bracket struct {
+	Lo, Hi int
+	Frac   float64
+}
+
+// Locate returns the bracket containing value within the axis.
+// For a non-wrapping axis, value must lie in [Left, Left + (N-1)*Step);
+// for a wrapping axis, value must lie in [Left, Left + N*Step).
+//
+// Any value outside that interval, including NaN, yields an *AxisError.
+func (a Axis) Locate(value float64) (Bracket, error) {
+	pos := (value - a.Left) / a.Step
+
+	maxLo := a.N - 2
+	if a.Wrap {
+		maxLo = a.N - 1
+	}
+
+	// Range-check pos before converting. int(pos) truncates toward zero, so
+	// positions in (-1, 0) would otherwise map to Lo = 0 with a negative
+	// Frac, and converting NaN or a huge value to int is implementation-
+	// dependent. The negated comparisons make NaN fail both tests.
+	if !(pos >= 0) || !(pos < float64(maxLo+1)) {
+		return Bracket{}, &AxisError{Axis: a.Name, Value: value}
+	}
+	lo := int(pos)
+
+	hi := lo + 1
+	if a.Wrap && hi == a.N {
+		hi = 0 // periodic axis: the cell after the last point closes on index 0
+	}
+	return Bracket{Lo: lo, Hi: hi, Frac: pos - float64(lo)}, nil
+}
+
+// EvalTrilinear samples a 3D field via f at the eight corners defined by b3
+// and returns the trilinearly interpolated value.
+//
+// The corners are visited in the order (axis0 outer, axis2 inner), matching
+// the Cython reference. With f(i,j,k) = a*i + b*j + c*k + d this returns
+// a*pos0 + b*pos1 + c*pos2 + d exactly, modulo floating-point rounding.
+func EvalTrilinear(b3 [3]Bracket, f func(i, j, k int) float64) float64 {
+	wa0, wa1 := 1-b3[0].Frac, b3[0].Frac
+	wb0, wb1 := 1-b3[1].Frac, b3[1].Frac
+	wc0, wc1 := 1-b3[2].Frac, b3[2].Frac
+	a0, a1 := b3[0].Lo, b3[0].Hi
+	bb0, bb1 := b3[1].Lo, b3[1].Hi
+	c0, c1 := b3[2].Lo, b3[2].Hi
+
+	return wa0*wb0*wc0*f(a0, bb0, c0) +
+		wa0*wb0*wc1*f(a0, bb0, c1) +
+		wa0*wb1*wc0*f(a0, bb1, c0) +
+		wa0*wb1*wc1*f(a0, bb1, c1) +
+		wa1*wb0*wc0*f(a1, bb0, c0) +
+		wa1*wb0*wc1*f(a1, bb0, c1) +
+		wa1*wb1*wc0*f(a1, bb1, c0) +
+		wa1*wb1*wc1*f(a1, bb1, c1)
+}
+
+// Lerp returns (1-l)*a + l*b.
+func Lerp(a, b, l float64) float64 { + return (1-l)*a + l*b +} diff --git a/internal/numerics/grid_test.go b/internal/numerics/grid_test.go new file mode 100644 index 0000000..342d39c --- /dev/null +++ b/internal/numerics/grid_test.go @@ -0,0 +1,94 @@ +package numerics + +import ( + "math" + "testing" +) + +func TestAxisLocate(t *testing.T) { + a := Axis{Left: -90, Step: 0.5, N: 361, Name: "lat"} + + b, err := a.Locate(-90) + if err != nil || b.Lo != 0 || b.Hi != 1 || b.Frac != 0 { + t.Errorf("Locate(-90) = %+v, %v; want {0 1 0}, nil", b, err) + } + + b, err = a.Locate(0) + if err != nil || b.Lo != 180 || b.Hi != 181 || b.Frac != 0 { + t.Errorf("Locate(0) = %+v, %v; want {180 181 0}, nil", b, err) + } + + b, err = a.Locate(-89.75) + if err != nil || b.Lo != 0 || b.Hi != 1 || math.Abs(b.Frac-0.5) > 1e-12 { + t.Errorf("Locate(-89.75) = %+v, %v; want frac=0.5", b, err) + } + + // 90 is exactly on the upper boundary — there's no Hi above it + if _, err := a.Locate(90); err == nil { + t.Errorf("Locate(90) should error, got nil") + } + + if _, err := a.Locate(-91); err == nil { + t.Errorf("Locate(-91) should error, got nil") + } +} + +func TestAxisLocateWrap(t *testing.T) { + a := Axis{Left: 0, Step: 0.5, N: 720, Wrap: true, Name: "lng"} + + b, err := a.Locate(0) + if err != nil || b.Lo != 0 || b.Hi != 1 || b.Frac != 0 { + t.Errorf("Locate(0) = %+v, %v", b, err) + } + + // Right up against the wrap boundary + b, err = a.Locate(359.75) + if err != nil || b.Lo != 719 || b.Hi != 0 || math.Abs(b.Frac-0.5) > 1e-12 { + t.Errorf("Locate(359.75) = %+v, %v; want {719 0 0.5}", b, err) + } + + // 360 is outside the half-open interval + if _, err := a.Locate(360); err == nil { + t.Errorf("Locate(360) should error, got nil") + } +} + +func TestEvalTrilinear(t *testing.T) { + // Field f(i,j,k) = 100*i + 10*j + k. + f := func(i, j, k int) float64 { return 100*float64(i) + 10*float64(j) + float64(k) } + + // At all fractions = 0.5, expected value is the mean of the 8 corners. 
+ bs := [3]Bracket{{Lo: 0, Hi: 1, Frac: 0.5}, {Lo: 0, Hi: 1, Frac: 0.5}, {Lo: 0, Hi: 1, Frac: 0.5}} + got := EvalTrilinear(bs, f) + want := (0 + 1 + 10 + 11 + 100 + 101 + 110 + 111) / 8.0 + if math.Abs(got-want) > 1e-12 { + t.Errorf("EvalTrilinear at center = %v, want %v", got, want) + } + + // At all fractions = 0, expected value is f(lo, lo, lo) = 0. + bs = [3]Bracket{{Lo: 0, Hi: 1, Frac: 0}, {Lo: 0, Hi: 1, Frac: 0}, {Lo: 0, Hi: 1, Frac: 0}} + got = EvalTrilinear(bs, f) + if got != 0 { + t.Errorf("EvalTrilinear at (lo,lo,lo) = %v, want 0", got) + } + + // Asymmetric: linear field f(i,j,k) = i should give frac of axis 0 exactly. + f2 := func(i, _, _ int) float64 { return float64(i) } + bs = [3]Bracket{{Lo: 0, Hi: 1, Frac: 0.3}, {Lo: 0, Hi: 1, Frac: 0.7}, {Lo: 0, Hi: 1, Frac: 0.9}} + got = EvalTrilinear(bs, f2) + if math.Abs(got-0.3) > 1e-12 { + t.Errorf("EvalTrilinear of i-field = %v, want 0.3", got) + } +} + +func TestLerp(t *testing.T) { + if Lerp(10, 20, 0) != 10 { + t.Errorf("Lerp(10, 20, 0) != 10") + } + if Lerp(10, 20, 1) != 20 { + t.Errorf("Lerp(10, 20, 1) != 20") + } + if math.Abs(Lerp(10, 20, 0.25)-12.5) > 1e-12 { + t.Errorf("Lerp(10, 20, 0.25) != 12.5") + } +} diff --git a/internal/numerics/ode.go b/internal/numerics/ode.go new file mode 100644 index 0000000..86fdc4c --- /dev/null +++ b/internal/numerics/ode.go @@ -0,0 +1,61 @@ +package numerics + +// VecAdd computes y + k*dy on the domain state type S. +// Any coordinate-wrap or other domain-specific operation lives here. +type VecAdd[S any] func(y S, k float64, dy S) S + +// VecLerp computes (1-l)*a + l*b on the domain state type S. +type VecLerp[S any] func(a, b S, l float64) S + +// Deriv computes the time derivative of state. +type Deriv[S any] func(t float64, y S) S + +// Trigger reports whether a termination condition holds at (t, y). +type Trigger[S any] func(t float64, y S) bool + +// RK4Step performs one classical Runge-Kutta-4 step from (t, y) with step dt. 
+// dt may be negative to integrate backwards in time. +func RK4Step[S any](t float64, y S, dt float64, deriv Deriv[S], add VecAdd[S]) S { + k1 := deriv(t, y) + k2 := deriv(t+dt/2, add(y, dt/2, k1)) + k3 := deriv(t+dt/2, add(y, dt/2, k2)) + k4 := deriv(t+dt, add(y, dt, k3)) + + y2 := y + y2 = add(y2, dt/6, k1) + y2 = add(y2, dt/3, k2) + y2 = add(y2, dt/3, k3) + y2 = add(y2, dt/6, k4) + return y2 +} + +// RefineTrigger locates the trigger point between (t1, y1) (trigger not fired) +// and (t2, y2) (trigger fired) via binary search in the linear-interpolation +// parameter space, stopping when the parameter interval is narrower than tol. +// +// Returns the final midpoint sampled, matching the behavior of Tawhiri's +// solver.pyx (the returned point is *not* guaranteed to satisfy the trigger; +// for tol << 1 the difference is at most one tolerance-width either side). +func RefineTrigger[S any]( + t1 float64, y1 S, + t2 float64, y2 S, + trigger Trigger[S], + lerp VecLerp[S], + tol float64, +) (float64, S) { + left, right := 0.0, 1.0 + t3 := t2 + y3 := y2 + + for right-left > tol { + mid := (left + right) / 2 + t3 = Lerp(t1, t2, mid) + y3 = lerp(y1, y2, mid) + if trigger(t3, y3) { + right = mid + } else { + left = mid + } + } + return t3, y3 +} diff --git a/internal/numerics/ode_test.go b/internal/numerics/ode_test.go new file mode 100644 index 0000000..8a9dba5 --- /dev/null +++ b/internal/numerics/ode_test.go @@ -0,0 +1,61 @@ +package numerics + +import ( + "math" + "testing" +) + +// scalarAdd / scalarLerp let us drive RK4 on a plain float64. +func scalarAdd(y float64, k float64, dy float64) float64 { return y + k*dy } +func scalarLerpF(a, b float64, l float64) float64 { return Lerp(a, b, l) } + +func TestRK4ExponentialDecay(t *testing.T) { + // dy/dt = -y → exact: y(t) = y0 * exp(-t). 
+ deriv := func(_ float64, y float64) float64 { return -y } + + y := 1.0 + tnow := 0.0 + dt := 0.01 + for range 100 { + y = RK4Step(tnow, y, dt, deriv, scalarAdd) + tnow += dt + } + want := math.Exp(-1.0) + if math.Abs(y-want) > 1e-8 { + t.Errorf("RK4 exp decay at t=1: got %v, want %v (diff %v)", y, want, y-want) + } +} + +func TestRK4ReverseTime(t *testing.T) { + // dy/dt = y → exact: y(t) = y0 * exp(t). + // Integrating from t=1 backwards with dt=-0.01 over 100 steps should give y0. + deriv := func(_ float64, y float64) float64 { return y } + + y := math.E + tnow := 1.0 + dt := -0.01 + for range 100 { + y = RK4Step(tnow, y, dt, deriv, scalarAdd) + tnow += dt + } + if math.Abs(y-1.0) > 1e-8 { + t.Errorf("RK4 reverse: got %v, want 1.0 (diff %v)", y, y-1.0) + } +} + +func TestRefineTrigger(t *testing.T) { + // y crosses 0 at l=0.4 between y1=1 and y2=-1.5. + y1, y2 := 1.0, -1.5 + t1, t2 := 0.0, 1.0 + trig := func(_ float64, y float64) bool { return y <= 0 } + + tr, yr := RefineTrigger(t1, y1, t2, y2, trig, scalarLerpF, 0.001) + + // The exact crossing is at l = 1/(1+1.5) = 0.4 → t = 0.4, y = 0. + if math.Abs(tr-0.4) > 0.01 { + t.Errorf("Refined t = %v, want ~0.4", tr) + } + if math.Abs(yr) > 0.01 { + t.Errorf("Refined y = %v, want ~0", yr) + } +} diff --git a/internal/numerics/search.go b/internal/numerics/search.go new file mode 100644 index 0000000..71836b0 --- /dev/null +++ b/internal/numerics/search.go @@ -0,0 +1,19 @@ +package numerics + +// Bisect returns the largest index i in [imin, imax] such that f(i) < target, +// assuming f is monotonically nondecreasing on that range. +// +// If target <= f(imin), returns imin. If target > f(imax), returns imax. +// Performs O(log(imax-imin)) evaluations of f. 
+func Bisect(imin, imax int, target float64, f func(i int) float64) int { + lo, hi := imin, imax + for lo < hi { + mid := (lo + hi + 1) / 2 + if target <= f(mid) { + hi = mid - 1 + } else { + lo = mid + } + } + return lo +} diff --git a/internal/numerics/search_test.go b/internal/numerics/search_test.go new file mode 100644 index 0000000..c84f847 --- /dev/null +++ b/internal/numerics/search_test.go @@ -0,0 +1,28 @@ +package numerics + +import "testing" + +func TestBisect(t *testing.T) { + // f(i) = 10*i, monotone increasing. + f := func(i int) float64 { return 10 * float64(i) } + + // target = 25 → largest i with 10i < 25 is i=2 + if got := Bisect(0, 10, 25, f); got != 2 { + t.Errorf("Bisect target=25 = %d, want 2", got) + } + + // target on boundary: target = 30, condition is target <= f(mid) so f(3)=30 → not less; want 2 + if got := Bisect(0, 10, 30, f); got != 2 { + t.Errorf("Bisect target=30 = %d, want 2", got) + } + + // target below all values + if got := Bisect(0, 10, -5, f); got != 0 { + t.Errorf("Bisect target=-5 = %d, want 0", got) + } + + // target above all values + if got := Bisect(0, 10, 1000, f); got != 10 { + t.Errorf("Bisect target=1000 = %d, want 10", got) + } +} diff --git a/internal/prediction/interpolate.go b/internal/prediction/interpolate.go deleted file mode 100644 index 5ef0d14..0000000 --- a/internal/prediction/interpolate.go +++ /dev/null @@ -1,153 +0,0 @@ -package prediction - -import ( - "fmt" - - "predictor-refactored/internal/dataset" -) - -// Exact port of the reference interpolation logic (interpolate.pyx). -// 4D interpolation: time, latitude, longitude, altitude (via geopotential height). - -// lerp1 holds an index and interpolation weight for one axis. -type lerp1 struct { - index int - lerp float64 -} - -// lerp3 holds indices and a combined weight for the (hour, lat, lon) axes. -type lerp3 struct { - hour, lat, lng int - lerp float64 -} - -// RangeError indicates a coordinate is outside the dataset bounds. 
-type RangeError struct { - Variable string - Value float64 -} - -func (e *RangeError) Error() string { - return fmt.Sprintf("%s=%f out of range", e.Variable, e.Value) -} - -// pick computes interpolation indices and weights for a single axis. -// left: axis start, step: axis spacing, n: number of points, value: query value. -// Returns two lerp1 values (lower and upper bracket). -func pick(left, step float64, n int, value float64, variableName string) ([2]lerp1, error) { - a := (value - left) / step - b := int(a) // truncation toward zero, same as Cython cast - if b < 0 || b >= n-1 { - return [2]lerp1{}, &RangeError{Variable: variableName, Value: value} - } - l := a - float64(b) - return [2]lerp1{ - {index: b, lerp: 1 - l}, - {index: b + 1, lerp: l}, - }, nil -} - -// pick3 computes 8 trilinear interpolation weights for (hour, lat, lng). -func pick3(hour, lat, lng float64) ([8]lerp3, error) { - lhour, err := pick(0, 3, 65, hour, "hour") - if err != nil { - return [8]lerp3{}, err - } - llat, err := pick(-90, 0.5, 361, lat, "lat") - if err != nil { - return [8]lerp3{}, err - } - // Longitude wraps: tell pick the axis is one larger, then wrap index 720 → 0 - llng, err := pick(0, 0.5, 720+1, lng, "lng") - if err != nil { - return [8]lerp3{}, err - } - if llng[1].index == 720 { - llng[1].index = 0 - } - - var out [8]lerp3 - i := 0 - for _, a := range lhour { - for _, b := range llat { - for _, c := range llng { - out[i] = lerp3{ - hour: a.index, - lat: b.index, - lng: c.index, - lerp: a.lerp * b.lerp * c.lerp, - } - i++ - } - } - } - return out, nil -} - -// interp3 performs 8-point weighted interpolation at a given variable and pressure level. 
-func interp3(ds *dataset.File, lerps [8]lerp3, variable, level int) float64 { - var r float64 - for i := 0; i < 8; i++ { - v := ds.Val(lerps[i].hour, level, variable, lerps[i].lat, lerps[i].lng) - r += float64(v) * lerps[i].lerp - } - return r -} - -// search finds the largest pressure level index where interpolated geopotential -// height is less than the target altitude. Searches levels 0..45 (excludes topmost). -func search(ds *dataset.File, lerps [8]lerp3, target float64) int { - lower, upper := 0, 45 - - for lower < upper { - mid := (lower + upper + 1) / 2 - test := interp3(ds, lerps, dataset.VarHeight, mid) - if target <= test { - upper = mid - 1 - } else { - lower = mid - } - } - - return lower -} - -// interp4 performs altitude-interpolated wind lookup using two bracketing levels. -func interp4(ds *dataset.File, lerps [8]lerp3, altLerp lerp1, variable int) float64 { - lower := interp3(ds, lerps, variable, altLerp.index) - upper := interp3(ds, lerps, variable, altLerp.index+1) - return lower*altLerp.lerp + upper*(1-altLerp.lerp) -} - -// GetWind returns interpolated (u, v) wind components for the given position. -// hour: fractional hours since dataset start. -// lat: latitude in degrees (-90 to +90). -// lng: longitude in degrees (0 to 360). -// alt: altitude in metres above sea level. 
-func GetWind(ds *dataset.File, warnings *Warnings, hour, lat, lng, alt float64) (u, v float64, err error) { - lerps, err := pick3(hour, lat, lng) - if err != nil { - return 0, 0, err - } - - altidx := search(ds, lerps, alt) - lower := interp3(ds, lerps, dataset.VarHeight, altidx) - upper := interp3(ds, lerps, dataset.VarHeight, altidx+1) - - var altLerp float64 - if lower != upper { - altLerp = (upper - alt) / (upper - lower) - } else { - altLerp = 0.5 - } - - if altLerp < 0 { - warnings.AltitudeTooHigh.Add(1) - } - - alt1 := lerp1{index: altidx, lerp: altLerp} - u = interp4(ds, lerps, alt1, dataset.VarWindU) - v = interp4(ds, lerps, alt1, dataset.VarWindV) - - return u, v, nil -} diff --git a/internal/prediction/models.go b/internal/prediction/models.go deleted file mode 100644 index 8048c46..0000000 --- a/internal/prediction/models.go +++ /dev/null @@ -1,188 +0,0 @@ -package prediction - -import ( - "math" - "time" - - "predictor-refactored/internal/dataset" - "predictor-refactored/internal/elevation" -) - -// Exact port of the reference flight models (models.py). - -const ( - pi180 = math.Pi / 180.0 - _180pi = 180.0 / math.Pi -) - -// --- Up/Down Models --- - -// ConstantAscent returns a model with constant vertical velocity (m/s). -func ConstantAscent(ascentRate float64) Model { - return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) { - return 0, 0, ascentRate - } -} - -// DragDescent returns a descent-under-parachute model. -// seaLevelDescentRate is the descent rate at sea level (m/s, positive value). -// Uses the NASA atmosphere model for density at altitude. -func DragDescent(seaLevelDescentRate float64) Model { - dragCoefficient := seaLevelDescentRate * 1.1045 - - return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) { - return 0, 0, -dragCoefficient / math.Sqrt(nasaDensity(alt)) - } -} - -// nasaDensity computes air density using the NASA atmosphere model. 
-// Reference: http://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html -func nasaDensity(alt float64) float64 { - var temp, pressure float64 - - switch { - case alt > 25000: - temp = -131.21 + 0.00299*alt - pressure = 2.488 * math.Pow((temp+273.1)/216.6, -11.388) - case alt > 11000: - temp = -56.46 - pressure = 22.65 * math.Exp(1.73-0.000157*alt) - default: - temp = 15.04 - 0.00649*alt - pressure = 101.29 * math.Pow((temp+273.1)/288.08, 5.256) - } - - return pressure / (0.2869 * (temp + 273.1)) -} - -// --- Sideways Models --- - -// WindVelocity returns a model that gives lateral movement at the wind velocity. -// ds is the wind dataset, dsEpoch is the dataset start time as UNIX timestamp. -func WindVelocity(ds *dataset.File, dsEpoch float64, warnings *Warnings) Model { - return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) { - tHours := (t - dsEpoch) / 3600.0 - u, v, err := GetWind(ds, warnings, tHours, lat, lng, alt) - if err != nil { - return 0, 0, 0 - } - - R := 6371009.0 + alt - dlat = _180pi * v / R - dlng = _180pi * u / (R * math.Cos(lat*pi180)) - return dlat, dlng, 0 - } -} - -// --- Model Combinations --- - -// LinearModel returns a model that sums all component models. -func LinearModel(models ...Model) Model { - return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) { - for _, m := range models { - d1, d2, d3 := m(t, lat, lng, alt) - dlat += d1 - dlng += d2 - dalt += d3 - } - return - } -} - -// --- Termination Criteria --- - -// BurstTermination returns a terminator that fires when altitude >= burstAltitude. -func BurstTermination(burstAltitude float64) Terminator { - return func(t, lat, lng, alt float64) bool { - return alt >= burstAltitude - } -} - -// SeaLevelTermination fires when altitude <= 0. -func SeaLevelTermination(t, lat, lng, alt float64) bool { - return alt <= 0 -} - -// TimeTermination returns a terminator that fires when t > maxTime. 
-func TimeTermination(maxTime float64) Terminator { - return func(t, lat, lng, alt float64) bool { - return t > maxTime - } -} - -// ElevationTermination returns a terminator that fires when alt < ground level. -// Uses ruaumoko-compatible elevation data. Longitude is normalised internally. -func ElevationTermination(elev *elevation.Dataset) Terminator { - return func(t, lat, lng, alt float64) bool { - return elev.Get(lat, lng) > alt - } -} - -// --- Pre-Defined Profiles --- - -// Stage pairs a model with its termination criterion. -type Stage struct { - Model Model - Terminator Terminator -} - -// StandardProfile creates the chain for a standard high-altitude balloon flight: -// ascent at constant rate → burst → descent under parachute. -// If elev is non-nil, descent terminates at ground level; otherwise at sea level. -func StandardProfile(ascentRate, burstAltitude, descentRate float64, - ds *dataset.File, dsEpoch float64, warnings *Warnings, - elev *elevation.Dataset) []Stage { - - wind := WindVelocity(ds, dsEpoch, warnings) - - modelUp := LinearModel(ConstantAscent(ascentRate), wind) - termUp := BurstTermination(burstAltitude) - - modelDown := LinearModel(DragDescent(descentRate), wind) - var termDown Terminator - if elev != nil { - termDown = ElevationTermination(elev) - } else { - termDown = Terminator(SeaLevelTermination) - } - - return []Stage{ - {Model: modelUp, Terminator: termUp}, - {Model: modelDown, Terminator: termDown}, - } -} - -// FloatProfile creates the chain for a floating balloon flight: -// ascent to float altitude → float until stop time. 
-func FloatProfile(ascentRate, floatAltitude float64, stopTime time.Time, - ds *dataset.File, dsEpoch float64, warnings *Warnings) []Stage { - - wind := WindVelocity(ds, dsEpoch, warnings) - - modelUp := LinearModel(ConstantAscent(ascentRate), wind) - termUp := BurstTermination(floatAltitude) - - modelFloat := wind - termFloat := TimeTermination(float64(stopTime.Unix())) - - return []Stage{ - {Model: modelUp, Terminator: termUp}, - {Model: modelFloat, Terminator: termFloat}, - } -} - -// RunPrediction runs a prediction with the given profile stages. -// launchTime is a UNIX timestamp. -func RunPrediction(launchTime float64, lat, lng, alt float64, stages []Stage) []StageResult { - chain := make([]struct { - Model Model - Terminator Terminator - }, len(stages)) - - for i, s := range stages { - chain[i].Model = s.Model - chain[i].Terminator = s.Terminator - } - - return Solve(launchTime, lat, lng, alt, chain) -} diff --git a/internal/prediction/solver.go b/internal/prediction/solver.go deleted file mode 100644 index 62e29a7..0000000 --- a/internal/prediction/solver.go +++ /dev/null @@ -1,180 +0,0 @@ -package prediction - -import "math" - -// Exact port of the reference RK4 solver (solver.pyx). -// Integrates balloon state using RK4 with dt=60 seconds. -// Termination uses binary search refinement (tolerance 0.01). - -// Vec holds the balloon state: latitude, longitude, altitude. -type Vec struct { - Lat float64 - Lng float64 - Alt float64 -} - -// Model is a function that returns (dlat/dt, dlng/dt, dalt/dt) given state. -// t is UNIX timestamp, lat/lng in degrees, alt in metres. -type Model func(t float64, lat, lng, alt float64) (dlat, dlng, dalt float64) - -// Terminator returns true when integration should stop. -type Terminator func(t float64, lat, lng, alt float64) bool - -// StageResult holds the trajectory points for one flight stage. 
-type StageResult struct { - Points []TrajectoryPoint -} - -// TrajectoryPoint is a single point in a trajectory (used by solver). -type TrajectoryPoint struct { - T float64 // UNIX timestamp - Lat float64 - Lng float64 - Alt float64 -} - -// pymod returns a % b with Python semantics (always non-negative when b > 0). -func pymod(a, b float64) float64 { - r := math.Mod(a, b) - if r < 0 { - r += b - } - return r -} - -// vecadd returns a + k*b, with lng wrapped to [0, 360). -func vecadd(a Vec, k float64, b Vec) Vec { - return Vec{ - Lat: a.Lat + k*b.Lat, - Lng: pymod(a.Lng+k*b.Lng, 360.0), - Alt: a.Alt + k*b.Alt, - } -} - -// scalarLerp returns (1-l)*a + l*b. -func scalarLerp(a, b, l float64) float64 { - return (1-l)*a + l*b -} - -// lngLerp interpolates longitude handling the 0/360 wrap-around. -func lngLerp(a, b, l float64) float64 { - l2 := 1 - l - - if a > b { - a, b = b, a - l, l2 = l2, l - } - - // distance round one way: b - a - // distance around other: (a + 360) - b - if b-a < 180.0 { - return l2*a + l*b - } - return pymod(l2*(a+360)+l*b, 360.0) -} - -// vecLerp returns (1-l)*a + l*b with proper longitude wrapping. -func vecLerp(a, b Vec, l float64) Vec { - return Vec{ - Lat: scalarLerp(a.Lat, b.Lat, l), - Lng: lngLerp(a.Lng, b.Lng, l), - Alt: scalarLerp(a.Alt, b.Alt, l), - } -} - -// rk4 integrates from initial conditions using RK4. -// dt=60.0 seconds, terminationTolerance=0.01. 
-func rk4(t float64, lat, lng, alt float64, model Model, terminator Terminator) []TrajectoryPoint { - const dt = 60.0 - const terminationTolerance = 0.01 - - y := Vec{Lat: lat, Lng: lng, Alt: alt} - result := []TrajectoryPoint{{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt}} - - for { - // Evaluate model at 4 points (standard RK4) - k1lat, k1lng, k1alt := model(t, y.Lat, y.Lng, y.Alt) - k1 := Vec{Lat: k1lat, Lng: k1lng, Alt: k1alt} - - mid1 := vecadd(y, dt/2, k1) - k2lat, k2lng, k2alt := model(t+dt/2, mid1.Lat, mid1.Lng, mid1.Alt) - k2 := Vec{Lat: k2lat, Lng: k2lng, Alt: k2alt} - - mid2 := vecadd(y, dt/2, k2) - k3lat, k3lng, k3alt := model(t+dt/2, mid2.Lat, mid2.Lng, mid2.Alt) - k3 := Vec{Lat: k3lat, Lng: k3lng, Alt: k3alt} - - end := vecadd(y, dt, k3) - k4lat, k4lng, k4alt := model(t+dt, end.Lat, end.Lng, end.Alt) - k4 := Vec{Lat: k4lat, Lng: k4lng, Alt: k4alt} - - // y2 = y + dt/6*k1 + dt/3*k2 + dt/3*k3 + dt/6*k4 - y2 := y - y2 = vecadd(y2, dt/6, k1) - y2 = vecadd(y2, dt/3, k2) - y2 = vecadd(y2, dt/3, k3) - y2 = vecadd(y2, dt/6, k4) - - t2 := t + dt - - if terminator(t2, y2.Lat, y2.Lng, y2.Alt) { - // Binary search to refine the termination point. - // Find l in [0, 1] such that (t3, y3) = lerp((t, y), (t2, y2), l) - // is near where the terminator fires. - left := 0.0 - right := 1.0 - - var t3 float64 - var y3 Vec - t3 = t2 - y3 = y2 - - for right-left > terminationTolerance { - mid := (left + right) / 2 - t3 = scalarLerp(t, t2, mid) - y3 = vecLerp(y, y2, mid) - - if terminator(t3, y3.Lat, y3.Lng, y3.Alt) { - right = mid - } else { - left = mid - } - } - - result = append(result, TrajectoryPoint{T: t3, Lat: y3.Lat, Lng: y3.Lng, Alt: y3.Alt}) - break - } - - // Update current state - t = t2 - y = y2 - result = append(result, TrajectoryPoint{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt}) - } - - return result -} - -// Solve runs through a chain of (model, terminator) stages. -// Returns one StageResult per stage. 
-func Solve(t, lat, lng, alt float64, chain []struct { - Model Model - Terminator Terminator -}) []StageResult { - var results []StageResult - - for _, stage := range chain { - points := rk4(t, lat, lng, alt, stage.Model, stage.Terminator) - results = append(results, StageResult{Points: points}) - - // Next stage starts where this one ended - if len(points) > 0 { - last := points[len(points)-1] - t = last.T - lat = last.Lat - lng = last.Lng - alt = last.Alt - } - } - - return results -} diff --git a/internal/prediction/warnings.go b/internal/prediction/warnings.go deleted file mode 100644 index 1beeb1a..0000000 --- a/internal/prediction/warnings.go +++ /dev/null @@ -1,21 +0,0 @@ -package prediction - -import "sync/atomic" - -// Warnings tracks warning conditions during a prediction run. -type Warnings struct { - AltitudeTooHigh atomic.Int64 -} - -// ToMap returns warnings as a map suitable for JSON serialization. -// Only includes warnings that have fired. -func (w *Warnings) ToMap() map[string]any { - result := make(map[string]any) - if n := w.AltitudeTooHigh.Load(); n > 0 { - result["altitude_too_high"] = map[string]any{ - "count": n, - "description": "The altitude went too high, above the max forecast wind. Wind data will be unreliable", - } - } - return result -} diff --git a/internal/service/service.go b/internal/service/service.go deleted file mode 100644 index 4ccd1d4..0000000 --- a/internal/service/service.go +++ /dev/null @@ -1,245 +0,0 @@ -package service - -import ( - "context" - "os" - "path/filepath" - "sort" - "strings" - "sync" - "time" - - "predictor-refactored/internal/dataset" - "predictor-refactored/internal/downloader" - "predictor-refactored/internal/elevation" - - "go.uber.org/zap" -) - -// Service orchestrates the dataset lifecycle and provides prediction capabilities. 
-type Service struct { - mu sync.RWMutex - ds *dataset.File - elev *elevation.Dataset - cfg *downloader.Config - dl *downloader.Downloader - log *zap.Logger - updating sync.Mutex // prevents concurrent downloads -} - -// New creates a new Service. -func New(cfg *downloader.Config, log *zap.Logger) *Service { - return &Service{ - cfg: cfg, - dl: downloader.NewDownloader(cfg, log), - log: log, - } -} - -// LoadElevation loads the ruaumoko-compatible elevation dataset from path. -// If the file doesn't exist, elevation termination is disabled (falls back to sea level). -func (s *Service) LoadElevation(path string) { - ds, err := elevation.Open(path) - if err != nil { - s.log.Warn("elevation dataset not available, using sea-level termination", - zap.String("path", path), zap.Error(err)) - return - } - s.elev = ds - s.log.Info("elevation dataset loaded", zap.String("path", path)) -} - -// Elevation returns the elevation dataset (may be nil). -func (s *Service) Elevation() *elevation.Dataset { - return s.elev -} - -// Ready returns true if the service has a loaded dataset. -func (s *Service) Ready() bool { - s.mu.RLock() - defer s.mu.RUnlock() - return s.ds != nil -} - -// DatasetTime returns the forecast time of the currently loaded dataset. -func (s *Service) DatasetTime() (time.Time, bool) { - s.mu.RLock() - defer s.mu.RUnlock() - if s.ds == nil { - return time.Time{}, false - } - return s.ds.DSTime, true -} - -// Dataset returns the current dataset for reading. -func (s *Service) Dataset() *dataset.File { - s.mu.RLock() - defer s.mu.RUnlock() - return s.ds -} - -// Update checks for and downloads new forecast data if needed. 
-func (s *Service) Update(ctx context.Context) error { - if !s.updating.TryLock() { - s.log.Info("update already in progress, skipping") - return nil - } - defer s.updating.Unlock() - - // Check if current dataset is still fresh - if dsTime, ok := s.DatasetTime(); ok { - if time.Since(dsTime) < s.cfg.DatasetTTL { - s.log.Info("dataset still fresh, skipping update", - zap.Time("dataset_time", dsTime), - zap.Duration("age", time.Since(dsTime))) - return nil - } - } - - // Try loading an existing dataset from disk first - if err := s.loadExistingDataset(); err == nil { - return nil - } - - // Find latest available model run - run, err := s.dl.FindLatestRun(ctx) - if err != nil { - return err - } - - // Download and assemble - path, err := s.dl.Download(ctx, run) - if err != nil { - return err - } - - // Open the new dataset - ds, err := dataset.Open(path, run) - if err != nil { - return err - } - - // Swap in the new dataset - s.setDataset(ds) - s.log.Info("dataset loaded", zap.Time("run", run), zap.String("path", path)) - - // Clean old datasets - s.cleanOldDatasets(path) - - return nil -} - -// loadExistingDataset tries to find and load an existing dataset from the data directory. 
-func (s *Service) loadExistingDataset() error { - entries, err := os.ReadDir(s.cfg.DataDir) - if err != nil { - return err - } - - // Collect valid dataset files (name is YYYYMMDDHH, no extension, correct size) - type candidate struct { - name string - path string - run time.Time - } - var candidates []candidate - - for _, e := range entries { - if e.IsDir() || strings.Contains(e.Name(), ".") { - continue - } - if len(e.Name()) != 10 { - continue - } - - run, err := time.Parse("2006010215", e.Name()) - if err != nil { - continue - } - - path := filepath.Join(s.cfg.DataDir, e.Name()) - info, err := os.Stat(path) - if err != nil || info.Size() != dataset.DatasetSize { - continue - } - - if time.Since(run) > s.cfg.DatasetTTL { - continue - } - - candidates = append(candidates, candidate{name: e.Name(), path: path, run: run}) - } - - if len(candidates) == 0 { - return os.ErrNotExist - } - - // Pick the newest - sort.Slice(candidates, func(i, j int) bool { - return candidates[i].run.After(candidates[j].run) - }) - - best := candidates[0] - ds, err := dataset.Open(best.path, best.run) - if err != nil { - return err - } - - s.setDataset(ds) - s.log.Info("loaded existing dataset", - zap.Time("run", best.run), - zap.String("path", best.path)) - return nil -} - -// setDataset swaps the current dataset with a new one, closing the old one. -func (s *Service) setDataset(ds *dataset.File) { - s.mu.Lock() - old := s.ds - s.ds = ds - s.mu.Unlock() - - if old != nil { - if err := old.Close(); err != nil { - s.log.Error("failed to close old dataset", zap.Error(err)) - } - } -} - -// cleanOldDatasets removes dataset files other than the one at keepPath. 
-func (s *Service) cleanOldDatasets(keepPath string) { - entries, err := os.ReadDir(s.cfg.DataDir) - if err != nil { - return - } - - for _, e := range entries { - if e.IsDir() { - continue - } - path := filepath.Join(s.cfg.DataDir, e.Name()) - if path == keepPath { - continue - } - // Remove old datasets and temp files - if len(e.Name()) == 10 || strings.HasSuffix(e.Name(), ".downloading") { - if err := os.Remove(path); err != nil { - s.log.Warn("failed to remove old file", zap.String("path", path), zap.Error(err)) - } else { - s.log.Info("removed old dataset", zap.String("path", path)) - } - } - } -} - -// Close releases all resources. -func (s *Service) Close() error { - s.mu.Lock() - defer s.mu.Unlock() - if s.ds != nil { - err := s.ds.Close() - s.ds = nil - return err - } - return nil -} diff --git a/internal/transport/middleware/log.go b/internal/transport/middleware/log.go deleted file mode 100644 index fbbbbc1..0000000 --- a/internal/transport/middleware/log.go +++ /dev/null @@ -1,30 +0,0 @@ -package middleware - -import ( - "time" - - "github.com/ogen-go/ogen/middleware" - "go.uber.org/zap" -) - -// Logging returns an ogen middleware that logs request duration. 
-func Logging(log *zap.Logger) middleware.Middleware { - return func(req middleware.Request, next func(req middleware.Request) (middleware.Response, error)) (middleware.Response, error) { - lg := log.With(zap.String("operation", req.OperationID)) - - start := time.Now() - resp, err := next(req) - dur := time.Since(start) - - if err != nil { - lg.Error("request failed", - zap.Duration("duration", dur), - zap.Error(err)) - } else { - lg.Info("request completed", - zap.Duration("duration", dur)) - } - - return resp, err - } -} diff --git a/internal/transport/rest/handler/deps.go b/internal/transport/rest/handler/deps.go deleted file mode 100644 index f81a3b8..0000000 --- a/internal/transport/rest/handler/deps.go +++ /dev/null @@ -1,16 +0,0 @@ -package handler - -import ( - "time" - - "predictor-refactored/internal/dataset" - "predictor-refactored/internal/elevation" -) - -// Service defines the interface the handler needs from the service layer. -type Service interface { - Ready() bool - DatasetTime() (time.Time, bool) - Dataset() *dataset.File - Elevation() *elevation.Dataset -} diff --git a/internal/transport/rest/handler/handler.go b/internal/transport/rest/handler/handler.go deleted file mode 100644 index fc1f693..0000000 --- a/internal/transport/rest/handler/handler.go +++ /dev/null @@ -1,216 +0,0 @@ -package handler - -import ( - "context" - "net/http" - "time" - - "predictor-refactored/internal/prediction" - api "predictor-refactored/pkg/rest" - - "go.uber.org/zap" -) - -var _ api.Handler = (*Handler)(nil) - -// Handler implements the ogen-generated api.Handler interface. -type Handler struct { - svc Service - log *zap.Logger -} - -// New creates a new Handler. -func New(svc Service, log *zap.Logger) *Handler { - return &Handler{svc: svc, log: log} -} - -// PerformPrediction implements the prediction endpoint. 
-func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredictionParams) (*api.PredictionResponse, error) { - if !h.svc.Ready() { - return nil, newError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up") - } - - ds := h.svc.Dataset() - if ds == nil { - return nil, newError(http.StatusServiceUnavailable, "dataset unavailable") - } - - dsEpoch := float64(ds.DSTime.Unix()) - - // Parse parameters with defaults - profile := "standard_profile" - if p, ok := params.Profile.Get(); ok { - profile = string(p) - } - - ascentRate := 5.0 - if v, ok := params.AscentRate.Get(); ok { - ascentRate = v - } - - burstAltitude := 28000.0 - if v, ok := params.BurstAltitude.Get(); ok { - burstAltitude = v - } - - descentRate := 5.0 - if v, ok := params.DescentRate.Get(); ok { - descentRate = v - } - - launchAlt := 0.0 - if v, ok := params.LaunchAltitude.Get(); ok { - launchAlt = v - } - - // Normalize longitude to [0, 360) - lng := params.LaunchLongitude - if lng < 0 { - lng += 360.0 - } - - launchTime := float64(params.LaunchDatetime.Unix()) - - warnings := &prediction.Warnings{} - - // Build profile chain - elev := h.svc.Elevation() - var stages []prediction.Stage - switch profile { - case "standard_profile": - stages = prediction.StandardProfile( - ascentRate, burstAltitude, descentRate, - ds, dsEpoch, warnings, elev) - case "float_profile": - floatAlt := 25000.0 - if v, ok := params.FloatAltitude.Get(); ok { - floatAlt = v - } - stopTime := params.LaunchDatetime.Add(24 * time.Hour) - if v, ok := params.StopDatetime.Get(); ok { - stopTime = v - } - stages = prediction.FloatProfile( - ascentRate, floatAlt, stopTime, - ds, dsEpoch, warnings) - default: - return nil, newError(http.StatusBadRequest, "unknown profile: "+profile) - } - - // Run prediction - startTime := time.Now().UTC() - results := prediction.RunPrediction(launchTime, params.LaunchLatitude, lng, launchAlt, stages) - completeTime := time.Now().UTC() - - // Build response - 
stageNames := []string{"ascent", "descent"} - if profile == "float_profile" { - stageNames = []string{"ascent", "float"} - } - - var predItems []api.PredictionResponsePredictionItem - for i, sr := range results { - stageName := "ascent" - if i < len(stageNames) { - stageName = stageNames[i] - } - - var stageEnum api.PredictionResponsePredictionItemStage - switch stageName { - case "ascent": - stageEnum = api.PredictionResponsePredictionItemStageAscent - case "descent": - stageEnum = api.PredictionResponsePredictionItemStageDescent - case "float": - stageEnum = api.PredictionResponsePredictionItemStageFloat - } - - var traj []api.PredictionResponsePredictionItemTrajectoryItem - for _, pt := range sr.Points { - ptLng := pt.Lng - if ptLng > 180 { - ptLng -= 360 - } - traj = append(traj, api.PredictionResponsePredictionItemTrajectoryItem{ - Datetime: time.Unix(int64(pt.T), 0).UTC(), - Latitude: pt.Lat, - Longitude: ptLng, - Altitude: pt.Alt, - }) - } - - predItems = append(predItems, api.PredictionResponsePredictionItem{ - Stage: stageEnum, - Trajectory: traj, - }) - } - - resp := &api.PredictionResponse{ - Prediction: predItems, - Metadata: api.PredictionResponseMetadata{ - StartDatetime: startTime, - CompleteDatetime: completeTime, - }, - } - - // Echo request - resp.Request = api.NewOptPredictionResponseRequest(api.PredictionResponseRequest{ - Dataset: api.NewOptString(ds.DSTime.Format("2006-01-02T15:04:05Z")), - LaunchLatitude: api.NewOptFloat64(params.LaunchLatitude), - LaunchLongitude: api.NewOptFloat64(params.LaunchLongitude), - LaunchDatetime: api.NewOptString(params.LaunchDatetime.Format(time.RFC3339)), - LaunchAltitude: params.LaunchAltitude, - }) - - // Warnings - warnMap := warnings.ToMap() - if len(warnMap) > 0 { - resp.Warnings = api.NewOptPredictionResponseWarnings(api.PredictionResponseWarnings{}) - } - - h.log.Info("prediction complete", - zap.String("profile", profile), - zap.Int("stages", len(results)), - zap.Duration("elapsed", 
completeTime.Sub(startTime))) - - return resp, nil -} - -// ReadinessCheck implements the health check endpoint. -func (h *Handler) ReadinessCheck(ctx context.Context) (*api.ReadinessResponse, error) { - resp := &api.ReadinessResponse{} - - if h.svc.Ready() { - resp.Status = api.ReadinessResponseStatusOk - if dsTime, ok := h.svc.DatasetTime(); ok { - resp.DatasetTime = api.NewOptDateTime(dsTime) - } - } else { - resp.Status = api.ReadinessResponseStatusNotReady - resp.ErrorMessage = api.NewOptString("no dataset loaded") - } - - return resp, nil -} - -// NewError creates an ErrorStatusCode from an error returned by a handler. -func (h *Handler) NewError(ctx context.Context, err error) *api.ErrorStatusCode { - if statusErr, ok := err.(*api.ErrorStatusCode); ok { - return statusErr - } - - h.log.Error("unhandled error", zap.Error(err)) - return newError(http.StatusInternalServerError, err.Error()) -} - -func newError(status int, description string) *api.ErrorStatusCode { - return &api.ErrorStatusCode{ - StatusCode: status, - Response: api.Error{ - Error: api.ErrorError{ - Type: http.StatusText(status), - Description: description, - }, - }, - } -} diff --git a/internal/transport/rest/transport.go b/internal/transport/rest/transport.go deleted file mode 100644 index 3744270..0000000 --- a/internal/transport/rest/transport.go +++ /dev/null @@ -1,75 +0,0 @@ -package rest - -import ( - "context" - "fmt" - "net/http" - - "predictor-refactored/internal/transport/middleware" - "predictor-refactored/internal/transport/rest/handler" - api "predictor-refactored/pkg/rest" - - "go.uber.org/zap" -) - -// Transport wraps the ogen HTTP server. -type Transport struct { - srv *api.Server - handler *handler.Handler - port int - log *zap.Logger -} - -// New creates a new REST transport. 
-func New(h *handler.Handler, port int, log *zap.Logger) (*Transport, error) { - srv, err := api.NewServer( - h, - api.WithMiddleware(middleware.Logging(log)), - ) - if err != nil { - return nil, fmt.Errorf("create ogen server: %w", err) - } - - return &Transport{ - srv: srv, - handler: h, - port: port, - log: log, - }, nil -} - -// Run starts the HTTP server. Blocks until the server stops. -func (t *Transport) Run() error { - mux := http.NewServeMux() - mux.Handle("/", t.srv) - - httpSrv := &http.Server{ - Addr: fmt.Sprintf(":%d", t.port), - Handler: corsMiddleware(mux), - } - - t.log.Info("starting HTTP server", zap.Int("port", t.port)) - return httpSrv.ListenAndServe() -} - -// Shutdown gracefully stops the HTTP server. -func (t *Transport) Shutdown(ctx context.Context) error { - // The ogen server doesn't have a shutdown method; - // shutdown is handled by the http.Server in main.go - return nil -} - -func corsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type") - - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusNoContent) - return - } - - next.ServeHTTP(w, r) - }) -} diff --git a/internal/weather/gfs/constants.go b/internal/weather/gfs/constants.go new file mode 100644 index 0000000..77ed6ee --- /dev/null +++ b/internal/weather/gfs/constants.go @@ -0,0 +1,141 @@ +package gfs + +import "fmt" + +// Dataset shape: (hour, pressure_level, variable, latitude, longitude). +// Matches the cube layout used by the reference Tawhiri implementation. 
+const ( + NumHours = 65 // 0, 3, 6, ..., 192 hours forecast + NumLevels = 47 // pressure levels + NumVariables = 3 // geopotential height, U-wind, V-wind + NumLatitudes = 361 // -90.0 to +90.0 inclusive in 0.5° steps + NumLongitudes = 720 // 0.0 to 359.5 in 0.5° steps + + HourStep = 3 + MaxHour = 192 + Resolution = 0.5 + LatStart = -90.0 + LonStart = 0.0 + + VarHeight = 0 + VarWindU = 1 + VarWindV = 2 + + ElementSize = 4 // float32 + + // DatasetSize is the canonical file size: every grid cell × element size. + DatasetSize int64 = int64(NumHours) * int64(NumLevels) * int64(NumVariables) * + int64(NumLatitudes) * int64(NumLongitudes) * int64(ElementSize) +) + +// LevelSet identifies which GRIB file (primary/secondary) carries a level. +type LevelSet int + +const ( + LevelSetA LevelSet = iota // pgrb2 — primary file + LevelSetB // pgrb2b — secondary file +) + +// Pressures lists the 47 pressure levels (hPa) in dataset index order, +// descending from surface to top of atmosphere. +var Pressures = [NumLevels]int{ + 1000, 975, 950, 925, 900, 875, 850, 825, 800, 775, + 750, 725, 700, 675, 650, 625, 600, 575, 550, 525, + 500, 475, 450, 425, 400, 375, 350, 325, 300, 275, + 250, 225, 200, 175, 150, 125, 100, 70, 50, 30, + 20, 10, 7, 5, 3, 2, 1, +} + +// PressuresPgrb2 lists the levels carried by the primary GRIB file. +var PressuresPgrb2 = []int{ + 10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 350, 400, + 450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 925, + 950, 975, 1000, +} + +// PressuresPgrb2b lists the levels carried by the secondary GRIB file. 
+var PressuresPgrb2b = []int{ + 1, 2, 3, 5, 7, 125, 175, 225, 275, 325, 375, 425, + 475, 525, 575, 625, 675, 725, 775, 825, 875, +} + +var pressureIndex map[int]int +var pressureLevelSet map[int]LevelSet + +func init() { + pressureIndex = make(map[int]int, NumLevels) + for i, p := range Pressures { + pressureIndex[p] = i + } + pressureLevelSet = make(map[int]LevelSet, NumLevels) + for _, p := range PressuresPgrb2 { + pressureLevelSet[p] = LevelSetA + } + for _, p := range PressuresPgrb2b { + pressureLevelSet[p] = LevelSetB + } +} + +// PressureIndex returns the dataset index for a pressure level in hPa, +// or -1 when the level is unknown. +func PressureIndex(hPa int) int { + idx, ok := pressureIndex[hPa] + if !ok { + return -1 + } + return idx +} + +// PressureLevelSet returns the GRIB file set carrying a pressure level. +func PressureLevelSet(hPa int) (LevelSet, bool) { + ls, ok := pressureLevelSet[hPa] + return ls, ok +} + +// HourIndex returns the dataset time index for a forecast hour, or -1 when +// the hour is outside the range or not a multiple of HourStep. +func HourIndex(hour int) int { + if hour < 0 || hour > MaxHour || hour%HourStep != 0 { + return -1 + } + return hour / HourStep +} + +// Hours returns the full list of forecast hours, [0, 3, 6, ..., MaxHour]. +func Hours() []int { + out := make([]int, 0, NumHours) + for h := 0; h <= MaxHour; h += HourStep { + out = append(out, h) + } + return out +} + +// VariableIndex maps a GRIB (category, number) pair to a dataset variable +// index, returning -1 for parameters this dataset does not store. +func VariableIndex(parameterCategory, parameterNumber int) int { + switch { + case parameterCategory == 3 && parameterNumber == 5: + return VarHeight + case parameterCategory == 2 && parameterNumber == 2: + return VarWindU + case parameterCategory == 2 && parameterNumber == 3: + return VarWindV + default: + return -1 + } +} + +// S3 URL configuration for NOAA GFS data on the public S3 mirror. 
+const S3BaseURL = "https://noaa-gfs-bdp-pds.s3.amazonaws.com" + +// GribURL returns the S3 URL for a primary (pgrb2) GRIB file. +func GribURL(date string, runHour, forecastStep int) string { + return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d", + S3BaseURL, date, runHour, runHour, forecastStep) +} + +// GribURLB returns the S3 URL for a secondary (pgrb2b) GRIB file. +func GribURLB(date string, runHour, forecastStep int) string { + return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2b.0p50.f%03d", + S3BaseURL, date, runHour, runHour, forecastStep) +} diff --git a/internal/weather/gfs/file.go b/internal/weather/gfs/file.go new file mode 100644 index 0000000..720107b --- /dev/null +++ b/internal/weather/gfs/file.go @@ -0,0 +1,150 @@ +package gfs + +import ( + "encoding/binary" + "fmt" + "math" + "os" + "time" + + mmap "github.com/edsrzf/mmap-go" +) + +// File is an mmap-backed wind dataset file. The layout is a flat C-order +// row-major array of float32 values, shape (hour, level, variable, lat, lng). +type File struct { + mm mmap.MMap + file *os.File + writable bool + // Epoch is the forecast run time (UTC) the file represents. + Epoch time.Time +} + +// Open opens an existing dataset file for reading. +func Open(path string, epoch time.Time) (*File, error) { + f, err := os.Open(path) + if err != nil { + return nil, fmt.Errorf("open dataset: %w", err) + } + info, err := f.Stat() + if err != nil { + f.Close() + return nil, fmt.Errorf("stat dataset: %w", err) + } + if info.Size() != DatasetSize { + f.Close() + return nil, fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, info.Size()) + } + mm, err := mmap.Map(f, mmap.RDONLY, 0) + if err != nil { + f.Close() + return nil, fmt.Errorf("mmap dataset: %w", err) + } + return &File{mm: mm, file: f, writable: false, Epoch: epoch}, nil +} + +// Create creates a new dataset file of the canonical size, mmap'd read-write. 
+func Create(path string) (*File, error) { + f, err := os.Create(path) + if err != nil { + return nil, fmt.Errorf("create dataset: %w", err) + } + if err := f.Truncate(DatasetSize); err != nil { + f.Close() + return nil, fmt.Errorf("truncate dataset: %w", err) + } + mm, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0) + if err != nil { + f.Close() + return nil, fmt.Errorf("mmap dataset: %w", err) + } + return &File{mm: mm, file: f, writable: true}, nil +} + +// OpenWritable opens an existing dataset file for read-write access. +// Used when resuming a partial download. +func OpenWritable(path string) (*File, error) { + f, err := os.OpenFile(path, os.O_RDWR, 0o644) + if err != nil { + return nil, fmt.Errorf("open dataset rw: %w", err) + } + info, err := f.Stat() + if err != nil { + f.Close() + return nil, fmt.Errorf("stat dataset: %w", err) + } + if info.Size() != DatasetSize { + f.Close() + return nil, fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, info.Size()) + } + mm, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0) + if err != nil { + f.Close() + return nil, fmt.Errorf("mmap dataset: %w", err) + } + return &File{mm: mm, file: f, writable: true}, nil +} + +// offset returns the byte offset of the [hour][level][variable][lat][lng] cell. +func offset(hour, level, variable, lat, lng int) int64 { + idx := int64(hour) + idx = idx*int64(NumLevels) + int64(level) + idx = idx*int64(NumVariables) + int64(variable) + idx = idx*int64(NumLatitudes) + int64(lat) + idx = idx*int64(NumLongitudes) + int64(lng) + return idx * int64(ElementSize) +} + +// Val reads one cell as a float32. +func (d *File) Val(hour, level, variable, lat, lng int) float32 { + off := offset(hour, level, variable, lat, lng) + return math.Float32frombits(binary.LittleEndian.Uint32(d.mm[off : off+4])) +} + +// SetVal writes one cell. Only valid on writable files. 
+func (d *File) SetVal(hour, level, variable, lat, lng int, val float32) {
+	// NOTE: the writable flag is not checked here; the caller must only use
+	// this on files from Create or OpenWritable. Files from Open are mapped
+	// with mmap.RDONLY, so a store would go through a read-only mapping.
+	start := offset(hour, level, variable, lat, lng)
+	binary.LittleEndian.PutUint32(d.mm[start:start+4], math.Float32bits(val))
+}
+
+// BlitGribData copies one decoded GRIB grid into the dataset, flipping the
+// latitude axis from GRIB's north-to-south scan order to our south-to-north
+// storage order. Values are narrowed from float64 to float32 on the way in.
+//
+// gribData must hold exactly NumLatitudes*NumLongitudes values, laid out row
+// by row. An error is returned, and nothing is written, if the length differs
+// or if the file was not opened for writing.
+func (d *File) BlitGribData(hourIdx, levelIdx, varIdx int, gribData []float64) error {
+	// Refuse up front rather than storing through a read-only mapping.
+	if !d.writable {
+		return fmt.Errorf("dataset is read-only")
+	}
+	expected := NumLatitudes * NumLongitudes
+	if len(gribData) != expected {
+		return fmt.Errorf("grib data has %d values, expected %d", len(gribData), expected)
+	}
+	for lat := range NumLatitudes {
+		// Storage row lat is GRIB row NumLatitudes-1-lat; derive the flip
+		// from the grid constant instead of hard-coding the last row index.
+		gribRow := (NumLatitudes - 1 - lat) * NumLongitudes
+		for lng := range NumLongitudes {
+			d.SetVal(hourIdx, levelIdx, varIdx, lat, lng, float32(gribData[gribRow+lng]))
+		}
+	}
+	return nil
+}
+
+// Flush flushes the mmap to disk. It is a no-op when no mapping is held.
+func (d *File) Flush() error {
+	if d.mm == nil {
+		return nil
+	}
+	return d.mm.Flush()
+}
+
+// Close unmaps and closes the file.
+//
+// Both steps are always attempted, and the File forgets its mapping and
+// descriptor afterwards, so a repeated Close does nothing and returns nil.
+// If unmapping fails, that error is returned (wrapped) and any error from
+// closing the descriptor is dropped; otherwise the descriptor's close error,
+// if any, is returned.
+func (d *File) Close() error {
+	var unmapErr error
+	if d.mm != nil {
+		if err := d.mm.Unmap(); err != nil {
+			unmapErr = fmt.Errorf("unmap: %w", err)
+		}
+		d.mm = nil
+	}
+
+	var closeErr error
+	if d.file != nil {
+		closeErr = d.file.Close()
+		// Drop the handle even after a failed unmap so it is never closed twice.
+		d.file = nil
+	}
+
+	if unmapErr != nil {
+		return unmapErr
+	}
+	return closeErr
+}
diff --git a/internal/weather/gfs/wind.go b/internal/weather/gfs/wind.go
new file mode 100644
index 0000000..01329b4
--- /dev/null
+++ b/internal/weather/gfs/wind.go
@@ -0,0 +1,109 @@
+package gfs
+
+import (
+	"time"
+
+	"predictor-refactored/internal/numerics"
+	"predictor-refactored/internal/weather"
+)
+
+// Wind is a WindField backed by a GFS dataset file.
+type Wind struct {
+	// file is the dataset every sample is read from.
+	file *File
+}
+
+// NewWind returns a Wind backed by file.
+func NewWind(file *File) *Wind {
+	return &Wind{file: file}
+}
+
+// Epoch returns the forecast run time of the underlying file.
+func (w *Wind) Epoch() time.Time {
+	return w.file.Epoch
+}
+
+// Source reports the dataset identifier, always "noaa-gfs-0p50".
+func (w *Wind) Source() string {
+	return "noaa-gfs-0p50"
+}
+
+// Close closes the dataset file this Wind reads from and returns its error.
+func (w *Wind) Close() error {
+	return w.file.Close()
+}
+
+// Interpolation axes of the GFS 0.5-degree dataset grid.
+var (
+	// hourAxis covers forecast hours counted from the run epoch.
+	hourAxis = numerics.Axis{
+		Left: 0,
+		Step: float64(HourStep),
+		N:    NumHours,
+		Name: "hour",
+	}
+	// latAxis covers latitude, starting at LatStart.
+	latAxis = numerics.Axis{
+		Left: LatStart,
+		Step: Resolution,
+		N:    NumLatitudes,
+		Name: "lat",
+	}
+	// lngAxis covers longitude, starting at LonStart; it is the only
+	// axis marked as wrapping.
+	lngAxis = numerics.Axis{
+		Left: LonStart,
+		Step: Resolution,
+		N:    NumLongitudes,
+		Wrap: true,
+		Name: "lng",
+	}
+)
+
+// Wind returns the wind at UNIX time t (seconds), position (lat, lng) and
+// altitude alt.
+//
+// The time and horizontal position are first bracketed on the hour, latitude
+// and longitude axes; an error from any of those lookups is returned as-is
+// with a zero Sample. Vertical interpolation matches Tawhiri: pick the pair
+// of adjacent pressure levels whose interpolated geopotential heights bracket
+// alt, then blend U and V linearly between those two levels. AboveModel is
+// set when alt lies above the upper level's height, in which case the blend
+// extrapolates.
+func (w *Wind) Wind(t, lat, lng, alt float64) (weather.Sample, error) {
+	elapsedHours := (t - float64(w.file.Epoch.Unix())) / 3600.0
+
+	hourBr, err := hourAxis.Locate(elapsedHours)
+	if err != nil {
+		return weather.Sample{}, err
+	}
+	latBr, err := latAxis.Locate(lat)
+	if err != nil {
+		return weather.Sample{}, err
+	}
+	lngBr, err := lngAxis.Locate(lng)
+	if err != nil {
+		return weather.Sample{}, err
+	}
+	cell := [3]numerics.Bracket{hourBr, latBr, lngBr}
+
+	// sample interpolates one variable on one pressure level across the
+	// (hour, lat, lng) bracket.
+	sample := func(level, variable int) float64 {
+		return numerics.EvalTrilinear(cell, func(i, j, k int) float64 {
+			return float64(w.file.Val(i, level, variable, j, k))
+		})
+	}
+
+	// Search levels 0..NumLevels-2 so that lower+1 is always a valid level.
+	lower := numerics.Bisect(0, NumLevels-2, alt, func(level int) float64 {
+		return sample(level, VarHeight)
+	})
+	upper := lower + 1
+
+	lowerHGT := sample(lower, VarHeight)
+	upperHGT := sample(upper, VarHeight)
+
+	// altFrac is the weight of the lower level. When both levels report the
+	// same height the ratio is undefined, so the two are weighted equally.
+	altFrac := 0.5
+	if lowerHGT != upperHGT {
+		altFrac = (upperHGT - alt) / (upperHGT - lowerHGT)
+	}
+
+	lowerU, upperU := sample(lower, VarWindU), sample(upper, VarWindU)
+	lowerV, upperV := sample(lower, VarWindV), sample(upper, VarWindV)
+
+	result := weather.Sample{
+		U: lowerU*altFrac + upperU*(1-altFrac),
+		V: lowerV*altFrac + upperV*(1-altFrac),
+		// A negative lower-level weight means alt exceeds upperHGT.
+		AboveModel: altFrac < 0,
+	}
+	return result, nil
+}
diff --git a/internal/weather/types.go b/internal/weather/types.go
new file mode 100644
index 0000000..68ff4a6
--- /dev/null
+++ b/internal/weather/types.go
@@ -0,0 +1,37 @@
+// Package weather declares the abstract interface through which trajectory
+// engines sample atmospheric data. Implementations for specific data sources
+// live in its subpackages.
+package weather
+
+import "time"
+
+// Sample is what a wind field returns for a single query point.
+type Sample struct {
+	// U is the eastward wind component in m/s.
+	U float64
+	// V is the northward wind component in m/s.
+	V float64
+	// AboveModel reports that the query altitude was above the highest
+	// pressure level available in the underlying dataset. In that case U
+	// and V are linear extrapolations and should be treated as unreliable.
+	AboveModel bool
+}
+
+// WindField supplies 3D wind data interpolated at arbitrary points.
+//
+// Implementations must be safe for concurrent use.
+type WindField interface {
+	// Wind samples the field at (t, lat, lng, alt).
+	//
+	// t is UNIX seconds. lat is degrees in [-90, +90]. lng is degrees in
+	// [0, 360]; normalizing it is the caller's job. alt is metres above
+	// mean sea level.
+	//
+	// An error is returned when any coordinate is outside the field's domain.
+	Wind(t, lat, lng, alt float64) (Sample, error)
+
+	// Epoch returns the time the field is anchored to (forecast run time).
+	Epoch() time.Time
+
+	// Source names the underlying dataset, for use in logs and metrics.
+	Source() string
+}