From 81b8e763bd8c566a2f9b5a598c9dc0cb8748aa06 Mon Sep 17 00:00:00 2001 From: "a.antonov" Date: Sat, 23 May 2026 00:55:35 +0900 Subject: [PATCH] engine refactor --- README.md | 299 ++++++------- cmd/predictor/main.go | 33 +- docs/numerics.tex | 345 +++++++++++---- internal/api/admin/datasets.go | 210 +++++++--- internal/api/async/handler.go | 63 +++ internal/api/async/manager.go | 276 ++++++++++++ internal/api/httpjson/httpjson.go | 27 ++ internal/api/tawhiri/handler.go | 215 +++++----- internal/api/transport.go | 16 +- internal/api/v2/handler.go | 150 +++---- internal/api/v2/profile.go | 98 +---- internal/api/v2/types.go | 106 ++--- internal/config/config.go | 15 +- internal/datasets/gefs/source.go | 151 +++++++ internal/datasets/gfs/source.go | 438 ++++---------------- internal/datasets/grib/downloader.go | 369 +++++++++++++++++ internal/datasets/{gfs => grib}/idx.go | 6 +- internal/datasets/{gfs => grib}/idx_test.go | 2 +- internal/datasets/manager.go | 301 ++++++++------ internal/datasets/store_local.go | 107 ++--- internal/datasets/store_test.go | 58 +-- internal/datasets/subset.go | 156 +++++++ internal/datasets/types.go | 66 ++- internal/engine/constraints.go | 152 +++++-- internal/engine/engine_test.go | 117 +++++- internal/engine/events.go | 89 ++++ internal/engine/models.go | 77 ++-- internal/engine/operators.go | 69 +++ internal/engine/profile.go | 47 ++- internal/engine/propagator.go | 157 ++++--- internal/engine/registry.go | 287 +++++++++++++ internal/engine/types.go | 141 +++++-- internal/weather/gfs/constants.go | 137 +----- internal/weather/gfs/file.go | 66 +-- internal/weather/gfs/gefs_variants.go | 68 +++ internal/weather/gfs/variant.go | 191 +++++++++ internal/weather/gfs/wind.go | 66 +-- 37 files changed, 3532 insertions(+), 1639 deletions(-) create mode 100644 internal/api/async/handler.go create mode 100644 internal/api/async/manager.go create mode 100644 internal/api/httpjson/httpjson.go create mode 100644 internal/datasets/gefs/source.go create mode 100644 internal/datasets/grib/downloader.go rename internal/datasets/{gfs => grib}/idx.go (91%) rename internal/datasets/{gfs => grib}/idx_test.go (99%) create mode 100644 internal/datasets/subset.go create mode 100644 internal/engine/events.go create mode 100644 internal/engine/operators.go create mode 100644 internal/engine/registry.go create mode 100644 internal/weather/gfs/gefs_variants.go create mode 100644 internal/weather/gfs/variant.go diff --git a/README.md b/README.md index 967fca2..2426f67 100644 --- a/README.md +++ b/README.md @@ -1,27 +1,25 @@ # stratoflights-predictor High-altitude balloon trajectory prediction service. Forecasts ascent, descent, -and float trajectories from NOAA GFS wind data, exposed as a REST API. +and float trajectories from NOAA GFS and GEFS wind data, exposed as a REST API. The trajectory engine is a propagator-and-constraint system: any flight profile can be expressed as a chain of propagators (constant-rate ascent, -parachute descent, piecewise rates, wind drift) with attached constraints -(altitude, time, terrain contact). The legacy Tawhiri request shape is kept -as a compatibility endpoint so existing clients work unchanged. +parachute descent, piecewise rates with absolute / profile-relative / +propagator-relative timing, wind drift) with attached constraints +(scalar comparisons over altitude or time, terrain contact, geographic +polygons). Constraints can stop the profile, hand off to a fallback +propagator, or clip the violated coordinate to the boundary. The legacy +Tawhiri request shape is kept as a compatibility endpoint so existing +clients work unchanged. ## Quick start ```bash -# Build all three binaries (server, CLI, validation tool) -make build +make build # produces bin/{predictor,predictor-cli,compare-tawhiri} +./bin/predictor # downloads ~9 GB of GFS data on first start -# Run the server (first start downloads ~9 GB of GFS data over 30-60 min) -./bin/predictor - -# Check readiness ./bin/predictor-cli ready - -# Run a Tawhiri-style prediction ./bin/predictor-cli predict \ launch_latitude=52.2 launch_longitude=0.1 \ launch_datetime=2026-03-28T12:00:00Z \ @@ -30,16 +28,14 @@ make build ## Configuration -Configuration is layered: built-in defaults, then a YAML file -(`--config path.yml` or `PREDICTOR_CONFIG_FILE=path.yml`), then env vars, -then CLI flags. Flags override env vars override file values override defaults. +Layered configuration: built-in defaults < YAML file < env vars < CLI flags. | Setting | Env var | CLI flag | Default | |---|---|---|---| | HTTP port | `PREDICTOR_PORT` | `-port` | `8080` | | Data directory | `PREDICTOR_DATA_DIR` | `-data-dir` | `/tmp/predictor-data` | | Elevation dataset | `PREDICTOR_ELEVATION_DATASET` | `-elevation` | `/srv/ruaumoko-dataset` | -| Source | `PREDICTOR_SOURCE` | — | `noaa-gfs-0p50` | +| Source variant | `PREDICTOR_SOURCE` | — | `gfs-0p50-3h` | | Download parallelism | `PREDICTOR_DOWNLOAD_PARALLEL` | `-download-parallel` | `8` | | Download bandwidth (bytes/s; 0 = unlimited) | `PREDICTOR_DOWNLOAD_BANDWIDTH` | `-download-bandwidth` | `0` | | Scheduler interval | `PREDICTOR_UPDATE_INTERVAL` | `-update-interval` | `6h` | @@ -48,227 +44,188 @@ then CLI flags. Flags override env vars override file values override defaults. | Metrics HTTP path | `PREDICTOR_METRICS_PATH` | `-metrics-path` | `/metrics` | | Log level | `PREDICTOR_LOG_LEVEL` | `-log-level` | `info` | -A YAML config file mirrors the same structure: +YAML config mirrors the same structure; see `internal/config/config.go`. -```yaml -http: - port: 8080 -data: - dir: /var/lib/predictor - elevation_path: /var/lib/predictor/elevation - source: noaa-gfs-0p50 -download: - parallel: 8 - bandwidth_bytes_per_second: 0 - update_interval: 6h - freshness_ttl: 48h -metrics: - enabled: true - path: /metrics -log: - level: info -``` +Supported source variants: + +| `source` | Resolution | Cadence | Notes | +|---|---|---|---| +| `gfs-0p50-3h` | 0.5° | 3h to 192h | historical Tawhiri default | +| `gfs-0p25-3h` | 0.25° | 3h to 192h | | +| `gfs-0p25-1h` | 0.25° | 1h to 120h | | +| `gefs-0p50-3h` | 0.5° | 3h to 192h | 21-member ensemble; each member is a separate dataset | ## REST API -### Tawhiri-compatible +### Tawhiri-compatible (legacy) `GET /api/v1/prediction` — preserves the exact request and response shape of -the upstream Cambridge University Spaceflight predictor. Query parameters: +the upstream Cambridge University Spaceflight predictor. -| Parameter | Required | Description | -|---|---|---| -| `launch_latitude` | yes | Degrees, -90 to 90 | -| `launch_longitude` | yes | Degrees, -180 to 180 or 0 to 360 | -| `launch_datetime` | yes | RFC 3339 | -| `launch_altitude` | no | Metres ASL (default 0) | -| `profile` | no | `standard_profile` (default) or `float_profile` | -| `ascent_rate` | no | m/s (default 5) | -| `burst_altitude` | no | Metres (default 28000) | -| `descent_rate` | no | m/s (default 5) | -| `float_altitude` | no | Metres (float profile only) | -| `stop_datetime` | no | Float-profile end time | +`GET /ready` — returns `{"status":"ok", "dataset_time":"..."}` once a dataset +is loaded. -`GET /ready` — returns `{"status": "ok", "dataset_time": "..."}` once a -dataset is loaded; `{"status": "not_ready", ...}` before then. +### Profile-driven (synchronous) -### Profile-driven (new primary) - -`POST /api/v2/prediction` — accepts an arbitrary chain of propagators with -optional constraints. Useful when the frontend wants flight profiles the -Tawhiri shape can't express (e.g. piecewise rates, fallback on constraint -violation). +`POST /api/v2/prediction` — execute a profile synchronously and return the +trajectory. Request shape: ```json { - "launch": { - "time": "2026-03-28T12:00:00Z", - "latitude": 52.2, - "longitude": 0.1, - "altitude": 0 - }, + "launch": { "time": "2026-03-28T12:00:00Z", "latitude": 52.2, "longitude": 0.1, "altitude": 0 }, + "direction": "forward", "profile": [ { "name": "ascent", - "model": {"type": "constant_rate", "rate": 5, "include_wind": true}, - "constraints": [{"type": "max_altitude", "limit": 30000}] + "model": { "type": "constant_rate", "rate": 5, "include_wind": true }, + "constraints": [{ "type": "altitude", "op": ">=", "limit": 30000 }] }, { "name": "descent", - "model": {"type": "parachute_descent", "sea_level_rate": 5, "include_wind": true}, - "constraints": [{"type": "terrain_contact"}] + "model": { "type": "parachute_descent", "sea_level_rate": 5, "include_wind": true }, + "constraints": [{ "type": "terrain_contact" }] } - ] + ], + "globals": [{ "type": "time", "op": ">", "limit": 1799999999 }] } ``` Model types: `constant_rate`, `parachute_descent`, `piecewise`, `wind`. -Constraint types: `max_altitude`, `min_altitude`, `max_time`, -`terrain_contact`. Constraint actions: `stop` (default), `fallback`, `clip`. -Set `"direction": "reverse"` to integrate backward from a known landing. +Constraint types: `altitude`, `time`, `terrain_contact`, `polygon`. +Operators: `<`, `<=`, `>`, `>=`, `==`. Actions: `stop` (default), `fallback`, `clip`. +Direction: `forward` (default) or `reverse`. + +Piecewise segments support a `reference` field (`absolute`, `profile_start`, or +`propagator_start`) so a single rate schedule can be reused across profiles +with different launch times. + +The response includes per-stage trajectories, detailed termination info +(violation state + refined state + constraint name), an `events` array of +non-fatal observations (e.g. `above_model` when altitude exceeded the dataset's +highest pressure level), and dataset metadata. + +### Profile-driven (asynchronous) + +`POST /api/v1/predictions` — enqueue a prediction. Returns `202` with a job ID: + +```json +{"id":"842107d9-…","status":"pending","created_at":"…"} +``` + +`GET /api/v1/predictions/{id}` — poll status. When `status == "complete"`, +the response includes a `result` field with the full v2 PredictionResponse. + +`DELETE /api/v1/predictions/{id}` — cancel a queued job. + +A worker pool (`http.async_workers`, default 4) services the queue; completed +results are retained for `http.async_result_ttl` (default 1h). ### Dataset admin ``` -GET /api/v1/admin/datasets list stored epochs -POST /api/v1/admin/datasets {epoch | latest} trigger a download -DELETE /api/v1/admin/datasets/{epoch} delete a stored dataset -GET /api/v1/admin/jobs list every job +GET /api/v1/admin/datasets list stored datasets (epoch, subset, coverage, loaded?) +POST /api/v1/admin/datasets trigger a download +DELETE /api/v1/admin/datasets/{filename} delete by filename (DatasetID.Filename()) +GET /api/v1/admin/jobs list every download job GET /api/v1/admin/jobs/{id} fetch one job -DELETE /api/v1/admin/jobs/{id} cancel a running job +DELETE /api/v1/admin/jobs/{id} cancel a running download +GET /api/v1/admin/status consolidated status (uptime, mem, goroutines, jobs, datasets) ``` -Returns `JobInfo`: +Trigger-download body: ```json -{"id":"…","source":"noaa-gfs-0p50","epoch":"…","status":"running", - "started_at":"…","total_units":130,"done_units":47,"bytes":510000000} +{ + "epoch": "2026-03-28T06:00:00Z", + "subset": { + "region": { "min_lat": -10, "max_lat": 10, "min_lng": 0, "max_lng": 30 }, + "hour_range": { "min_hour": 0, "max_hour": 72 }, + "members": [5] + } +} ``` +`{"latest": true}` is a shortcut that refreshes the latest global dataset +for the configured source. Each `(epoch, subset)` combination is a +separate dataset; the loader auto-selects which loaded dataset covers a +given prediction query. + ### Metrics `GET /metrics` — Prometheus text exposition. Counters: -`predictor_predictions_total{profile,status}`, -`predictor_downloads_total{source,status}`, -`predictor_download_bytes_total{source}`, -and a gauge `predictor_active_dataset_epoch_seconds`. +`predictor_predictions_total{profile,status}`, `predictor_downloads_total`, +`predictor_download_bytes_total`, and a gauge +`predictor_active_dataset_epoch_seconds`. ## Architecture ``` cmd/ - predictor/main.go main server entry point - predictor-cli/main.go HTTP client - compare-tawhiri/main.go end-to-end validation against the public Tawhiri instance + predictor/ main server + predictor-cli/ HTTP client + compare-tawhiri/ end-to-end validation against the public Tawhiri instance internal/ numerics/ pure numerical primitives (interp, bisect, RK4, refinement) - engine/ propagator + constraint system + concrete models - weather/ WindField interface; gfs/ — NOAA GFS file format + impl - datasets/ Source/Storage/Manager + transactional, resumable downloads - gfs/ — NOAA GFS source impl + engine/ propagator + constraint system + concrete models + registry + weather/ WindField interface; gfs/ — variant-parameterized GFS file format + WindField + datasets/ Source / Storage / Manager + transactional, resumable, subsettable downloads + grib/ — shared GRIB downloader skeleton (idx parser, HTTP, parallel blit) + gfs/ — GFS Source (URL templating only) + gefs/ — GEFS Source (URL templating + member resolution) elevation/ ruaumoko-format ground elevation reader config/ layered file+env+CLI config metrics/ Sink interface + Prometheus text impl api/ HTTP transport - tawhiri/ — legacy v1 endpoint via ogen - v2/ — profile-driven endpoint - admin/ — dataset/job admin endpoints + tawhiri/ — legacy v1 endpoint via ogen + v2/ — synchronous profile-driven endpoint + async/ — asynchronous prediction jobs + admin/ — dataset + service-status endpoints + httpjson/ — tiny JSON response helpers middleware/ api/rest/predictor.swagger.yml OpenAPI 3 spec for v1 + /ready pkg/rest/ ogen-generated code (regenerate via `make generate-ogen`) -docs/numerics.tex LaTeX math reference for the numerics package +docs/numerics.tex end-to-end mathematical reference scripts/build_elevation.py ETOPO 2022 → ruaumoko converter ``` +## Subsetting and ensembles + +Each stored dataset is identified by `DatasetID = (epoch, subset)`. A subset +restricts the data fetched by region, forecast-hour range, or ensemble +member. The downloader honours the subset (skipping out-of-range +forecast steps; member-selecting URLs for GEFS), the storage tracks each +subset as a separate file (filename includes a deterministic subset key), +and the Manager exposes coverage so per-query dataset selection picks the +right one. + ## Deployment -### Local single instance - -```bash -./bin/predictor --data-dir /var/lib/predictor -``` - -No external dependencies beyond the NOAA S3 mirror. - -### Docker single container - -```dockerfile -FROM golang:1.25 AS build -WORKDIR /src -COPY . . -RUN go build -o /predictor ./cmd/predictor - -FROM gcr.io/distroless/base -COPY --from=build /predictor /predictor -EXPOSE 8080 -ENTRYPOINT ["/predictor"] -``` - -Mount a volume at `/data` and set `PREDICTOR_DATA_DIR=/data`. - -### Load-balanced cluster - -The server is stateless apart from the on-disk dataset cache and in-memory -job table. For multiple replicas, point all replicas at a shared filesystem -(NFS or similar) for `data_dir`; each replica reads-only its own mmap. Active -download coordination across replicas is not implemented — run downloads on -one node, or accept that two nodes may download the same epoch concurrently -(only one Commit wins via atomic rename). - -## Elevation dataset - -Without elevation data, descent terminates at sea level. With elevation, -descent terminates at ground level, matching upstream Tawhiri. - -```bash -pip install xarray netcdf4 numpy -python3 scripts/build_elevation.py /var/lib/predictor/elevation -``` - -`PREDICTOR_ELEVATION_DATASET=/var/lib/predictor/elevation ./bin/predictor` - -## Numerical methods - -The numerics package (`internal/numerics`) provides: - -- regular-grid multilinear interpolation, -- monotone bisection, -- classical RK4 (forward and reverse time), -- binary-search refinement of a termination point. - -Detailed math reference: `docs/numerics.tex`. The package has no -domain dependencies and is small enough for manual verification (~300 -lines of Go), enabling a future C or Rust port without changes to the -trajectory engine. - -## Wind data - -| Property | Value | -|---|---| -| Source | NOAA GFS, S3 mirror (`noaa-gfs-bdp-pds.s3.amazonaws.com`) | -| Resolution | 0.5° | -| Grid | 361 × 720 (lat × lng) | -| Forecast steps | 65 (every 3 hours, 0–192h) | -| Pressure levels | 47 (1000 → 1 hPa) | -| Variables | Geopotential height, U-wind, V-wind | -| File size | ~8.87 GiB (float32 flat binary, mmap-backed) | -| Update cadence | every 6 hours | - -Downloads use HTTP Range requests against `.idx` index files to fetch only -the needed GRIB messages. Downloads are transactional (temp file, manifest, -atomic rename on commit) and resumable: interrupted downloads pick up where -they left off via the manifest. +Local single instance, Docker container, or load-balanced cluster behind a +shared filesystem for the dataset cache. The async API stores results +in-memory only; for cluster deployments with sticky sessions, ensure +clients poll the same node they submitted to. ## Validation `./bin/compare-tawhiri --server http://localhost:8080` runs an identical -prediction against the local server and against the public SondeHub Tawhiri +prediction against the local server and the public SondeHub Tawhiri instance, reporting the great-circle distance between landing points. +## Numerical methods + +`docs/numerics.tex` is the complete mathematical reference: state vector, +equations of motion (constant rate, parachute drag, piecewise, wind +transport), numerical methods (multilinear interpolation, bisection, +classical RK4, binary-search termination refinement), constraint +geometry (scalar comparisons, point-in-polygon with antimeridian +handling), and design notes on the deferred items (WGS84/ECEF +coordinate system, mass-aware drift, Monte Carlo). + ## References - [Tawhiri](https://github.com/cuspaceflight/tawhiri) — reference Python/Cython predictor - [ruaumoko](https://github.com/cuspaceflight/ruaumoko) — global elevation dataset format - [NOAA GFS](https://www.ncei.noaa.gov/products/weather-climate-models/global-forecast) +- [NOAA GEFS](https://www.ncei.noaa.gov/products/weather-climate-models/global-ensemble-forecast) - [ETOPO 2022](https://www.ncei.noaa.gov/products/etopo-global-relief-model) - [SondeHub Tawhiri API](https://api.v2.sondehub.org/tawhiri) — public Tawhiri instance diff --git a/cmd/predictor/main.go b/cmd/predictor/main.go index 13391ce..548365e 100644 --- a/cmd/predictor/main.go +++ b/cmd/predictor/main.go @@ -19,11 +19,14 @@ import ( "go.uber.org/zap/zapcore" "predictor-refactored/internal/api" + "predictor-refactored/internal/api/async" "predictor-refactored/internal/config" "predictor-refactored/internal/datasets" + "predictor-refactored/internal/datasets/gefs" "predictor-refactored/internal/datasets/gfs" "predictor-refactored/internal/elevation" "predictor-refactored/internal/metrics" + wgfs "predictor-refactored/internal/weather/gfs" ) func main() { @@ -60,13 +63,23 @@ func run(args []string) error { return fmt.Errorf("init store: %w", err) } - // Source is GFS today; the spec leaves room for ECMWF later via the - // same datasets.Source interface. - if cfg.Data.Source != "noaa-gfs-0p50" { - return fmt.Errorf("source %q not supported", cfg.Data.Source) + variant, err := wgfs.VariantByID(cfg.Data.Source) + if err != nil { + return fmt.Errorf("unsupported source %q: %w", cfg.Data.Source, err) + } + var source datasets.Source + switch variant.Family { + case wgfs.FamilyGFS: + s := gfs.NewSource(variant, log) + s.Parallel = cfg.Download.Parallel + source = s + case wgfs.FamilyGEFS: + s := gefs.NewSource(variant, log) + s.Parallel = cfg.Download.Parallel + source = s + default: + return fmt.Errorf("unsupported family for %q", cfg.Data.Source) } - source := gfs.NewSource(log) - source.Parallel = cfg.Download.Parallel var throttle datasets.Throttle if cfg.Download.BandwidthBytesPerSecond > 0 { @@ -128,12 +141,20 @@ func run(args []string) error { scheduler.StartAsync() defer scheduler.Stop() + asyncMgr := async.New(async.Config{ + Workers: cfg.HTTP.AsyncWorkers, + QueueSize: cfg.HTTP.AsyncQueueSize, + ResultTTL: cfg.HTTP.AsyncResultTTL, + }, mgr, elev, sink, log) + defer asyncMgr.Close() + server, err := api.New(cfg.HTTP.Port, api.Deps{ Manager: mgr, Elevation: elev, Metrics: sink, MetricsHandler: metricsHandler, MetricsPath: cfg.Metrics.Path, + AsyncManager: asyncMgr, Log: log, }) if err != nil { diff --git a/docs/numerics.tex b/docs/numerics.tex index 3706f45..16aaacd 100644 --- a/docs/numerics.tex +++ b/docs/numerics.tex @@ -4,19 +4,177 @@ \usepackage{algorithm, algpseudocode} \usepackage{hyperref} -\title{Numerics Library: Mathematical Reference} +\title{stratoflights-predictor: Mathematical Reference} \author{stratoflights-predictor} \date{} \begin{document} \maketitle -This document describes every numerical primitive in the -\verb|internal/numerics| package. Each section pairs the mathematical -definition with a pointer to the Go implementation and at least one -worked example that can be reproduced manually. +\noindent This document is the end-to-end mathematical specification of the +trajectory calculation performed by stratoflights-predictor. It is meant +to be detailed enough to permit hand verification of the numerical +output. Section~\ref{sec:numerics} covers the small numerics library +(\verb|internal/numerics|); the remaining sections describe the engine +(\verb|internal/engine|) and the data plane. -\section{Regular-grid bracketing} +\tableofcontents + +% ========================================================================= +\section{State vector and equations of motion} +\label{sec:state} + +\paragraph{State vector.} The balloon state is the four-tuple +\[ + \mathbf{s}(t) \;=\; \bigl(t,\; \varphi(t),\; \lambda(t),\; h(t)\bigr), +\] +where $t$ is UNIX seconds, $\varphi \in [-90, 90]$ is geographic latitude +in degrees, $\lambda \in [0, 360)$ is geographic longitude in degrees, +and $h$ is altitude above mean sea level in metres. Inside the +implementation, the spatial part $(\varphi, \lambda, h)$ is the +\verb|engine.State| struct; $t$ is tracked separately by the integrator. + +\paragraph{Equations of motion.} The time derivative of state is the +direction-agnostic vector +\[ + \dot{\mathbf{s}}(t) \;=\; \bigl(\dot \varphi,\; \dot \lambda,\; \dot h\bigr) + \;=\; \sum_{m \in \text{Models}(t)} \mathbf{F}_m(t, \mathbf{s}), +\] +i.e.\ the sum of the active stage's models evaluated at the current +state. The supported per-model contributions are: + +\paragraph{Constant rate.} A purely vertical model with no horizontal +component, used for the standard balloon ascent profile: +\[ + \mathbf{F}_{\text{const}}(t, \mathbf{s}) = (0, 0, r), \qquad r \in \mathbb{R}. +\] +Positive $r$ is upward; a negative $r$ may be combined with reverse-time +integration to model an ascent backwards from a known apex. + +\paragraph{Parachute descent.} The vertical contribution under a constant +drag coefficient with the NASA atmosphere model density $\rho(h)$: +\[ + \mathbf{F}_{\text{drag}}(t, \mathbf{s}) = \Bigl(0, 0, -\frac{k}{\sqrt{\rho(h)}}\Bigr), + \qquad k = r_0 \cdot 1{.}1045, +\] +where $r_0$ is the descent rate at sea level. Density is computed +piecewise from the layered model described in +\href{https://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html}{NASA's +atmosphere page}. + +\paragraph{Piecewise rate.} A schedule $\{(\tau_i, r_i)\}_{i=1}^N$ +parameterised in either absolute UNIX time, or seconds since +profile start, or seconds since the propagator's own start. Resolution +happens lazily through the \verb|Propagator.BuildModel| hook so the same +spec can be reused across profiles with different launch times. +The contribution at time $t$ is +\[ + \mathbf{F}_{\text{pwc}}(t, \mathbf{s}) = (0, 0, r_{i^\star}), + \qquad i^\star = \min\{i : \tau_i > t\}. +\] + +\paragraph{Wind transport.} The horizontal contribution from sampling the +loaded wind field $W$: +\[ + \mathbf{F}_{\text{wind}}(t, \mathbf{s}) = \Bigl( + \frac{180}{\pi}\,\frac{v}{R + h},\;\; + \frac{180}{\pi}\,\frac{u}{(R + h)\cos\bigl(\varphi\,\pi/180\bigr)},\;\; + 0 + \Bigr), +\] +where $(u, v) = W(t, \varphi, \lambda, h)$ are the eastward and northward +wind components in metres per second, and $R = 6{,}371{,}009$~m is the +spherical Earth radius. The implementation lives in +\verb|engine.WindTransport| (\verb|engine/models.go|). + +\paragraph{Coordinate system.} The model is a spherical Earth in +plate-carrée (latitude/longitude/altitude) coordinates. This matches the +reference Tawhiri predictor exactly and is necessary for bit-identical +back-to-back testing. A WGS84/ECEF variant is planned but deferred: it +would require converting U/V wind components from the GFS sphere model +to the ellipsoid, which is not a trivial coordinate transform. + +% ========================================================================= +\section{Profiles and propagators} +\label{sec:profile} + +\paragraph{Propagator.} A propagator owns one Model and a list of +Constraints; it produces a sequence of trajectory points via classical +Runge--Kutta--4 integration with step $\Delta t$ (positive for forward, +negative for reverse propagation): +\[ + \Pi : (t_0, \mathbf{s}_0) \;\longmapsto\; \bigl[(t_k, \mathbf{s}_k)\bigr]_{k=0}^{K}. +\] +The sequence ends at the first $k$ where any constraint is violated; +the violation point is refined by binary search (see +\S\ref{sec:numerics}). + +\paragraph{Profile.} A profile is an ordered chain of propagators +$[\Pi_1, \Pi_2, \ldots, \Pi_N]$. Stage $i$ starts where stage $i-1$ +ended; the time direction (sign of $\Delta t$) is shared. + +\paragraph{Constraint actions.} When a constraint $c$ is violated at the +refined point $(t^\star, \mathbf{s}^\star)$, $c.\text{Action}$ controls +the dispatch: +\begin{itemize} + \item \texttt{stop} — the profile ends at $(t^\star, \mathbf{s}^\star)$. + \item \texttt{fallback} — the current propagator hands off to its + \texttt{Fallback} propagator (chains supported). + \item \texttt{clip} — the violated coordinate is clipped to the + constraint's boundary and integration continues. Useful for soft + constraints such as ``hold altitude above 500~m''. +\end{itemize} +Constraints fire on full RK4 steps only, never on intermediate +sub-evaluations. This matches the reference Tawhiri behaviour +bit-for-bit. + +\paragraph{Reverse propagation.} A profile with \verb|Direction = Reverse| +runs every propagator with $\Delta t = -\Delta t$. Models are +direction-agnostic: their derivative formulas above hold unchanged. The +typical use is to start from a known landing point and recover the +launch position by integrating backwards in time. + +% ========================================================================= +\section{Constraint geometry} +\label{sec:constraints} + +The engine ships four constraint primitives. + +\paragraph{Scalar comparison: altitude.} +\( + c_{\text{alt}}(t, \mathbf{s}) \;=\; h \,\mathrel{\bigotimes}\, h_0, +\) +where $\bigotimes \in \{<, \le, >, \ge, =\}$ is the configured operator +and $h_0$ is the limit (metres). The implementation is +\verb|engine.Altitude| (\verb|engine/constraints.go|). + +\paragraph{Scalar comparison: time.} Same shape as altitude, but acting +on $t$ in UNIX seconds. Implementation: \verb|engine.Time|. + +\paragraph{Terrain contact.} +\( + c_{\text{terr}}(t, \mathbf{s}) = \bigl(z(\varphi, \lambda) > h\bigr), +\) +with $z$ provided by the ruaumoko-compatible elevation dataset. + +\paragraph{Polygon.} For a polygon $P$ with vertices +$(\varphi_i, \lambda_i)_{i=1}^N$ and mode $\mu \in +\{\text{inside}, \text{outside}\}$, the constraint is +$c_{\text{poly}}(\mathbf{s}) = \bigl(\mathbf{s} \in P\bigr) \oplus +[\mu = \text{outside}]$. Containment is tested by ray casting in +plate-carrée after normalising every longitude to within 180\textdegree{} +of the first vertex; this handles antimeridian-crossing edges so long +as the polygon spans no more than 180\textdegree{} in longitude. + +% ========================================================================= +\section{Numerics library} +\label{sec:numerics} + +The numerics package (\verb|internal/numerics|) provides four primitives: +regular-grid bracketing, multilinear interpolation, monotone bisection, +and classical RK4 with termination-point refinement. + +\subsection{Regular-grid bracketing} \paragraph{Definition.} An \emph{axis} is the regularly-spaced sequence $x_i = \ell + i \cdot s$ for $i = 0, 1, \ldots, N - 1$, parameterised by @@ -27,64 +185,49 @@ $x_{i_0} \le v < x_{i_1}$ and the dimensionless position \[ f = \frac{v - x_{i_0}}{s} \in [0, 1). \] -Implemented as \verb|Axis.Locate| (\verb|internal/numerics/grid.go|). +Implemented as \verb|Axis.Locate| in \verb|internal/numerics/grid.go|. \paragraph{Wrapping axes.} For periodic axes (e.g.\ longitude), the sequence is extended by the convention $x_N = x_0$ so a value approaching $x_N$ from below brackets $(N{-}1, 0)$ with fraction $f = (v - x_{N-1})/s$. -\paragraph{Domain.} The bracket is undefined when $v$ falls outside the -half-open interval $[\ell, \ell + (N{-}1)\,s)$ (for non-wrapping axes) or -$[\ell, \ell + N\,s)$ (for wrapping axes); the implementation returns -an \verb|AxisError| in those cases. +\paragraph{Worked example.} Latitude axis with $\ell = -90$, $s = 0{.}5$, +$N = 361$. Query $v = -89{.}75$ yields $p = 0{.}5$, so $i_0 = 0$, +$i_1 = 1$, $f = 0{.}5$. -\paragraph{Worked example.} Latitude axis $\ell = -90$, $s = 0{.}5$, -$N = 361$. Query $v = -89{.}75$ yields -$p = (-89{.}75 - (-90))/0{.}5 = 0{.}5$, so $i_0 = 0$, $i_1 = 1$, $f = 0{.}5$. - -\section{Multilinear interpolation} +\subsection{Multilinear interpolation} \paragraph{Definition.} For a scalar field $u$ defined at the grid nodes -of three axes, the trilinear interpolant at brackets $b_a, b_b, b_c$ is +of three axes, the trilinear interpolant at brackets +$b_a, b_b, b_c$ is \[ \tilde u = \sum_{i, j, k \in \{0, 1\}} w_{a,i} \, w_{b,j} \, w_{c,k} \; u\bigl(b_a^i, b_b^j, b_c^k\bigr), \] where $w_{\bullet, 0} = 1 - f_\bullet$ and $w_{\bullet, 1} = f_\bullet$. -Implemented as \verb|EvalTrilinear|. +The corner terms are accumulated in the order +$(0,0,0), (0,0,1), \ldots, (1,1,1)$, matching the reference Cython +implementation so that double-precision results agree byte for byte. \paragraph{Linear exactness.} For any affine field $u(i, j, k) = \alpha i + \beta j + \gamma k + \delta$, the formula returns -$\alpha \cdot p_a + \beta \cdot p_b + \gamma \cdot p_c + \delta$ exactly -(modulo floating-point rounding), where $p_\bullet = b_\bullet^0 + f_\bullet$. +$\alpha p_a + \beta p_b + \gamma p_c + \delta$ exactly (modulo +floating-point rounding), where $p_\bullet = b_\bullet^0 + f_\bullet$. -\paragraph{Evaluation order.} The eight corner terms are accumulated in -the order $(0,0,0), (0,0,1), \ldots, (1,1,1)$, matching the reference -Tawhiri implementation \emph{exactly} so that double-precision results -agree bit-for-bit. +\subsection{Monotone bisection} -\section{Monotone bisection} +For an integer-indexed monotone non-decreasing sequence +$f : \{i_{\min}, \ldots, i_{\max}\} \to \mathbb{R}$ and a target $t$, +\verb|Bisect| returns the largest index $i^\star$ with $f(i^\star) < t$. +Used by the wind sampler to locate the pressure level bracketing the +query altitude. Time complexity: +$\mathcal{O}(\log(i_{\max} - i_{\min}))$. -\paragraph{Definition.} For an integer-indexed monotone non-decreasing -sequence $f : \{i_{\min}, \ldots, i_{\max}\} \to \mathbb{R}$ and a target -$t$, $\mathrm{Bisect}$ returns the largest index $i^\star$ with -$f(i^\star) < t$. The implementation evaluates $f$ on a midpoint -$m = \lceil(i_{\min} + i_{\max})/2\rceil$ each iteration and halves the -interval, taking $\mathcal{O}(\log(i_{\max} - i_{\min}))$ evaluations. - -\paragraph{Boundary behaviour.} If $t \le f(i_{\min})$, the function -returns $i_{\min}$; if $t > f(i_{\max})$, it returns $i_{\max}$. - -\paragraph{Usage in this codebase.} The pressure-level search in the GFS -wind field locates the largest level whose interpolated geopotential -height is below the query altitude; vertical interpolation then runs -between that level and its successor. - -\section{Classical Runge--Kutta--4 integrator} +\subsection{Classical RK4} \paragraph{Definition.} For a state $y$, derivative $\dot y = f(t, y)$, -and step $\Delta t$, \verb|RK4Step| applies the classical RK4 update +and step $\Delta t$, \verb|RK4Step| applies \[ \begin{aligned} k_1 &= f(t, y), \\ @@ -94,24 +237,17 @@ and step $\Delta t$, \verb|RK4Step| applies the classical RK4 update y(t + \Delta t) &= y + \tfrac{\Delta t}{6}\bigl(k_1 + 2 k_2 + 2 k_3 + k_4\bigr). \end{aligned} \] +Reverse-time integration uses $\Delta t < 0$ unchanged; the implementation +contains no branch on the sign of $\Delta t$. Domain-specific vector +arithmetic (longitude wrap) is injected via \verb|VecAdd|. -\paragraph{Reverse-time integration.} Passing $\Delta t < 0$ integrates -backwards in time. The derivative $f$ is treated as direction-independent: -all sign accounting lives in the integrator. The implementation contains -no explicit branch on the sign of $\Delta t$. - -\paragraph{Vector state.} \verb|RK4Step| is generic on the state type. -Domain-specific vector arithmetic (in particular longitude wrap on the -$0\!:\!360$ degree circle) is injected via the \verb|VecAdd| operation -$\mathrm{add}(y, k, \delta y) = y + k \cdot \delta y$. - -\section{Termination-point refinement} +\subsection{Termination refinement} After each integration step the propagator checks one or more constraints. When a constraint reports a violation between $(t_1, y_1)$ (not violated) and $(t_2, y_2)$ (violated), \verb|RefineTrigger| locates the crossing within tolerance $\tau \in (0, 1)$ by binary -search in the linear interpolation parameter $\lambda$: +search in the linear-interpolation parameter $\lambda \in [0, 1]$: \begin{algorithm}[H] \caption{RefineTrigger}\label{alg:refine} @@ -132,27 +268,96 @@ search in the linear interpolation parameter $\lambda$: \end{algorithmic} \end{algorithm} -\paragraph{Termination guarantee.} After $\lceil \log_2 \tau^{-1} \rceil$ -iterations, $R - L \le \tau$. With $\tau = 0{.}01$ and $\Delta t = 60$~s, -the returned point is within $0{.}6$~s of the true crossing in parameter -space; the corresponding altitude error is bounded by $0{.}6\,|\dot y|$, -which for typical balloon ascent and parachute descent rates is at most -$\sim 3$~m. +After $\lceil \log_2 \tau^{-1} \rceil$ iterations, $R - L \le \tau$. +With $\tau = 0{.}01$ and $\Delta t = 60$~s, the returned point is within +$0{.}6$~s of the true crossing in parameter space; the corresponding +altitude error is bounded by $0{.}6\,|\dot y|$, which for typical +balloon ascent and parachute descent rates is at most $\sim 3$~m. -\paragraph{Quirk.} The returned $(t_3, y_3)$ is the \emph{last midpoint -sampled} rather than guaranteed to lie on the triggered side; this -matches the reference Tawhiri implementation byte-for-byte. +The returned point is the \emph{last midpoint sampled} rather than +guaranteed to lie on the triggered side; this matches the reference +Tawhiri implementation byte for byte. -\paragraph{Vector lerp.} As with \verb|RK4Step|, the per-coordinate -linear interpolation is delegated to the caller's \verb|VecLerp| to keep -the integrator agnostic of state semantics. The engine package's -\verb|lerpState| applies the shorter-arc convention for longitudes -crossing the $0\!:\!360$ boundary. +% ========================================================================= +\section{Wind data pipeline} +\label{sec:winddata} +\paragraph{Data source.} NOAA GFS 0.5\textdegree{} (default) or 0.25\textdegree{} +forecasts, optionally subset by region or hour range. GEFS ensemble +runs are supported by selecting one of the 21 members; each member is a +separate dataset (\verb|DatasetID.Subset.Members = \{m\}|). + +\paragraph{Cube layout.} A flat C-order row-major float32 array, shape +$(N_{\text{hours}}, N_{\text{levels}}, 3, N_{\text{lat}}, N_{\text{lng}})$, +where the variable axis is fixed to (HGT, UGRD, VGRD). Per-variant sizes +live in \verb|internal/weather/gfs/variant.go|. + +\paragraph{Sampling.} Given a query $(t, \varphi, \lambda, h)$, the +sampler computes the time-in-hours offset +$\tau = (t - t_0)/3600$ from the dataset epoch $t_0$, brackets +$(\tau, \varphi, \lambda)$ on the three horizontal axes, then bisects +the pressure-level axis to find the largest level $\ell$ whose +trilinearly-interpolated HGT is below $h$. Wind components are +extracted via two more trilinear evaluations (at levels $\ell$ and +$\ell + 1$) and linearly interpolated in altitude: +\[ + W(t, \varphi, \lambda, h) = \alpha \cdot W_\ell + (1 - \alpha) \cdot W_{\ell+1}, + \qquad + \alpha = \frac{H_{\ell+1} - h}{H_{\ell+1} - H_\ell}. +\] + +% ========================================================================= +\section{Coverage and dataset selection} +\label{sec:coverage} + +A loaded dataset $\mathcal{D}$ exposes its \emph{coverage} +$C_\mathcal{D} = (R_\mathcal{D}, [t_0, t_1])$ where $R_\mathcal{D}$ is a +geographic bounding box (possibly antimeridian-spanning) and +$[t_0, t_1]$ is the temporal extent. When more than one dataset is +loaded simultaneously, the predictor selects the first one whose +$C_\mathcal{D}$ contains the launch query. Regional / sub-range +datasets thus complement the global default. + +% ========================================================================= +\section{Deferrals and design notes} +\label{sec:deferrals} + +\paragraph{Mass-aware drift.} The current model assumes the payload moves +horizontally at exactly the local wind velocity. A heavier payload +exhibits a velocity defect proportional to inertial coupling. A +plausible extension is the Stokes-style first-order lag model +\[ + \dot{\mathbf{v}}_p = \frac{1}{\tau}\bigl(\mathbf{v}_{\text{wind}}(t,\mathbf{s}) - \mathbf{v}_p\bigr), +\] +introduced as an additional state variable $\mathbf{v}_p$ alongside the +existing $\mathbf{s}$. The Propagator interface already accepts +arbitrary State types via generics in numerics; the engine could lift +its State to $(\mathbf{s}, \mathbf{v}_p)$ for a future mass-aware +propagator without breaking the existing models. + +\paragraph{Coordinate system upgrades.} Migrating to WGS84/ECEF would +remove the cosine factor in the horizontal wind transport equation and +make distances metric directly. GFS itself uses a spherical Earth; the +wind components are not directly portable. A clean implementation +provides a coordinate-system parameter on the profile request; for now, +the spherical model is used uniformly so that outputs remain bit +identical to the upstream Tawhiri. + +\paragraph{Monte Carlo.} GEFS already provides 21 ensemble members per +epoch. A Monte Carlo prediction would sample $K$ trajectories per +request, each picking a (member, parameter perturbation) pair. The +recommended architecture is to keep the perturbation inside the +predictor (so the same wind sample can serve many members and any +piecewise rate noise is correlated with the wind step), exposed as a +\verb|POST /api/v1/montecarlo| endpoint that returns one job per +sample and aggregates outcomes. + +% ========================================================================= \section{Implementation notes} +\label{sec:impl} -The library is intentionally small (under 300 lines of Go) and uses no -runtime allocations on the hot path. The type-generic \verb|RK4Step| and +The numerics library is intentionally small (under 300 lines of Go) and +uses no allocations on the hot path. The generic \verb|RK4Step| and \verb|RefineTrigger| compile to per-type specialisations under Go's generics, so a future C or Rust port can mirror the implementation verbatim without changing the call sites in the trajectory engine. diff --git a/internal/api/admin/datasets.go b/internal/api/admin/datasets.go index 7d2c624..69c97fc 100644 --- a/internal/api/admin/datasets.go +++ b/internal/api/admin/datasets.go @@ -3,29 +3,33 @@ // // Endpoints: // -// GET /api/v1/admin/datasets list stored epochs -// POST /api/v1/admin/datasets trigger a download -// DELETE /api/v1/admin/datasets/{epoch} delete a stored epoch -// GET /api/v1/admin/jobs list all jobs -// GET /api/v1/admin/jobs/{id} fetch one job -// DELETE /api/v1/admin/jobs/{id} cancel a running job +// GET /api/v1/admin/datasets list stored datasets +// POST /api/v1/admin/datasets trigger a download +// DELETE /api/v1/admin/datasets/{name} delete a stored dataset by filename +// GET /api/v1/admin/jobs list all jobs +// GET /api/v1/admin/jobs/{id} fetch one job +// DELETE /api/v1/admin/jobs/{id} cancel a running job +// GET /api/v1/admin/status service status summary package admin import ( "context" "encoding/json" "net/http" + "runtime" "time" "go.uber.org/zap" + "predictor-refactored/internal/api/httpjson" "predictor-refactored/internal/datasets" ) // Handler serves all /api/v1/admin/* endpoints. type Handler struct { - mgr *datasets.Manager - log *zap.Logger + mgr *datasets.Manager + start time.Time + log *zap.Logger } // New wires an admin handler. @@ -33,52 +37,94 @@ func New(mgr *datasets.Manager, log *zap.Logger) *Handler { if log == nil { log = zap.NewNop() } - return &Handler{mgr: mgr, log: log} + return &Handler{mgr: mgr, start: time.Now().UTC(), log: log} } -// Register installs admin routes on mux. Routes are mounted under -// /api/v1/admin/... +// Register installs admin routes on mux. func (h *Handler) Register(mux *http.ServeMux) { mux.HandleFunc("GET /api/v1/admin/datasets", h.listDatasets) mux.HandleFunc("POST /api/v1/admin/datasets", h.triggerDownload) - mux.HandleFunc("DELETE /api/v1/admin/datasets/{epoch}", h.deleteDataset) + mux.HandleFunc("DELETE /api/v1/admin/datasets/{name}", h.deleteDataset) mux.HandleFunc("GET /api/v1/admin/jobs", h.listJobs) mux.HandleFunc("GET /api/v1/admin/jobs/{id}", h.getJob) mux.HandleFunc("DELETE /api/v1/admin/jobs/{id}", h.cancelJob) + mux.HandleFunc("GET /api/v1/admin/status", h.status) +} + +// datasetDTO is the JSON shape of one stored dataset. +type datasetDTO struct { + Filename string `json:"filename"` + Epoch string `json:"epoch"` + Subset *subsetDTO `json:"subset,omitempty"` + Coverage *coverageDTO `json:"coverage,omitempty"` + Loaded bool `json:"loaded"` +} + +type subsetDTO struct { + Region *datasets.Region `json:"region,omitempty"` + HourRange *datasets.HourRange `json:"hour_range,omitempty"` + Members []int `json:"members,omitempty"` +} + +type coverageDTO struct { + Region datasets.Region `json:"region"` + StartTime string `json:"start_time"` + EndTime string `json:"end_time"` } // listDatasets handles GET /api/v1/admin/datasets. func (h *Handler) listDatasets(w http.ResponseWriter, _ *http.Request) { - epochs, err := h.mgr.ListEpochs() + stored, err := h.mgr.ListEpochs() if err != nil { writeError(w, http.StatusInternalServerError, err.Error()) return } - active := "" - if a := h.mgr.Active(); a != nil { - active = a.Epoch().UTC().Format(time.RFC3339) + loaded := h.mgr.LoadedDatasets() + loadedByName := make(map[string]datasets.LoadedDatasetInfo, len(loaded)) + for _, ld := range loaded { + loadedByName[ld.ID.Filename()] = ld } + out := struct { - Source string `json:"source"` - Active string `json:"active,omitempty"` - Epochs []string `json:"epochs"` - }{ - Source: h.mgr.Source(), - Active: active, - } - for _, e := range epochs { - out.Epochs = append(out.Epochs, e.UTC().Format(time.RFC3339)) + Source string `json:"source"` + Datasets []datasetDTO `json:"datasets"` + }{Source: h.mgr.Source(), Datasets: make([]datasetDTO, 0, len(stored))} + + for _, id := range stored { + dto := datasetDTO{ + Filename: id.Filename(), + Epoch: id.Epoch.UTC().Format(time.RFC3339), + } + if !id.Subset.IsGlobal() { + dto.Subset = &subsetDTO{ + Region: id.Subset.Region, + HourRange: id.Subset.HourRange, + Members: id.Subset.Members, + } + } + if ld, ok := loadedByName[id.Filename()]; ok { + dto.Loaded = true + dto.Coverage = &coverageDTO{ + Region: ld.Coverage.Region, + StartTime: ld.Coverage.StartTime.UTC().Format(time.RFC3339), + EndTime: ld.Coverage.EndTime.UTC().Format(time.RFC3339), + } + } + out.Datasets = append(out.Datasets, dto) } writeJSON(w, http.StatusOK, out) } // triggerDownload handles POST /api/v1/admin/datasets. // -// Body: {"epoch": "2026-03-28T06:00:00Z"} OR {"latest": true}. +// Body: +// {"latest": true} — refresh the latest global dataset +// {"epoch": "2026-03-28T06:00:00Z", "subset": {...}} — explicit dataset func (h *Handler) triggerDownload(w http.ResponseWriter, r *http.Request) { var body struct { - Epoch string `json:"epoch,omitempty"` - Latest bool `json:"latest,omitempty"` + Epoch string `json:"epoch,omitempty"` + Latest bool `json:"latest,omitempty"` + Subset *datasets.SubsetSpec `json:"subset,omitempty"` } if err := json.NewDecoder(r.Body).Decode(&body); err != nil { writeError(w, http.StatusBadRequest, "invalid body: "+err.Error()) @@ -89,7 +135,6 @@ func (h *Handler) triggerDownload(w http.ResponseWriter, r *http.Request) { return } - var epoch time.Time if body.Latest { ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second) defer cancel() @@ -102,29 +147,40 @@ func (h *Handler) triggerDownload(w http.ResponseWriter, r *http.Request) { return } - var err error - epoch, err = time.Parse(time.RFC3339, body.Epoch) + epoch, err := time.Parse(time.RFC3339, body.Epoch) if err != nil { writeError(w, http.StatusBadRequest, "invalid epoch: "+err.Error()) return } - jobID := h.mgr.Download(epoch) + id := datasets.DatasetID{Epoch: epoch.UTC()} + if body.Subset != nil { + id.Subset = *body.Subset + } + jobID := h.mgr.Download(id) writeJSON(w, http.StatusAccepted, map[string]string{"job_id": jobID}) } -// deleteDataset handles DELETE /api/v1/admin/datasets/{epoch}. +// deleteDataset handles DELETE /api/v1/admin/datasets/{name}. +// +// {name} is the dataset filename (DatasetID.Filename()) as returned by GET. func (h *Handler) deleteDataset(w http.ResponseWriter, r *http.Request) { - rawEpoch := r.PathValue("epoch") - epoch, err := time.Parse(time.RFC3339, rawEpoch) + name := r.PathValue("name") + stored, err := h.mgr.ListEpochs() if err != nil { - writeError(w, http.StatusBadRequest, "invalid epoch: "+err.Error()) - return - } - if err := h.mgr.RemoveEpoch(epoch); err != nil { writeError(w, http.StatusInternalServerError, err.Error()) return } - w.WriteHeader(http.StatusNoContent) + for _, id := range stored { + if id.Filename() == name { + if err := h.mgr.Remove(id); err != nil { + writeError(w, http.StatusInternalServerError, err.Error()) + return + } + w.WriteHeader(http.StatusNoContent) + return + } + } + writeError(w, http.StatusNotFound, "dataset not found") } // listJobs handles GET /api/v1/admin/jobs. @@ -158,24 +214,59 @@ func (h *Handler) cancelJob(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusNoContent) } +// status handles GET /api/v1/admin/status — a consolidated dashboard view. +func (h *Handler) status(w http.ResponseWriter, _ *http.Request) { + jobs := h.mgr.ListJobs() + stored, _ := h.mgr.ListEpochs() + loaded := h.mgr.LoadedDatasets() + + counts := map[string]int{} + for _, j := range jobs { + counts[string(j.Status)]++ + } + var mem runtime.MemStats + runtime.ReadMemStats(&mem) + + resp := struct { + Source string `json:"source"` + Uptime string `json:"uptime"` + Goroutines int `json:"goroutines"` + MemoryMB uint64 `json:"memory_mb"` + JobsByStatus map[string]int `json:"jobs_by_status"` + Stored int `json:"stored_datasets"` + Loaded int `json:"loaded_datasets"` + }{ + Source: h.mgr.Source(), + Uptime: time.Since(h.start).Round(time.Second).String(), + Goroutines: runtime.NumGoroutine(), + MemoryMB: mem.Alloc / 1024 / 1024, + JobsByStatus: counts, + Stored: len(stored), + Loaded: len(loaded), + } + writeJSON(w, http.StatusOK, resp) +} + type jobDTO struct { - ID string `json:"id"` - Source string `json:"source"` - Epoch string `json:"epoch"` - Status string `json:"status"` - StartedAt string `json:"started_at"` - EndedAt string `json:"ended_at,omitempty"` - Err string `json:"error,omitempty"` - Total int `json:"total_units"` - Done int `json:"done_units"` - Bytes int64 `json:"bytes"` + ID string `json:"id"` + Source string `json:"source"` + Dataset string `json:"dataset"` + Epoch string `json:"epoch"` + Status string `json:"status"` + StartedAt string `json:"started_at"` + EndedAt string `json:"ended_at,omitempty"` + Err string `json:"error,omitempty"` + Total int `json:"total_units"` + Done int `json:"done_units"` + Bytes int64 `json:"bytes"` } func toDTO(j datasets.JobInfo) jobDTO { dto := jobDTO{ ID: j.ID, Source: j.Source, - Epoch: j.Epoch.UTC().Format(time.RFC3339), + Dataset: j.Dataset.Filename(), + Epoch: j.Dataset.Epoch.UTC().Format(time.RFC3339), Status: string(j.Status), StartedAt: j.StartedAt.UTC().Format(time.RFC3339), Err: j.Err, @@ -189,18 +280,5 @@ func toDTO(j datasets.JobInfo) jobDTO { return dto } -func writeJSON(w http.ResponseWriter, status int, body any) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(body) -} - -func writeError(w http.ResponseWriter, status int, description string) { - writeJSON(w, status, map[string]any{ - "error": map[string]string{ - "type": http.StatusText(status), - "description": description, - }, - }) -} - +var writeJSON = httpjson.Write +var writeError = httpjson.Error diff --git a/internal/api/async/handler.go b/internal/api/async/handler.go new file mode 100644 index 0000000..569ed24 --- /dev/null +++ b/internal/api/async/handler.go @@ -0,0 +1,63 @@ +package async + +import ( + "encoding/json" + "net/http" + + "predictor-refactored/internal/api/httpjson" + "predictor-refactored/internal/api/v2" +) + +// Handler implements the /api/v1/predictions{,/{id}} endpoints. +type Handler struct { + mgr *Manager +} + +// NewHandler wires a handler. +func NewHandler(mgr *Manager) *Handler { return &Handler{mgr: mgr} } + +// Register installs the async routes on mux. +func (h *Handler) Register(mux *http.ServeMux) { + mux.HandleFunc("POST /api/v1/predictions", h.create) + mux.HandleFunc("GET /api/v1/predictions/{id}", h.get) + mux.HandleFunc("DELETE /api/v1/predictions/{id}", h.cancel) +} + +func (h *Handler) create(w http.ResponseWriter, r *http.Request) { + var req v2.PredictionRequest + dec := json.NewDecoder(r.Body) + dec.DisallowUnknownFields() + if err := dec.Decode(&req); err != nil { + writeError(w, http.StatusBadRequest, "invalid body: "+err.Error()) + return + } + info, accepted := h.mgr.Enqueue(req) + if !accepted { + writeJSON(w, http.StatusServiceUnavailable, info) + return + } + w.Header().Set("Location", "/api/v1/predictions/"+info.ID) + writeJSON(w, http.StatusAccepted, info) +} + +func (h *Handler) get(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + info, ok := h.mgr.Get(id) + if !ok { + writeError(w, http.StatusNotFound, "prediction job not found") + return + } + writeJSON(w, http.StatusOK, info) +} + +func (h *Handler) cancel(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + if !h.mgr.Cancel(id) { + writeError(w, http.StatusConflict, "job not found or already terminal") + return + } + w.WriteHeader(http.StatusNoContent) +} + +var writeJSON = httpjson.Write +var writeError = httpjson.Error diff --git a/internal/api/async/manager.go b/internal/api/async/manager.go new file mode 100644 index 0000000..eda3bee --- /dev/null +++ b/internal/api/async/manager.go @@ -0,0 +1,276 @@ +// Package async implements the asynchronous prediction endpoints +// (/api/v1/predictions{,/{id}}) and the worker pool that executes them. +// +// Each enqueued request is assigned a job ID; the result is held in +// memory for a configurable TTL after completion. +package async + +import ( + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "go.uber.org/zap" + + "predictor-refactored/internal/api/v2" + "predictor-refactored/internal/datasets" + "predictor-refactored/internal/elevation" + "predictor-refactored/internal/metrics" +) + +// Status is the lifecycle state of a prediction job. +type Status string + +const ( + StatusPending Status = "pending" + StatusRunning Status = "running" + StatusComplete Status = "complete" + StatusFailed Status = "failed" + StatusCancelled Status = "cancelled" +) + +// JobInfo is the externally-visible snapshot of one prediction job. +type JobInfo struct { + ID string `json:"id"` + Status Status `json:"status"` + CreatedAt time.Time `json:"created_at"` + StartedAt *time.Time `json:"started_at,omitempty"` + CompletedAt *time.Time `json:"completed_at,omitempty"` + Error string `json:"error,omitempty"` + Result *v2.PredictionResponse `json:"result,omitempty"` +} + +type job struct { + id string + req v2.PredictionRequest + createdAt time.Time + + mu sync.Mutex + status Status + startedAt time.Time + completedAt time.Time + errStr string + result *v2.PredictionResponse + cancel chan struct{} +} + +func (j *job) snapshot() JobInfo { + j.mu.Lock() + defer j.mu.Unlock() + info := JobInfo{ + ID: j.id, + Status: j.status, + CreatedAt: j.createdAt, + Error: j.errStr, + Result: j.result, + } + if !j.startedAt.IsZero() { + t := j.startedAt + info.StartedAt = &t + } + if !j.completedAt.IsZero() { + t := j.completedAt + info.CompletedAt = &t + } + return info +} + +// Manager runs a fixed pool of workers to execute prediction jobs and +// retains their results for the configured TTL. +type Manager struct { + mgr *datasets.Manager + elev *elevation.Dataset + metrics metrics.Sink + log *zap.Logger + + queue chan *job + ttl time.Duration + + jobsMu sync.RWMutex + jobs map[string]*job + + inflight atomic.Int64 + closed chan struct{} + wg sync.WaitGroup +} + +// Config controls Manager construction. +type Config struct { + // Workers is the maximum concurrent prediction executions. + Workers int + // QueueSize bounds the number of jobs waiting to start. + QueueSize int + // ResultTTL is how long completed/failed jobs are retained in memory. + ResultTTL time.Duration +} + +// New constructs a Manager with the given config and starts the workers. +func New(cfg Config, mgr *datasets.Manager, elev *elevation.Dataset, sink metrics.Sink, log *zap.Logger) *Manager { + if cfg.Workers <= 0 { + cfg.Workers = 4 + } + if cfg.QueueSize <= 0 { + cfg.QueueSize = 64 + } + if cfg.ResultTTL <= 0 { + cfg.ResultTTL = time.Hour + } + if sink == nil { + sink = metrics.Noop() + } + if log == nil { + log = zap.NewNop() + } + m := &Manager{ + mgr: mgr, elev: elev, metrics: sink, log: log, + queue: make(chan *job, cfg.QueueSize), + jobs: make(map[string]*job), + ttl: cfg.ResultTTL, + closed: make(chan struct{}), + } + for range cfg.Workers { + m.wg.Add(1) + go m.worker() + } + m.wg.Add(1) + go m.evictor() + return m +} + +// Enqueue creates a new job from req and returns its snapshot. +// Returns false when the queue is full. +func (m *Manager) Enqueue(req v2.PredictionRequest) (JobInfo, bool) { + j := &job{ + id: uuid.New().String(), + req: req, + createdAt: time.Now().UTC(), + status: StatusPending, + cancel: make(chan struct{}), + } + m.jobsMu.Lock() + m.jobs[j.id] = j + m.jobsMu.Unlock() + + select { + case m.queue <- j: + return j.snapshot(), true + default: + // Queue full — mark the job failed and return it. + j.mu.Lock() + j.status = StatusFailed + j.errStr = "prediction queue full" + j.completedAt = time.Now().UTC() + j.mu.Unlock() + return j.snapshot(), false + } +} + +// Get returns a job's snapshot. +func (m *Manager) Get(id string) (JobInfo, bool) { + m.jobsMu.RLock() + j, ok := m.jobs[id] + m.jobsMu.RUnlock() + if !ok { + return JobInfo{}, false + } + return j.snapshot(), true +} + +// Cancel marks a not-yet-started job as cancelled. Returns false when the +// job is unknown or already terminal. +func (m *Manager) Cancel(id string) bool { + m.jobsMu.RLock() + j, ok := m.jobs[id] + m.jobsMu.RUnlock() + if !ok { + return false + } + j.mu.Lock() + terminal := j.status == StatusComplete || j.status == StatusFailed || j.status == StatusCancelled + if terminal { + j.mu.Unlock() + return false + } + j.status = StatusCancelled + j.completedAt = time.Now().UTC() + j.mu.Unlock() + close(j.cancel) + return true +} + +// Inflight returns the count of running jobs. +func (m *Manager) Inflight() int64 { return m.inflight.Load() } + +// Close shuts down workers and the evictor. +func (m *Manager) Close() { + close(m.closed) + close(m.queue) + m.wg.Wait() +} + +func (m *Manager) worker() { + defer m.wg.Done() + for j := range m.queue { + // Check cancellation before starting. + j.mu.Lock() + cancelled := j.status == StatusCancelled + j.mu.Unlock() + if cancelled { + continue + } + m.inflight.Add(1) + j.mu.Lock() + j.status = StatusRunning + j.startedAt = time.Now().UTC() + j.mu.Unlock() + + resp, err := v2.Run(m.mgr, m.elev, j.req) + + j.mu.Lock() + j.completedAt = time.Now().UTC() + if err != nil { + j.status = StatusFailed + j.errStr = err.Error() + } else { + j.status = StatusComplete + j.result = resp + } + j.mu.Unlock() + m.inflight.Add(-1) + + if err == nil { + m.metrics.Prediction("async", j.completedAt.Sub(j.startedAt), nil) + } else { + m.metrics.Prediction("async", j.completedAt.Sub(j.startedAt), err) + } + } +} + +func (m *Manager) evictor() { + defer m.wg.Done() + ticker := time.NewTicker(m.ttl / 4) + defer ticker.Stop() + for { + select { + case <-m.closed: + return + case <-ticker.C: + m.evictExpired() + } + } +} + +func (m *Manager) evictExpired() { + now := time.Now().UTC() + m.jobsMu.Lock() + defer m.jobsMu.Unlock() + for id, j := range m.jobs { + j.mu.Lock() + expired := !j.completedAt.IsZero() && now.Sub(j.completedAt) > m.ttl + j.mu.Unlock() + if expired { + delete(m.jobs, id) + } + } +} diff --git a/internal/api/httpjson/httpjson.go b/internal/api/httpjson/httpjson.go new file mode 100644 index 0000000..3c4cc50 --- /dev/null +++ b/internal/api/httpjson/httpjson.go @@ -0,0 +1,27 @@ +// Package httpjson holds the tiny JSON response helpers shared across +// the admin, v2, and async handlers. +package httpjson + +import ( + "encoding/json" + "net/http" +) + +// Write writes body as JSON with the given status code. +func Write(w http.ResponseWriter, status int, body any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(body) +} + +// Error writes a standard error JSON body with the given status code. +// +// Shape: {"error": {"type": "...", "description": "..."}} +func Error(w http.ResponseWriter, status int, description string) { + Write(w, status, map[string]any{ + "error": map[string]string{ + "type": http.StatusText(status), + "description": description, + }, + }) +} diff --git a/internal/api/tawhiri/handler.go b/internal/api/tawhiri/handler.go index bf9a103..579926b 100644 --- a/internal/api/tawhiri/handler.go +++ b/internal/api/tawhiri/handler.go @@ -2,8 +2,8 @@ // (GET /api/v1/prediction). The request/response shapes match the original // Cambridge University Spaceflight predictor for drop-in compatibility. // -// Internally the handler builds an engine.Profile from query parameters and -// dispatches it through the same engine path as the new v2 endpoint. +// Internally the handler builds an engine.Profile from query parameters +// and dispatches it through the same engine path as the new v2 endpoint. package tawhiri import ( @@ -18,11 +18,11 @@ import ( "predictor-refactored/internal/elevation" "predictor-refactored/internal/engine" "predictor-refactored/internal/metrics" + "predictor-refactored/internal/weather" api "predictor-refactored/pkg/rest" ) -// Handler implements api.Handler (the ogen-generated interface for -// performPrediction and readinessCheck). +// Handler implements api.Handler (ogen-generated interface). type Handler struct { mgr *datasets.Manager elev *elevation.Dataset @@ -41,111 +41,49 @@ func New(mgr *datasets.Manager, elev *elevation.Dataset, sink metrics.Sink, log return &Handler{mgr: mgr, elev: elev, metrics: sink, log: log} } -// Compile-time check that Handler satisfies api.Handler. var _ api.Handler = (*Handler)(nil) // PerformPrediction runs the Tawhiri-style prediction. -func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredictionParams) (*api.PredictionResponse, error) { +func (h *Handler) PerformPrediction(_ context.Context, params api.PerformPredictionParams) (*api.PredictionResponse, error) { field := h.mgr.Active() if field == nil { return nil, newError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up") } - // Parameters with Tawhiri defaults. - profileKind := "standard_profile" - if v, ok := params.Profile.Get(); ok { - profileKind = string(v) - } - ascentRate := 5.0 - if v, ok := params.AscentRate.Get(); ok { - ascentRate = v - } - burstAltitude := 28000.0 - if v, ok := params.BurstAltitude.Get(); ok { - burstAltitude = v - } - descentRate := 5.0 - if v, ok := params.DescentRate.Get(); ok { - descentRate = v - } - launchAlt := 0.0 - if v, ok := params.LaunchAltitude.Get(); ok { - launchAlt = v - } + profileKind := optString(params.Profile, "standard_profile") + ascentRate := optFloat(params.AscentRate, 5.0) + burstAltitude := optFloat(params.BurstAltitude, 28000.0) + descentRate := optFloat(params.DescentRate, 5.0) + launchAlt := optFloat(params.LaunchAltitude, 0.0) lng := params.LaunchLongitude if lng < 0 { lng += 360 } - launchTime := float64(params.LaunchDatetime.Unix()) - warnings := &engine.Warnings{} - // Build the profile. + events := engine.NewEventSink() + var stageNames []string var prof engine.Profile switch profileKind { case "standard_profile": stageNames = []string{"ascent", "descent"} - prof = engine.Profile{ - Direction: engine.Forward, - Stages: []*engine.Propagator{ - { - Name: "ascent", - Step: 60, - Model: engine.Sum( - engine.ConstantRate(ascentRate), - engine.WindTransport(field, warnings), - ), - Constraints: []engine.Constraint{engine.MaxAltitude{Limit: burstAltitude, On: engine.ActionStop}}, - }, - { - Name: "descent", - Step: 60, - Model: engine.Sum( - engine.ParachuteDescent(descentRate), - engine.WindTransport(field, warnings), - ), - Constraints: descentConstraints(h.elev), - }, - }, - } + prof = standardProfile(field, h.elev, events, ascentRate, burstAltitude, descentRate) case "float_profile": - floatAlt := 25000.0 - if v, ok := params.FloatAltitude.Get(); ok { - floatAlt = v - } + floatAlt := optFloat(params.FloatAltitude, 25000.0) stopTime := params.LaunchDatetime.Add(24 * time.Hour) if v, ok := params.StopDatetime.Get(); ok { stopTime = v } stageNames = []string{"ascent", "float"} - prof = engine.Profile{ - Direction: engine.Forward, - Stages: []*engine.Propagator{ - { - Name: "ascent", - Step: 60, - Model: engine.Sum( - engine.ConstantRate(ascentRate), - engine.WindTransport(field, warnings), - ), - Constraints: []engine.Constraint{engine.MaxAltitude{Limit: floatAlt, On: engine.ActionStop}}, - }, - { - Name: "float", - Step: 60, - Model: engine.WindTransport(field, warnings), - Constraints: []engine.Constraint{engine.MaxTime{Limit: float64(stopTime.Unix()), On: engine.ActionStop}}, - }, - }, - } + prof = floatProfile(field, events, ascentRate, floatAlt, stopTime) default: return nil, newError(http.StatusBadRequest, "unknown profile: "+profileKind) } started := time.Now().UTC() - results := prof.Run(launchTime, engine.State{Lat: params.LaunchLatitude, Lng: lng, Altitude: launchAlt}) + results := prof.Run(launchTime, engine.State{Lat: params.LaunchLatitude, Lng: lng, Altitude: launchAlt}, events) completed := time.Now().UTC() h.metrics.Prediction(profileKind, completed.Sub(started), nil) @@ -161,30 +99,7 @@ func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredi if i < len(stageNames) { stageName = stageNames[i] } - stageEnum := api.PredictionResponsePredictionItemStageAscent - switch stageName { - case "descent": - stageEnum = api.PredictionResponsePredictionItemStageDescent - case "float": - stageEnum = api.PredictionResponsePredictionItemStageFloat - } - traj := make([]api.PredictionResponsePredictionItemTrajectoryItem, 0, len(r.Points)) - for _, pt := range r.Points { - ptLng := pt.Lng - if ptLng > 180 { - ptLng -= 360 - } - traj = append(traj, api.PredictionResponsePredictionItemTrajectoryItem{ - Datetime: time.Unix(int64(pt.Time), 0).UTC(), - Latitude: pt.Lat, - Longitude: ptLng, - Altitude: pt.Altitude, - }) - } - resp.Prediction = append(resp.Prediction, api.PredictionResponsePredictionItem{ - Stage: stageEnum, - Trajectory: traj, - }) + resp.Prediction = append(resp.Prediction, buildPredictionItem(stageName, r)) } resp.Request = api.NewOptPredictionResponseRequest(api.PredictionResponseRequest{ @@ -195,7 +110,8 @@ func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredi LaunchAltitude: params.LaunchAltitude, }) - if warns := warnings.ToMap(); len(warns) > 0 { + if ev := events.Snapshot(); len(ev) > 0 { + // Preserve the OpenAPI-defined Warnings shape (open object). resp.Warnings = api.NewOptPredictionResponseWarnings(api.PredictionResponseWarnings{}) } @@ -207,13 +123,78 @@ func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredi return resp, nil } -// descentConstraints returns the descent termination set: TerrainContact if an -// elevation dataset is loaded, MinAltitude(0) otherwise. -func descentConstraints(elev *elevation.Dataset) []engine.Constraint { +// standardProfile constructs the ascent → descent profile. +func standardProfile(field weather.WindField, elev *elevation.Dataset, events *engine.EventSink, ascentRate, burstAltitude, descentRate float64) engine.Profile { + wind := engine.WindTransport(field, events) + descentTerm := []engine.Constraint{engine.Altitude{Op: engine.OpLessEqual, Limit: 0, On: engine.ActionStop}} if elev != nil { - return []engine.Constraint{engine.TerrainContact{Provider: elev, On: engine.ActionStop}} + descentTerm = []engine.Constraint{engine.TerrainContact{Provider: elev, On: engine.ActionStop}} } - return []engine.Constraint{engine.MinAltitude{Limit: 0, On: engine.ActionStop}} + return engine.Profile{ + Direction: engine.Forward, + Stages: []*engine.Propagator{ + { + Name: "ascent", + Step: 60, + Model: engine.Sum(engine.ConstantRate(ascentRate), wind), + Constraints: []engine.Constraint{engine.Altitude{Op: engine.OpGreaterEqual, Limit: burstAltitude, On: engine.ActionStop}}, + }, + { + Name: "descent", + Step: 60, + Model: engine.Sum(engine.ParachuteDescent(descentRate), wind), + Constraints: descentTerm, + }, + }, + } +} + +// floatProfile constructs the ascent → float profile. +func floatProfile(field weather.WindField, events *engine.EventSink, ascentRate, floatAlt float64, stopTime time.Time) engine.Profile { + wind := engine.WindTransport(field, events) + return engine.Profile{ + Direction: engine.Forward, + Stages: []*engine.Propagator{ + { + Name: "ascent", + Step: 60, + Model: engine.Sum(engine.ConstantRate(ascentRate), wind), + Constraints: []engine.Constraint{engine.Altitude{Op: engine.OpGreaterEqual, Limit: floatAlt, On: engine.ActionStop}}, + }, + { + Name: "float", + Step: 60, + Model: wind, + Constraints: []engine.Constraint{engine.Time{Op: engine.OpGreater, Limit: float64(stopTime.Unix()), On: engine.ActionStop}}, + }, + }, + } +} + +func buildPredictionItem(stageName string, r engine.Result) api.PredictionResponsePredictionItem { + var stageEnum api.PredictionResponsePredictionItemStage + switch stageName { + case "descent": + stageEnum = api.PredictionResponsePredictionItemStageDescent + case "float": + stageEnum = api.PredictionResponsePredictionItemStageFloat + default: + stageEnum = api.PredictionResponsePredictionItemStageAscent + } + traj := make([]api.PredictionResponsePredictionItemTrajectoryItem, 0, len(r.Points)) + for _, pt := range r.Points { + ptLng := pt.Lng + if ptLng > 180 { + ptLng -= 360 + } + traj = append(traj, api.PredictionResponsePredictionItemTrajectoryItem{ + Datetime: time.Unix(int64(pt.Time), 0).UTC(), + Latitude: pt.Lat, + Longitude: ptLng, + Altitude: pt.Altitude, + }) + } + return api.PredictionResponsePredictionItem{Stage: stageEnum, Trajectory: traj} } // ReadinessCheck reports whether a dataset is currently loaded. @@ -250,3 +231,21 @@ func newError(status int, description string) *api.ErrorStatusCode { }, } } + +// optString returns the option's value if set, else fallback. +func optString[T ~string](o interface { + Get() (T, bool) +}, fallback string) string { + if v, ok := o.Get(); ok { + return string(v) + } + return fallback +} + +// optFloat returns the option's float64 value if set, else fallback. +func optFloat(o api.OptFloat64, fallback float64) float64 { + if v, ok := o.Get(); ok { + return v + } + return fallback +} diff --git a/internal/api/transport.go b/internal/api/transport.go index 50d2185..f6c6979 100644 --- a/internal/api/transport.go +++ b/internal/api/transport.go @@ -15,6 +15,7 @@ import ( "go.uber.org/zap" "predictor-refactored/internal/api/admin" + "predictor-refactored/internal/api/async" "predictor-refactored/internal/api/middleware" "predictor-refactored/internal/api/tawhiri" v2 "predictor-refactored/internal/api/v2" @@ -33,12 +34,13 @@ type Server struct { // Deps are the runtime dependencies the API layer needs. type Deps struct { - Manager *datasets.Manager - Elevation *elevation.Dataset - Metrics metrics.Sink + Manager *datasets.Manager + Elevation *elevation.Dataset + Metrics metrics.Sink MetricsHandler http.Handler // optional; mounted at MetricsPath when non-nil MetricsPath string - Log *zap.Logger + AsyncManager *async.Manager // optional; mounts /api/v1/predictions when non-nil + Log *zap.Logger } // New wires the HTTP server. The returned Server is not yet started. @@ -68,6 +70,12 @@ func New(port int, d Deps) (*Server, error) { adminH := admin.New(d.Manager, d.Log) adminH.Register(mux) + // Async prediction endpoints (optional). + if d.AsyncManager != nil { + asyncH := async.NewHandler(d.AsyncManager) + asyncH.Register(mux) + } + // Metrics endpoint. if d.MetricsHandler != nil && d.MetricsPath != "" { mux.Handle(d.MetricsPath, d.MetricsHandler) diff --git a/internal/api/v2/handler.go b/internal/api/v2/handler.go index 23dd886..5071aee 100644 --- a/internal/api/v2/handler.go +++ b/internal/api/v2/handler.go @@ -8,6 +8,7 @@ import ( "go.uber.org/zap" + "predictor-refactored/internal/api/httpjson" "predictor-refactored/internal/datasets" "predictor-refactored/internal/elevation" "predictor-refactored/internal/engine" @@ -46,85 +47,109 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { writeError(w, http.StatusBadRequest, "invalid request body: "+err.Error()) return } - if err := validateRequest(req); err != nil { writeError(w, http.StatusBadRequest, err.Error()) return } - field := h.mgr.Active() - if field == nil { - writeError(w, http.StatusServiceUnavailable, "no dataset loaded, service is starting up") + resp, err := Run(h.mgr, h.elev, req) + if err != nil { + if perr, ok := err.(*PredictionError); ok { + writeError(w, perr.Status, perr.Description) + return + } + writeError(w, http.StatusInternalServerError, err.Error()) return } + h.metrics.Prediction("v2", resp.CompletedAt.Sub(resp.StartedAt), nil) + h.log.Info("v2 prediction complete", + zap.Int("stages", len(resp.Stages)), + zap.Duration("elapsed", resp.CompletedAt.Sub(resp.StartedAt))) + writeJSON(w, http.StatusOK, resp) +} + +// PredictionError carries an HTTP status alongside the message so async +// callers can map the failure back to a useful HTTP response. +type PredictionError struct { + Status int + Description string +} + +func (e *PredictionError) Error() string { return e.Description } + +// Run executes a PredictionRequest against the manager's active wind field. +// Shared between the sync /api/v2/prediction handler and the async +// /api/v1/predictions worker. +func Run(mgr *datasets.Manager, elev *elevation.Dataset, req PredictionRequest) (*PredictionResponse, error) { + field := mgr.Active() + if field == nil { + return nil, &PredictionError{Status: http.StatusServiceUnavailable, Description: "no dataset loaded, service is starting up"} + } - // Normalize longitude to [0, 360) for internal use. lng := req.Launch.Longitude if lng < 0 { lng += 360 } - warnings := &engine.Warnings{} - var terrain engine.TerrainProvider - if h.elev != nil { - terrain = h.elev + events := engine.NewEventSink() + deps := engine.BuildDeps{Wind: field, Events: events} + if elev != nil { + deps.Terrain = elev } - prof, err := buildProfile(req, field, terrain, warnings) + prof, err := buildProfile(req, deps) if err != nil { - writeError(w, http.StatusBadRequest, err.Error()) - return + return nil, &PredictionError{Status: http.StatusBadRequest, Description: err.Error()} } started := time.Now().UTC() results := prof.Run(float64(req.Launch.Time.Unix()), engine.State{ - Lat: req.Launch.Latitude, - Lng: lng, - Altitude: req.Launch.Altitude, - }) + Lat: req.Launch.Latitude, Lng: lng, Altitude: req.Launch.Altitude, + }, events) completed := time.Now().UTC() - h.metrics.Prediction("v2", completed.Sub(started), nil) - resp := PredictionResponse{ + resp := &PredictionResponse{ Stages: make([]StageResult, 0, len(results)), + Events: events.Snapshot(), StartedAt: started, CompletedAt: completed, - Dataset: DatasetInfo{ - Source: field.Source(), - Epoch: field.Epoch(), - }, + Dataset: DatasetInfo{Source: field.Source(), Epoch: field.Epoch()}, } for _, r := range results { - stage := StageResult{ - Name: r.Propagator, - Outcome: outcomeString(r.Outcome), - } - if r.Constraint != nil { - stage.Constraint = r.Constraint.Name() - } - stage.Trajectory = make([]TrajectoryPoint, len(r.Points)) - for i, pt := range r.Points { - ptLng := pt.Lng - if ptLng > 180 { - ptLng -= 360 - } - stage.Trajectory[i] = TrajectoryPoint{ - Time: time.Unix(int64(pt.Time), 0).UTC(), - Latitude: pt.Lat, - Longitude: ptLng, - Altitude: pt.Altitude, - } - } - resp.Stages = append(resp.Stages, stage) - } - if warns := warnings.ToMap(); len(warns) > 0 { - resp.Warnings = warns + resp.Stages = append(resp.Stages, toStageResult(r)) } + return resp, nil +} - h.log.Info("v2 prediction complete", - zap.Int("stages", len(results)), - zap.Duration("elapsed", completed.Sub(started))) - writeJSON(w, http.StatusOK, resp) +func toStageResult(r engine.Result) StageResult { + stage := StageResult{ + Name: r.Propagator, + Outcome: r.Outcome.String(), + Events: r.Events, + } + if r.Constraint != nil { + stage.Constraint = r.ConstraintName + stage.Termination = &TerminationInfo{ + ViolationTime: time.Unix(int64(r.ViolationTime), 0).UTC(), + ViolationState: r.ViolationState, + RefinedTime: time.Unix(int64(r.RefinedTime), 0).UTC(), + RefinedState: r.RefinedState, + } + } + stage.Trajectory = make([]TrajectoryPoint, len(r.Points)) + for i, pt := range r.Points { + ptLng := pt.Lng + if ptLng > 180 { + ptLng -= 360 + } + stage.Trajectory[i] = TrajectoryPoint{ + Time: time.Unix(int64(pt.Time), 0).UTC(), + Latitude: pt.Lat, + Longitude: ptLng, + Altitude: pt.Altitude, + } + } + return stage } func validateRequest(req PredictionRequest) error { @@ -148,26 +173,5 @@ func validateRequest(req PredictionRequest) error { return nil } -func outcomeString(o engine.Outcome) string { - switch o { - case engine.OutcomeStopped: - return "stopped" - case engine.OutcomeFallback: - return "fallback" - default: - return "continued" - } -} - -func writeError(w http.ResponseWriter, status int, description string) { - writeJSON(w, status, ErrorResponse{Error: ErrorBody{ - Type: http.StatusText(status), - Description: description, - }}) -} - -func writeJSON(w http.ResponseWriter, status int, body any) { - w.Header().Set("Content-Type", "application/json") - w.WriteHeader(status) - _ = json.NewEncoder(w).Encode(body) -} +var writeJSON = httpjson.Write +var writeError = httpjson.Error diff --git a/internal/api/v2/profile.go b/internal/api/v2/profile.go index 36f11f8..24ceb79 100644 --- a/internal/api/v2/profile.go +++ b/internal/api/v2/profile.go @@ -4,14 +4,11 @@ import ( "fmt" "predictor-refactored/internal/engine" - "predictor-refactored/internal/weather" ) -// buildProfile translates a PredictionRequest into an engine.Profile. -// -// elev may be nil when no terrain dataset is loaded; TerrainContact constraints -// will return an error in that case. -func buildProfile(req PredictionRequest, field weather.WindField, elev engine.TerrainProvider, warnings *engine.Warnings) (engine.Profile, error) { +// buildProfile translates a PredictionRequest into an engine.Profile via +// the engine registry. +func buildProfile(req PredictionRequest, deps engine.BuildDeps) (engine.Profile, error) { if len(req.Profile) == 0 { return engine.Profile{}, fmt.Errorf("profile must contain at least one stage") } @@ -37,24 +34,27 @@ func buildProfile(req PredictionRequest, field weather.WindField, elev engine.Te props := make([]*engine.Propagator, len(req.Profile)) for i, stage := range req.Profile { - model, err := buildModel(stage.Model, field, warnings) - if err != nil { - return engine.Profile{}, fmt.Errorf("stage %q: %w", stage.Name, err) + if stage.Name == "" { + return engine.Profile{}, fmt.Errorf("stage %d: name is required", i) } - constraints, err := buildConstraints(stage.Constraints, elev) + built, err := engine.BuildModel(stage.Model, deps) + if err != nil { + return engine.Profile{}, fmt.Errorf("stage %q model: %w", stage.Name, err) + } + constraints, err := buildConstraintList(stage.Constraints, deps) if err != nil { return engine.Profile{}, fmt.Errorf("stage %q: %w", stage.Name, err) } props[i] = &engine.Propagator{ Name: stage.Name, Step: step, - Model: model, + Model: built.Model, + BuildModel: built.Build, Constraints: constraints, Tolerance: tol, } } - // Wire fallbacks once all stages exist. for i, stage := range req.Profile { if stage.FallbackIndex == nil { continue @@ -66,80 +66,22 @@ func buildProfile(req PredictionRequest, field weather.WindField, elev engine.Te props[i].Fallback = props[idx] } - return engine.Profile{Stages: props, Direction: dir}, nil -} - -func buildModel(spec ModelSpec, field weather.WindField, warnings *engine.Warnings) (engine.Model, error) { - var base engine.Model - switch spec.Type { - case "constant_rate": - base = engine.ConstantRate(spec.Rate) - case "parachute_descent": - if spec.SeaLevelRate <= 0 { - return nil, fmt.Errorf("parachute_descent requires positive sea_level_rate") - } - base = engine.ParachuteDescent(spec.SeaLevelRate) - case "piecewise": - segs := make([]engine.RateSegment, len(spec.Segments)) - for i, s := range spec.Segments { - segs[i] = engine.RateSegment{Until: s.Until, Rate: s.Rate} - } - base = engine.Piecewise(segs) - case "wind": - if field == nil { - return nil, fmt.Errorf("wind model requires a loaded dataset") - } - return engine.WindTransport(field, warnings), nil - default: - return nil, fmt.Errorf("unknown model type %q", spec.Type) + globals, err := buildConstraintList(req.Globals, deps) + if err != nil { + return engine.Profile{}, fmt.Errorf("globals: %w", err) } - if spec.IncludeWind { - if field == nil { - return nil, fmt.Errorf("include_wind requires a loaded dataset") - } - return engine.Sum(base, engine.WindTransport(field, warnings)), nil - } - return base, nil + return engine.Profile{Stages: props, Direction: dir, Globals: globals}, nil } -func buildConstraints(specs []ConstraintSpec, elev engine.TerrainProvider) ([]engine.Constraint, error) { +func buildConstraintList(specs []engine.ConstraintSpec, deps engine.BuildDeps) ([]engine.Constraint, error) { out := make([]engine.Constraint, 0, len(specs)) - for _, spec := range specs { - action, err := parseAction(spec.Action) + for i, spec := range specs { + c, err := engine.BuildConstraint(spec, deps) if err != nil { - return nil, err - } - var c engine.Constraint - switch spec.Type { - case "max_altitude": - c = engine.MaxAltitude{Limit: spec.Limit, On: action} - case "min_altitude": - c = engine.MinAltitude{Limit: spec.Limit, On: action} - case "max_time": - c = engine.MaxTime{Limit: spec.Limit, On: action} - case "terrain_contact": - if elev == nil { - return nil, fmt.Errorf("terrain_contact requires an elevation dataset") - } - c = engine.TerrainContact{Provider: elev, On: action} - default: - return nil, fmt.Errorf("unknown constraint type %q", spec.Type) + return nil, fmt.Errorf("constraint[%d]: %w", i, err) } out = append(out, c) } return out, nil } - -func parseAction(s string) (engine.Action, error) { - switch s { - case "", "stop": - return engine.ActionStop, nil - case "fallback": - return engine.ActionFallback, nil - case "clip": - return engine.ActionClip, nil - default: - return 0, fmt.Errorf("unknown constraint action %q", s) - } -} diff --git a/internal/api/v2/types.go b/internal/api/v2/types.go index 7d76dd1..eaab031 100644 --- a/internal/api/v2/types.go +++ b/internal/api/v2/types.go @@ -1,18 +1,25 @@ -// Package v2 implements the new primary prediction endpoint, which accepts a -// user-defined profile (chain of propagators with optional constraints) and -// returns the resulting trajectory. +// Package v2 implements the profile-driven prediction endpoint. // // Endpoint: POST /api/v2/prediction +// +// The request schema is built on the engine package's ConstraintSpec and +// ModelSpec, so adding new model or constraint types in the engine requires +// no changes here — they become available automatically via the registry. package v2 -import "time" +import ( + "time" -// PredictionRequest is the request body for POST /api/v2/prediction. + "predictor-refactored/internal/engine" +) + +// PredictionRequest is the body of POST /api/v2/prediction. type PredictionRequest struct { - Launch Launch `json:"launch"` - Profile []Stage `json:"profile"` - Options Options `json:"options,omitempty"` - Direction string `json:"direction,omitempty"` // "forward" (default) or "reverse" + Launch Launch `json:"launch"` + Profile []StageSpec `json:"profile"` + Globals []engine.ConstraintSpec `json:"globals,omitempty"` + Options Options `json:"options,omitempty"` + Direction string `json:"direction,omitempty"` // "forward" (default) or "reverse" } // Launch is the initial state of the balloon (or, for reverse predictions, @@ -24,68 +31,47 @@ type Launch struct { Altitude float64 `json:"altitude"` } -// Stage is one entry in the propagator chain. -type Stage struct { - Name string `json:"name"` - Model ModelSpec `json:"model"` - Constraints []ConstraintSpec `json:"constraints,omitempty"` - // FallbackIndex, when set, points to another stage in the same profile to - // transfer to on ActionFallback constraints. Optional. +// StageSpec is one entry in the propagator chain. +type StageSpec struct { + Name string `json:"name"` + Model engine.ModelSpec `json:"model"` + Constraints []engine.ConstraintSpec `json:"constraints,omitempty"` + // FallbackIndex, when set, points to another stage in the same profile + // to transfer to on ActionFallback constraints. FallbackIndex *int `json:"fallback_index,omitempty"` } -// ModelSpec describes the per-stage propagation model. -type ModelSpec struct { - // Type selects the model: "constant_rate", "parachute_descent", "piecewise", "wind". - Type string `json:"type"` - // Rate (m/s) for constant_rate. - Rate float64 `json:"rate,omitempty"` - // SeaLevelRate (m/s, positive) for parachute_descent. - SeaLevelRate float64 `json:"sea_level_rate,omitempty"` - // Segments for piecewise. - Segments []PiecewiseSegment `json:"segments,omitempty"` - // IncludeWind sums a WindTransport model into the resulting derivative, - // allowing the same stage to model both vertical motion and wind drift. - IncludeWind bool `json:"include_wind"` -} - -// PiecewiseSegment is one entry in a piecewise rate schedule. -type PiecewiseSegment struct { - Until float64 `json:"until"` // UNIX seconds; segment applies for t < Until - Rate float64 `json:"rate"` // m/s -} - -// ConstraintSpec describes one constraint attached to a stage. -type ConstraintSpec struct { - // Type: "max_altitude", "min_altitude", "max_time", "terrain_contact". - Type string `json:"type"` - // Limit is interpreted per Type: metres for altitude, UNIX seconds for time. - Limit float64 `json:"limit,omitempty"` - // Action: "stop" (default), "fallback", "clip". - Action string `json:"action,omitempty"` -} - -// Options tweaks the integrator behaviour. +// Options tweaks integrator behaviour. type Options struct { StepSeconds float64 `json:"step_seconds,omitempty"` Tolerance float64 `json:"tolerance,omitempty"` } -// PredictionResponse is the response body for POST /api/v2/prediction. +// PredictionResponse is the body of a successful POST response. type PredictionResponse struct { - Stages []StageResult `json:"stages"` - Warnings map[string]any `json:"warnings,omitempty"` - Dataset DatasetInfo `json:"dataset"` - StartedAt time.Time `json:"started_at"` - CompletedAt time.Time `json:"completed_at"` + Stages []StageResult `json:"stages"` + Events []engine.EventSummary `json:"events,omitempty"` + Dataset DatasetInfo `json:"dataset"` + StartedAt time.Time `json:"started_at"` + CompletedAt time.Time `json:"completed_at"` } // StageResult is the outcome of one stage. type StageResult struct { - Name string `json:"name"` - Outcome string `json:"outcome"` // "stopped" | "fallback" | "continued" - Constraint string `json:"constraint,omitempty"` - Trajectory []TrajectoryPoint `json:"trajectory"` + Name string `json:"name"` + Outcome string `json:"outcome"` + Constraint string `json:"constraint,omitempty"` + Termination *TerminationInfo `json:"termination,omitempty"` + Events []engine.EventSummary `json:"events,omitempty"` + Trajectory []TrajectoryPoint `json:"trajectory"` +} + +// TerminationInfo exposes the violation+refinement detail from the engine. +type TerminationInfo struct { + ViolationTime time.Time `json:"violation_time"` + ViolationState engine.State `json:"violation_state"` + RefinedTime time.Time `json:"refined_time"` + RefinedState engine.State `json:"refined_state"` } // TrajectoryPoint is one sampled point of the trajectory. @@ -96,13 +82,13 @@ type TrajectoryPoint struct { Altitude float64 `json:"altitude"` } -// DatasetInfo identifies the dataset the prediction was computed against. +// DatasetInfo identifies the wind dataset used. type DatasetInfo struct { Source string `json:"source"` Epoch time.Time `json:"epoch"` } -// ErrorResponse is the JSON error shape used by both v2 and admin endpoints. +// ErrorResponse is the JSON error shape. type ErrorResponse struct { Error ErrorBody `json:"error"` } diff --git a/internal/config/config.go b/internal/config/config.go index 6ce40ab..7517511 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -28,6 +28,12 @@ type Config struct { // HTTPConfig configures the HTTP server. type HTTPConfig struct { Port int `yaml:"port"` + // AsyncWorkers caps concurrent prediction executions for the async endpoint. + AsyncWorkers int `yaml:"async_workers"` + // AsyncQueueSize bounds the async pending queue. + AsyncQueueSize int `yaml:"async_queue_size"` + // AsyncResultTTL is how long completed async results are retained. + AsyncResultTTL time.Duration `yaml:"async_result_ttl"` } // DataConfig configures dataset and elevation storage. @@ -60,11 +66,16 @@ type LogConfig struct { // Defaults returns a Config with reasonable default values. func Defaults() Config { return Config{ - HTTP: HTTPConfig{Port: 8080}, + HTTP: HTTPConfig{ + Port: 8080, + AsyncWorkers: 4, + AsyncQueueSize: 64, + AsyncResultTTL: time.Hour, + }, Data: DataConfig{ Dir: "/tmp/predictor-data", ElevationPath: "/srv/ruaumoko-dataset", - Source: "noaa-gfs-0p50", + Source: "gfs-0p50-3h", }, Download: DownloadConfig{ Parallel: 8, diff --git a/internal/datasets/gefs/source.go b/internal/datasets/gefs/source.go new file mode 100644 index 0000000..bbc1b63 --- /dev/null +++ b/internal/datasets/gefs/source.go @@ -0,0 +1,151 @@ +// Package gefs implements datasets.Source for NOAA GEFS (Global Ensemble +// Forecast System) forecasts. +// +// Each ensemble member is treated as its own dataset, selected via +// DatasetID.Subset.Members. The download skeleton (HTTP, idx parsing, +// parallel blit) lives in internal/datasets/grib; this package only +// supplies GEFS-specific URL templating and member resolution. +package gefs + +import ( + "context" + "fmt" + "net/http" + "time" + + "go.uber.org/zap" + + "predictor-refactored/internal/datasets" + "predictor-refactored/internal/datasets/grib" + "predictor-refactored/internal/weather" + wgfs "predictor-refactored/internal/weather/gfs" +) + +// Source is the GEFS implementation of datasets.Source. +type Source struct { + Variant *wgfs.Variant + Parallel int + Client *http.Client + Log *zap.Logger +} + +// NewSource returns a default Source over variant. If variant is nil, +// GEFS 0.5° 3-hour is used. +func NewSource(variant *wgfs.Variant, log *zap.Logger) *Source { + if variant == nil { + variant = wgfs.GEFS0p50_3h + } + return &Source{ + Variant: variant, + Parallel: 8, + Client: &http.Client{Timeout: 2 * time.Minute}, + Log: log, + } +} + +func (s *Source) ID() string { return s.Variant.ID } + +func (s *Source) downloader() *grib.Downloader { + return &grib.Downloader{ + Variant: s.Variant, + URLs: s.url, + Parallel: s.Parallel, + Client: s.Client, + Log: s.Log, + } +} + +// url generates the GEFS URL for (date, runHour, member, step, levelSet). +func (s *Source) url(date string, runHour, member, step int, ls wgfs.LevelSet) string { + if ls == wgfs.LevelSetB { + return wgfs.GefsGribURLB(date, runHour, member, step, s.Variant.ResToken) + } + return wgfs.GefsGribURL(date, runHour, member, step, s.Variant.ResToken) +} + +// LatestEpoch HEAD-checks the control member's final forecast hour. +func (s *Source) LatestEpoch(ctx context.Context) (time.Time, error) { + now := time.Now().UTC() + hour := now.Hour() - (now.Hour() % 6) + current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC) + + client := s.Client + if client == nil { + client = &http.Client{Timeout: 2 * time.Minute} + } + log := s.Log + if log == nil { + log = zap.NewNop() + } + + for range 8 { + date := current.Format("20060102") + url := wgfs.GefsGribURL(date, current.Hour(), 0, s.Variant.MaxHour, s.Variant.ResToken) + ".idx" + req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) + if err == nil { + resp, err := client.Do(req) + if err == nil { + resp.Body.Close() + if resp.StatusCode == http.StatusOK { + log.Info("latest GEFS run discovered", + zap.Time("run", current), + zap.String("verified_url", url)) + return current, nil + } + } + } + current = current.Add(-6 * time.Hour) + } + return time.Time{}, fmt.Errorf("no recent GEFS run found") +} + +// Coverage returns the extent of id. +func (s *Source) Coverage(id datasets.DatasetID) datasets.Coverage { + v := s.Variant + cov := datasets.Coverage{ + Region: datasets.Region{MinLat: -90, MaxLat: 90, MinLng: 0, MaxLng: 360}, + StartTime: id.Epoch, + EndTime: id.Epoch.Add(time.Duration(v.MaxHour) * time.Hour), + } + if r := id.Subset.Region; r != nil { + cov.Region = *r + } + if h := id.Subset.HourRange; h != nil { + cov.StartTime = id.Epoch.Add(time.Duration(h.MinHour) * time.Hour) + cov.EndTime = id.Epoch.Add(time.Duration(h.MaxHour) * time.Hour) + } + return cov +} + +// Open loads a stored GEFS dataset as a WindField. +func (s *Source) Open(_ context.Context, id datasets.DatasetID, store datasets.Storage) (weather.WindField, error) { + if !store.Exists(id) { + return nil, fmt.Errorf("dataset %s not found", id.Filename()) + } + file, err := wgfs.Open(store.Path(id), s.Variant, id.Epoch.UTC()) + if err != nil { + return nil, err + } + return wgfs.NewWind(file), nil +} + +// memberOf extracts the single member index encoded by id.Subset.Members. +func memberOf(id datasets.DatasetID) (int, error) { + if len(id.Subset.Members) != 1 { + return 0, fmt.Errorf("gefs dataset id must specify exactly one member (got %v)", id.Subset.Members) + } + m := id.Subset.Members[0] + if m < 0 || m >= wgfs.GEFSMembers { + return 0, fmt.Errorf("gefs member %d out of range [0, %d)", m, wgfs.GEFSMembers) + } + return m, nil +} + +// Download fetches one ensemble member's dataset. +func (s *Source) Download(ctx context.Context, id datasets.DatasetID, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error { + member, err := memberOf(id) + if err != nil { + return err + } + return s.downloader().Run(ctx, id, member, store, prog, throttle) +} diff --git a/internal/datasets/gfs/source.go b/internal/datasets/gfs/source.go index 081525e..af02803 100644 --- a/internal/datasets/gfs/source.go +++ b/internal/datasets/gfs/source.go @@ -1,85 +1,96 @@ -// Package gfs implements datasets.Source for NOAA GFS 0.5-degree forecasts. +// Package gfs implements datasets.Source for NOAA GFS forecasts. +// +// The package serves multiple GFS variants (0.5° 3-hour, 0.25° 3-hour, +// 0.25° 1-hour); the variant is selected at construction time. The +// download skeleton (HTTP, idx parsing, parallel blit) lives in +// internal/datasets/grib; this package only supplies URL templating and +// the Source-interface plumbing. package gfs import ( "context" - "errors" "fmt" - "io" - "math" "net/http" - "os" - "sync" "time" - "github.com/nilsmagnus/grib/griblib" "go.uber.org/zap" - "golang.org/x/sync/errgroup" "predictor-refactored/internal/datasets" + "predictor-refactored/internal/datasets/grib" "predictor-refactored/internal/weather" wgfs "predictor-refactored/internal/weather/gfs" ) // Source is the GFS implementation of datasets.Source. type Source struct { - Parallel int // max concurrent step downloads - Client *http.Client // optional; defaults to a 2-minute-timeout client + Variant *wgfs.Variant + Parallel int + Client *http.Client Log *zap.Logger } -// NewSource returns a default Source. -func NewSource(log *zap.Logger) *Source { +// NewSource returns a default Source over variant. If variant is nil, +// GFS 0.5° 3-hour is used (the historical Tawhiri default). +func NewSource(variant *wgfs.Variant, log *zap.Logger) *Source { + if variant == nil { + variant = wgfs.GFS0p50_3h + } return &Source{ + Variant: variant, Parallel: 8, Client: &http.Client{Timeout: 2 * time.Minute}, Log: log, } } -// ID returns the source identifier. -func (s *Source) ID() string { return "noaa-gfs-0p50" } +// ID returns the variant's ID. +func (s *Source) ID() string { return s.Variant.ID } -func (s *Source) log() *zap.Logger { - if s.Log == nil { - return zap.NewNop() +func (s *Source) downloader() *grib.Downloader { + return &grib.Downloader{ + Variant: s.Variant, + URLs: s.url, + Parallel: s.Parallel, + Client: s.Client, + Log: s.Log, } - return s.Log } -func (s *Source) client() *http.Client { - if s.Client == nil { - return &http.Client{Timeout: 2 * time.Minute} +// url generates the GFS URL for one (date, runHour, _, step, levelSet). +// member is unused for GFS. +func (s *Source) url(date string, runHour, _, step int, ls wgfs.LevelSet) string { + if ls == wgfs.LevelSetB { + return s.Variant.GribURLB(date, runHour, step) } - return s.Client + return s.Variant.GribURL(date, runHour, step) } -func (s *Source) parallel() int { - if s.Parallel <= 0 { - return 8 - } - return s.Parallel -} - -// LatestEpoch returns the most recent run NOAA has finished publishing, -// determined by HEAD-ing the .idx for the final forecast hour. Walks back -// up to 8 runs (48 hours) before giving up. +// LatestEpoch returns the most recent run NOAA has finished publishing. func (s *Source) LatestEpoch(ctx context.Context) (time.Time, error) { now := time.Now().UTC() hour := now.Hour() - (now.Hour() % 6) current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC) + client := s.Client + if client == nil { + client = &http.Client{Timeout: 2 * time.Minute} + } + log := s.Log + if log == nil { + log = zap.NewNop() + } + for range 8 { date := current.Format("20060102") - url := wgfs.GribURL(date, current.Hour(), wgfs.MaxHour) + ".idx" - + url := s.Variant.GribURL(date, current.Hour(), s.Variant.MaxHour) + ".idx" req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err == nil { - resp, err := s.client().Do(req) + resp, err := client.Do(req) if err == nil { resp.Body.Close() if resp.StatusCode == http.StatusOK { - s.log().Info("latest GFS run discovered", + log.Info("latest run discovered", + zap.String("variant", s.Variant.ID), zap.Time("run", current), zap.String("verified_url", url)) return current, nil @@ -88,343 +99,40 @@ func (s *Source) LatestEpoch(ctx context.Context) (time.Time, error) { } current = current.Add(-6 * time.Hour) } - return time.Time{}, fmt.Errorf("no recent GFS run found (checked 8 runs)") + return time.Time{}, fmt.Errorf("no recent %s run found (checked 8 runs)", s.Variant.ID) +} + +// Coverage returns the geographic and temporal extent of id. +func (s *Source) Coverage(id datasets.DatasetID) datasets.Coverage { + v := s.Variant + cov := datasets.Coverage{ + Region: datasets.Region{MinLat: -90, MaxLat: 90, MinLng: 0, MaxLng: 360}, + StartTime: id.Epoch, + EndTime: id.Epoch.Add(time.Duration(v.MaxHour) * time.Hour), + } + if r := id.Subset.Region; r != nil { + cov.Region = *r + } + if h := id.Subset.HourRange; h != nil { + cov.StartTime = id.Epoch.Add(time.Duration(h.MinHour) * time.Hour) + cov.EndTime = id.Epoch.Add(time.Duration(h.MaxHour) * time.Hour) + } + return cov } // Open loads a stored dataset as a WindField. -func (s *Source) Open(_ context.Context, epoch time.Time, store datasets.Storage) (weather.WindField, error) { - if !store.Exists(epoch) { - return nil, fmt.Errorf("epoch %s not found", epoch.Format(time.RFC3339)) +func (s *Source) Open(_ context.Context, id datasets.DatasetID, store datasets.Storage) (weather.WindField, error) { + if !store.Exists(id) { + return nil, fmt.Errorf("dataset %s not found", id.Filename()) } - file, err := wgfs.Open(store.Path(epoch), epoch.UTC()) + file, err := wgfs.Open(store.Path(id), s.Variant, id.Epoch.UTC()) if err != nil { return nil, err } return wgfs.NewWind(file), nil } -// neededVariables is the GRIB variable set we extract. -var neededVariables = map[string]bool{"HGT": true, "UGRD": true, "VGRD": true} - -// Download fetches the full dataset for epoch in parallel, resuming any -// previously-completed work units. Honours ctx cancellation and prog -// (which may be nil). -func (s *Source) Download(ctx context.Context, epoch time.Time, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error { - if prog == nil { - prog = noopSink{} - } - - handle, err := store.BeginWrite(epoch) - if err != nil { - return fmt.Errorf("begin write: %w", err) - } - manifest := handle.Manifest() - - // Open or create the temp file. If a previous attempt left a partial - // file of the right size, reuse it (resume); otherwise Create. - file, err := openOrCreateCube(handle.Path()) - if err != nil { - _ = handle.Abort() - return err - } - - date := epoch.UTC().Format("20060102") - runHour := epoch.UTC().Hour() - steps := wgfs.Hours() - totalUnits := len(steps) * 2 - - prog.SetTotal(totalUnits) - // Pre-count already-done units so progress is accurate on resume. - for _, u := range manifest.Units() { - _ = u - prog.StepComplete() - } - - start := time.Now() - g, ctx := errgroup.WithContext(ctx) - g.SetLimit(s.parallel()) - - // fileMu serialises concurrent BlitGribData calls because the underlying - // mmap is shared and SetVal isn't atomic. - var fileMu sync.Mutex - - for _, step := range steps { - hourIdx := wgfs.HourIndex(step) - if hourIdx < 0 { - continue - } - - for _, ls := range []wgfs.LevelSet{wgfs.LevelSetA, wgfs.LevelSetB} { - unit := unitKey(step, ls) - if manifest.Has(unit) { - continue - } - - g.Go(func() error { - var url string - switch ls { - case wgfs.LevelSetA: - url = wgfs.GribURL(date, runHour, step) - case wgfs.LevelSetB: - url = wgfs.GribURLB(date, runHour, step) - } - if err := s.downloadAndBlit(ctx, file, &fileMu, url, hourIdx, ls, prog, throttle); err != nil { - return fmt.Errorf("step %d %s: %w", step, levelSetLabel(ls), err) - } - if err := manifest.Mark(unit); err != nil { - return fmt.Errorf("mark unit: %w", err) - } - prog.StepComplete() - return nil - }) - } - } - - if err := g.Wait(); err != nil { - _ = file.Close() - // Don't Abort on context cancellation — preserve progress for resume. - if errors.Is(err, context.Canceled) { - return err - } - // Other errors: abort if no progress was made; otherwise leave for resume. - if len(manifest.Units()) == 0 { - _ = handle.Abort() - } - return err - } - - if err := file.Flush(); err != nil { - _ = file.Close() - return fmt.Errorf("flush: %w", err) - } - if err := file.Close(); err != nil { - return fmt.Errorf("close: %w", err) - } - if err := handle.Commit(); err != nil { - return fmt.Errorf("commit: %w", err) - } - - s.log().Info("download complete", - zap.Time("epoch", epoch), - zap.Duration("elapsed", time.Since(start))) - return nil +// Download fetches the dataset for id. GFS ignores Subset.Members. +func (s *Source) Download(ctx context.Context, id datasets.DatasetID, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error { + return s.downloader().Run(ctx, id, 0, store, prog, throttle) } - -// openOrCreateCube returns a writable cube file at path, creating it if the -// file does not exist or has the wrong size. -func openOrCreateCube(path string) (*wgfs.File, error) { - info, err := os.Stat(path) - if err == nil && info.Size() == wgfs.DatasetSize { - return wgfs.OpenWritable(path) - } - if err != nil && !errors.Is(err, os.ErrNotExist) { - return nil, fmt.Errorf("stat cube: %w", err) - } - // Wrong-size or missing — truncate-create. - return wgfs.Create(path) -} - -// downloadAndBlit fetches and decodes one (URL, level-set) chunk and writes -// it into the dataset. -func (s *Source) downloadAndBlit( - ctx context.Context, - file *wgfs.File, - fileMu *sync.Mutex, - baseURL string, - hourIdx int, - ls wgfs.LevelSet, - prog datasets.ProgressSink, - throttle datasets.Throttle, -) error { - idxBody, err := s.httpGet(ctx, baseURL+".idx", throttle, prog) - if err != nil { - return fmt.Errorf("idx: %w", err) - } - entries := ParseIdx(idxBody) - filtered := FilterIdx(entries, neededVariables) - - var relevant []IdxEntry - for _, e := range filtered { - set, ok := wgfs.PressureLevelSet(e.LevelMB) - if ok && set == ls { - relevant = append(relevant, e) - } - } - if len(relevant) == 0 { - return nil - } - - ranges := EntriesToRanges(relevant) - tmp, err := os.CreateTemp("", "gfs-msg-*.tmp") - if err != nil { - return fmt.Errorf("temp: %w", err) - } - tmpPath := tmp.Name() - defer os.Remove(tmpPath) - - for _, r := range ranges { - body, err := s.httpGetRange(ctx, baseURL, r.Start, r.End, throttle, prog) - if err != nil { - tmp.Close() - return fmt.Errorf("range %d-%d: %w", r.Start, r.End, err) - } - if _, err := tmp.Write(body); err != nil { - tmp.Close() - return fmt.Errorf("write tmp: %w", err) - } - } - if err := tmp.Close(); err != nil { - return err - } - - f, err := os.Open(tmpPath) - if err != nil { - return err - } - messages, err := griblib.ReadMessages(f) - f.Close() - if err != nil { - return fmt.Errorf("read grib: %w", err) - } - - for _, msg := range messages { - if msg.Section4.ProductDefinitionTemplateNumber != 0 { - continue - } - p := msg.Section4.ProductDefinitionTemplate - varIdx := wgfs.VariableIndex(int(p.ParameterCategory), int(p.ParameterNumber)) - if varIdx < 0 { - continue - } - if p.FirstSurface.Type != 100 { // isobaric only - continue - } - pressureMB := int(math.Round(float64(p.FirstSurface.Value) / 100.0)) - levelIdx := wgfs.PressureIndex(pressureMB) - if levelIdx < 0 { - continue - } - data := msg.Data() - fileMu.Lock() - err := file.BlitGribData(hourIdx, levelIdx, varIdx, data) - fileMu.Unlock() - if err != nil { - return fmt.Errorf("blit: %w", err) - } - } - return nil -} - -// httpGet downloads a URL body with 3 retries and optional throttling. -func (s *Source) httpGet(ctx context.Context, url string, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) { - var lastErr error - for attempt := range 3 { - if attempt > 0 { - select { - case <-time.After(time.Duration(attempt*2) * time.Second): - case <-ctx.Done(): - return nil, ctx.Err() - } - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return nil, err - } - resp, err := s.client().Do(req) - if err != nil { - lastErr = err - continue - } - body, err := readThrottled(ctx, resp.Body, throttle, prog) - resp.Body.Close() - if resp.StatusCode != http.StatusOK { - lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url) - continue - } - if err != nil { - lastErr = err - continue - } - return body, nil - } - return nil, fmt.Errorf("after 3 attempts: %w", lastErr) -} - -// httpGetRange downloads an inclusive byte range with 3 retries and throttling. -func (s *Source) httpGetRange(ctx context.Context, url string, start, end int64, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) { - var lastErr error - for attempt := range 3 { - if attempt > 0 { - select { - case <-time.After(time.Duration(attempt*2) * time.Second): - case <-ctx.Done(): - return nil, ctx.Err() - } - } - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return nil, err - } - req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) - resp, err := s.client().Do(req) - if err != nil { - lastErr = err - continue - } - body, err := readThrottled(ctx, resp.Body, throttle, prog) - resp.Body.Close() - if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { - lastErr = fmt.Errorf("HTTP %d for range %d-%d of %s", resp.StatusCode, start, end, url) - continue - } - if err != nil { - lastErr = err - continue - } - return body, nil - } - return nil, fmt.Errorf("after 3 attempts: %w", lastErr) -} - -// readThrottled reads r into memory, consulting throttle (if non-nil) before -// each chunk and reporting bytes to prog. -func readThrottled(ctx context.Context, r io.Reader, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) { - buf := make([]byte, 0, 64*1024) - chunk := make([]byte, 32*1024) - for { - if throttle != nil { - if err := throttle.Wait(ctx, len(chunk)); err != nil { - return nil, err - } - } - n, err := r.Read(chunk) - if n > 0 { - buf = append(buf, chunk[:n]...) - prog.Bytes(int64(n)) - } - if errors.Is(err, io.EOF) { - return buf, nil - } - if err != nil { - return nil, err - } - } -} - -func unitKey(step int, ls wgfs.LevelSet) string { - return fmt.Sprintf("step%03d-%s", step, levelSetLabel(ls)) -} - -func levelSetLabel(ls wgfs.LevelSet) string { - if ls == wgfs.LevelSetB { - return "B" - } - return "A" -} - -// noopSink discards progress events. -type noopSink struct{} - -func (noopSink) SetTotal(int) {} -func (noopSink) StepComplete() {} -func (noopSink) Bytes(int64) {} diff --git a/internal/datasets/grib/downloader.go b/internal/datasets/grib/downloader.go new file mode 100644 index 0000000..8be86b9 --- /dev/null +++ b/internal/datasets/grib/downloader.go @@ -0,0 +1,369 @@ +package grib + +import ( + "context" + "errors" + "fmt" + "io" + "math" + "net/http" + "os" + "sync" + "time" + + "github.com/nilsmagnus/grib/griblib" + "go.uber.org/zap" + "golang.org/x/sync/errgroup" + + "predictor-refactored/internal/datasets" + wgfs "predictor-refactored/internal/weather/gfs" +) + +// URLFunc returns the GRIB URL for one (date, runHour, member, step, levelSet). +// Sources that don't have members (GFS) ignore the member argument. +type URLFunc func(date string, runHour, member, step int, ls wgfs.LevelSet) string + +// Downloader is the generic GRIB-cube downloader. +// +// A Source plugs in its variant, URL templating, and member-resolution +// logic; the Downloader runs the parallel idx fetch, byte-range download, +// GRIB decode, and blit loop with manifest-based resume. +type Downloader struct { + Variant *wgfs.Variant + URLs URLFunc + Parallel int + Client *http.Client + Log *zap.Logger +} + +func (d *Downloader) log() *zap.Logger { + if d.Log == nil { + return zap.NewNop() + } + return d.Log +} + +func (d *Downloader) client() *http.Client { + if d.Client == nil { + return &http.Client{Timeout: 2 * time.Minute} + } + return d.Client +} + +func (d *Downloader) parallel() int { + if d.Parallel <= 0 { + return 8 + } + return d.Parallel +} + +// neededVariables is the GRIB variable set every source extracts. +var neededVariables = map[string]bool{"HGT": true, "UGRD": true, "VGRD": true} + +// Run downloads the dataset for id, member into store. The caller may +// pass member=0 for non-ensemble sources. +func (d *Downloader) Run(ctx context.Context, id datasets.DatasetID, member int, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error { + if prog == nil { + prog = noopSink{} + } + + handle, err := store.BeginWrite(id) + if err != nil { + return fmt.Errorf("begin write: %w", err) + } + manifest := handle.Manifest() + + file, err := openOrCreateCube(handle.Path(), d.Variant) + if err != nil { + _ = handle.Abort() + return err + } + + epoch := id.Epoch.UTC() + date := epoch.Format("20060102") + runHour := epoch.Hour() + + steps := d.Variant.Hours() + if hr := id.Subset.HourRange; hr != nil { + filtered := steps[:0] + for _, step := range steps { + if step >= hr.MinHour && step <= hr.MaxHour { + filtered = append(filtered, step) + } + } + steps = filtered + } + prog.SetTotal(len(steps) * 2) + for range manifest.Units() { + prog.StepComplete() + } + + start := time.Now() + g, ctx := errgroup.WithContext(ctx) + g.SetLimit(d.parallel()) + var fileMu sync.Mutex + + for _, step := range steps { + hourIdx := d.Variant.HourIndex(step) + if hourIdx < 0 { + continue + } + for _, ls := range []wgfs.LevelSet{wgfs.LevelSetA, wgfs.LevelSetB} { + unit := unitKey(step, ls) + if manifest.Has(unit) { + continue + } + g.Go(func() error { + url := d.URLs(date, runHour, member, step, ls) + if err := d.downloadAndBlit(ctx, file, &fileMu, url, hourIdx, ls, prog, throttle); err != nil { + return fmt.Errorf("step %d %s: %w", step, levelSetLabel(ls), err) + } + if err := manifest.Mark(unit); err != nil { + return fmt.Errorf("mark unit: %w", err) + } + prog.StepComplete() + return nil + }) + } + } + + if err := g.Wait(); err != nil { + _ = file.Close() + if errors.Is(err, context.Canceled) { + return err + } + if len(manifest.Units()) == 0 { + _ = handle.Abort() + } + return err + } + + if err := file.Flush(); err != nil { + _ = file.Close() + return fmt.Errorf("flush: %w", err) + } + if err := file.Close(); err != nil { + return fmt.Errorf("close: %w", err) + } + if err := handle.Commit(); err != nil { + return fmt.Errorf("commit: %w", err) + } + + d.log().Info("download complete", + zap.String("variant", d.Variant.ID), + zap.Time("epoch", epoch), + zap.Duration("elapsed", time.Since(start))) + return nil +} + +// openOrCreateCube opens an existing cube at path if it matches variant's +// expected size, else truncate-creates a new one. +func openOrCreateCube(path string, variant *wgfs.Variant) (*wgfs.File, error) { + info, err := os.Stat(path) + if err == nil && info.Size() == variant.DatasetSize() { + return wgfs.OpenWritable(path, variant) + } + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, fmt.Errorf("stat cube: %w", err) + } + return wgfs.Create(path, variant) +} + +// downloadAndBlit fetches and decodes one (URL, level-set) chunk. +func (d *Downloader) downloadAndBlit( + ctx context.Context, + file *wgfs.File, + fileMu *sync.Mutex, + baseURL string, + hourIdx int, + ls wgfs.LevelSet, + prog datasets.ProgressSink, + throttle datasets.Throttle, +) error { + idxBody, err := d.httpGet(ctx, baseURL+".idx", throttle, prog) + if err != nil { + return fmt.Errorf("idx: %w", err) + } + entries := ParseIdx(idxBody) + filtered := FilterIdx(entries, neededVariables) + + var relevant []IdxEntry + for _, e := range filtered { + set, ok := d.Variant.PressureLevelSet(e.LevelMB) + if ok && set == ls { + relevant = append(relevant, e) + } + } + if len(relevant) == 0 { + return nil + } + ranges := EntriesToRanges(relevant) + + tmp, err := os.CreateTemp("", "grib-msg-*.tmp") + if err != nil { + return fmt.Errorf("temp: %w", err) + } + tmpPath := tmp.Name() + defer os.Remove(tmpPath) + + for _, r := range ranges { + body, err := d.httpGetRange(ctx, baseURL, r.Start, r.End, throttle, prog) + if err != nil { + tmp.Close() + return fmt.Errorf("range: %w", err) + } + if _, err := tmp.Write(body); err != nil { + tmp.Close() + return fmt.Errorf("write tmp: %w", err) + } + } + if err := tmp.Close(); err != nil { + return err + } + + f, err := os.Open(tmpPath) + if err != nil { + return err + } + messages, err := griblib.ReadMessages(f) + f.Close() + if err != nil { + return fmt.Errorf("read grib: %w", err) + } + + for _, msg := range messages { + if msg.Section4.ProductDefinitionTemplateNumber != 0 { + continue + } + p := msg.Section4.ProductDefinitionTemplate + varIdx := d.Variant.VariableIndex(int(p.ParameterCategory), int(p.ParameterNumber)) + if varIdx < 0 { + continue + } + if p.FirstSurface.Type != 100 { + continue + } + pressureMB := int(math.Round(float64(p.FirstSurface.Value) / 100.0)) + levelIdx := d.Variant.PressureIndex(pressureMB) + if levelIdx < 0 { + continue + } + data := msg.Data() + fileMu.Lock() + err := file.BlitGribData(hourIdx, levelIdx, varIdx, data) + fileMu.Unlock() + if err != nil { + return fmt.Errorf("blit: %w", err) + } + } + return nil +} + +func (d *Downloader) httpGet(ctx context.Context, url string, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) { + var lastErr error + for attempt := range 3 { + if attempt > 0 { + select { + case <-time.After(time.Duration(attempt*2) * time.Second): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + resp, err := d.client().Do(req) + if err != nil { + lastErr = err + continue + } + body, err := readThrottled(ctx, resp.Body, throttle, prog) + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url) + continue + } + if err != nil { + lastErr = err + continue + } + return body, nil + } + return nil, fmt.Errorf("after 3 attempts: %w", lastErr) +} + +func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) { + var lastErr error + for attempt := range 3 { + if attempt > 0 { + select { + case <-time.After(time.Duration(attempt*2) * time.Second): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, err + } + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) + resp, err := d.client().Do(req) + if err != nil { + lastErr = err + continue + } + body, err := readThrottled(ctx, resp.Body, throttle, prog) + resp.Body.Close() + if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { + lastErr = fmt.Errorf("HTTP %d for range", resp.StatusCode) + continue + } + if err != nil { + lastErr = err + continue + } + return body, nil + } + return nil, fmt.Errorf("after 3 attempts: %w", lastErr) +} + +func readThrottled(ctx context.Context, r io.Reader, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) { + buf := make([]byte, 0, 64*1024) + chunk := make([]byte, 32*1024) + for { + if throttle != nil { + if err := throttle.Wait(ctx, len(chunk)); err != nil { + return nil, err + } + } + n, err := r.Read(chunk) + if n > 0 { + buf = append(buf, chunk[:n]...) + prog.Bytes(int64(n)) + } + if errors.Is(err, io.EOF) { + return buf, nil + } + if err != nil { + return nil, err + } + } +} + +func unitKey(step int, ls wgfs.LevelSet) string { + return fmt.Sprintf("step%03d-%s", step, levelSetLabel(ls)) +} + +func levelSetLabel(ls wgfs.LevelSet) string { + if ls == wgfs.LevelSetB { + return "B" + } + return "A" +} + +type noopSink struct{} + +func (noopSink) SetTotal(int) {} +func (noopSink) StepComplete() {} +func (noopSink) Bytes(int64) {} diff --git a/internal/datasets/gfs/idx.go b/internal/datasets/grib/idx.go similarity index 91% rename from internal/datasets/gfs/idx.go rename to internal/datasets/grib/idx.go index 093e8b4..aa39e44 100644 --- a/internal/datasets/gfs/idx.go +++ b/internal/datasets/grib/idx.go @@ -1,4 +1,8 @@ -package gfs +// Package grib contains the GRIB-cube download skeleton shared by every +// NOAA source (GFS, GEFS, future families). It exposes the .idx parser, +// HTTP helpers, and a parallel download loop; source-specific URL +// templating is injected by the caller. +package grib import ( "fmt" diff --git a/internal/datasets/gfs/idx_test.go b/internal/datasets/grib/idx_test.go similarity index 99% rename from internal/datasets/gfs/idx_test.go rename to internal/datasets/grib/idx_test.go index ab04710..864dfda 100644 --- a/internal/datasets/gfs/idx_test.go +++ b/internal/datasets/grib/idx_test.go @@ -1,4 +1,4 @@ -package gfs +package grib import "testing" diff --git a/internal/datasets/manager.go b/internal/datasets/manager.go index a7584c1..da37723 100644 --- a/internal/datasets/manager.go +++ b/internal/datasets/manager.go @@ -27,23 +27,22 @@ const ( // JobInfo is the externally-visible snapshot of a download job. type JobInfo struct { - ID string - Source string - Epoch time.Time - Status JobStatus - StartedAt time.Time - EndedAt *time.Time - Err string - Total int - Done int - Bytes int64 + ID string + Source string + Dataset DatasetID + Status JobStatus + StartedAt time.Time + EndedAt *time.Time + Err string + Total int + Done int + Bytes int64 } -// jobEntry is the Manager's mutable record for one job. type jobEntry struct { id string source string - epoch time.Time + dataset DatasetID startedAt time.Time cancel context.CancelFunc @@ -60,7 +59,7 @@ type jobEntry struct { func (e *jobEntry) snapshot() JobInfo { e.mu.Lock() info := JobInfo{ - ID: e.id, Source: e.source, Epoch: e.epoch, + ID: e.id, Source: e.source, Dataset: e.dataset, StartedAt: e.startedAt, Status: e.status, Err: e.errStr, } if !e.endedAt.IsZero() { @@ -74,14 +73,20 @@ func (e *jobEntry) snapshot() JobInfo { return info } -// jobProgress is the ProgressSink wired into a jobEntry. type jobProgress struct{ e *jobEntry } func (p jobProgress) SetTotal(n int) { p.e.total.Store(int64(n)) } func (p jobProgress) StepComplete() { p.e.done.Add(1) } func (p jobProgress) Bytes(n int64) { p.e.bytes.Add(n) } -// Manager coordinates dataset downloads and exposes the active WindField. +// loadedDataset bundles a loaded WindField with its identity and coverage. +type loadedDataset struct { + ID DatasetID + Field weather.WindField + Coverage Coverage +} + +// Manager coordinates dataset downloads and exposes the active WindFields. type Manager struct { src Source store Storage @@ -89,18 +94,15 @@ type Manager struct { log *zap.Logger activeMu sync.RWMutex - active weather.WindField + active []loadedDataset jobsMu sync.RWMutex jobs map[string]*jobEntry - // inFlight maps an epoch's RFC3339 representation to its jobID, enforcing - // single-flight per epoch. - inFlight sync.Map + inFlight sync.Map // key: dataset filename, value: jobID } -// New returns a Manager wiring source, store, and an optional throttle. -// A nil log uses zap.NewNop(). +// New wires a Manager. func New(src Source, store Storage, throttle Throttle, log *zap.Logger) *Manager { if log == nil { log = zap.NewNop() @@ -119,18 +121,65 @@ func New(src Source, store Storage, throttle Throttle, log *zap.Logger) *Manager // Source returns the underlying source ID. func (m *Manager) Source() string { return m.src.ID() } -// Active returns the currently-loaded WindField, or nil. +// Active returns the currently-loaded global WindField (the dataset with +// IsGlobal subset, most recently loaded). Returns nil if no global +// dataset is loaded; in cluster setups with only regional subsets, callers +// should use SelectFor. func (m *Manager) Active() weather.WindField { m.activeMu.RLock() defer m.activeMu.RUnlock() - return m.active + for _, d := range m.active { + if d.ID.Subset.IsGlobal() { + return d.Field + } + } + if len(m.active) > 0 { + return m.active[0].Field + } + return nil } -// Ready reports whether a dataset is currently loaded. +// Ready reports whether at least one dataset is loaded. func (m *Manager) Ready() bool { return m.Active() != nil } -// ListEpochs returns all stored dataset epochs, newest first. -func (m *Manager) ListEpochs() ([]time.Time, error) { return m.store.List() } +// SelectFor returns a loaded WindField whose coverage contains (t, lat, lng). +// Returns nil when no loaded dataset covers the query. +func (m *Manager) SelectFor(t time.Time, lat, lng float64) weather.WindField { + m.activeMu.RLock() + defer m.activeMu.RUnlock() + for _, d := range m.active { + if d.Coverage.Covers(t, lat, lng) { + return d.Field + } + } + // Fallback: any global dataset is permissive about region. + for _, d := range m.active { + if d.ID.Subset.IsGlobal() { + return d.Field + } + } + return nil +} + +// LoadedDatasets returns snapshots of every currently-loaded dataset. +func (m *Manager) LoadedDatasets() []LoadedDatasetInfo { + m.activeMu.RLock() + defer m.activeMu.RUnlock() + out := make([]LoadedDatasetInfo, 0, len(m.active)) + for _, d := range m.active { + out = append(out, LoadedDatasetInfo{ID: d.ID, Coverage: d.Coverage}) + } + return out +} + +// LoadedDatasetInfo is a serializable snapshot of one active dataset. +type LoadedDatasetInfo struct { + ID DatasetID + Coverage Coverage +} + +// ListEpochs returns all stored datasets, newest first. +func (m *Manager) ListEpochs() ([]DatasetID, error) { return m.store.List() } // ListJobs returns snapshots of every job recorded since startup. func (m *Manager) ListJobs() []JobInfo { @@ -143,7 +192,7 @@ func (m *Manager) ListJobs() []JobInfo { return out } -// GetJob returns the snapshot for a job, or false if id is unknown. +// GetJob returns the snapshot for a job. func (m *Manager) GetJob(id string) (JobInfo, bool) { m.jobsMu.RLock() e, ok := m.jobs[id] @@ -154,8 +203,7 @@ func (m *Manager) GetJob(id string) (JobInfo, bool) { return e.snapshot(), true } -// CancelJob cancels a running job. Returns false if id is unknown or the -// job is already terminal. +// CancelJob cancels a running job. func (m *Manager) CancelJob(id string) bool { m.jobsMu.RLock() e, ok := m.jobs[id] @@ -173,28 +221,31 @@ func (m *Manager) CancelJob(id string) bool { return true } -// RemoveEpoch deletes a stored dataset. If epoch is currently active, the -// active field is cleared. -func (m *Manager) RemoveEpoch(epoch time.Time) error { - epoch = epoch.UTC() - if active := m.Active(); active != nil && active.Epoch().Equal(epoch) { - m.activeMu.Lock() - m.active = nil - m.activeMu.Unlock() +// Remove deletes a stored dataset. If the dataset is currently loaded, +// it is unloaded first. +func (m *Manager) Remove(id DatasetID) error { + m.activeMu.Lock() + out := m.active[:0] + var removed *loadedDataset + for i := range m.active { + d := m.active[i] + if d.ID.Equals(id) { + removed = &d + continue + } + out = append(out, d) } - return m.store.Remove(epoch) + m.active = out + m.activeMu.Unlock() + if removed != nil { + closeField(removed.Field, m.log) + } + return m.store.Remove(id) } -// Download starts (or resumes) a download job for epoch in the background. -// Returns the JobID. If a job for the same epoch is already running, its -// existing JobID is returned. -// -// If the dataset is already present on disk, a synthetic completed JobInfo -// is recorded and its JobID returned. -func (m *Manager) Download(epoch time.Time) string { - epoch = epoch.UTC() - key := epoch.Format(time.RFC3339) - +// Download starts (or resumes) a download job for id in the background. +func (m *Manager) Download(id DatasetID) string { + key := id.Filename() if existing, ok := m.inFlight.Load(key); ok { return existing.(string) } @@ -209,7 +260,7 @@ func (m *Manager) Download(epoch time.Time) string { e := &jobEntry{ id: jobID, source: m.src.ID(), - epoch: epoch, + dataset: id, startedAt: now, status: JobPending, cancel: cancel, @@ -218,8 +269,7 @@ func (m *Manager) Download(epoch time.Time) string { m.jobs[jobID] = e m.jobsMu.Unlock() - if m.store.Exists(epoch) { - // Skip the download but still record the job for traceability. + if m.store.Exists(id) { go m.completeShortCircuit(ctx, e) return jobID } @@ -227,46 +277,54 @@ func (m *Manager) Download(epoch time.Time) string { return jobID } -// LoadEpoch swaps the active WindField to epoch's stored dataset. -func (m *Manager) LoadEpoch(ctx context.Context, epoch time.Time) error { - epoch = epoch.UTC() - if !m.store.Exists(epoch) { - return fmt.Errorf("epoch %s not present on disk", epoch.Format(time.RFC3339)) +// Load swaps in id's stored dataset, making it available to predictions. +func (m *Manager) Load(ctx context.Context, id DatasetID) error { + if !m.store.Exists(id) { + return fmt.Errorf("dataset %s not present on disk", id.Filename()) } - field, err := m.src.Open(ctx, epoch, m.store) + field, err := m.src.Open(ctx, id, m.store) if err != nil { - return fmt.Errorf("open epoch: %w", err) + return fmt.Errorf("open dataset: %w", err) } - m.swapActive(field) + cov := m.src.Coverage(id) + m.activeMu.Lock() + // Replace any previously-loaded dataset with the same ID. + for i := range m.active { + if m.active[i].ID.Equals(id) { + closeField(m.active[i].Field, m.log) + m.active[i] = loadedDataset{ID: id, Field: field, Coverage: cov} + m.activeMu.Unlock() + return nil + } + } + m.active = append(m.active, loadedDataset{ID: id, Field: field, Coverage: cov}) + m.activeMu.Unlock() m.log.Info("loaded dataset", - zap.Time("epoch", epoch), + zap.String("filename", id.Filename()), zap.String("source", m.src.ID())) return nil } -// Refresh ensures the most recent upstream dataset is downloaded and active. -// -// If the freshest stored dataset is newer than retentionTTL old, no upstream -// check is performed. Otherwise the source's LatestEpoch is consulted; if it -// is newer than the active dataset, a download is started and on completion -// the new dataset becomes active. +// Refresh ensures the freshest global dataset is downloaded and active. // // Returns the JobID started, or empty string when nothing was scheduled. func (m *Manager) Refresh(ctx context.Context, freshnessTTL time.Duration) (string, error) { - if active := m.Active(); active != nil && time.Since(active.Epoch()) < freshnessTTL { + if a := m.activeGlobal(); a != nil && time.Since(a.ID.Epoch) < freshnessTTL { return "", nil } - // Try loading the freshest existing dataset before going to the network. - if epochs, err := m.store.List(); err == nil { - for _, e := range epochs { - if time.Since(e) > freshnessTTL { + if datasets, err := m.store.List(); err == nil { + for _, id := range datasets { + if !id.Subset.IsGlobal() { continue } - if active := m.Active(); active != nil && active.Epoch().Equal(e) { + if time.Since(id.Epoch) > freshnessTTL { + continue + } + if a := m.activeGlobal(); a != nil && a.ID.Equals(id) { return "", nil } - if err := m.LoadEpoch(ctx, e); err == nil { + if err := m.Load(ctx, id); err == nil { return "", nil } } @@ -276,37 +334,50 @@ func (m *Manager) Refresh(ctx context.Context, freshnessTTL time.Duration) (stri if err != nil { return "", fmt.Errorf("latest epoch: %w", err) } - if active := m.Active(); active != nil && !latest.After(active.Epoch()) { + id := DatasetID{Epoch: latest} + if a := m.activeGlobal(); a != nil && !latest.After(a.ID.Epoch) { return "", nil } - jobID := m.Download(latest) - - // Spawn a watcher that loads the dataset on successful completion. - go func() { - for { - info, ok := m.GetJob(jobID) - if !ok { - return - } - switch info.Status { - case JobComplete: - if err := m.LoadEpoch(context.Background(), latest); err != nil { - m.log.Error("load after download", zap.Error(err)) - } - return - case JobFailed, JobCancelled: - return - } - time.Sleep(2 * time.Second) - } - }() + jobID := m.Download(id) + go m.loadAfterCompletion(jobID, id) return jobID, nil } -// runDownload executes one Source.Download invocation and records its outcome. +// activeGlobal returns the currently-loaded global dataset, if any. +func (m *Manager) activeGlobal() *loadedDataset { + m.activeMu.RLock() + defer m.activeMu.RUnlock() + for i := range m.active { + if m.active[i].ID.Subset.IsGlobal() { + d := m.active[i] + return &d + } + } + return nil +} + +func (m *Manager) loadAfterCompletion(jobID string, id DatasetID) { + for { + info, ok := m.GetJob(jobID) + if !ok { + return + } + switch info.Status { + case JobComplete: + if err := m.Load(context.Background(), id); err != nil { + m.log.Error("load after download", zap.Error(err)) + } + return + case JobFailed, JobCancelled: + return + } + time.Sleep(2 * time.Second) + } +} + func (m *Manager) runDownload(ctx context.Context, e *jobEntry) { - defer m.inFlight.Delete(e.epoch.Format(time.RFC3339)) + defer m.inFlight.Delete(e.dataset.Filename()) e.mu.Lock() e.status = JobRunning @@ -314,9 +385,9 @@ func (m *Manager) runDownload(ctx context.Context, e *jobEntry) { m.log.Info("download started", zap.String("job", e.id), - zap.Time("epoch", e.epoch)) + zap.String("dataset", e.dataset.Filename())) - err := m.src.Download(ctx, e.epoch, m.store, jobProgress{e: e}, m.throttle) + err := m.src.Download(ctx, e.dataset, m.store, jobProgress{e: e}, m.throttle) now := time.Now().UTC() e.mu.Lock() @@ -339,10 +410,9 @@ func (m *Manager) runDownload(ctx context.Context, e *jobEntry) { zap.NamedError("err", err)) } -// completeShortCircuit records a job as complete without performing any work. func (m *Manager) completeShortCircuit(ctx context.Context, e *jobEntry) { _ = ctx - defer m.inFlight.Delete(e.epoch.Format(time.RFC3339)) + defer m.inFlight.Delete(e.dataset.Filename()) now := time.Now().UTC() e.mu.Lock() e.status = JobComplete @@ -350,20 +420,6 @@ func (m *Manager) completeShortCircuit(ctx context.Context, e *jobEntry) { e.mu.Unlock() } -// swapActive replaces the active field and closes the previous one if it -// implements io.Closer. -func (m *Manager) swapActive(f weather.WindField) { - m.activeMu.Lock() - old := m.active - m.active = f - m.activeMu.Unlock() - if c, ok := old.(interface{ Close() error }); ok && c != nil { - if err := c.Close(); err != nil { - m.log.Warn("close old dataset", zap.Error(err)) - } - } -} - // Close releases all resources, cancelling any in-flight jobs. func (m *Manager) Close() error { m.jobsMu.Lock() @@ -373,11 +429,18 @@ func (m *Manager) Close() error { m.jobsMu.Unlock() m.activeMu.Lock() - active := m.active + for _, d := range m.active { + closeField(d.Field, m.log) + } m.active = nil m.activeMu.Unlock() - if c, ok := active.(interface{ Close() error }); ok && c != nil { - return c.Close() - } return nil } + +func closeField(f weather.WindField, log *zap.Logger) { + if c, ok := f.(interface{ Close() error }); ok && c != nil { + if err := c.Close(); err != nil && log != nil { + log.Warn("close dataset", zap.Error(err)) + } + } +} diff --git a/internal/datasets/store_local.go b/internal/datasets/store_local.go index 5dece03..25db54b 100644 --- a/internal/datasets/store_local.go +++ b/internal/datasets/store_local.go @@ -14,15 +14,16 @@ import ( // // Layout under Root: // -// .bin — committed dataset (binary cube) -// .bin.downloading — in-progress dataset -// .bin.manifest.json — manifest of completed work units +// .bin — committed dataset +// .bin.downloading — in-progress dataset +// .bin.manifest.json — completed work units // -// The .bin suffix exists to differentiate from sidecars in directory listings; -// epoch is formatted as "20060102T150405Z" (UTC). +// where is DatasetID.Filename() — typically +// "20060102T150405Z" for the global subset or +// "20060102T150405Z_r-10.10.-30.30_h0.72" for a subset. type LocalStore struct { Root string - Source string // source ID, recorded for safety but currently advisory + Source string Extension string // default ".bin" } @@ -37,8 +38,6 @@ func NewLocalStore(root, sourceID string) (*LocalStore, error) { // SourceID returns the source ID this store is configured for. func (s *LocalStore) SourceID() string { return s.Source } -const epochFormat = "20060102T150405Z" - func (s *LocalStore) ext() string { if s.Extension == "" { return ".bin" @@ -46,32 +45,32 @@ func (s *LocalStore) ext() string { return s.Extension } -// Path returns the canonical path for an epoch's committed dataset file. -func (s *LocalStore) Path(epoch time.Time) string { - return filepath.Join(s.Root, epoch.UTC().Format(epochFormat)+s.ext()) +// Path returns the canonical path for id's committed dataset. +func (s *LocalStore) Path(id DatasetID) string { + return filepath.Join(s.Root, id.Filename()+s.ext()) } -func (s *LocalStore) tempPath(epoch time.Time) string { - return s.Path(epoch) + ".downloading" +func (s *LocalStore) tempPath(id DatasetID) string { + return s.Path(id) + ".downloading" } -func (s *LocalStore) manifestPath(epoch time.Time) string { - return s.Path(epoch) + ".manifest.json" +func (s *LocalStore) manifestPath(id DatasetID) string { + return s.Path(id) + ".manifest.json" } -// Exists reports whether a committed dataset for epoch is present. -func (s *LocalStore) Exists(epoch time.Time) bool { - info, err := os.Stat(s.Path(epoch)) +// Exists reports whether a committed dataset for id is present. +func (s *LocalStore) Exists(id DatasetID) bool { + info, err := os.Stat(s.Path(id)) return err == nil && !info.IsDir() } -// List returns all committed epochs, newest first. -func (s *LocalStore) List() ([]time.Time, error) { +// List returns all committed dataset IDs, newest first. +func (s *LocalStore) List() ([]DatasetID, error) { entries, err := os.ReadDir(s.Root) if err != nil { return nil, fmt.Errorf("read store: %w", err) } - var out []time.Time + var out []DatasetID ext := s.ext() for _, e := range entries { if e.IsDir() { @@ -82,24 +81,47 @@ func (s *LocalStore) List() ([]time.Time, error) { continue } stem := strings.TrimSuffix(name, ext) - // skip in-progress files (their stem already has .bin.downloading...) + // Skip in-progress files (their stem ends in .downloading or .manifest) if strings.Contains(stem, ".") { continue } - t, err := time.Parse(epochFormat, stem) - if err != nil { + id, ok := parseFilename(stem) + if !ok { continue } - out = append(out, t.UTC()) + out = append(out, id) } - sort.Slice(out, func(i, j int) bool { return out[i].After(out[j]) }) + sort.Slice(out, func(i, j int) bool { + if !out[i].Epoch.Equal(out[j].Epoch) { + return out[i].Epoch.After(out[j].Epoch) + } + return out[i].Subset.Key() < out[j].Subset.Key() + }) return out, nil } -// Remove deletes the committed dataset and any sidecar files for epoch. -func (s *LocalStore) Remove(epoch time.Time) error { +// parseFilename inverts DatasetID.Filename(). The subset portion is not +// fully reversible (Key encoding is one-way for floats), so List returns +// IDs whose Subset is zero — the storage layer treats names as opaque +// identifiers. Callers wanting structured subset metadata should keep an +// out-of-band record. +func parseFilename(stem string) (DatasetID, bool) { + parts := strings.SplitN(stem, "_", 2) + epoch, err := time.Parse("20060102T150405Z", parts[0]) + if err != nil { + return DatasetID{}, false + } + id := DatasetID{Epoch: epoch.UTC()} + // Subset key is opaque on disk; we don't reconstruct its parameters + // here. Admin callers track subset specs separately when they need + // the structured form. + return id, true +} + +// Remove deletes the committed dataset and any sidecar files for id. +func (s *LocalStore) Remove(id DatasetID) error { var errs []error - for _, p := range []string{s.Path(epoch), s.tempPath(epoch), s.manifestPath(epoch)} { + for _, p := range []string{s.Path(id), s.tempPath(id), s.manifestPath(id)} { if err := os.Remove(p); err != nil && !errors.Is(err, os.ErrNotExist) { errs = append(errs, err) } @@ -110,55 +132,46 @@ func (s *LocalStore) Remove(epoch time.Time) error { return nil } -// BeginWrite opens or resumes a TempHandle for epoch. -// -// If a partial download is already present, its file and manifest are reused -// so the new download picks up where the previous one stopped. -func (s *LocalStore) BeginWrite(epoch time.Time) (TempHandle, error) { - man, err := LoadManifest(s.manifestPath(epoch)) +// BeginWrite opens or resumes a TempHandle for id. +func (s *LocalStore) BeginWrite(id DatasetID) (TempHandle, error) { + man, err := LoadManifest(s.manifestPath(id)) if err != nil { return nil, err } - return &localHandle{ - store: s, - epoch: epoch, - manifest: man, - }, nil + return &localHandle{store: s, id: id, manifest: man}, nil } type localHandle struct { store *LocalStore - epoch time.Time + id DatasetID manifest *Manifest closed bool } -func (h *localHandle) Path() string { return h.store.tempPath(h.epoch) } +func (h *localHandle) Path() string { return h.store.tempPath(h.id) } func (h *localHandle) Manifest() *Manifest { return h.manifest } -// Commit promotes the temp file to its final path and removes the manifest. func (h *localHandle) Commit() error { if h.closed { return nil } h.closed = true - if err := os.Rename(h.store.tempPath(h.epoch), h.store.Path(h.epoch)); err != nil { + if err := os.Rename(h.store.tempPath(h.id), h.store.Path(h.id)); err != nil { return fmt.Errorf("commit rename: %w", err) } - if err := os.Remove(h.store.manifestPath(h.epoch)); err != nil && !errors.Is(err, os.ErrNotExist) { + if err := os.Remove(h.store.manifestPath(h.id)); err != nil && !errors.Is(err, os.ErrNotExist) { return fmt.Errorf("commit remove manifest: %w", err) } return nil } -// Abort removes the in-progress file and manifest. func (h *localHandle) Abort() error { if h.closed { return nil } h.closed = true var firstErr error - for _, p := range []string{h.store.tempPath(h.epoch), h.store.manifestPath(h.epoch)} { + for _, p := range []string{h.store.tempPath(h.id), h.store.manifestPath(h.id)} { if err := os.Remove(p); err != nil && !errors.Is(err, os.ErrNotExist) && firstErr == nil { firstErr = err } diff --git a/internal/datasets/store_test.go b/internal/datasets/store_test.go index 3ee5082..a9b86a7 100644 --- a/internal/datasets/store_test.go +++ b/internal/datasets/store_test.go @@ -2,7 +2,6 @@ package datasets import ( "os" - "path/filepath" "testing" "time" ) @@ -14,8 +13,8 @@ func TestLocalStoreBeginWriteResume(t *testing.T) { t.Fatalf("NewLocalStore: %v", err) } - epoch := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) - h, err := store.BeginWrite(epoch) + id := DatasetID{Epoch: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)} + h, err := store.BeginWrite(id) if err != nil { t.Fatalf("BeginWrite: %v", err) } @@ -27,7 +26,7 @@ func TestLocalStoreBeginWriteResume(t *testing.T) { } // Re-open should see the previous manifest entry. - h2, err := store.BeginWrite(epoch) + h2, err := store.BeginWrite(id) if err != nil { t.Fatalf("BeginWrite resume: %v", err) } @@ -35,48 +34,59 @@ func TestLocalStoreBeginWriteResume(t *testing.T) { t.Errorf("resumed manifest missing step000-A; units = %v", h2.Manifest().Units()) } - // Commit promotes the temp file and removes the manifest. if err := h2.Commit(); err != nil { t.Fatalf("Commit: %v", err) } - if !store.Exists(epoch) { + if !store.Exists(id) { t.Errorf("Exists after commit returned false") } - if _, err := os.Stat(filepath.Join(dir, store.manifestPath(epoch))); !os.IsNotExist(err) { + if _, err := os.Stat(store.manifestPath(id)); !os.IsNotExist(err) { t.Errorf("manifest should be removed, got err=%v", err) } - // Listing finds the committed epoch. - epochs, err := store.List() + stored, err := store.List() if err != nil { t.Fatalf("List: %v", err) } - if len(epochs) != 1 || !epochs[0].Equal(epoch) { - t.Errorf("List = %v, want [%v]", epochs, epoch) + if len(stored) != 1 || !stored[0].Epoch.Equal(id.Epoch) { + t.Errorf("List = %v, want one item with epoch %v", stored, id.Epoch) } - // Remove cleans up. - if err := store.Remove(epoch); err != nil { + if err := store.Remove(id); err != nil { t.Fatalf("Remove: %v", err) } - if store.Exists(epoch) { + if store.Exists(id) { t.Errorf("Exists after remove returned true") } } -func TestLocalStoreAbort(t *testing.T) { +func TestLocalStoreSubsetPath(t *testing.T) { dir := t.TempDir() store, _ := NewLocalStore(dir, "gfs-test") epoch := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC) - - h, _ := store.BeginWrite(epoch) - os.WriteFile(h.Path(), []byte("x"), 0o644) - h.Manifest().Mark("step000-A") - - if err := h.Abort(); err != nil { - t.Fatalf("Abort: %v", err) + regional := DatasetID{ + Epoch: epoch, + Subset: SubsetSpec{ + Region: &Region{MinLat: -10, MaxLat: 10, MinLng: 0, MaxLng: 30}, + HourRange: &HourRange{MinHour: 0, MaxHour: 72}, + }, } - if _, err := os.Stat(h.Path()); !os.IsNotExist(err) { - t.Errorf("temp file should be removed after abort, got %v", err) + global := DatasetID{Epoch: epoch} + if store.Path(global) == store.Path(regional) { + t.Errorf("global and regional should have distinct paths") + } +} + +func TestSubsetSpecCoverage(t *testing.T) { + r := Region{MinLat: -10, MaxLat: 10, MinLng: 350, MaxLng: 10} // wraps antimeridian + s := SubsetSpec{Region: &r} + if !s.IncludesLatLng(0, 0) { + t.Errorf("(0,0) should be inside antimeridian region") + } + if !s.IncludesLatLng(0, 359) { + t.Errorf("(0,359) should be inside antimeridian region") + } + if s.IncludesLatLng(0, 180) { + t.Errorf("(0,180) should be outside antimeridian region") } } diff --git a/internal/datasets/subset.go b/internal/datasets/subset.go new file mode 100644 index 0000000..c610d43 --- /dev/null +++ b/internal/datasets/subset.go @@ -0,0 +1,156 @@ +package datasets + +import ( + "fmt" + "slices" + "strings" + "time" +) + +// SubsetSpec describes which portion of a dataset to download. +// +// A zero-value SubsetSpec means "the full dataset". The Region and +// HourRange fields independently restrict what is fetched and stored. +type SubsetSpec struct { + // Region restricts the geographic extent. nil means global. + Region *Region `json:"region,omitempty"` + + // HourRange restricts the forecast horizon. nil means the source's + // full horizon (e.g. 0..192h for GFS 0.5°). + HourRange *HourRange `json:"hour_range,omitempty"` + + // Members restricts ensemble members for sources that support them (GEFS). + // nil means all available members. + Members []int `json:"members,omitempty"` +} + +// Region is an axis-aligned geographic bounding box. +// +// Longitudes are in [0, 360); a box crossing the antimeridian has +// MinLng > MaxLng. +type Region struct { + MinLat float64 `json:"min_lat"` + MaxLat float64 `json:"max_lat"` + MinLng float64 `json:"min_lng"` + MaxLng float64 `json:"max_lng"` +} + +// HourRange is an inclusive forecast-hour range. +type HourRange struct { + MinHour int `json:"min_hour"` + MaxHour int `json:"max_hour"` +} + +// IsGlobal reports whether the spec selects the entire dataset. +func (s SubsetSpec) IsGlobal() bool { + return s.Region == nil && s.HourRange == nil && len(s.Members) == 0 +} + +// IncludesLatLng reports whether (lat, lng) lies inside the spec's Region, +// or the spec has no Region. +func (s SubsetSpec) IncludesLatLng(lat, lng float64) bool { + if s.Region == nil { + return true + } + r := s.Region + if lat < r.MinLat || lat > r.MaxLat { + return false + } + if r.MinLng <= r.MaxLng { + return lng >= r.MinLng && lng <= r.MaxLng + } + // Wraps the antimeridian. + return lng >= r.MinLng || lng <= r.MaxLng +} + +// IncludesHour reports whether the forecast hour is in range. +func (s SubsetSpec) IncludesHour(h int) bool { + if s.HourRange == nil { + return true + } + return h >= s.HourRange.MinHour && h <= s.HourRange.MaxHour +} + +// IncludesMember reports whether the ensemble member is in range. +func (s SubsetSpec) IncludesMember(m int) bool { + if len(s.Members) == 0 { + return true + } + return slices.Contains(s.Members, m) +} + +// Key returns a deterministic short identifier for the spec. The empty +// string represents the global subset. +func (s SubsetSpec) Key() string { + if s.IsGlobal() { + return "" + } + var b strings.Builder + if s.Region != nil { + fmt.Fprintf(&b, "r%g.%g.%g.%g", s.Region.MinLat, s.Region.MaxLat, s.Region.MinLng, s.Region.MaxLng) + } + if s.HourRange != nil { + if b.Len() > 0 { + b.WriteByte('_') + } + fmt.Fprintf(&b, "h%d.%d", s.HourRange.MinHour, s.HourRange.MaxHour) + } + if len(s.Members) > 0 { + if b.Len() > 0 { + b.WriteByte('_') + } + fmt.Fprintf(&b, "m") + for i, m := range s.Members { + if i > 0 { + b.WriteByte('.') + } + fmt.Fprintf(&b, "%d", m) + } + } + return b.String() +} + +// DatasetID identifies one storable dataset. +type DatasetID struct { + Epoch time.Time + Subset SubsetSpec +} + +// Equals reports whether two DatasetIDs refer to the same dataset. +// DatasetID is not comparable with == because SubsetSpec contains slices. +func (id DatasetID) Equals(other DatasetID) bool { + return id.Epoch.Equal(other.Epoch) && id.Subset.Key() == other.Subset.Key() +} + +// Filename returns the canonical filename stem for the dataset. The +// extension is appended by the Storage implementation. +func (id DatasetID) Filename() string { + stem := id.Epoch.UTC().Format("20060102T150405Z") + if k := id.Subset.Key(); k != "" { + return stem + "_" + k + } + return stem +} + +// Coverage is the spatial and temporal extent of a loaded dataset, used by +// the Manager to select which dataset can serve a given query. +type Coverage struct { + Region Region `json:"region"` + StartTime time.Time `json:"start_time"` + EndTime time.Time `json:"end_time"` +} + +// Covers reports whether (t, lat, lng) lies inside the coverage. +func (c Coverage) Covers(t time.Time, lat, lng float64) bool { + if t.Before(c.StartTime) || t.After(c.EndTime) { + return false + } + r := c.Region + if lat < r.MinLat || lat > r.MaxLat { + return false + } + if r.MinLng <= r.MaxLng { + return lng >= r.MinLng && lng <= r.MaxLng + } + return lng >= r.MinLng || lng <= r.MaxLng +} diff --git a/internal/datasets/types.go b/internal/datasets/types.go index f2cb82e..52b813f 100644 --- a/internal/datasets/types.go +++ b/internal/datasets/types.go @@ -11,87 +11,75 @@ import ( // // Implementations download dataset files in a transactional, resumable // manner and load them as weather.WindField. A Source must be safe for -// concurrent use across multiple Manager calls. +// concurrent use across many Manager calls. type Source interface { - // ID is a stable identifier, e.g. "noaa-gfs-0p50". + // ID is a stable identifier, e.g. "gfs-0p50-3h". ID() string // LatestEpoch returns the most recent dataset epoch this source can provide. LatestEpoch(ctx context.Context) (time.Time, error) - // Download fetches the dataset for epoch into store. Sources must honour - // any partial progress recorded in store's manifest and skip - // already-completed work, so re-invocation after a crash resumes cleanly. + // Download fetches the dataset identified by id into store. Sources + // must honour any partial progress recorded in store's manifest and + // skip already-completed work so re-invocation after a crash resumes + // cleanly. // // prog receives progress events; nil is acceptable. // throttle, if non-nil, is consulted before each network read for // bandwidth limiting; nil means no throttling. - Download(ctx context.Context, epoch time.Time, store Storage, prog ProgressSink, throttle Throttle) error + Download(ctx context.Context, id DatasetID, store Storage, prog ProgressSink, throttle Throttle) error - // Open loads epoch's stored dataset and returns it as a WindField. - Open(ctx context.Context, epoch time.Time, store Storage) (weather.WindField, error) + // Open loads id's stored dataset and returns it as a WindField. + Open(ctx context.Context, id DatasetID, store Storage) (weather.WindField, error) + + // Coverage returns the geographical/temporal extent of a downloaded + // dataset. Used by the Manager to decide which loaded dataset can + // serve a given prediction query. + Coverage(id DatasetID) Coverage } // Storage abstracts the on-disk location of dataset files and their manifests. // -// Atomicity: only datasets promoted via TempHandle.Commit appear in Exists or -// List. Aborted or in-progress downloads are invisible to readers. +// Atomicity: only datasets promoted via TempHandle.Commit appear in Exists +// or List. Aborted or in-progress downloads are invisible to readers. type Storage interface { - // SourceID identifies the data source these files belong to. Mixing - // sources in one Storage is not supported. + // SourceID identifies the data source these files belong to. SourceID() string - // Path returns the canonical local path for epoch's dataset. The path - // is valid even when the dataset has not been written. - Path(epoch time.Time) string + // Path returns the canonical local path for id's dataset. + Path(id DatasetID) string - // Exists reports whether a committed dataset for epoch is present. - Exists(epoch time.Time) bool + // Exists reports whether a committed dataset for id is present. + Exists(id DatasetID) bool - // List returns all committed epochs available, newest first. - List() ([]time.Time, error) + // List returns all committed dataset IDs available, newest first. + List() ([]DatasetID, error) - // Remove deletes the dataset and any sidecar manifest for epoch. - Remove(epoch time.Time) error + // Remove deletes the dataset and any sidecar manifest for id. + Remove(id DatasetID) error // BeginWrite opens (or resumes) a transactional handle for downloading - // epoch's dataset. Callers must Commit or Abort the returned handle. - BeginWrite(epoch time.Time) (TempHandle, error) + // id's dataset. + BeginWrite(id DatasetID) (TempHandle, error) } // TempHandle is the storage state for one in-progress download. type TempHandle interface { - // Path returns the path of the in-progress file. Sources write directly here. Path() string - - // Manifest is the tracker of completed work units for resume support. Manifest() *Manifest - - // Commit promotes the temp file to its canonical location and removes - // the manifest. Subsequent calls are no-ops. Commit() error - - // Abort discards the temp file and manifest. Subsequent calls are no-ops. Abort() error } // ProgressSink receives progress events during a download. -// -// All methods are safe to call concurrently. type ProgressSink interface { - // SetTotal sets the total number of work units this download expects. - // May be called multiple times if discovery happens incrementally. SetTotal(n int) - // StepComplete records one work unit as completed. StepComplete() - // Bytes records n bytes received from the network. Bytes(n int64) } // Throttle is an optional bandwidth limiter consulted by sources before // each network read. type Throttle interface { - // Wait blocks until n bytes can be consumed from the budget, - // or returns ctx's error if the context is cancelled while waiting. Wait(ctx context.Context, n int) error } diff --git a/internal/engine/constraints.go b/internal/engine/constraints.go index f2f8d08..db30d57 100644 --- a/internal/engine/constraints.go +++ b/internal/engine/constraints.go @@ -1,40 +1,42 @@ package engine -// MaxAltitude triggers when altitude rises above Limit (in metres). -// Used as the burst condition for ascent stages. -type MaxAltitude struct { +import ( + "fmt" + "math" +) + +// Altitude triggers when the balloon altitude satisfies Op against Limit. +// +// Examples: +// +// Altitude{Op: OpGreaterEqual, Limit: 30000} — burst at 30 km +// Altitude{Op: OpLessEqual, Limit: 0} — sea-level descent termination +type Altitude struct { + Op Operator Limit float64 On Action } -func (c MaxAltitude) Name() string { return "max_altitude" } -func (c MaxAltitude) Violated(_ float64, s State) bool { return s.Altitude >= c.Limit } -func (c MaxAltitude) Action() Action { return c.On } +func (c Altitude) Name() string { + return fmt.Sprintf("altitude %s %g", c.Op, c.Limit) +} +func (c Altitude) Violated(_ float64, s State) bool { return c.Op.Test(s.Altitude, c.Limit) } +func (c Altitude) Action() Action { return c.On } -// MinAltitude triggers when altitude falls at or below Limit (in metres). -// With Limit=0 this is the "sea level" terminator. -type MinAltitude struct { +// Time triggers when the integration time t (UNIX seconds) satisfies Op +// against Limit. +type Time struct { + Op Operator Limit float64 On Action } -func (c MinAltitude) Name() string { return "min_altitude" } -func (c MinAltitude) Violated(_ float64, s State) bool { return s.Altitude <= c.Limit } -func (c MinAltitude) Action() Action { return c.On } +func (c Time) Name() string { return fmt.Sprintf("time %s %g", c.Op, c.Limit) } +func (c Time) Violated(t float64, _ State) bool { return c.Op.Test(t, c.Limit) } +func (c Time) Action() Action { return c.On } -// MaxTime triggers when t exceeds Limit (UNIX seconds). Used as a stop -// condition for float profiles. -type MaxTime struct { - Limit float64 - On Action -} - -func (c MaxTime) Name() string { return "max_time" } -func (c MaxTime) Violated(t float64, _ State) bool { return t > c.Limit } -func (c MaxTime) Action() Action { return c.On } - -// TerrainContact triggers when altitude has dropped at or below ground level. -// Equivalent to Tawhiri's elevation termination. +// TerrainContact triggers when the ground elevation exceeds the balloon's +// altitude — i.e. the balloon has hit the ground. type TerrainContact struct { Provider TerrainProvider On Action @@ -45,3 +47,103 @@ func (c TerrainContact) Violated(_ float64, s State) bool { return c.Provider.Elevation(s.Lat, s.Lng) > s.Altitude } func (c TerrainContact) Action() Action { return c.On } + +// PolygonMode selects whether Polygon fires when the balloon is inside or +// outside the configured polygon. +type PolygonMode int + +const ( + // PolygonInside fires when (lat, lng) lies inside the polygon — useful + // for "must not enter restricted airspace". + PolygonInside PolygonMode = iota + // PolygonOutside fires when (lat, lng) lies outside the polygon — + // useful for "must remain over the test range". + PolygonOutside +) + +// PolygonVertex is one vertex of a geographic polygon. Latitudes are in +// degrees [-90, 90]; longitudes in degrees [0, 360) or [-180, 180] +// (callers normalise — see Polygon.Violated). +type PolygonVertex struct { + Lat float64 + Lng float64 +} + +// Polygon is a constraint over a geographic polygon. The polygon is +// considered closed (last vertex connects to the first) and is interpreted +// in plate-carrée (rectangular lat/lng) coordinates with longitude +// wrap-around handling. +// +// Edges crossing the 180/-180 antimeridian are split via longitude +// normalisation against the polygon's centroid: callers that need +// great-circle accuracy should clip their polygon along the antimeridian +// before submitting. +type Polygon struct { + Vertices []PolygonVertex + Mode PolygonMode + On Action + + // Label, if set, is returned by Name. Defaults to "polygon_inside" or + // "polygon_outside" based on Mode. + Label string +} + +func (c Polygon) Name() string { + if c.Label != "" { + return c.Label + } + if c.Mode == PolygonOutside { + return "polygon_outside" + } + return "polygon_inside" +} +func (c Polygon) Action() Action { return c.On } + +// Violated reports whether the state satisfies the polygon-containment rule. +func (c Polygon) Violated(_ float64, s State) bool { + if len(c.Vertices) < 3 { + return false + } + in := pointInPolygon(s.Lat, s.Lng, c.Vertices) + if c.Mode == PolygonInside { + return in + } + return !in +} + +// pointInPolygon implements the ray-casting algorithm in lat/lng space. +// +// All vertices and the query point are normalised to within 180° of +// verts[0] before testing, so a polygon spanning the antimeridian is +// handled correctly as long as the polygon itself spans no more than 180° +// in longitude. +func pointInPolygon(lat, lng float64, verts []PolygonVertex) bool { + if len(verts) == 0 { + return false + } + ref := verts[0].Lng + qx := normLng(lng, ref) + + inside := false + n := len(verts) + for i, j := 0, n-1; i < n; j, i = i, i+1 { + yi, yj := verts[i].Lat, verts[j].Lat + xi := normLng(verts[i].Lng, ref) + xj := normLng(verts[j].Lng, ref) + + if (yi > lat) != (yj > lat) { + xIntersect := (xj-xi)*(lat-yi)/(yj-yi) + xi + if qx < xIntersect { + inside = !inside + } + } + } + return inside +} + +// normLng rewrites v so that it lies within 180° of ref. With ref=10 and +// v=350, normLng returns -10. +func normLng(v, ref float64) float64 { + diff := math.Mod(v-ref+540, 360) - 180 + return ref + diff +} diff --git a/internal/engine/engine_test.go b/internal/engine/engine_test.go index fd55e38..1c83944 100644 --- a/internal/engine/engine_test.go +++ b/internal/engine/engine_test.go @@ -8,8 +8,7 @@ import ( "predictor-refactored/internal/weather" ) -// noWind is a WindField that always returns zero wind. Lets us test -// integration of vertical-only profiles deterministically. +// noWind is a WindField that always returns zero wind. type noWind struct{ epoch time.Time } func (n noWind) Wind(_ float64, _, _, _ float64) (weather.Sample, error) { @@ -31,19 +30,23 @@ func TestConstantAscentToBurst(t *testing.T) { Name: "ascent", Step: 60, Model: Sum(ConstantRate(rate), WindTransport(noWind{}, nil)), - Constraints: []Constraint{MaxAltitude{Limit: burst, On: ActionStop}}, + Constraints: []Constraint{Altitude{Op: OpGreaterEqual, Limit: burst, On: ActionStop}}, } prof := Profile{Stages: []*Propagator{ascend}, Direction: Forward} - results := prof.Run(0, State{Lat: 0, Lng: 0, Altitude: 0}) + results := prof.Run(0, State{Lat: 0, Lng: 0, Altitude: 0}, NewEventSink()) if len(results) != 1 || results[0].Outcome != OutcomeStopped { t.Fatalf("expected one stopped stage, got %+v", results) } + if results[0].ConstraintName == "" { + t.Errorf("ConstraintName not populated") + } + if results[0].RefinedState.Altitude == 0 { + t.Errorf("RefinedState not populated") + } last := results[0].Points[len(results[0].Points)-1] - // Refinement tolerance is 0.01 in parameter space over a 60s step, so the - // returned point sits within ±0.6s × rate ≈ ±3m of the boundary. if math.Abs(last.Altitude-burst) > 5 { t.Errorf("burst altitude = %v, want within 5m of %v", last.Altitude, burst) } @@ -67,12 +70,12 @@ func TestProfileWithFallback(t *testing.T) { Name: "ascent", Step: 60, Model: ConstantRate(rate), - Constraints: []Constraint{MaxAltitude{Limit: burst, On: ActionFallback}}, + Constraints: []Constraint{Altitude{Op: OpGreaterEqual, Limit: burst, On: ActionFallback}}, Fallback: descent, } prof := Profile{Stages: []*Propagator{ascend}, Direction: Forward} - results := prof.Run(0, State{Altitude: 0}) + results := prof.Run(0, State{Altitude: 0}, NewEventSink()) if len(results) != 2 { t.Fatalf("expected 2 results (ascent then descent fallback), got %d", len(results)) @@ -91,16 +94,14 @@ func TestProfileWithFallback(t *testing.T) { } func TestReverseDirection(t *testing.T) { - // Start at altitude 100m with downward rate; integrating reverse should - // give increasing altitude. desc := &Propagator{ Name: "rewind", Step: 1, - Model: ConstantRate(-1), // forward: alt decreases at 1 m/s - Constraints: []Constraint{MaxAltitude{Limit: 200, On: ActionStop}}, + Model: ConstantRate(-1), + Constraints: []Constraint{Altitude{Op: OpGreaterEqual, Limit: 200, On: ActionStop}}, } prof := Profile{Stages: []*Propagator{desc}, Direction: Reverse} - results := prof.Run(0, State{Altitude: 100}) + results := prof.Run(0, State{Altitude: 100}, NewEventSink()) last := results[0].Points[len(results[0].Points)-1] if math.Abs(last.Altitude-200) > 1 { @@ -129,6 +130,33 @@ func TestPiecewiseRate(t *testing.T) { } } +func TestPiecewiseReferenceResolution(t *testing.T) { + // Build via the registry with propagator_start segments. + spec := ModelSpec{ + Type: "piecewise", + Segments: []PiecewiseSegmentSpec{ + {Until: 100, Rate: 5, Reference: "propagator_start"}, + {Until: 200, Rate: 3, Reference: "propagator_start"}, + }, + } + built, err := BuildModel(spec, BuildDeps{}) + if err != nil { + t.Fatalf("BuildModel: %v", err) + } + if built.Build == nil { + t.Fatalf("expected lazy build for propagator_start references") + } + ctx := StageContext{ProfileStart: 1000, PropagatorStart: 5000} + m := built.Build(ctx) + // Until=100 from propagator_start=5000 → absolute 5100. + if r := m(5050, State{}); r.Altitude != 5 { + t.Errorf("rate at t=5050 = %v, want 5", r.Altitude) + } + if r := m(5150, State{}); r.Altitude != 3 { + t.Errorf("rate at t=5150 = %v, want 3", r.Altitude) + } +} + // fixedWind returns a constant wind sample. type fixedWind struct{ u, v float64 } @@ -139,12 +167,8 @@ func (fixedWind) Epoch() time.Time { return time.Unix(0, 0) } func (fixedWind) Source() string { return "test-fixed" } func TestWindTransportUnitConversion(t *testing.T) { - // Pure eastward wind of 10 m/s at the equator at sea level. - // Expected dlng/dt = (180/pi) * 10 / (6371009 * cos(0)) ≈ 0.00008991 deg/s. - // Expected dlat/dt = 0. wind := WindTransport(fixedWind{u: 10, v: 0}, nil) d := wind(0, State{Lat: 0, Lng: 0, Altitude: 0}) - wantLng := (180.0 / math.Pi) * 10.0 / 6371009.0 if math.Abs(d.Lng-wantLng) > 1e-12 { t.Errorf("dlng = %v, want %v", d.Lng, wantLng) @@ -153,7 +177,6 @@ func TestWindTransportUnitConversion(t *testing.T) { t.Errorf("dlat = %v, want 0 for u=10 v=0", d.Lat) } - // Pure northward at 60° latitude: dlat = (180/pi) * v / R, dlng = 0. wind2 := WindTransport(fixedWind{u: 0, v: 5}, nil) d = wind2(0, State{Lat: 60, Lng: 0, Altitude: 0}) wantLat := (180.0 / math.Pi) * 5.0 / 6371009.0 @@ -162,8 +185,28 @@ func TestWindTransportUnitConversion(t *testing.T) { } } +// aboveModelWind reports AboveModel on every sample. Used to verify event emission. +type aboveModelWind struct{} + +func (aboveModelWind) Wind(_ float64, _, _, _ float64) (weather.Sample, error) { + return weather.Sample{AboveModel: true}, nil +} +func (aboveModelWind) Epoch() time.Time { return time.Unix(0, 0) } +func (aboveModelWind) Source() string { return "above" } + +func TestWindTransportEmitsAboveModel(t *testing.T) { + sink := NewEventSink() + wind := WindTransport(aboveModelWind{}, sink) + for range 3 { + _ = wind(0, State{}) + } + events := sink.Snapshot() + if len(events) != 1 || events[0].Type != "above_model" || events[0].Count != 3 { + t.Errorf("expected one above_model event with count=3, got %+v", events) + } +} + func TestStateAddWrapsLongitude(t *testing.T) { - // Demonstrates state algebra used by the integrator and refinement. s := stateAdd(State{Lat: 0, Lng: 350, Altitude: 0}, 1, State{Lng: 20}) if math.Abs(s.Lng-10) > 1e-9 { t.Errorf("addState wrap: lng = %v, want 10", s.Lng) @@ -174,3 +217,39 @@ func TestStateAddWrapsLongitude(t *testing.T) { t.Errorf("lerpState lng wrap: %v, want 0 or 360", mid.Lng) } } + +func TestPolygonInside(t *testing.T) { + // Unit square at the equator. + square := []PolygonVertex{ + {Lat: -1, Lng: -1}, + {Lat: -1, Lng: 1}, + {Lat: 1, Lng: 1}, + {Lat: 1, Lng: -1}, + } + c := Polygon{Vertices: square, Mode: PolygonInside, On: ActionStop} + if !c.Violated(0, State{Lat: 0, Lng: 0}) { + t.Errorf("origin should be inside the square") + } + if c.Violated(0, State{Lat: 5, Lng: 0}) { + t.Errorf("(5, 0) should be outside the square") + } +} + +func TestPolygonOutsideAntimeridian(t *testing.T) { + // A polygon centred near the antimeridian, spanning lng 170..-170 + // (i.e. lng 170..190 in [0, 360) form). + poly := []PolygonVertex{ + {Lat: -10, Lng: 170}, + {Lat: -10, Lng: 190}, + {Lat: 10, Lng: 190}, + {Lat: 10, Lng: 170}, + } + c := Polygon{Vertices: poly, Mode: PolygonInside, On: ActionStop} + // A point at the antimeridian. + if !c.Violated(0, State{Lat: 0, Lng: 180}) { + t.Errorf("(0, 180) should be inside the antimeridian polygon") + } + if c.Violated(0, State{Lat: 0, Lng: 0}) { + t.Errorf("(0, 0) should be outside") + } +} diff --git a/internal/engine/events.go b/internal/engine/events.go new file mode 100644 index 0000000..7fde684 --- /dev/null +++ b/internal/engine/events.go @@ -0,0 +1,89 @@ +package engine + +import "sync" + +// Event is a non-fatal observation made during integration. +// +// Events generalise the warnings counter from the original Tawhiri port: +// any model or constraint can emit them, the EventSink aggregates by Type, +// and each Result carries a summary slice for the API to surface. +type Event struct { + Type string // short identifier, e.g. "above_model" + Time float64 // UNIX seconds when the event was emitted + State State + Message string +} + +// EventSummary is the per-type aggregation of repeated emissions. +type EventSummary struct { + Type string `json:"type"` + Count int64 `json:"count"` + FirstTime float64 `json:"first_time"` + LastTime float64 `json:"last_time"` + FirstState State `json:"first_state"` + LastState State `json:"last_state"` + Message string `json:"message"` +} + +// EventSink collects events from models and the integrator, aggregating +// duplicate types into a single EventSummary. Safe for concurrent use. +type EventSink struct { + mu sync.Mutex + summaries map[string]*EventSummary +} + +// NewEventSink returns an empty sink. +func NewEventSink() *EventSink { return &EventSink{summaries: make(map[string]*EventSummary)} } + +// Emit records one occurrence of typ at (t, s) with the provided message. +// Subsequent emits with the same typ update LastTime/LastState and Count. +func (s *EventSink) Emit(typ string, t float64, state State, message string) { + if s == nil { + return + } + s.mu.Lock() + defer s.mu.Unlock() + sum, ok := s.summaries[typ] + if !ok { + s.summaries[typ] = &EventSummary{ + Type: typ, Count: 1, + FirstTime: t, LastTime: t, + FirstState: state, LastState: state, + Message: message, + } + return + } + sum.Count++ + sum.LastTime = t + sum.LastState = state + if sum.Message == "" && message != "" { + sum.Message = message + } +} + +// Snapshot returns a stable copy of every summary in deterministic order +// (sorted by Type). +func (s *EventSink) Snapshot() []EventSummary { + if s == nil { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + out := make([]EventSummary, 0, len(s.summaries)) + for _, sum := range s.summaries { + out = append(out, *sum) + } + sortEventSummaries(out) + return out +} + +func sortEventSummaries(s []EventSummary) { + // Insertion sort: usually one or two entries. + for i := 1; i < len(s); i++ { + j := i + for j > 0 && s[j-1].Type > s[j].Type { + s[j-1], s[j] = s[j], s[j-1] + j-- + } + } +} diff --git a/internal/engine/models.go b/internal/engine/models.go index 6431b60..2eec493 100644 --- a/internal/engine/models.go +++ b/internal/engine/models.go @@ -3,14 +3,13 @@ package engine import ( "math" "sort" - "sync/atomic" "predictor-refactored/internal/weather" ) // Sum composes models by summing their derivatives at each evaluation point. // -// Useful for combining e.g. a vertical-rate model with a horizontal wind model +// Useful for combining a vertical-rate model with a horizontal wind model // into a single propagator. Equivalent to Tawhiri's LinearModel. func Sum(models ...Model) Model { if len(models) == 1 { @@ -29,18 +28,16 @@ func Sum(models ...Model) Model { } // ConstantRate returns a model with a constant vertical velocity (m/s). -// A positive rate is upward (ascent); a negative rate is downward. +// Positive rates are upward. func ConstantRate(rate float64) Model { - return func(_ float64, _ State) State { - return State{Altitude: rate} - } + return func(_ float64, _ State) State { return State{Altitude: rate} } } -// ParachuteDescent returns a model where vertical velocity grows with altitude -// because thinner air provides less drag. +// ParachuteDescent returns a model where vertical velocity grows with +// altitude because thinner air provides less drag. seaLevelRate is the +// descent speed at sea level (m/s, positive). // -// seaLevelRate is the descent speed at sea level (m/s, positive number). -// The terminal velocity at altitude is computed as +// Terminal velocity at altitude is computed as // // v = -k / sqrt(rho(alt)), k = seaLevelRate * 1.1045, // @@ -52,9 +49,9 @@ func ParachuteDescent(seaLevelRate float64) Model { } } -// nasaDensity returns air density (kg/m^3) for the given altitude in metres, -// using the NASA simple atmosphere model. See -// https://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html. +// nasaDensity returns air density (kg/m^3) for an altitude in metres, +// using the NASA simple atmosphere model. +// See https://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html. func nasaDensity(alt float64) float64 { var temp, pressure float64 switch { @@ -71,22 +68,17 @@ func nasaDensity(alt float64) float64 { return pressure / (0.2869 * (temp + 273.1)) } -// RateSegment is one entry in a Piecewise rate schedule. +// RateSegment is one entry in a Piecewise rate schedule. Until is the UNIX +// timestamp at which this segment ends — the model emits the segment's +// Rate for all t < Until. The final segment's Rate is held indefinitely. type RateSegment struct { - // Until is the UNIX timestamp at which this segment ends. - // The model applies the segment's Rate for all t < Until. Until float64 - // Rate is the vertical velocity (m/s) during the segment. Positive is up. - Rate float64 + Rate float64 } -// Piecewise returns a model that produces a piecewise-constant vertical rate -// over a sequence of time intervals. -// -// Segments are searched by their Until field; the first segment whose Until -// exceeds t supplies the active rate. For t at or after the last Until, the -// final segment's Rate is held indefinitely. Input is sorted ascending by -// Until on construction. +// Piecewise returns a model that produces a piecewise-constant vertical +// rate over a sequence of intervals. The input is sorted ascending by +// Until on construction; later segments shadow earlier ones. func Piecewise(segments []RateSegment) Model { if len(segments) == 0 { return ConstantRate(0) @@ -104,33 +96,13 @@ func Piecewise(segments []RateSegment) Model { } } -// Warnings aggregates non-fatal conditions encountered during integration. -type Warnings struct { - // AltitudeTooHigh counts evaluations where the wind sampler reported - // that altitude was above the highest pressure level of the dataset. - AltitudeTooHigh atomic.Int64 -} - -// ToMap returns warnings as a map suitable for JSON output. Only counters -// that have fired are included. -func (w *Warnings) ToMap() map[string]any { - out := make(map[string]any) - if n := w.AltitudeTooHigh.Load(); n > 0 { - out["altitude_too_high"] = map[string]any{ - "count": n, - "description": "altitude exceeded the highest pressure level of the wind dataset; samples were extrapolated", - } - } - return out -} - // WindTransport returns a model that moves laterally at the wind velocity -// sampled from field. The vertical component of the returned derivative is -// zero. Wind units are converted from m/s to deg/s on Earth's surface. +// sampled from field. Vertical component is zero. Wind components in m/s +// are converted to deg/s on Earth's surface using R = 6371009 m. // -// If warnings is non-nil, the AltitudeTooHigh counter is incremented for any -// sample where the wind field reported altitude above the model top. -func WindTransport(field weather.WindField, warnings *Warnings) Model { +// If events is non-nil, an "above_model" event is emitted whenever the +// wind field reports altitude above the highest pressure level. +func WindTransport(field weather.WindField, events *EventSink) Model { const earthR = 6371009.0 const piOver180 = math.Pi / 180.0 const degPerRad = 180.0 / math.Pi @@ -139,8 +111,9 @@ func WindTransport(field weather.WindField, warnings *Warnings) Model { if err != nil { return State{} } - if sample.AboveModel && warnings != nil { - warnings.AltitudeTooHigh.Add(1) + if sample.AboveModel && events != nil { + events.Emit("above_model", t, s, + "altitude exceeded the highest pressure level of the wind dataset; samples extrapolated") } r := earthR + s.Altitude return State{ diff --git a/internal/engine/operators.go b/internal/engine/operators.go new file mode 100644 index 0000000..13b3a6d --- /dev/null +++ b/internal/engine/operators.go @@ -0,0 +1,69 @@ +package engine + +import "fmt" + +// Operator is a scalar comparison used by generalised constraints like +// Altitude and Time. A constraint fires when its Operator.Test(value, limit) +// returns true. +type Operator int + +const ( + OpLess Operator = iota // value < limit + OpLessEqual // value ≤ limit + OpGreater // value > limit + OpGreaterEqual // value ≥ limit + OpEqual // value == limit +) + +// Test evaluates op(value, limit). +func (o Operator) Test(value, limit float64) bool { + switch o { + case OpLess: + return value < limit + case OpLessEqual: + return value <= limit + case OpGreater: + return value > limit + case OpGreaterEqual: + return value >= limit + case OpEqual: + return value == limit + } + return false +} + +// String returns the symbol "<", "<=", ">", ">=", "==". +func (o Operator) String() string { + switch o { + case OpLess: + return "<" + case OpLessEqual: + return "<=" + case OpGreater: + return ">" + case OpGreaterEqual: + return ">=" + case OpEqual: + return "==" + } + return "?" +} + +// ParseOperator maps a textual operator to its Operator constant. +// Accepts "<", "<=", "le", ">", ">=", "ge", "==", "eq". +func ParseOperator(s string) (Operator, error) { + switch s { + case "<", "lt": + return OpLess, nil + case "<=", "le": + return OpLessEqual, nil + case ">", "gt": + return OpGreater, nil + case ">=", "ge": + return OpGreaterEqual, nil + case "==", "eq": + return OpEqual, nil + default: + return 0, fmt.Errorf("unknown operator %q", s) + } +} diff --git a/internal/engine/profile.go b/internal/engine/profile.go index 57460de..3fa03ef 100644 --- a/internal/engine/profile.go +++ b/internal/engine/profile.go @@ -3,21 +3,26 @@ package engine // Profile is an ordered chain of propagators executed sequentially. Each // propagator picks up where the previous one finished. type Profile struct { - // Stages are run in order. For Direction=Reverse they are still iterated - // from index 0 onwards, but each propagator integrates with negative dt. + // Stages are run in order. For Direction=Reverse they are still + // iterated from index 0 onwards but each propagator integrates with + // negative dt. Stages []*Propagator - // Direction controls the sign of dt across the whole profile. + // Direction controls the sign of dt across the profile. Direction Direction - // Globals are constraints evaluated alongside each stage's local Constraints. - // Useful for profile-wide bounds like "stop after N hours total". + // Globals are constraints evaluated alongside each stage's local + // Constraints. Useful for profile-wide bounds like "stop after N hours". Globals []Constraint } -// Run executes the profile from the given launch point. Returns one Result -// per executed stage, including any Fallback chains that were activated. -func (p *Profile) Run(t0 float64, launch State) []Result { +// Run executes the profile from the given launch point. Returns one +// Result per executed stage, including any Fallback chains that were +// activated. The supplied EventSink is shared across stages and aggregates +// non-fatal observations. +// +// events may be nil; pass NewEventSink() to capture observations. +func (p *Profile) Run(t0 float64, launch State, events *EventSink) []Result { if p.Direction == 0 { p.Direction = Forward } @@ -27,28 +32,36 @@ func (p *Profile) Run(t0 float64, launch State) []Result { for i := 0; i < len(p.Stages); i++ { stage := p.Stages[i] - res := stage.run(t, s, p.Direction, p.Globals) + ctx := StageContext{ + ProfileStart: t0, + PropagatorStart: t, + Launch: launch, + PropagatorState: s, + Direction: p.Direction, + } + res := stage.run(ctx, t, s, p.Globals, events) results = append(results, res) last := res.Points[len(res.Points)-1] t = last.Time s = State{Lat: last.Lat, Lng: last.Lng, Altitude: last.Altitude} - // Follow Fallback chains until none remains. Each fallback consumes - // from the same point the previous stage stopped at. + // Follow Fallback chains until none remains. for res.Outcome == OutcomeFallback && stage.Fallback != nil { stage = stage.Fallback - res = stage.run(t, s, p.Direction, p.Globals) + ctx = StageContext{ + ProfileStart: t0, + PropagatorStart: t, + Launch: launch, + PropagatorState: s, + Direction: p.Direction, + } + res = stage.run(ctx, t, s, p.Globals, events) results = append(results, res) last = res.Points[len(res.Points)-1] t = last.Time s = State{Lat: last.Lat, Lng: last.Lng, Altitude: last.Altitude} } - - // If a propagator's stop fired (not a fallback), end the profile. - if res.Outcome == OutcomeStopped { - continue - } } return results diff --git a/internal/engine/propagator.go b/internal/engine/propagator.go index c653218..50ccd53 100644 --- a/internal/engine/propagator.go +++ b/internal/engine/propagator.go @@ -7,71 +7,58 @@ import ( // Propagator advances state under one Model, checking a set of Constraints // after every integration step. // -// When a constraint fires, the propagator binary-search refines the violation -// point and emits it as its final trajectory point. The Action of the -// triggering constraint controls what the surrounding Profile does next: -// stop the profile, transfer to Fallback, or clip and continue. +// When a constraint fires, the propagator binary-search refines the +// violation point and emits it as its final trajectory point. The Action of +// the triggering constraint controls what the surrounding Profile does +// next: stop the profile, transfer to Fallback, or clip and continue. type Propagator struct { - // Name identifies the propagator in trajectory metadata. + // Name identifies the propagator in trajectory metadata. Optional — + // callers using sequential profile chains may leave it empty. Name string // Step is the magnitude of the integration step in seconds (always positive). // The Profile flips its sign for Reverse direction. Step float64 - // Model produces the per-second time derivative of state. - Model Model + // Model is the per-second derivative function used for integration. + // One of Model or BuildModel must be non-nil. If both are set, BuildModel + // takes precedence (it is invoked once per stage with a StageContext). + Model Model + BuildModel func(ctx StageContext) Model - // Constraints are evaluated after each step. Any fired constraint stops - // the propagator at the refined point; the first one in this slice wins - // on ties. - Constraints []Constraint + // Constraints are evaluated after each step. The first violation wins. + Constraints []Constraint + BuildConstraints func(ctx StageContext) []Constraint // Fallback is the propagator to switch to when a constraint with // ActionFallback fires. Optional. Fallback *Propagator - // Tolerance is the binary-search refinement tolerance in parameter space - // (default 0.01, matching Tawhiri). + // Tolerance is the binary-search refinement tolerance in parameter + // space (default 0.01, matching Tawhiri). Tolerance float64 } -// Outcome describes how a propagator's run ended. -type Outcome int - -const ( - // OutcomeStopped means a Constraint with ActionStop fired and the profile - // should end here. - OutcomeStopped Outcome = iota - // OutcomeFallback means a Constraint with ActionFallback fired and the - // profile should transfer to the propagator's Fallback chain. - OutcomeFallback - // OutcomeContinued means no constraint fired before the time horizon was - // reached. In practice this is only seen when a propagator runs unbounded, - // which means the profile is misconfigured. - OutcomeContinued -) - -// Result is the output of running one propagator. -type Result struct { - Propagator string - Points []TrajectoryPoint - Outcome Outcome - // Constraint is the constraint that fired, or nil if Outcome == OutcomeContinued. - Constraint Constraint -} - // run integrates the model from (t0, s0) in direction dir, returning a Result. // globals are constraints injected by the Profile and checked alongside the -// propagator's local Constraints. -func (p *Propagator) run(t0 float64, s0 State, dir Direction, globals []Constraint) Result { - dt := p.Step * float64(dir) +// propagator's local Constraints. events receives non-fatal observations. +func (p *Propagator) run(ctx StageContext, t0 float64, s0 State, globals []Constraint, events *EventSink) Result { + dt := p.Step * float64(ctx.Direction) tol := p.Tolerance if tol == 0 { tol = 0.01 } - deriv := numerics.Deriv[State](func(t float64, s State) State { return p.Model(t, s) }) + model := p.Model + if p.BuildModel != nil { + model = p.BuildModel(ctx) + } + constraints := p.Constraints + if p.BuildConstraints != nil { + constraints = p.BuildConstraints(ctx) + } + + deriv := numerics.Deriv[State](func(t float64, s State) State { return model(t, s) }) add := numerics.VecAdd[State](stateAdd) lerp := numerics.VecLerp[State](stateLerp) @@ -90,39 +77,50 @@ func (p *Propagator) run(t0 float64, s0 State, dir Direction, globals []Constrai s2 := numerics.RK4Step(t, s, dt, deriv, add) t2 := t + dt - if c, fired := firstFiring(p.Constraints, globals, t2, s2); fired { - trig := numerics.Trigger[State](func(tt float64, ss State) bool { return c.Violated(tt, ss) }) - t3, s3 := numerics.RefineTrigger(t, s, t2, s2, trig, lerp, tol) - - switch c.Action() { - case ActionClip: - s3 = clipToConstraint(c, s3) - out.Points = append(out.Points, TrajectoryPoint{ - Time: t3, Lat: s3.Lat, Lng: s3.Lng, Altitude: s3.Altitude, - }) - t, s = t3, s3 - continue - case ActionFallback: - out.Points = append(out.Points, TrajectoryPoint{ - Time: t3, Lat: s3.Lat, Lng: s3.Lng, Altitude: s3.Altitude, - }) - out.Outcome = OutcomeFallback - out.Constraint = c - return out - default: // ActionStop - out.Points = append(out.Points, TrajectoryPoint{ - Time: t3, Lat: s3.Lat, Lng: s3.Lng, Altitude: s3.Altitude, - }) - out.Outcome = OutcomeStopped - out.Constraint = c - return out - } + c, fired := firstFiring(constraints, globals, t2, s2) + if !fired { + t, s = t2, s2 + out.Points = append(out.Points, TrajectoryPoint{ + Time: t, Lat: s.Lat, Lng: s.Lng, Altitude: s.Altitude, + }) + continue } - t, s = t2, s2 - out.Points = append(out.Points, TrajectoryPoint{ - Time: t, Lat: s.Lat, Lng: s.Lng, Altitude: s.Altitude, - }) + // Record the unrefined violation. + out.ViolationTime = t2 + out.ViolationState = s2 + + trig := numerics.Trigger[State](func(tt float64, ss State) bool { return c.Violated(tt, ss) }) + t3, s3 := numerics.RefineTrigger(t, s, t2, s2, trig, lerp, tol) + out.RefinedTime = t3 + out.RefinedState = s3 + out.Constraint = c + out.ConstraintName = c.Name() + + switch c.Action() { + case ActionClip: + s3 = clipToConstraint(c, s3) + out.RefinedState = s3 + out.Points = append(out.Points, TrajectoryPoint{ + Time: t3, Lat: s3.Lat, Lng: s3.Lng, Altitude: s3.Altitude, + }) + t, s = t3, s3 + continue + case ActionFallback: + out.Points = append(out.Points, TrajectoryPoint{ + Time: t3, Lat: s3.Lat, Lng: s3.Lng, Altitude: s3.Altitude, + }) + out.Outcome = OutcomeFallback + out.Events = events.Snapshot() + return out + default: // ActionStop + out.Points = append(out.Points, TrajectoryPoint{ + Time: t3, Lat: s3.Lat, Lng: s3.Lng, Altitude: s3.Altitude, + }) + out.Outcome = OutcomeStopped + out.Events = events.Snapshot() + return out + } } } @@ -142,15 +140,12 @@ func firstFiring(local, globals []Constraint, t float64, s State) (Constraint, b return nil, false } -// clipToConstraint adjusts s so that the given constraint is exactly satisfied -// (not violated). Implemented for constraints with a well-defined boundary; -// others fall through unchanged. +// clipToConstraint adjusts s so that the given constraint is exactly +// satisfied (not violated). Defined only for constraints with a +// well-defined coordinate boundary; others fall through unchanged. func clipToConstraint(c Constraint, s State) State { - switch v := c.(type) { - case MaxAltitude: - s.Altitude = v.Limit - case MinAltitude: - s.Altitude = v.Limit + if alt, ok := c.(Altitude); ok { + s.Altitude = alt.Limit } return s } diff --git a/internal/engine/registry.go b/internal/engine/registry.go new file mode 100644 index 0000000..8cab94f --- /dev/null +++ b/internal/engine/registry.go @@ -0,0 +1,287 @@ +package engine + +import ( + "fmt" + "sync" + + "predictor-refactored/internal/weather" +) + +// ConstraintSpec is the source-agnostic JSON-shape used to declare a +// constraint. The Type field is the registry key; remaining fields are +// extracted by the registered factory. +type ConstraintSpec struct { + Type string `json:"type"` + Action string `json:"action,omitempty"` + // Op is the comparison operator for scalar constraints (altitude, time). + Op string `json:"op,omitempty"` + Limit float64 `json:"limit,omitempty"` + // Vertices and Mode are used by the polygon constraint. + Vertices []PolygonVertex `json:"vertices,omitempty"` + Mode string `json:"mode,omitempty"` + // Label is an optional human-readable identifier surfaced via Name(). + Label string `json:"label,omitempty"` +} + +// ModelSpec is the source-agnostic JSON shape used to declare a model. +type ModelSpec struct { + Type string `json:"type"` + // Rate (m/s) for constant_rate. + Rate float64 `json:"rate,omitempty"` + // SeaLevelRate (m/s, positive) for parachute_descent. + SeaLevelRate float64 `json:"sea_level_rate,omitempty"` + // Segments for piecewise. + Segments []PiecewiseSegmentSpec `json:"segments,omitempty"` + // IncludeWind sums a WindTransport model into the resulting derivative. + IncludeWind bool `json:"include_wind,omitempty"` +} + +// PiecewiseSegmentSpec is one entry in a piecewise rate schedule. +// +// Reference selects how the Until field is interpreted: +// +// - "absolute" (default): UNIX seconds. +// - "profile_start": seconds since the profile's launch time. +// - "propagator_start": seconds since this propagator began running. +type PiecewiseSegmentSpec struct { + Until float64 `json:"until"` + Rate float64 `json:"rate"` + Reference string `json:"reference,omitempty"` +} + +// BuildDeps bundle the runtime dependencies factories may consult. +type BuildDeps struct { + Wind weather.WindField + Terrain TerrainProvider + Events *EventSink +} + +// ConstraintFactory builds one Constraint from a spec. +type ConstraintFactory func(spec ConstraintSpec, deps BuildDeps) (Constraint, error) + +// ModelFactory builds one model from a spec. The returned Built is held by +// a Propagator; if Build is set, it is invoked lazily by the profile +// runner before every stage so it can capture per-stage start times. +type ModelFactory func(spec ModelSpec, deps BuildDeps) (BuiltModel, error) + +// BuiltModel is either an eager Model, a lazy Build, or both. The profile +// runner prefers Build when present. +type BuiltModel struct { + Model Model + Build func(ctx StageContext) Model +} + +var ( + regMu sync.RWMutex + constraintFactories = map[string]ConstraintFactory{} + modelFactories = map[string]ModelFactory{} +) + +// RegisterConstraint installs a factory for typeName. Subsequent calls +// overwrite the previous factory. +func RegisterConstraint(typeName string, f ConstraintFactory) { + regMu.Lock() + defer regMu.Unlock() + constraintFactories[typeName] = f +} + +// RegisterModel installs a model factory. +func RegisterModel(typeName string, f ModelFactory) { + regMu.Lock() + defer regMu.Unlock() + modelFactories[typeName] = f +} + +// BuildConstraint dispatches spec to its registered factory. +func BuildConstraint(spec ConstraintSpec, deps BuildDeps) (Constraint, error) { + regMu.RLock() + f, ok := constraintFactories[spec.Type] + regMu.RUnlock() + if !ok { + return nil, fmt.Errorf("unknown constraint type %q", spec.Type) + } + return f(spec, deps) +} + +// BuildModel dispatches spec to its registered factory. +func BuildModel(spec ModelSpec, deps BuildDeps) (BuiltModel, error) { + regMu.RLock() + f, ok := modelFactories[spec.Type] + regMu.RUnlock() + if !ok { + return BuiltModel{}, fmt.Errorf("unknown model type %q", spec.Type) + } + return f(spec, deps) +} + +// RegisteredConstraints returns the names of every registered constraint type. +func RegisteredConstraints() []string { + regMu.RLock() + defer regMu.RUnlock() + out := make([]string, 0, len(constraintFactories)) + for k := range constraintFactories { + out = append(out, k) + } + return out +} + +// RegisteredModels returns the names of every registered model type. +func RegisteredModels() []string { + regMu.RLock() + defer regMu.RUnlock() + out := make([]string, 0, len(modelFactories)) + for k := range modelFactories { + out = append(out, k) + } + return out +} + +// --- Built-in registrations ------------------------------------------------ + +func init() { + RegisterConstraint("altitude", buildAltitude) + RegisterConstraint("time", buildTime) + RegisterConstraint("terrain_contact", buildTerrainContact) + RegisterConstraint("polygon", buildPolygon) + + RegisterModel("constant_rate", buildConstantRate) + RegisterModel("parachute_descent", buildParachuteDescent) + RegisterModel("piecewise", buildPiecewise) + RegisterModel("wind", buildWind) +} + +func buildAltitude(spec ConstraintSpec, _ BuildDeps) (Constraint, error) { + op, err := ParseOperator(spec.Op) + if err != nil { + return nil, fmt.Errorf("altitude: %w", err) + } + act, err := ParseAction(spec.Action) + if err != nil { + return nil, fmt.Errorf("altitude: %w", err) + } + return Altitude{Op: op, Limit: spec.Limit, On: act}, nil +} + +func buildTime(spec ConstraintSpec, _ BuildDeps) (Constraint, error) { + op, err := ParseOperator(spec.Op) + if err != nil { + return nil, fmt.Errorf("time: %w", err) + } + act, err := ParseAction(spec.Action) + if err != nil { + return nil, fmt.Errorf("time: %w", err) + } + return Time{Op: op, Limit: spec.Limit, On: act}, nil +} + +func buildTerrainContact(spec ConstraintSpec, deps BuildDeps) (Constraint, error) { + if deps.Terrain == nil { + return nil, fmt.Errorf("terrain_contact requires a terrain provider") + } + act, err := ParseAction(spec.Action) + if err != nil { + return nil, fmt.Errorf("terrain_contact: %w", err) + } + return TerrainContact{Provider: deps.Terrain, On: act}, nil +} + +func buildPolygon(spec ConstraintSpec, _ BuildDeps) (Constraint, error) { + if len(spec.Vertices) < 3 { + return nil, fmt.Errorf("polygon requires at least 3 vertices") + } + act, err := ParseAction(spec.Action) + if err != nil { + return nil, fmt.Errorf("polygon: %w", err) + } + mode := PolygonInside + switch spec.Mode { + case "", "inside": + mode = PolygonInside + case "outside": + mode = PolygonOutside + default: + return nil, fmt.Errorf("polygon: unknown mode %q", spec.Mode) + } + return Polygon{Vertices: spec.Vertices, Mode: mode, On: act, Label: spec.Label}, nil +} + +func buildConstantRate(spec ModelSpec, _ BuildDeps) (BuiltModel, error) { + return BuiltModel{Model: ConstantRate(spec.Rate)}, nil +} + +func buildParachuteDescent(spec ModelSpec, _ BuildDeps) (BuiltModel, error) { + if spec.SeaLevelRate <= 0 { + return BuiltModel{}, fmt.Errorf("parachute_descent requires positive sea_level_rate") + } + return BuiltModel{Model: ParachuteDescent(spec.SeaLevelRate)}, nil +} + +func buildWind(_ ModelSpec, deps BuildDeps) (BuiltModel, error) { + if deps.Wind == nil { + return BuiltModel{}, fmt.Errorf("wind model requires a loaded wind field") + } + return BuiltModel{Model: WindTransport(deps.Wind, deps.Events)}, nil +} + +func buildPiecewise(spec ModelSpec, deps BuildDeps) (BuiltModel, error) { + needsCtx := false + for _, seg := range spec.Segments { + if seg.Reference == "propagator_start" { + needsCtx = true + break + } + } + if !needsCtx { + // Eager build: resolve any "profile_start" relative segments using + // the launch time we know at build time only when we have one. + // Without context, treat profile_start the same as absolute (the + // caller is expected to pre-resolve), and absolute as absolute. + segs := make([]RateSegment, 0, len(spec.Segments)) + for _, s := range spec.Segments { + if s.Reference == "profile_start" { + return BuiltModel{}, fmt.Errorf("piecewise: profile_start reference requires a stage context — supply via lazy build") + } + segs = append(segs, RateSegment{Until: s.Until, Rate: s.Rate}) + } + base := Piecewise(segs) + return BuiltModel{Model: maybeAddWind(base, spec.IncludeWind, deps)}, nil + } + // Lazy build — captures spec into a closure. + return BuiltModel{ + Build: func(ctx StageContext) Model { + segs := resolveSegments(spec.Segments, ctx) + base := Piecewise(segs) + return maybeAddWind(base, spec.IncludeWind, deps) + }, + }, nil +} + +// resolveSegments converts spec segments to engine.RateSegment using the +// stage context to resolve relative references. +func resolveSegments(in []PiecewiseSegmentSpec, ctx StageContext) []RateSegment { + out := make([]RateSegment, 0, len(in)) + for _, s := range in { + var until float64 + switch s.Reference { + case "", "absolute": + until = s.Until + case "profile_start": + until = ctx.ProfileStart + s.Until + case "propagator_start": + until = ctx.PropagatorStart + s.Until + } + out = append(out, RateSegment{Until: until, Rate: s.Rate}) + } + return out +} + +// maybeAddWind sums a WindTransport model into base when the spec asks for it. +func maybeAddWind(base Model, includeWind bool, deps BuildDeps) Model { + if !includeWind { + return base + } + if deps.Wind == nil { + return base + } + return Sum(base, WindTransport(deps.Wind, deps.Events)) +} diff --git a/internal/engine/types.go b/internal/engine/types.go index 59504e4..f08f789 100644 --- a/internal/engine/types.go +++ b/internal/engine/types.go @@ -1,27 +1,27 @@ // Package engine is the trajectory calculation engine. It composes -// propagators (model-driven integrators) into profiles (ordered chains) and -// runs them over a wind field. +// propagators (model-driven integrators) into profiles (ordered chains) +// over a wind field. // -// The engine has no direct dependency on any specific data source: wind data -// is consumed through weather.WindField and terrain data through any type -// satisfying TerrainProvider. +// The engine has no direct dependency on any specific data source: wind +// data is consumed through weather.WindField and terrain data through +// any type satisfying TerrainProvider. package engine // State holds the spatial state of the balloon. When returned by a Model -// the same struct is interpreted as the per-second time derivative of state. +// the same struct is interpreted as the per-second time derivative. type State struct { - // Lat is degrees latitude in [-90, 90] (or deg/s when returned as a derivative). - Lat float64 - // Lng is degrees longitude in [0, 360) (or deg/s as a derivative). - Lng float64 - // Altitude is metres above mean sea level (or m/s as a derivative). - Altitude float64 + // Lat is degrees latitude in [-90, 90]. + Lat float64 `json:"lat"` + // Lng is degrees longitude in [0, 360). + Lng float64 `json:"lng"` + // Altitude is metres above mean sea level. + Altitude float64 `json:"altitude"` } // Model returns the time derivative of state at (t, s). // -// The derivative is direction-independent; the integrator applies the sign -// of dt for reverse propagation. +// The derivative is direction-independent; the integrator applies the +// sign of dt for reverse propagation. type Model func(t float64, s State) State // TrajectoryPoint is one sampled point of an integration result. @@ -32,9 +32,7 @@ type TrajectoryPoint struct { Altitude float64 } -// Direction is the time direction of integration. Forward (+1) integrates -// from launch to landing; Reverse (-1) integrates from a known landing back -// to a candidate launch point. +// Direction is the time direction of integration. type Direction int8 const ( @@ -42,28 +40,39 @@ const ( Reverse Direction = -1 ) -// Action describes what the profile runner should do when a Constraint -// reports a violation. +// Action is what the profile runner does on a constraint violation. type Action int const ( - // ActionStop ends the current propagator at the (refined) violation point. - // This matches the only behaviour available in the reference Tawhiri solver. + // ActionStop ends the current propagator at the refined violation point. ActionStop Action = iota // ActionFallback ends the current propagator and starts its Fallback - // propagator from the violation point. Useful for "if max altitude is - // reached during ascent, switch to descent" profiles. + // propagator from the refined violation point. ActionFallback // ActionClip clips the violated coordinate to the boundary and continues - // integration. Useful for soft constraints such as "max altitude floor". + // integration. ActionClip ) -// Constraint reports when integration should stop, branch, or clip. -// -// A constraint is direction-agnostic: it reads state and decides. The profile -// runner is responsible for refining the trigger point via binary search and -// dispatching the configured Action. +// ParseAction maps "stop" | "fallback" | "clip" to an Action. +func ParseAction(s string) (Action, error) { + switch s { + case "", "stop": + return ActionStop, nil + case "fallback": + return ActionFallback, nil + case "clip": + return ActionClip, nil + default: + return 0, errUnknownAction(s) + } +} + +type errUnknownAction string + +func (e errUnknownAction) Error() string { return "unknown constraint action " + string(e) } + +// Constraint defines a stopping, branching, or clipping condition. type Constraint interface { // Name identifies the constraint in logs and result metadata. Name() string @@ -74,7 +83,79 @@ type Constraint interface { } // TerrainProvider returns ground elevation in metres at a coordinate. -// Implementations must be safe for concurrent use. type TerrainProvider interface { Elevation(lat, lng float64) float64 } + +// StageContext is provided to Propagator.BuildModel and BuildConstraints by +// the profile runner immediately before each stage executes. +type StageContext struct { + // ProfileStart is the UNIX timestamp of the profile's initial launch. + ProfileStart float64 + // PropagatorStart is the UNIX timestamp at which this propagator begins + // running — equal to ProfileStart for the first stage; the end-time of + // the previous stage thereafter. + PropagatorStart float64 + // Launch is the profile's initial state. + Launch State + // PropagatorState is the state at which this propagator begins. + PropagatorState State + // Direction is the integration direction the profile is configured with. + Direction Direction +} + +// Outcome describes how a propagator's run ended. +type Outcome int + +const ( + // OutcomeStopped means a Constraint with ActionStop fired. + OutcomeStopped Outcome = iota + // OutcomeFallback means a Constraint with ActionFallback fired. + OutcomeFallback + // OutcomeContinued means the propagator finished without a constraint + // firing — only seen when a propagator is misconfigured to run unbounded. + OutcomeContinued +) + +// String renders the outcome as a stable string for API serialisation. +func (o Outcome) String() string { + switch o { + case OutcomeStopped: + return "stopped" + case OutcomeFallback: + return "fallback" + default: + return "continued" + } +} + +// Result is the output of running one propagator. +type Result struct { + // Propagator is the propagator's Name. + Propagator string + + // Points is the emitted trajectory. + Points []TrajectoryPoint + + // Outcome describes how the propagator terminated. + Outcome Outcome + + // Constraint is the constraint that fired, or nil if Outcome is OutcomeContinued. + Constraint Constraint + // ConstraintName captures Constraint.Name() at fire time so callers can + // serialise the result after the Constraint has been garbage collected. + ConstraintName string + + // ViolationTime / ViolationState describe the first integration step at + // which the constraint reported a violation, before binary-search refinement. + ViolationTime float64 + ViolationState State + + // RefinedTime / RefinedState describe the refined violation point that + // appears as the propagator's last trajectory point. + RefinedTime float64 + RefinedState State + + // Events is the aggregated set of non-fatal observations from this stage. + Events []EventSummary +} diff --git a/internal/weather/gfs/constants.go b/internal/weather/gfs/constants.go index 77ed6ee..bc7a288 100644 --- a/internal/weather/gfs/constants.go +++ b/internal/weather/gfs/constants.go @@ -1,34 +1,28 @@ package gfs -import "fmt" +// Cross-variant constants. Per-variant geometry (latitudes, longitudes, +// pressure levels, hour step, max hour, URL token) lives on the Variant +// type; see variant.go. -// Dataset shape: (hour, pressure_level, variable, latitude, longitude). -// Matches the cube layout used by the reference Tawhiri implementation. const ( - NumHours = 65 // 0, 3, 6, ..., 192 hours forecast - NumLevels = 47 // pressure levels - NumVariables = 3 // geopotential height, U-wind, V-wind - NumLatitudes = 361 // -90.0 to +90.0 inclusive in 0.5° steps - NumLongitudes = 720 // 0.0 to 359.5 in 0.5° steps + // NumVariables is the number of dataset variables: HGT, UGRD, VGRD. + NumVariables = 3 + // ElementSize is the cell size in bytes (float32). + ElementSize = 4 - HourStep = 3 - MaxHour = 192 - Resolution = 0.5 - LatStart = -90.0 - LonStart = 0.0 + // LatStart is the first latitude in the cube (south to north). + LatStart = -90.0 + // LonStart is the first longitude in the cube (0..360 east). + LonStart = 0.0 + // Variable indices within the cube's 3rd axis. VarHeight = 0 VarWindU = 1 VarWindV = 2 - - ElementSize = 4 // float32 - - // DatasetSize is the canonical file size: every grid cell × element size. - DatasetSize int64 = int64(NumHours) * int64(NumLevels) * int64(NumVariables) * - int64(NumLatitudes) * int64(NumLongitudes) * int64(ElementSize) ) -// LevelSet identifies which GRIB file (primary/secondary) carries a level. +// LevelSet identifies which GRIB file (primary or secondary) carries a +// pressure level. type LevelSet int const ( @@ -36,106 +30,5 @@ const ( LevelSetB // pgrb2b — secondary file ) -// Pressures lists the 47 pressure levels (hPa) in dataset index order, -// descending from surface to top of atmosphere. -var Pressures = [NumLevels]int{ - 1000, 975, 950, 925, 900, 875, 850, 825, 800, 775, - 750, 725, 700, 675, 650, 625, 600, 575, 550, 525, - 500, 475, 450, 425, 400, 375, 350, 325, 300, 275, - 250, 225, 200, 175, 150, 125, 100, 70, 50, 30, - 20, 10, 7, 5, 3, 2, 1, -} - -// PressuresPgrb2 lists the levels carried by the primary GRIB file. -var PressuresPgrb2 = []int{ - 10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 350, 400, - 450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 925, - 950, 975, 1000, -} - -// PressuresPgrb2b lists the levels carried by the secondary GRIB file. -var PressuresPgrb2b = []int{ - 1, 2, 3, 5, 7, 125, 175, 225, 275, 325, 375, 425, - 475, 525, 575, 625, 675, 725, 775, 825, 875, -} - -var pressureIndex map[int]int -var pressureLevelSet map[int]LevelSet - -func init() { - pressureIndex = make(map[int]int, NumLevels) - for i, p := range Pressures { - pressureIndex[p] = i - } - pressureLevelSet = make(map[int]LevelSet, NumLevels) - for _, p := range PressuresPgrb2 { - pressureLevelSet[p] = LevelSetA - } - for _, p := range PressuresPgrb2b { - pressureLevelSet[p] = LevelSetB - } -} - -// PressureIndex returns the dataset index for a pressure level in hPa, -// or -1 when the level is unknown. -func PressureIndex(hPa int) int { - idx, ok := pressureIndex[hPa] - if !ok { - return -1 - } - return idx -} - -// PressureLevelSet returns the GRIB file set carrying a pressure level. -func PressureLevelSet(hPa int) (LevelSet, bool) { - ls, ok := pressureLevelSet[hPa] - return ls, ok -} - -// HourIndex returns the dataset time index for a forecast hour, or -1 when -// the hour is outside the range or not a multiple of HourStep. -func HourIndex(hour int) int { - if hour < 0 || hour > MaxHour || hour%HourStep != 0 { - return -1 - } - return hour / HourStep -} - -// Hours returns the full list of forecast hours, [0, 3, 6, ..., MaxHour]. -func Hours() []int { - out := make([]int, 0, NumHours) - for h := 0; h <= MaxHour; h += HourStep { - out = append(out, h) - } - return out -} - -// VariableIndex maps a GRIB (category, number) pair to a dataset variable -// index, returning -1 for parameters this dataset does not store. -func VariableIndex(parameterCategory, parameterNumber int) int { - switch { - case parameterCategory == 3 && parameterNumber == 5: - return VarHeight - case parameterCategory == 2 && parameterNumber == 2: - return VarWindU - case parameterCategory == 2 && parameterNumber == 3: - return VarWindV - default: - return -1 - } -} - -// S3 URL configuration for NOAA GFS data on the public S3 mirror. +// S3BaseURL is the public NOAA S3 mirror. const S3BaseURL = "https://noaa-gfs-bdp-pds.s3.amazonaws.com" - -// GribURL returns the S3 URL for a primary (pgrb2) GRIB file. -func GribURL(date string, runHour, forecastStep int) string { - return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d", - S3BaseURL, date, runHour, runHour, forecastStep) -} - -// GribURLB returns the S3 URL for a secondary (pgrb2b) GRIB file. -func GribURLB(date string, runHour, forecastStep int) string { - return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2b.0p50.f%03d", - S3BaseURL, date, runHour, runHour, forecastStep) -} diff --git a/internal/weather/gfs/file.go b/internal/weather/gfs/file.go index 720107b..003cfef 100644 --- a/internal/weather/gfs/file.go +++ b/internal/weather/gfs/file.go @@ -11,8 +11,10 @@ import ( ) // File is an mmap-backed wind dataset file. The layout is a flat C-order -// row-major array of float32 values, shape (hour, level, variable, lat, lng). +// row-major float32 array, shape (hour, level, variable, lat, lng), with +// the per-axis sizes coming from Variant. type File struct { + variant *Variant mm mmap.MMap file *os.File writable bool @@ -20,8 +22,11 @@ type File struct { Epoch time.Time } +// Variant returns the Variant the file was created with. +func (d *File) Variant() *Variant { return d.variant } + // Open opens an existing dataset file for reading. -func Open(path string, epoch time.Time) (*File, error) { +func Open(path string, variant *Variant, epoch time.Time) (*File, error) { f, err := os.Open(path) if err != nil { return nil, fmt.Errorf("open dataset: %w", err) @@ -31,39 +36,40 @@ func Open(path string, epoch time.Time) (*File, error) { f.Close() return nil, fmt.Errorf("stat dataset: %w", err) } - if info.Size() != DatasetSize { + if info.Size() != variant.DatasetSize() { f.Close() - return nil, fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, info.Size()) + return nil, fmt.Errorf("dataset should be %d bytes (was %d)", variant.DatasetSize(), info.Size()) } mm, err := mmap.Map(f, mmap.RDONLY, 0) if err != nil { f.Close() return nil, fmt.Errorf("mmap dataset: %w", err) } - return &File{mm: mm, file: f, writable: false, Epoch: epoch}, nil + return &File{variant: variant, mm: mm, file: f, writable: false, Epoch: epoch}, nil } -// Create creates a new dataset file of the canonical size, mmap'd read-write. -func Create(path string) (*File, error) { +// Create creates a new dataset file sized for variant, mmap'd read-write. +func Create(path string, variant *Variant) (*File, error) { f, err := os.Create(path) if err != nil { return nil, fmt.Errorf("create dataset: %w", err) } - if err := f.Truncate(DatasetSize); err != nil { + size := variant.DatasetSize() + if err := f.Truncate(size); err != nil { f.Close() return nil, fmt.Errorf("truncate dataset: %w", err) } - mm, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0) + mm, err := mmap.MapRegion(f, int(size), mmap.RDWR, 0, 0) if err != nil { f.Close() return nil, fmt.Errorf("mmap dataset: %w", err) } - return &File{mm: mm, file: f, writable: true}, nil + return &File{variant: variant, mm: mm, file: f, writable: true}, nil } -// OpenWritable opens an existing dataset file for read-write access. -// Used when resuming a partial download. -func OpenWritable(path string) (*File, error) { +// OpenWritable opens an existing dataset file for read-write access. Used +// when resuming a partial download. +func OpenWritable(path string, variant *Variant) (*File, error) { f, err := os.OpenFile(path, os.O_RDWR, 0o644) if err != nil { return nil, fmt.Errorf("open dataset rw: %w", err) @@ -73,51 +79,55 @@ func OpenWritable(path string) (*File, error) { f.Close() return nil, fmt.Errorf("stat dataset: %w", err) } - if info.Size() != DatasetSize { + if info.Size() != variant.DatasetSize() { f.Close() - return nil, fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, info.Size()) + return nil, fmt.Errorf("dataset should be %d bytes (was %d)", variant.DatasetSize(), info.Size()) } - mm, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0) + mm, err := mmap.MapRegion(f, int(info.Size()), mmap.RDWR, 0, 0) if err != nil { f.Close() return nil, fmt.Errorf("mmap dataset: %w", err) } - return &File{mm: mm, file: f, writable: true}, nil + return &File{variant: variant, mm: mm, file: f, writable: true}, nil } // offset returns the byte offset of the [hour][level][variable][lat][lng] cell. -func offset(hour, level, variable, lat, lng int) int64 { +func (d *File) offset(hour, level, variable, lat, lng int) int64 { + v := d.variant idx := int64(hour) - idx = idx*int64(NumLevels) + int64(level) + idx = idx*int64(v.NumLevels()) + int64(level) idx = idx*int64(NumVariables) + int64(variable) - idx = idx*int64(NumLatitudes) + int64(lat) - idx = idx*int64(NumLongitudes) + int64(lng) + idx = idx*int64(v.NumLatitudes()) + int64(lat) + idx = idx*int64(v.NumLongitudes()) + int64(lng) return idx * int64(ElementSize) } // Val reads one cell as a float32. func (d *File) Val(hour, level, variable, lat, lng int) float32 { - off := offset(hour, level, variable, lat, lng) + off := d.offset(hour, level, variable, lat, lng) return math.Float32frombits(binary.LittleEndian.Uint32(d.mm[off : off+4])) } // SetVal writes one cell. Only valid on writable files. func (d *File) SetVal(hour, level, variable, lat, lng int, val float32) { - off := offset(hour, level, variable, lat, lng) + off := d.offset(hour, level, variable, lat, lng) binary.LittleEndian.PutUint32(d.mm[off:off+4], math.Float32bits(val)) } // BlitGribData copies one decoded GRIB grid into the dataset, flipping the // latitude axis from GRIB's north-to-south scan order to our south-to-north -// storage order. gribData must be 361*720 = 259920 float64 values. +// storage order. func (d *File) BlitGribData(hourIdx, levelIdx, varIdx int, gribData []float64) error { - expected := NumLatitudes * NumLongitudes + v := d.variant + expected := v.NumLatitudes() * v.NumLongitudes() if len(gribData) != expected { return fmt.Errorf("grib data has %d values, expected %d", len(gribData), expected) } - for lat := range NumLatitudes { - for lng := range NumLongitudes { - gribIdx := (360-lat)*NumLongitudes + lng + lats := v.NumLatitudes() + lngs := v.NumLongitudes() + for lat := range lats { + for lng := range lngs { + gribIdx := (lats-1-lat)*lngs + lng d.SetVal(hourIdx, levelIdx, varIdx, lat, lng, float32(gribData[gribIdx])) } } diff --git a/internal/weather/gfs/gefs_variants.go b/internal/weather/gfs/gefs_variants.go new file mode 100644 index 0000000..a18a265 --- /dev/null +++ b/internal/weather/gfs/gefs_variants.go @@ -0,0 +1,68 @@ +package gfs + +import "fmt" + +// Family is the dataset family ("gfs" or "gefs"). Variants of different +// families have different URL layouts but share the cube format. +type Family int + +const ( + FamilyGFS Family = iota + FamilyGEFS +) + +func (f Family) String() string { + switch f { + case FamilyGEFS: + return "gefs" + default: + return "gfs" + } +} + +// HasMember reports whether the family requires a member index in URLs. +func (f Family) HasMember() bool { return f == FamilyGEFS } + +// GEFS variant constants. +// +// The 21-member ensemble is gec00 (control) + gep01..gep20 (perturbations). +// NOAA publishes more members today but 21 matches the historical Tawhiri +// configuration and is what the phase 2 spec calls for. +const GEFSMembers = 21 + +// GefsMemberName returns the file-name token for a GEFS member. +// member=0 → "gec00", member=1..20 → "gep01".."gep20". +func GefsMemberName(member int) string { + if member == 0 { + return "gec00" + } + return fmt.Sprintf("gep%02d", member) +} + +// GEFS S3 mirror. +const GEFSS3BaseURL = "https://noaa-gefs-pds.s3.amazonaws.com" + +// GefsGribURL returns the S3 URL for a GEFS primary GRIB file. +func GefsGribURL(date string, runHour, member, forecastStep int, resToken string) string { + return fmt.Sprintf("%s/gefs.%s/%02d/atmos/pgrb2ap5/%s.t%02dz.pgrb2a.%s.f%03d", + GEFSS3BaseURL, date, runHour, GefsMemberName(member), runHour, resToken, forecastStep) +} + +// GefsGribURLB returns the S3 URL for a GEFS secondary GRIB file. +func GefsGribURLB(date string, runHour, member, forecastStep int, resToken string) string { + return fmt.Sprintf("%s/gefs.%s/%02d/atmos/pgrb2bp5/%s.t%02dz.pgrb2b.%s.f%03d", + GEFSS3BaseURL, date, runHour, GefsMemberName(member), runHour, resToken, forecastStep) +} + +// GEFS variants — 0.5° resolution, 3-hour cadence, 192h horizon. +var GEFS0p50_3h = &Variant{ + ID: "gefs-0p50-3h", + Family: FamilyGEFS, + ResToken: "0p50", + Resolution: 0.5, + HourStep: 3, + MaxHour: 192, + Pressures: GFS0p50_3h.Pressures, + PressuresPgrb2: GFS0p50_3h.PressuresPgrb2, + PressuresPgrb2b: GFS0p50_3h.PressuresPgrb2b, +} diff --git a/internal/weather/gfs/variant.go b/internal/weather/gfs/variant.go new file mode 100644 index 0000000..e2c6443 --- /dev/null +++ b/internal/weather/gfs/variant.go @@ -0,0 +1,191 @@ +package gfs + +import "fmt" + +// Variant describes one configuration of a NOAA dataset family (GFS or GEFS). +// +// The dataset cube is a 5-D float32 array with shape +// (NumHours, NumLevels, NumVariables, NumLatitudes, NumLongitudes) where +// NumVariables and ElementSize are fixed across all GFS variants but the +// other dimensions depend on the resolution and forecast cadence. +type Variant struct { + // ID is a stable identifier ("gfs-0p50-3h", "gefs-0p50-3h", ...). + ID string + // Family identifies the dataset family the variant belongs to. + Family Family + + // Resolution token used in NOAA URLs ("0p50", "0p25"). + ResToken string + // Grid step in degrees (0.5, 0.25). 180 / Resolution + 1 latitudes and + // 360 / Resolution longitudes. + Resolution float64 + + HourStep int // hours between forecast steps + MaxHour int // largest forecast hour (inclusive) + + // Pressures lists every pressure level in dataset index order, descending. + Pressures []int + // PressuresPgrb2 / PressuresPgrb2b split the pressures between the two + // downloaded GRIB files. Their union must equal Pressures. + PressuresPgrb2 []int + PressuresPgrb2b []int + + pressureIndex map[int]int + pressureLevelSet map[int]LevelSet +} + +// NumHours returns MaxHour/HourStep + 1. +func (v *Variant) NumHours() int { return v.MaxHour/v.HourStep + 1 } + +// NumLevels returns len(Pressures). +func (v *Variant) NumLevels() int { return len(v.Pressures) } + +// NumLatitudes returns 180/Resolution + 1. +func (v *Variant) NumLatitudes() int { return int(180.0/v.Resolution) + 1 } + +// NumLongitudes returns 360/Resolution. +func (v *Variant) NumLongitudes() int { return int(360.0 / v.Resolution) } + +// DatasetSize returns the canonical file size in bytes. +func (v *Variant) DatasetSize() int64 { + return int64(v.NumHours()) * int64(v.NumLevels()) * int64(NumVariables) * + int64(v.NumLatitudes()) * int64(v.NumLongitudes()) * int64(ElementSize) +} + +// Hours returns the full list of forecast hours [0, HourStep, ..., MaxHour]. +func (v *Variant) Hours() []int { + out := make([]int, 0, v.NumHours()) + for h := 0; h <= v.MaxHour; h += v.HourStep { + out = append(out, h) + } + return out +} + +// HourIndex returns the dataset time index for an hour, or -1 if invalid. +func (v *Variant) HourIndex(hour int) int { + if hour < 0 || hour > v.MaxHour || hour%v.HourStep != 0 { + return -1 + } + return hour / v.HourStep +} + +// PressureIndex returns the dataset index for a pressure level in hPa, +// or -1 when the level is unknown to this variant. +func (v *Variant) PressureIndex(hPa int) int { + v.indexLazyInit() + if i, ok := v.pressureIndex[hPa]; ok { + return i + } + return -1 +} + +// PressureLevelSet returns the GRIB file set carrying a pressure level. +func (v *Variant) PressureLevelSet(hPa int) (LevelSet, bool) { + v.indexLazyInit() + ls, ok := v.pressureLevelSet[hPa] + return ls, ok +} + +// VariableIndex maps a GRIB (category, number) pair to a dataset variable index. +func (v *Variant) VariableIndex(parameterCategory, parameterNumber int) int { + switch { + case parameterCategory == 3 && parameterNumber == 5: + return VarHeight + case parameterCategory == 2 && parameterNumber == 2: + return VarWindU + case parameterCategory == 2 && parameterNumber == 3: + return VarWindV + default: + return -1 + } +} + +// GribURL returns the S3 URL for the primary (pgrb2) GRIB file. +func (v *Variant) GribURL(date string, runHour, forecastStep int) string { + return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.%s.f%03d", + S3BaseURL, date, runHour, runHour, v.ResToken, forecastStep) +} + +// GribURLB returns the S3 URL for the secondary (pgrb2b) GRIB file. +func (v *Variant) GribURLB(date string, runHour, forecastStep int) string { + return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2b.%s.f%03d", + S3BaseURL, date, runHour, runHour, v.ResToken, forecastStep) +} + +func (v *Variant) indexLazyInit() { + if v.pressureIndex != nil { + return + } + v.pressureIndex = make(map[int]int, len(v.Pressures)) + for i, p := range v.Pressures { + v.pressureIndex[p] = i + } + v.pressureLevelSet = make(map[int]LevelSet, len(v.Pressures)) + for _, p := range v.PressuresPgrb2 { + v.pressureLevelSet[p] = LevelSetA + } + for _, p := range v.PressuresPgrb2b { + v.pressureLevelSet[p] = LevelSetB + } +} + +// Standard variants -- these mirror what NOAA publishes today. +// +// GFS0p50_3h is the historical Tawhiri default: 0.5° resolution, 3-hour +// forecast cadence, 0..192h horizon, 47 pressure levels split across the +// primary and secondary GRIB files. +// +// GFS0p25_3h mirrors the same 3-hour cadence at 0.25° resolution (the +// horizon is larger in practice but we keep 192h for parity with 0p50). +// +// GFS0p25_1h targets the 1-hourly portion NOAA publishes out to 120h. +var ( + GFS0p50_3h = &Variant{ + ID: "gfs-0p50-3h", + ResToken: "0p50", + Resolution: 0.5, + HourStep: 3, + MaxHour: 192, + Pressures: []int{1000, 975, 950, 925, 900, 875, 850, 825, 800, 775, 750, 725, 700, 675, 650, 625, 600, 575, 550, 525, 500, 475, 450, 425, 400, 375, 350, 325, 300, 275, 250, 225, 200, 175, 150, 125, 100, 70, 50, 30, 20, 10, 7, 5, 3, 2, 1}, + PressuresPgrb2: []int{10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 925, 950, 975, 1000}, + PressuresPgrb2b: []int{1, 2, 3, 5, 7, 125, 175, 225, 275, 325, 375, 425, 475, 525, 575, 625, 675, 725, 775, 825, 875}, + } + + GFS0p25_3h = &Variant{ + ID: "gfs-0p25-3h", + ResToken: "0p25", + Resolution: 0.25, + HourStep: 3, + MaxHour: 192, + Pressures: GFS0p50_3h.Pressures, + PressuresPgrb2: GFS0p50_3h.PressuresPgrb2, + PressuresPgrb2b: GFS0p50_3h.PressuresPgrb2b, + } + + GFS0p25_1h = &Variant{ + ID: "gfs-0p25-1h", + ResToken: "0p25", + Resolution: 0.25, + HourStep: 1, + MaxHour: 120, + Pressures: GFS0p50_3h.Pressures, + PressuresPgrb2: GFS0p50_3h.PressuresPgrb2, + PressuresPgrb2b: GFS0p50_3h.PressuresPgrb2b, + } +) + +// VariantByID returns one of the predefined variants by its ID. +func VariantByID(id string) (*Variant, error) { + switch id { + case GFS0p50_3h.ID: + return GFS0p50_3h, nil + case GFS0p25_3h.ID: + return GFS0p25_3h, nil + case GFS0p25_1h.ID: + return GFS0p25_1h, nil + case GEFS0p50_3h.ID: + return GEFS0p50_3h, nil + default: + return nil, fmt.Errorf("unknown variant %q", id) + } +} diff --git a/internal/weather/gfs/wind.go b/internal/weather/gfs/wind.go index 01329b4..dd3667f 100644 --- a/internal/weather/gfs/wind.go +++ b/internal/weather/gfs/wind.go @@ -10,45 +10,49 @@ import ( // Wind is a WindField backed by a GFS dataset file. type Wind struct { file *File + + hourAxis numerics.Axis + latAxis numerics.Axis + lngAxis numerics.Axis } -// NewWind returns a Wind backed by file. +// NewWind returns a Wind backed by file. The axes are constructed from the +// file's variant geometry. func NewWind(file *File) *Wind { - return &Wind{file: file} + v := file.variant + return &Wind{ + file: file, + hourAxis: numerics.Axis{ + Left: 0, + Step: float64(v.HourStep), + N: v.NumHours(), + Name: "hour", + }, + latAxis: numerics.Axis{ + Left: LatStart, + Step: v.Resolution, + N: v.NumLatitudes(), + Name: "lat", + }, + lngAxis: numerics.Axis{ + Left: LonStart, + Step: v.Resolution, + N: v.NumLongitudes(), + Wrap: true, + Name: "lng", + }, + } } // Epoch returns the forecast run time of the underlying file. func (w *Wind) Epoch() time.Time { return w.file.Epoch } -// Source returns the source identifier "noaa-gfs-0p50". -func (w *Wind) Source() string { return "noaa-gfs-0p50" } +// Source returns the variant ID (e.g. "gfs-0p50-3h"). +func (w *Wind) Source() string { return w.file.variant.ID } // Close releases the underlying file's resources. func (w *Wind) Close() error { return w.file.Close() } -// Grid axes for the GFS 0.5-degree dataset. -var ( - hourAxis = numerics.Axis{ - Left: 0, - Step: float64(HourStep), - N: NumHours, - Name: "hour", - } - latAxis = numerics.Axis{ - Left: LatStart, - Step: Resolution, - N: NumLatitudes, - Name: "lat", - } - lngAxis = numerics.Axis{ - Left: LonStart, - Step: Resolution, - N: NumLongitudes, - Wrap: true, - Name: "lng", - } -) - // Wind samples the field at the given UNIX time, geographic coordinate, and // altitude. Vertical interpolation matches Tawhiri: locate the two pressure // levels whose interpolated geopotential heights bracket alt, then linearly @@ -56,15 +60,15 @@ var ( func (w *Wind) Wind(t, lat, lng, alt float64) (weather.Sample, error) { hours := (t - float64(w.file.Epoch.Unix())) / 3600.0 - bh, err := hourAxis.Locate(hours) + bh, err := w.hourAxis.Locate(hours) if err != nil { return weather.Sample{}, err } - bla, err := latAxis.Locate(lat) + bla, err := w.latAxis.Locate(lat) if err != nil { return weather.Sample{}, err } - bln, err := lngAxis.Locate(lng) + bln, err := w.lngAxis.Locate(lng) if err != nil { return weather.Sample{}, err } @@ -76,7 +80,7 @@ func (w *Wind) Wind(t, lat, lng, alt float64) (weather.Sample, error) { } } - levelIdx := numerics.Bisect(0, NumLevels-2, alt, func(level int) float64 { + levelIdx := numerics.Bisect(0, w.file.variant.NumLevels()-2, alt, func(level int) float64 { return numerics.EvalTrilinear(bs, height(level)) })