polish #8

Merged
a.petrov merged 4 commits from polish into main 2026-06-08 09:16:14 +00:00
115 changed files with 26566 additions and 3426 deletions

25
.dockerignore Normal file
View file

@ -0,0 +1,25 @@
# VCS and editor noise
.git
.gitignore
*.md
!README.md
# Build artifacts
/bin
/predictor
*.test
*.out
# Local data and datasets — never bake multi-GB cubes into the image
/data
*.bin
*.bin.downloading
*.manifest.json
/tmp
# Deployment + docs that aren't needed in the build context
/deploy
/examples
/docs
.forgejo
.github

View file

@ -0,0 +1,116 @@
name: CI/CD
# Test on every push/PR; build + push an image and deploy on develop (staging)
# and on v* tags (production). Deployment goes through the Swarmpit REST API.
on:
push:
branches: [main, develop]
tags: ["v*"]
pull_request:
branches: [main, develop]
env:
REGISTRY: git.intra.yksa.space
IMAGE_NAME: web/predictor
jobs:
test:
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: "1.25"
cache: true
- name: Check formatting
run: |
unformatted="$(gofmt -l .)"
if [ -n "$unformatted" ]; then
echo "These files need gofmt:"; echo "$unformatted"; exit 1
fi
- name: Vet
run: go vet ./...
- name: Build
run: go build ./...
- name: Test
run: go test -race ./...
build:
needs: test
runs-on: ubuntu-24.04
if: github.ref == 'refs/heads/develop' || startsWith(github.ref, 'refs/tags/v')
outputs:
tag: ${{ steps.meta.outputs.tag }}
steps:
- uses: actions/checkout@v4
- uses: docker/setup-buildx-action@v3
- uses: docker/login-action@v3
with:
registry: ${{ env.REGISTRY }}
username: ${{ secrets.REGISTRY_USERNAME }}
password: ${{ secrets.REGISTRY_PASSWORD }}
- name: Resolve image tag
id: meta
run: |
if [[ "${{ github.ref }}" == refs/tags/v* ]]; then
TAG="${GITHUB_REF#refs/tags/v}"
else
TAG="develop"
fi
echo "tag=${TAG}" >> "$GITHUB_OUTPUT"
echo "Resolved tag: ${TAG}"
- name: Build and push image
run: |
IMAGE="${REGISTRY}/${IMAGE_NAME}"
TAG="${{ steps.meta.outputs.tag }}"
TAGS="-t ${IMAGE}:${TAG}"
# Tagged releases also move :latest.
if [[ "${TAG}" != "develop" ]]; then
TAGS="${TAGS} -t ${IMAGE}:latest"
fi
docker buildx build \
--platform linux/amd64 \
--build-arg VERSION="${TAG}" \
--build-arg REVISION="${{ github.sha }}" \
--push ${TAGS} .
deploy-staging:
needs: build
runs-on: ubuntu-24.04
if: github.ref == 'refs/heads/develop'
environment: staging
steps:
- uses: actions/checkout@v4
- name: Deploy to Swarmpit (staging)
env:
SWARMPIT_URL: ${{ secrets.SWARMPIT_URL }}
SWARMPIT_TOKEN: ${{ secrets.SWARMPIT_TOKEN }}
STACK_NAME: ${{ secrets.STACK_NAME }}
CA_CERTIFICATES: ${{ secrets.CA_CERTIFICATES }}
TAG: ${{ needs.build.outputs.tag }}
run: sh deploy/swarmpit-deploy.sh
deploy-production:
needs: build
runs-on: ubuntu-24.04
if: startsWith(github.ref, 'refs/tags/v')
environment: production
steps:
- uses: actions/checkout@v4
- name: Deploy to Swarmpit (production)
env:
SWARMPIT_URL: ${{ secrets.SWARMPIT_URL }}
SWARMPIT_TOKEN: ${{ secrets.SWARMPIT_TOKEN }}
STACK_NAME: ${{ secrets.STACK_NAME }}
CA_CERTIFICATES: ${{ secrets.CA_CERTIFICATES }}
TAG: ${{ needs.build.outputs.tag }}
run: sh deploy/swarmpit-deploy.sh

156
DEPLOYMENT.md Normal file
View file

@ -0,0 +1,156 @@
# Deploying stratoflights-predictor
The predictor is a single static Go binary with no database and no required
external services. It downloads NOAA GFS/GEFS wind data to **node-local disk**
and serves the REST API (see `/docs` or `api/rest/predictor.swagger.yml`).
It is an **internal backend**: the public entrypoint is the stratoflights API
gateway, which calls the predictor over an internal overlay network. The
predictor enforces no auth of its own.
## Environments
| Environment | File | Notes |
|---|---|---|
| Local dev | `docker-compose.yml` | one instance, metrics off, named volume |
| Staging (single host) | `docker-compose.staging.yml` | all features + bundled Prometheus |
| Production (Swarm) | `docker-compose.swarm.yml` | node-pinned, replicated, metrics |
```bash
# Local
docker compose up --build
curl localhost:8080/ready
# Staging (single host, exercises the metrics pipeline)
docker compose -f docker-compose.staging.yml up --build
# Prometheus at :9090, predictor target should be UP
# Production — see below
```
## Production (Docker Swarm)
### Storage and node placement — the important part
The wind dataset is ~8.9 GiB (0.5°) and must live on **local disk, never NFS**.
To bound the number of copies, the service is pinned to nodes carrying the
`predictor.data=true` label; **label at most two nodes**. Each labelled node
keeps exactly one copy under a node-local bind mount.
On **each** labelled node, provision the local directories and a writable owner
for the non-root container (uid:gid `65532:65532`):
```bash
sudo mkdir -p /srv/predictor/data /srv/predictor/elevation
sudo chown -R 65532:65532 /srv/predictor
# (optional) seed the elevation dataset so descent terminates at ground level:
# python3 scripts/build_elevation.py /srv/predictor/elevation/ruaumoko-dataset
```
Label the two storage nodes:
```bash
docker node update --label-add predictor.data=true <node-a>
docker node update --label-add predictor.data=true <node-b>
```
Replicas are spread one-per-node by default (redundancy across both copies).
Scaling to multiple replicas **per** node is safe: they share the node-local
volume and coordinate the download with an exclusive `flock`, so only one
process per node fetches the dataset — the others wait and load the committed
file. To scale: `docker service scale predictor_predictor=4` (≤2 per node).
### Network
The gateway and Prometheus reach the predictor over a shared overlay. Create it
once and have the gateway stack join the same external network:
```bash
docker network create -d overlay --attachable stratoflights-net
```
The service is published only on that network under the alias `predictor`
(`http://predictor:8080`). No public Traefik router — the gateway is the edge.
### Deploy
Via the CI pipeline (recommended): push a `v*` tag → the image is built and the
stack is deployed through the Swarmpit API. Manually:
```bash
TAG=v1.0.0 docker stack deploy -c docker-compose.swarm.yml --with-registry-auth predictor
```
or import `docker-compose.swarm.yml` into Swarmpit and set `TAG`.
### Configuration
All settings are env vars (file/env/flag precedence; see README). Production
defaults are in `docker-compose.swarm.yml`:
| Variable | Purpose |
|---|---|
| `PREDICTOR_DATA_DIR=/data` | node-local dataset dir (bind mount) |
| `PREDICTOR_ELEVATION_DATASET=/srv/ruaumoko-dataset` | optional terrain data |
| `PREDICTOR_SOURCE=gfs-0p50-3h` | `gfs-0p50-3h`, `gfs-0p25-3h`, `gfs-0p25-1h`, `gefs-0p50-3h` |
| `PREDICTOR_DOWNLOAD_PARALLEL=16` | concurrent GRIB downloads |
| `PREDICTOR_UPDATE_INTERVAL=6h` | forecast refresh cadence |
| `PREDICTOR_METRICS_ENABLED=true` | expose `/metrics` |
No Docker secrets are needed — the predictor has no database or credentials.
### Health
- `GET /health` — liveness (always 200 while the process runs). The container
`HEALTHCHECK` calls the binary's `-healthcheck` mode (no curl in the image).
- `GET /ready` — readiness (200 only once a dataset is loaded). The gateway
should gate traffic on this; Swarm does **not** kill a container that is still
performing its first download thanks to the 120s `start_period`.
### Metrics
`/metrics` exposes Prometheus counters (`predictor_predictions_total`,
`predictor_downloads_total`, `predictor_download_bytes_total`) and the
`predictor_active_dataset_epoch_seconds` gauge. The service carries
`prometheus.scrape/port/path` deploy labels for Swarm service discovery; point
your central Prometheus at the `stratoflights-net` network.
## CI/CD (Forgejo → Swarmpit)
`.forgejo/workflows/ci-cd.yml`:
1. **test** (every push/PR): `gofmt` check, `go vet`, `go build`, `go test -race`.
2. **build** (develop branch and `v*` tags): buildx `linux/amd64` image pushed to
`git.intra.yksa.space/web/predictor` (`:develop`, or `:<version>` + `:latest`).
3. **deploy-staging** (develop) / **deploy-production** (`v*` tags): deploy
`docker-compose.swarm.yml` to the environment's Swarmpit stack via
`deploy/swarmpit-deploy.sh`.
Configure runner secrets (scope staging/production via Forgejo environments):
- `REGISTRY_USERNAME`, `REGISTRY_PASSWORD` — container registry
- `SWARMPIT_URL`, `SWARMPIT_TOKEN`, `STACK_NAME` — Swarmpit deploy target
- `CA_CERTIFICATES` — optional PEM bundle if Swarmpit uses a private CA
Cut a release:
```bash
git tag v1.0.0 && git push origin v1.0.0
```
## Operations
```bash
docker service ls --filter label=com.docker.stack.namespace=predictor
docker service logs -f predictor_predictor
docker service scale predictor_predictor=2 # ≤2 per labelled node
docker service rollback predictor_predictor
```
Trigger a dataset refresh or inspect jobs through the admin API:
```bash
curl -X POST http://predictor:8080/api/v1/admin/datasets -d '{"latest":true}'
curl http://predictor:8080/api/v1/admin/jobs
curl http://predictor:8080/api/v1/admin/status
```

37
Dockerfile Normal file
View file

@ -0,0 +1,37 @@
# syntax=docker/dockerfile:1
# --- build stage ---------------------------------------------------------
FROM golang:1.25 AS builder
WORKDIR /src
# Cache module downloads.
COPY go.mod go.sum ./
RUN go mod download
COPY . .
# Static, stripped binary — no CGO so it runs on distroless/scratch.
ARG VERSION=dev
ARG REVISION=unknown
RUN CGO_ENABLED=0 GOOS=linux go build \
-trimpath \
-ldflags="-s -w -X main.version=${VERSION} -X main.revision=${REVISION}" \
-o /predictor ./cmd/predictor
# --- runtime stage -------------------------------------------------------
# distroless/static:nonroot ships CA certificates (needed for TLS to the
# NOAA S3 mirror) and runs as uid:gid 65532:65532.
FROM gcr.io/distroless/static-debian12:nonroot AS runtime
COPY --from=builder /predictor /predictor
# Default data dir; mount a node-local volume here in production.
ENV PREDICTOR_DATA_DIR=/data
EXPOSE 8080
# Liveness probe via the binary itself — no shell/curl in the image.
HEALTHCHECK --interval=30s --timeout=5s --start-period=120s --retries=3 \
CMD ["/predictor", "-healthcheck"]
ENTRYPOINT ["/predictor"]

View file

@ -1,12 +1,20 @@
.PHONY: build run test fmt lint clean generate-ogen help
.PHONY: build server cli compare test fmt lint clean generate-ogen docs help
# Build the application
build:
go build -o predictor ./cmd/api
# Build all binaries
build: server cli compare
server:
go build -o bin/predictor ./cmd/predictor
cli:
go build -o bin/predictor-cli ./cmd/predictor-cli
compare:
go build -o bin/compare-tawhiri ./cmd/compare-tawhiri
# Run locally
run:
go run ./cmd/api
go run ./cmd/predictor
# Run tests
test:
@ -20,21 +28,29 @@ fmt:
lint:
golangci-lint run
# Generate ogen API code from swagger spec
# Build the numerics LaTeX doc (requires pdflatex)
docs:
cd docs && pdflatex numerics.tex
# Regenerate ogen API code from the OpenAPI spec. The same spec is embedded by
# the api package (api/spec.go) and served at /openapi.yaml + /docs (ReDoc).
generate-ogen:
go run github.com/ogen-go/ogen/cmd/ogen@latest --target pkg/rest --package rest --clean api/rest/predictor.swagger.yml
# Clean build artifacts
clean:
rm -f predictor
rm -rf bin/ docs/numerics.pdf docs/numerics.aux docs/numerics.log
# Show help
help:
@echo "Available commands:"
@echo " build - Build binary"
@echo " run - Run locally"
@echo " test - Run tests"
@echo " build - Build all binaries to bin/"
@echo " server - Build the HTTP server (cmd/predictor)"
@echo " cli - Build the CLI client (cmd/predictor-cli)"
@echo " compare - Build the validation tool (cmd/compare-tawhiri)"
@echo " run - Run the server with default config"
@echo " test - Run unit tests"
@echo " fmt - Format code"
@echo " lint - Lint code"
@echo " generate-ogen - Generate API code from swagger spec"
@echo " lint - Lint code (golangci-lint)"
@echo " docs - Build the numerics LaTeX doc (requires pdflatex)"
@echo " generate-ogen - Regenerate ogen code from the OpenAPI spec"
@echo " clean - Remove build artifacts"

432
README.md
View file

@ -1,261 +1,285 @@
# Balloon Trajectory Predictor
# stratoflights-predictor
High-altitude balloon trajectory prediction service. Predicts ascent, burst, and descent trajectories using GFS wind forecast data from NOAA.
High-altitude balloon trajectory prediction service. Forecasts ascent, descent,
and float trajectories from NOAA GFS and GEFS wind data, exposed as a REST API.
The prediction algorithms are an exact port of [Tawhiri](https://github.com/cuspaceflight/tawhiri) (Cambridge University Spaceflight) to Go, verified to produce identical results.
The trajectory engine is a propagator-and-constraint system: any flight
profile can be expressed as a chain of propagators (constant-rate ascent,
parachute descent, piecewise rates with absolute / profile-relative /
propagator-relative timing, wind drift) with attached constraints
(scalar comparisons over altitude or time, terrain contact, geographic
polygons). Constraints can stop the profile, hand off to a fallback
propagator, or clip the violated coordinate to the boundary. The legacy
Tawhiri request shape is kept as a compatibility endpoint so existing
clients work unchanged.
## Quick Start
## Quick start
```bash
# Build
make build
make build # produces bin/{predictor,predictor-cli,compare-tawhiri}
./bin/predictor # downloads ~9 GB of GFS data on first start
# Run (downloads ~9 GB of GFS data on first start, takes 30-60 min)
PREDICTOR_DATA_DIR=/tmp/predictor-data go run ./cmd/api
# Check readiness
curl http://localhost:8080/ready
# Run a prediction
curl 'http://localhost:8080/api/v1/prediction?launch_latitude=52.2&launch_longitude=0.1&launch_datetime=2026-03-28T12:00:00Z&launch_altitude=0&ascent_rate=5&burst_altitude=30000&descent_rate=5'
./bin/predictor-cli ready
./bin/predictor-cli predict \
launch_latitude=52.2 launch_longitude=0.1 \
launch_datetime=2026-03-28T12:00:00Z \
ascent_rate=5 burst_altitude=30000 descent_rate=5
```
## Configuration
All configuration is via environment variables.
Layered configuration: built-in defaults < YAML file < env vars < CLI flags.
| Variable | Default | Description |
|---|---|---|
| `PREDICTOR_PORT` | `8080` | HTTP server port |
| `PREDICTOR_DATA_DIR` | `/tmp/predictor-data` | Directory for wind datasets and temp files |
| `PREDICTOR_DOWNLOAD_PARALLEL` | `8` | Max concurrent GRIB download goroutines |
| `PREDICTOR_UPDATE_INTERVAL` | `6h` | How often to check for new forecasts |
| `PREDICTOR_DATASET_TTL` | `48h` | Max age before a dataset is considered stale |
| `PREDICTOR_ELEVATION_DATASET` | `/srv/ruaumoko-dataset` | Path to elevation dataset (optional) |
| Setting | Env var | CLI flag | Default |
|---|---|---|---|
| HTTP port | `PREDICTOR_PORT` | `-port` | `8080` |
| Data directory | `PREDICTOR_DATA_DIR` | `-data-dir` | `/tmp/predictor-data` |
| Elevation dataset | `PREDICTOR_ELEVATION_DATASET` | `-elevation` | `/srv/ruaumoko-dataset` |
| Source variant | `PREDICTOR_SOURCE` | — | `gfs-0p50-3h` |
| Download parallelism | `PREDICTOR_DOWNLOAD_PARALLEL` | `-download-parallel` | `8` |
| Download bandwidth (bytes/s; 0 = unlimited) | `PREDICTOR_DOWNLOAD_BANDWIDTH` | `-download-bandwidth` | `0` |
| Scheduler interval | `PREDICTOR_UPDATE_INTERVAL` | `-update-interval` | `6h` |
| Dataset freshness TTL | `PREDICTOR_DATASET_TTL` | `-freshness-ttl` | `48h` |
| Metrics enabled | `PREDICTOR_METRICS_ENABLED` | `-metrics` | `true` |
| Metrics HTTP path | `PREDICTOR_METRICS_PATH` | `-metrics-path` | `/metrics` |
| Log level | `PREDICTOR_LOG_LEVEL` | `-log-level` | `info` |
## API
YAML config mirrors the same structure; see `internal/config/config.go`.
### `GET /api/v1/prediction`
Supported source variants:
Run a balloon trajectory prediction.
| `source` | Resolution | Cadence | Notes |
|---|---|---|---|
| `gfs-0p50-3h` | 0.5° | 3h to 192h | historical Tawhiri default |
| `gfs-0p25-3h` | 0.25° | 3h to 192h | |
| `gfs-0p25-1h` | 0.25° | 1h to 120h | |
| `gefs-0p50-3h` | 0.5° | 3h to 192h | 21-member ensemble; each member is a separate dataset |
**Parameters** (query string):
## REST API
| Parameter | Required | Description |
|---|---|---|
| `launch_latitude` | yes | Launch latitude in degrees (-90 to 90) |
| `launch_longitude` | yes | Launch longitude in degrees (-180 to 180 or 0 to 360) |
| `launch_datetime` | yes | Launch time in RFC 3339 format |
| `launch_altitude` | no | Launch altitude in metres ASL (default: 0) |
| `profile` | no | `standard_profile` (default) or `float_profile` |
| `ascent_rate` | no | Ascent rate in m/s (default: 5) |
| `burst_altitude` | no | Burst altitude in metres (default: 28000) |
| `descent_rate` | no | Sea-level descent rate in m/s (default: 5) |
| `float_altitude` | no | Float altitude in metres (float_profile only) |
| `stop_datetime` | no | Float end time (float_profile only, default: +24h) |
### Tawhiri-compatible (legacy)
**Response** (Tawhiri-compatible):
`GET /api/v1/prediction` — preserves the exact request and response shape of
the upstream Cambridge University Spaceflight predictor.
`GET /ready` — returns `{"status":"ok", "dataset_time":"..."}` once a dataset
is loaded.
### Profile-driven (synchronous)
`POST /api/v2/prediction` — execute a profile synchronously and return the
trajectory. Request shape:
```json
{
"prediction": [
"launch": { "time": "2026-03-28T12:00:00Z", "latitude": 52.2, "longitude": 0.1, "altitude": 0 },
"direction": "forward",
"profile": [
{
"stage": "ascent",
"trajectory": [
{"datetime": "2026-03-28T12:00:00Z", "latitude": 52.2, "longitude": 0.1, "altitude": 0},
...
]
"name": "ascent",
"model": { "type": "constant_rate", "rate": 5, "include_wind": true },
"constraints": [{ "type": "altitude", "op": ">=", "limit": 30000 }]
},
{
"stage": "descent",
"trajectory": [...]
"name": "descent",
"model": { "type": "parachute_descent", "sea_level_rate": 5, "include_wind": true },
"constraints": [{ "type": "terrain_contact" }]
}
],
"metadata": {
"start_datetime": "...",
"complete_datetime": "..."
},
"request": {
"dataset": "2026-03-28T06:00:00Z",
"launch_latitude": 52.2,
...
"globals": [{ "type": "time", "op": ">", "limit": 1799999999 }]
}
```
Model types: `constant_rate`, `parachute_descent`, `piecewise`, `wind`.
Constraint types: `altitude`, `time`, `terrain_contact`, `polygon`.
Operators: `<`, `<=`, `>`, `>=`, `==`. Actions: `stop` (default), `fallback`, `clip`.
Direction: `forward` (default) or `reverse`.
Piecewise segments support a `reference` field (`absolute`, `profile_start`, or
`propagator_start`) so a single rate schedule can be reused across profiles
with different launch times.
The response includes per-stage trajectories, detailed termination info
(violation state + refined state + constraint name), an `events` array of
non-fatal observations (e.g. `above_model` when altitude exceeded the dataset's
highest pressure level), and dataset metadata.
### Profile-driven (asynchronous)
`POST /api/v1/predictions` — enqueue a prediction. Returns `202` with a job ID:
```json
{"id":"842107d9-…","status":"pending","created_at":"…"}
```
`GET /api/v1/predictions/{id}` — poll status. When `status == "complete"`,
the response includes a `result` field with the full v2 PredictionResponse.
`DELETE /api/v1/predictions/{id}` — cancel a queued job.
A worker pool (`http.async_workers`, default 4) services the queue; completed
results are retained for `http.async_result_ttl` (default 1h).
### Dataset admin
```
GET /api/v1/admin/datasets list stored datasets (epoch, subset, coverage, loaded?)
POST /api/v1/admin/datasets trigger a download
DELETE /api/v1/admin/datasets/{filename} delete by filename (DatasetID.Filename())
GET /api/v1/admin/jobs list every download job
GET /api/v1/admin/jobs/{id} fetch one job
DELETE /api/v1/admin/jobs/{id} cancel a running download
GET /api/v1/admin/status consolidated status (uptime, mem, goroutines, jobs, datasets)
```
Trigger-download body:
```json
{
"epoch": "2026-03-28T06:00:00Z",
"subset": {
"region": { "min_lat": -10, "max_lat": 10, "min_lng": 0, "max_lng": 30 },
"hour_range": { "min_hour": 0, "max_hour": 72 },
"members": [5]
}
}
```
### `GET /ready`
`{"latest": true}` is a shortcut that refreshes the latest global dataset
for the configured source. Each `(epoch, subset)` combination is a
separate dataset; the loader auto-selects which loaded dataset covers a
given prediction query.
Health check. Returns `{"status": "ok"}` when a dataset is loaded.
### Wind visualization
## Elevation Dataset
`GET /api/v1/wind/field` — a velocity grid in the
[wind-js-server](https://github.com/danwild/wind-js-server) / leaflet-velocity
format (a two-element `[U, V]` array of `{header, data}` records), suitable for
animated particle layers. Query params: `time`, `altitude`, `min_lat`,
`max_lat`, `min_lng`, `max_lng`, `step` (degrees, min `0.25`). Responses are
cached in memory by parameters.
Without elevation data, descent terminates at sea level (altitude <= 0). With elevation data, descent terminates at ground level, matching Tawhiri's behaviour.
`GET /api/v1/wind/meta` — active dataset source, epoch, suggested altitudes,
and bounding box.
### Building the elevation dataset
A runnable browser client is in [`examples/wind-demo`](examples/wind-demo).
The elevation dataset uses ETOPO 2022 at 30 arc-second resolution, converted to a ruaumoko-compatible binary format (21601 x 43200 grid of int16 little-endian elevation values in metres).
### Documentation & metrics
**Requirements**: Python 3, xarray, netcdf4, numpy.
`GET /docs` serves a [ReDoc](https://github.com/Redocly/redoc) rendering of the
full OpenAPI spec, which is also available raw at `GET /openapi.yaml`.
```bash
pip install xarray netcdf4 numpy
# Downloads ~1.1 GB from NOAA, produces ~1.74 GB binary file
python3 scripts/build_elevation.py /tmp/predictor-data/ruaumoko-dataset
```
To skip the download if you already have the ETOPO NetCDF file:
```bash
ETOPO_NC_PATH=/path/to/ETOPO_2022_v1_30s_N90W180_surface.nc \
python3 scripts/build_elevation.py /tmp/predictor-data/ruaumoko-dataset
```
The ETOPO 2022 NetCDF can be manually downloaded from:
https://www.ncei.noaa.gov/products/etopo-global-relief-model
### Using the elevation dataset
```bash
PREDICTOR_ELEVATION_DATASET=/tmp/predictor-data/ruaumoko-dataset go run ./cmd/api
```
If the file doesn't exist or can't be read, the service starts normally with a warning and falls back to sea-level termination.
`GET /metrics` — Prometheus text exposition. Counters:
`predictor_predictions_total{profile,status}`, `predictor_downloads_total`,
`predictor_download_bytes_total`, and a gauge
`predictor_active_dataset_epoch_seconds`.
## Architecture
The entire REST API is defined by one OpenAPI spec and served by an
[ogen](https://ogen.dev)-generated server; the `internal/api` package only
implements the generated `Handler` interface, mapping between the wire types
and the engine/dataset/wind subsystems. `/metrics`, `/docs`, and
`/openapi.yaml` are mounted on the same `http.ServeMux` alongside it.
```
cmd/api/main.go Entry point, config, scheduler, HTTP server
cmd/
predictor/ main server
predictor-cli/ HTTP client
compare-tawhiri/ end-to-end validation against the public Tawhiri instance
api/
rest/predictor.swagger.yml OpenAPI 3 spec — ogen input AND served at /openapi.yaml
spec.go embeds the spec (go:embed) for the docs handler
internal/
dataset/
dataset.go Shape constants, pressure levels, S3 URLs
file.go mmap-backed dataset file (read/write/blit)
downloader/
downloader.go S3 partial GRIB download (idx + range requests)
idx.go NOAA .idx file parser
config.go Environment-based configuration
elevation/
elevation.go Ruaumoko-compatible elevation dataset (mmap int16)
prediction/
interpolate.go 4D wind interpolation (time, lat, lon, altitude)
solver.go RK4 integrator with binary search termination
models.go Ascent, descent, wind models; flight profiles
warnings.go Prediction warning counters
service/
service.go Dataset lifecycle, concurrent-safe access
transport/
middleware/log.go Request logging middleware
rest/
handler/handler.go ogen API handler implementation
handler/deps.go Service interface
transport.go ogen HTTP server, CORS
api/rest/predictor.swagger.yml OpenAPI 3.0 spec
pkg/rest/ Generated ogen code (17 files)
scripts/
build_elevation.py ETOPO 2022 to ruaumoko converter
numerics/ performance-critical core: interpolation, bisection,
RK4 + crossing refinement, atmosphere density, vector
and polygon math (portable to C/Rust)
engine/ propagator + constraint orchestration + registry (thin over numerics)
weather/ WindField interface; gfs/ — variant-parameterized GFS cube + sampler
datasets/ Source / Storage / Manager + transactional, resumable, subsettable downloads
grib/ — shared GRIB downloader skeleton (idx parser, HTTP, parallel blit)
gfs/ — GFS Source (URL templating only)
gefs/ — GEFS Source (URL templating + member resolution)
windviz/ cube-agnostic wind-field rasterizer + cache
elevation/ ruaumoko-format ground elevation reader
config/ layered file+env+CLI config
metrics/ Sink interface + Prometheus text impl
api/ ogen Handler implementation
handler.go — composite handler + NewError
prediction.go — v1 (Tawhiri), v2, async predictions
datasets.go — dataset + job admin + status
wind.go — wind visualization endpoints
mapping.go — ogen <-> engine conversions
async/ — prediction worker pool
docs/ — ReDoc page + /openapi.yaml
middleware/ — ogen logging, CORS
pkg/rest/ ogen-generated server/client/types (regenerate via `make generate-ogen`)
examples/wind-demo/ Leaflet + leaflet-velocity sample client
docs/numerics.tex end-to-end mathematical reference
scripts/build_elevation.py ETOPO 2022 → ruaumoko converter
```
## Wind Dataset
## Subsetting and ensembles
The service downloads GFS 0.5-degree forecast data from NOAA S3:
Each stored dataset is identified by `DatasetID = (epoch, subset)`. A subset
restricts the data fetched by region, forecast-hour range, or ensemble
member. The downloader honours the subset (skipping out-of-range
forecast steps; member-selecting URLs for GEFS), the storage tracks each
subset as a separate file (filename includes a deterministic subset key),
and the Manager exposes coverage so per-query dataset selection picks the
right one.
| Property | Value |
## Deployment
The service ships as a single static binary in a distroless image and runs in
three configurations — see **[DEPLOYMENT.md](DEPLOYMENT.md)** for the full guide.
| Environment | File |
|---|---|
| Source | `noaa-gfs-bdp-pds.s3.amazonaws.com` |
| Resolution | 0.5 degrees |
| Grid | 361 lat x 720 lon |
| Time steps | 65 (every 3 hours, 0-192h) |
| Pressure levels | 47 (1000 to 1 hPa) |
| Variables | Geopotential height, U-wind, V-wind |
| Dataset size | 9,528,667,200 bytes (~8.87 GiB) |
| Update cadence | Every 6 hours (GFS runs at 00, 06, 12, 18 UTC) |
| Local dev | `docker compose up --build` (`docker-compose.yml`) |
| Staging (single host, + Prometheus) | `docker-compose.staging.yml` |
| Production (Docker Swarm) | `docker-compose.swarm.yml` |
Data is downloaded using HTTP Range requests against `.idx` index files, fetching only the needed GRIB messages (HGT, UGRD, VGRD at 47 pressure levels). Full download takes 30-60 minutes depending on bandwidth.
Production runs on Docker Swarm pinned to ≤2 nodes labelled `predictor.data=true`,
each holding one copy of the dataset on **node-local disk** (never NFS).
Replicas spread across the two nodes for redundancy; multiple replicas per node
share the node's dataset and coordinate downloads with a file lock so only one
fetches the ~9 GiB cube. The predictor is an internal backend reached by the
API gateway over an overlay network; it enforces no auth itself. CI/CD is a
Forgejo pipeline that builds, tests, and deploys to Swarmpit
(`.forgejo/workflows/ci-cd.yml`).
The dataset is stored as a memory-mapped flat binary file of float32 values in C-order with shape `(65, 47, 3, 361, 720)`.
The async prediction API stores results in memory only; behind a load balancer,
clients must poll the same instance they submitted to (or use the synchronous
`/api/v2/prediction`).
## Prediction Algorithms
### Health
All algorithms are exact ports of the reference implementations in Tawhiri. The following sections describe the key components.
- `GET /health` — liveness, always 200 while the process runs (used by the
container `HEALTHCHECK` via `predictor -healthcheck`).
- `GET /ready` — readiness, 200 only once a dataset is loaded.
### Interpolation (`internal/prediction/interpolate.go`)
## Validation
4D wind interpolation from the dataset grid to arbitrary coordinates.
`./bin/compare-tawhiri --server http://localhost:8080` runs an identical
prediction against the local server and the public SondeHub Tawhiri
instance, reporting the great-circle distance between landing points.
1. **Trilinear weights** (`pick3`): compute 8 interpolation weights for the (hour, lat, lon) cube corners.
2. **Altitude search** (`search`): binary search on interpolated geopotential height to find the two pressure levels bracketing the target altitude.
3. **Wind extraction** (`interp4`): 8-point weighted sum at each bracket level, then linear interpolation between levels.
## Numerical methods
Reference: `tawhiri/interpolate.pyx`
### Solver (`internal/prediction/solver.go`)
4th-order Runge-Kutta integrator with dt = 60 seconds.
- State vector: (latitude, longitude, altitude) in degrees and metres.
- Time: UNIX timestamp in seconds.
- Longitude is kept in [0, 360) via Python-style modulo after each `vecadd`.
- When a terminator fires, binary search refinement (tolerance 0.01) finds the precise termination point between the last good step and the first terminated step.
- Longitude interpolation (`lngLerp`) handles the 0/360 wrap-around.
Reference: `tawhiri/solver.pyx`
### Models (`internal/prediction/models.go`)
- **Constant ascent**: vertical velocity = ascent_rate m/s.
- **Drag descent**: NASA atmosphere density model with drag coefficient = sea_level_rate * 1.1045. Descent rate increases with altitude due to thinner air.
- **Wind velocity**: u, v components from interpolation converted to degrees/second: `dlat = (180/pi) * v / (R)`, `dlng = (180/pi) * u / (R * cos(lat))` where R = 6371009 + altitude.
- **Linear model**: sum of component models (e.g., wind + ascent).
- **Elevation termination**: `ground_elevation > altitude` using ruaumoko dataset.
Reference: `tawhiri/models.py`
### Profiles
- **standard_profile**: ascent (constant rate + wind) until burst altitude, then descent (drag + wind) until ground level.
- **float_profile**: ascent to float altitude, then drift at constant altitude until stop time.
## Verification
The predictor has been verified against the reference Tawhiri implementation:
| Test | Result |
|---|---|
| Dataset (step 0): 36.6M float32 values vs Python/cfgrib | 0 mismatches, max diff = 0.0 |
| Prediction burst point vs public Tawhiri API | Identical (lat, lon, alt all match) |
| Prediction landing point vs public Tawhiri API | Identical lat/lon, 5m altitude diff (different elevation datasets) |
| Descent point count | Identical (46 points) |
| Ascent point count | Identical (101 points) |
## Development
```bash
# Regenerate ogen API code after modifying the swagger spec
make generate-ogen
# Run tests
make test
# Format
make fmt
```
### Comparison tools
```bash
# Compare single dataset step against Python/cfgrib reference
go run ./cmd/compare_step0 <run_YYYYMMDDHH> <output_path>
# Run prediction and compare against public Tawhiri API
go run ./cmd/compare_prediction
```
`docs/numerics.tex` is the complete mathematical reference: state vector,
equations of motion (constant rate, parachute drag, piecewise, wind
transport), numerical methods (multilinear interpolation, bisection,
classical RK4, binary-search termination refinement), constraint
geometry (scalar comparisons, point-in-polygon with antimeridian
handling), and design notes on the deferred items (WGS84/ECEF
coordinate system, mass-aware drift, Monte Carlo).
## References
- [Tawhiri](https://github.com/cuspaceflight/tawhiri) — Reference Python/Cython predictor (Cambridge University Spaceflight)
- [tawhiri-downloader](https://github.com/cuspaceflight/tawhiri-downloader) — OCaml dataset downloader
- [ruaumoko](https://github.com/cuspaceflight/ruaumoko) — Global elevation dataset
- [NOAA GFS](https://www.ncei.noaa.gov/products/weather-climate-models/global-forecast) — Global Forecast System
- [NOAA GFS on S3](https://noaa-gfs-bdp-pds.s3.amazonaws.com/index.html) — Public S3 bucket
- [ETOPO 2022](https://www.ncei.noaa.gov/products/etopo-global-relief-model) — Global relief model for elevation data
- [SondeHub Tawhiri API](https://api.v2.sondehub.org/tawhiri) — Public Tawhiri instance for comparison
- [Tawhiri](https://github.com/cuspaceflight/tawhiri) — reference Python/Cython predictor
- [ruaumoko](https://github.com/cuspaceflight/ruaumoko) — global elevation dataset format
- [NOAA GFS](https://www.ncei.noaa.gov/products/weather-climate-models/global-forecast)
- [NOAA GEFS](https://www.ncei.noaa.gov/products/weather-climate-models/global-ensemble-forecast)
- [ETOPO 2022](https://www.ncei.noaa.gov/products/etopo-global-relief-model)
- [SondeHub Tawhiri API](https://api.v2.sondehub.org/tawhiri) — public Tawhiri instance

View file

@ -1,84 +1,37 @@
openapi: 3.0.4
openapi: 3.0.3
info:
title: Predictor API
version: 0.0.1
paths:
/api/v1/prediction:
get:
title: stratoflights-predictor API
version: "1.0.0"
description: |
Balloon trajectory prediction and wind-dataset management.
Three prediction surfaces are exposed:
* **`GET /api/v1/prediction`** — Tawhiri-compatible, drop-in for the
Cambridge University Spaceflight predictor.
* **`POST /api/v2/prediction`** — profile-driven synchronous prediction
(arbitrary chains of propagators with constraints).
* **`POST /api/v1/predictions`** — the same profile API run asynchronously
via a worker pool, polled by job id.
Dataset management (download, list, delete, job status) lives under
`/api/v1/admin/`, and wind-field visualization data (leaflet-velocity /
wind-layer format) under `/api/v1/wind/`.
servers:
- url: /
description: This server.
tags:
- Prediction
summary: Perform prediction
operationId: performPrediction
parameters:
- in: query
name: launch_latitude
required: true
schema:
type: number
- in: query
name: launch_longitude
required: true
schema:
type: number
- in: query
name: launch_datetime
required: true
schema:
type: string
format: date-time
- in: query
name: launch_altitude
schema:
type: number
- in: query
name: profile
schema:
type: string
enum: [standard_profile, float_profile]
default: standard_profile
- in: query
name: ascent_rate
schema:
type: number
- in: query
name: burst_altitude
schema:
type: number
- in: query
name: descent_rate
schema:
type: number
- in: query
name: float_altitude
schema:
type: number
- in: query
name: stop_datetime
schema:
type: string
format: date-time
- in: query
name: dataset
schema:
type: string
format: date-time
responses:
"200":
description: Prediction response
content:
application/json:
schema:
$ref: '#/components/schemas/PredictionResponse'
default:
description: Error
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
- name: Prediction
- name: Datasets
- name: Wind
- name: Health
paths:
/ready:
get:
tags:
- Health
tags: [Health]
summary: Readiness check
operationId: readinessCheck
responses:
@ -87,113 +40,652 @@ paths:
content:
application/json:
schema:
$ref: '#/components/schemas/ReadinessResponse'
$ref: "#/components/schemas/ReadinessResponse"
default:
$ref: "#/components/responses/DefaultError"
/api/v1/prediction:
get:
tags: [Prediction]
summary: Tawhiri-compatible prediction
operationId: performPrediction
parameters:
- { in: query, name: launch_latitude, required: true, schema: { type: number } }
- { in: query, name: launch_longitude, required: true, schema: { type: number } }
- { in: query, name: launch_datetime, required: true, schema: { type: string, format: date-time } }
- { in: query, name: launch_altitude, schema: { type: number } }
- { in: query, name: profile, schema: { type: string, enum: [standard_profile, float_profile], default: standard_profile } }
- { in: query, name: ascent_rate, schema: { type: number } }
- { in: query, name: burst_altitude, schema: { type: number } }
- { in: query, name: descent_rate, schema: { type: number } }
- { in: query, name: float_altitude, schema: { type: number } }
- { in: query, name: stop_datetime, schema: { type: string, format: date-time } }
- { in: query, name: dataset, schema: { type: string, format: date-time } }
responses:
"200":
description: Prediction response
content:
application/json:
schema:
$ref: "#/components/schemas/PredictionResponse"
default:
$ref: "#/components/responses/DefaultError"
/api/v2/prediction:
post:
tags: [Prediction]
summary: Profile-driven prediction (synchronous)
operationId: performPredictionV2
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/PredictionV2Request"
responses:
"200":
description: Prediction result
content:
application/json:
schema:
$ref: "#/components/schemas/PredictionV2Response"
default:
$ref: "#/components/responses/DefaultError"
/api/v1/predictions:
post:
tags: [Prediction]
summary: Enqueue an asynchronous prediction
operationId: createPredictionJob
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/PredictionV2Request"
responses:
"202":
description: Job accepted
content:
application/json:
schema:
$ref: "#/components/schemas/PredictionJob"
default:
$ref: "#/components/responses/DefaultError"
/api/v1/predictions/{id}:
get:
tags: [Prediction]
summary: Poll an asynchronous prediction job
operationId: getPredictionJob
parameters:
- { in: path, name: id, required: true, schema: { type: string } }
responses:
"200":
description: Job status (with result when complete)
content:
application/json:
schema:
$ref: "#/components/schemas/PredictionJob"
default:
$ref: "#/components/responses/DefaultError"
delete:
tags: [Prediction]
summary: Cancel a queued prediction job
operationId: cancelPredictionJob
parameters:
- { in: path, name: id, required: true, schema: { type: string } }
responses:
"204":
description: Cancelled
default:
$ref: "#/components/responses/DefaultError"
/api/v1/admin/datasets:
get:
tags: [Datasets]
summary: List stored datasets
operationId: listDatasets
responses:
"200":
description: Stored datasets
content:
application/json:
schema:
$ref: "#/components/schemas/DatasetList"
default:
$ref: "#/components/responses/DefaultError"
post:
tags: [Datasets]
summary: Trigger a dataset download
operationId: triggerDatasetDownload
requestBody:
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/DownloadRequest"
responses:
"202":
description: Download accepted
content:
application/json:
schema:
$ref: "#/components/schemas/DownloadAccepted"
default:
$ref: "#/components/responses/DefaultError"
/api/v1/admin/datasets/{name}:
delete:
tags: [Datasets]
summary: Delete a stored dataset by filename
operationId: deleteDataset
parameters:
- { in: path, name: name, required: true, schema: { type: string } }
responses:
"204":
description: Deleted
default:
$ref: "#/components/responses/DefaultError"
/api/v1/admin/jobs:
get:
tags: [Datasets]
summary: List dataset download jobs
operationId: listDatasetJobs
responses:
"200":
description: Download jobs
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/DownloadJob"
default:
$ref: "#/components/responses/DefaultError"
/api/v1/admin/jobs/{id}:
get:
tags: [Datasets]
summary: Get a dataset download job
operationId: getDatasetJob
parameters:
- { in: path, name: id, required: true, schema: { type: string } }
responses:
"200":
description: Download job
content:
application/json:
schema:
$ref: "#/components/schemas/DownloadJob"
default:
$ref: "#/components/responses/DefaultError"
delete:
tags: [Datasets]
summary: Cancel a running download job
operationId: cancelDatasetJob
parameters:
- { in: path, name: id, required: true, schema: { type: string } }
responses:
"204":
description: Cancelled
default:
$ref: "#/components/responses/DefaultError"
/api/v1/admin/status:
get:
tags: [Datasets]
summary: Service status summary
operationId: getServiceStatus
responses:
"200":
description: Status
content:
application/json:
schema:
$ref: "#/components/schemas/StatusResponse"
default:
$ref: "#/components/responses/DefaultError"
/api/v1/wind/meta:
get:
tags: [Wind]
summary: Wind-field visualization metadata
operationId: getWindMeta
responses:
"200":
description: Metadata describing the active dataset for visualization
content:
application/json:
schema:
$ref: "#/components/schemas/WindMeta"
default:
$ref: "#/components/responses/DefaultError"
/api/v1/wind/field:
get:
tags: [Wind]
summary: Wind-field velocity grid (leaflet-velocity / wind-layer format)
operationId: getWindField
parameters:
- { in: query, name: time, schema: { type: string, format: date-time } }
- { in: query, name: altitude, schema: { type: number } }
- { in: query, name: min_lat, schema: { type: number } }
- { in: query, name: max_lat, schema: { type: number } }
- { in: query, name: min_lng, schema: { type: number } }
- { in: query, name: max_lng, schema: { type: number } }
- { in: query, name: step, schema: { type: number } }
responses:
"200":
description: Two-component (U, V) velocity grid
content:
application/json:
schema:
type: array
items:
$ref: "#/components/schemas/WindComponent"
default:
$ref: "#/components/responses/DefaultError"
components:
responses:
DefaultError:
description: Error
content:
application/json:
schema:
$ref: '#/components/schemas/Error'
$ref: "#/components/schemas/Error"
components:
schemas:
Error:
type: object
required:
- error
required: [error]
properties:
error:
type: object
required:
- type
- description
required: [type, description]
properties:
type:
type: string
description:
type: string
type: { type: string }
description: { type: string }
ReadinessResponse:
type: object
required: [status]
properties:
status: { type: string, enum: [ok, not_ready, error] }
dataset_time: { type: string, format: date-time }
error_message: { type: string }
# --- Tawhiri v1 ---------------------------------------------------------
PredictionResponse:
type: object
required:
- prediction
- metadata
required: [prediction, metadata]
properties:
request:
type: object
properties:
dataset:
type: string
launch_latitude:
type: number
launch_longitude:
type: number
launch_datetime:
type: string
launch_altitude:
type: number
profile:
type: string
ascent_rate:
type: number
burst_altitude:
type: number
descent_rate:
type: number
dataset: { type: string }
launch_latitude: { type: number }
launch_longitude: { type: number }
launch_datetime: { type: string }
launch_altitude: { type: number }
profile: { type: string }
ascent_rate: { type: number }
burst_altitude: { type: number }
descent_rate: { type: number }
prediction:
type: array
items:
type: object
required:
- stage
- trajectory
required: [stage, trajectory]
properties:
stage:
type: string
enum: ["ascent", "descent", "float"]
stage: { type: string, enum: [ascent, descent, float] }
trajectory:
type: array
items:
type: object
required:
- datetime
- latitude
- longitude
- altitude
properties:
datetime:
type: string
format: date-time
latitude:
type: number
longitude:
type: number
altitude:
type: number
$ref: "#/components/schemas/TawhiriPoint"
metadata:
type: object
required:
- start_datetime
- complete_datetime
required: [start_datetime, complete_datetime]
properties:
start_datetime:
type: string
format: date-time
complete_datetime:
type: string
format: date-time
start_datetime: { type: string, format: date-time }
complete_datetime: { type: string, format: date-time }
warnings:
type: object
additionalProperties: true
ReadinessResponse:
TawhiriPoint:
type: object
required:
- status
required: [datetime, latitude, longitude, altitude]
properties:
status:
type: string
enum: [ok, not_ready, error]
dataset_time:
type: string
format: date-time
error_message:
datetime: { type: string, format: date-time }
latitude: { type: number }
longitude: { type: number }
altitude: { type: number }
# --- v2 profile-driven --------------------------------------------------
PredictionV2Request:
type: object
required: [launch, profile]
description: |
A profile-driven prediction. `profile` is an ordered chain of
propagators; each integrates from where the previous ended. A stage's
`constraints` decide when it ends and what happens next: stop the
profile, hand off to `fallback_index`, or clip to the boundary.
properties:
launch: { $ref: "#/components/schemas/Launch" }
direction:
type: string
enum: [forward, reverse]
default: forward
description: forward integrates launch→landing; reverse integrates backward in time.
profile:
type: array
items: { $ref: "#/components/schemas/StageSpec" }
globals:
type: array
description: constraints evaluated on every stage in addition to its own.
items: { $ref: "#/components/schemas/ConstraintSpec" }
options: { $ref: "#/components/schemas/Options" }
example:
launch: { time: "2026-03-28T12:00:00Z", latitude: 52.2, longitude: 0.1, altitude: 0 }
profile:
- name: ascent
model: { type: constant_rate, rate: 5, include_wind: true }
constraints: [{ type: altitude, op: ">=", limit: 30000 }]
- name: descent
model: { type: parachute_descent, sea_level_rate: 5, include_wind: true }
constraints: [{ type: terrain_contact }]
Launch:
type: object
required: [time, latitude, longitude]
properties:
time: { type: string, format: date-time }
latitude: { type: number }
longitude: { type: number }
altitude: { type: number }
StageSpec:
type: object
required: [name, model]
properties:
name: { type: string }
model: { $ref: "#/components/schemas/ModelSpec" }
constraints:
type: array
items: { $ref: "#/components/schemas/ConstraintSpec" }
fallback_index: { type: integer }
ModelSpec:
type: object
required: [type]
properties:
type: { type: string, enum: [constant_rate, parachute_descent, piecewise, wind] }
rate: { type: number }
sea_level_rate: { type: number }
include_wind: { type: boolean }
segments:
type: array
items: { $ref: "#/components/schemas/PiecewiseSegment" }
PiecewiseSegment:
type: object
required: [until, rate]
properties:
until: { type: number }
rate: { type: number }
reference: { type: string, enum: [absolute, profile_start, propagator_start], default: absolute }
ConstraintSpec:
type: object
required: [type]
properties:
type: { type: string, enum: [altitude, time, terrain_contact, polygon] }
op: { type: string, enum: ["<", "<=", ">", ">=", "=="] }
limit: { type: number }
action: { type: string, enum: [stop, fallback, clip], default: stop }
mode: { type: string, enum: [inside, outside] }
label: { type: string }
vertices:
type: array
items: { $ref: "#/components/schemas/PolygonVertex" }
PolygonVertex:
type: object
required: [lat, lng]
properties:
lat: { type: number }
lng: { type: number }
Options:
type: object
properties:
step_seconds: { type: number }
tolerance: { type: number }
PredictionV2Response:
type: object
required: [stages, dataset, started_at, completed_at]
properties:
stages:
type: array
items: { $ref: "#/components/schemas/StageResult" }
events:
type: array
items: { $ref: "#/components/schemas/EventSummary" }
dataset: { $ref: "#/components/schemas/DatasetInfo" }
started_at: { type: string, format: date-time }
completed_at: { type: string, format: date-time }
StageResult:
type: object
required: [name, outcome, trajectory]
properties:
name: { type: string }
outcome: { type: string, enum: [stopped, fallback, continued] }
constraint: { type: string }
termination: { $ref: "#/components/schemas/TerminationInfo" }
events:
type: array
items: { $ref: "#/components/schemas/EventSummary" }
trajectory:
type: array
items: { $ref: "#/components/schemas/TrajectoryPoint" }
TrajectoryPoint:
type: object
required: [time, latitude, longitude, altitude]
properties:
time: { type: string, format: date-time }
latitude: { type: number }
longitude: { type: number }
altitude: { type: number }
GeoState:
type: object
required: [lat, lng, altitude]
properties:
lat: { type: number }
lng: { type: number }
altitude: { type: number }
TerminationInfo:
type: object
required: [violation_time, violation_state, refined_time, refined_state]
properties:
violation_time: { type: string, format: date-time }
violation_state: { $ref: "#/components/schemas/GeoState" }
refined_time: { type: string, format: date-time }
refined_state: { $ref: "#/components/schemas/GeoState" }
EventSummary:
type: object
required: [type, count]
properties:
type: { type: string }
count: { type: integer, format: int64 }
first_time: { type: number }
last_time: { type: number }
first_state: { $ref: "#/components/schemas/GeoState" }
last_state: { $ref: "#/components/schemas/GeoState" }
message: { type: string }
DatasetInfo:
type: object
required: [source, epoch]
properties:
source: { type: string }
epoch: { type: string, format: date-time }
# --- async jobs ---------------------------------------------------------
PredictionJob:
type: object
required: [id, status, created_at]
properties:
id: { type: string }
status: { type: string, enum: [pending, running, complete, failed, cancelled] }
created_at: { type: string, format: date-time }
started_at: { type: string, format: date-time }
completed_at: { type: string, format: date-time }
error: { type: string }
result: { $ref: "#/components/schemas/PredictionV2Response" }
# --- dataset admin ------------------------------------------------------
Region:
type: object
required: [min_lat, max_lat, min_lng, max_lng]
properties:
min_lat: { type: number }
max_lat: { type: number }
min_lng: { type: number }
max_lng: { type: number }
HourRange:
type: object
required: [min_hour, max_hour]
properties:
min_hour: { type: integer }
max_hour: { type: integer }
SubsetSpec:
type: object
properties:
region: { $ref: "#/components/schemas/Region" }
hour_range: { $ref: "#/components/schemas/HourRange" }
members:
type: array
items: { type: integer }
Coverage:
type: object
required: [region, start_time, end_time]
properties:
region: { $ref: "#/components/schemas/Region" }
start_time: { type: string, format: date-time }
end_time: { type: string, format: date-time }
DownloadRequest:
type: object
properties:
epoch: { type: string, format: date-time }
latest: { type: boolean }
subset: { $ref: "#/components/schemas/SubsetSpec" }
DownloadAccepted:
type: object
required: [job_id]
properties:
job_id: { type: string }
DatasetEntry:
type: object
required: [filename, epoch, loaded]
properties:
filename: { type: string }
epoch: { type: string, format: date-time }
subset: { $ref: "#/components/schemas/SubsetSpec" }
coverage: { $ref: "#/components/schemas/Coverage" }
loaded: { type: boolean }
DatasetList:
type: object
required: [source, datasets]
properties:
source: { type: string }
datasets:
type: array
items: { $ref: "#/components/schemas/DatasetEntry" }
DownloadJob:
type: object
required: [id, source, dataset, epoch, status, started_at, total_units, done_units, bytes]
properties:
id: { type: string }
source: { type: string }
dataset: { type: string }
epoch: { type: string, format: date-time }
status: { type: string, enum: [pending, running, complete, failed, cancelled] }
started_at: { type: string, format: date-time }
ended_at: { type: string, format: date-time }
error: { type: string }
total_units: { type: integer }
done_units: { type: integer }
bytes: { type: integer, format: int64 }
StatusResponse:
type: object
required: [source, uptime, goroutines, memory_mb, jobs_by_status, stored_datasets, loaded_datasets]
properties:
source: { type: string }
uptime: { type: string }
goroutines: { type: integer }
memory_mb: { type: integer, format: int64 }
jobs_by_status:
type: object
additionalProperties: { type: integer }
stored_datasets: { type: integer }
loaded_datasets: { type: integer }
# --- wind visualization -------------------------------------------------
WindMeta:
type: object
required: [source, epoch, default_step, min_step, suggested_altitudes, bbox]
properties:
source: { type: string }
epoch: { type: string, format: date-time }
default_step: { type: number }
min_step: { type: number }
suggested_altitudes:
type: array
items: { type: integer }
bbox: { $ref: "#/components/schemas/Region" }
WindComponent:
type: object
required: [header, data]
properties:
header: { $ref: "#/components/schemas/WindHeader" }
data:
type: array
items: { type: number }
WindHeader:
type: object
required: [parameterCategory, parameterNumber, nx, ny, lo1, la1, lo2, la2, dx, dy, refTime, forecastTime]
properties:
parameterCategory: { type: integer }
parameterNumber: { type: integer }
parameterNumberName: { type: string }
parameterUnit: { type: string }
nx: { type: integer }
ny: { type: integer }
lo1: { type: number }
la1: { type: number }
lo2: { type: number }
la2: { type: number }
dx: { type: number }
dy: { type: number }
refTime: { type: string }
forecastTime: { type: integer }

13
api/spec.go Normal file
View file

@ -0,0 +1,13 @@
// Package apispec embeds the OpenAPI specification so it can be served at
// runtime (for the ReDoc documentation page and /openapi.yaml) without
// shipping a separate file alongside the binary.
//
// The spec at rest/predictor.swagger.yml is the single source of truth: it
// is both the ogen code-generation input (see the Makefile generate-ogen
// target) and the document served by the API's docs handler.
package apispec
import _ "embed"
//go:embed rest/predictor.swagger.yml
var Spec []byte

View file

@ -1,98 +0,0 @@
package main
import (
"context"
"fmt"
"os"
"os/signal"
"syscall"
"time"
"predictor-refactored/internal/downloader"
"predictor-refactored/internal/service"
"predictor-refactored/internal/transport/rest"
"predictor-refactored/internal/transport/rest/handler"
"github.com/go-co-op/gocron"
"go.uber.org/zap"
)
func main() {
log, err := zap.NewProduction()
if err != nil {
panic(err)
}
defer log.Sync()
cfg := downloader.LoadConfig()
log.Info("configuration loaded",
zap.String("data_dir", cfg.DataDir),
zap.Int("parallel", cfg.Parallel),
zap.Duration("update_interval", cfg.UpdateInterval),
zap.Duration("dataset_ttl", cfg.DatasetTTL))
if err := os.MkdirAll(cfg.DataDir, 0o755); err != nil {
log.Fatal("failed to create data directory", zap.Error(err))
}
svc := service.New(cfg, log)
defer svc.Close()
// Load elevation dataset (optional — falls back to sea-level termination)
elevPath := "/srv/ruaumoko-dataset"
if v := os.Getenv("PREDICTOR_ELEVATION_DATASET"); v != "" {
elevPath = v
}
svc.LoadElevation(elevPath)
// Initial dataset load (async so the server starts immediately)
go func() {
log.Info("performing initial dataset update...")
if err := svc.Update(context.Background()); err != nil {
log.Error("initial dataset update failed", zap.Error(err))
} else {
log.Info("initial dataset update complete")
}
}()
// Scheduler for periodic dataset updates
scheduler := gocron.NewScheduler(time.UTC)
scheduler.Every(cfg.UpdateInterval).Do(func() {
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
defer cancel()
log.Info("scheduled dataset update starting")
if err := svc.Update(ctx); err != nil {
log.Error("scheduled dataset update failed", zap.Error(err))
} else {
log.Info("scheduled dataset update complete")
}
})
scheduler.StartAsync()
defer scheduler.Stop()
// HTTP transport (ogen)
port := 8080
if p := os.Getenv("PREDICTOR_PORT"); p != "" {
fmt.Sscanf(p, "%d", &port)
}
h := handler.New(svc, log)
transport, err := rest.New(h, port, log)
if err != nil {
log.Fatal("failed to create transport", zap.Error(err))
}
go func() {
if err := transport.Run(); err != nil {
log.Fatal("HTTP server error", zap.Error(err))
}
}()
log.Info("service started")
// Graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
sig := <-sigChan
log.Info("received shutdown signal", zap.String("signal", sig.String()))
}

276
cmd/compare-tawhiri/main.go Normal file
View file

@ -0,0 +1,276 @@
// Command compare-tawhiri runs identical predictions against a local predictor
// and a hosted Tawhiri instance and reports how closely they agree.
//
// To make the comparison test the engine rather than data drift, it discovers
// the local predictor's loaded GFS run via /ready and asks Tawhiri to use the
// same run (the `dataset` parameter), so both integrate identical wind data.
// It compares the burst apex (terrain-independent) and the landing point
// (terrain-dependent) separately, since without the ruaumoko elevation dataset
// the local predictor terminates descent at sea level while Tawhiri uses
// ground elevation.
//
// Usage:
//
// compare-tawhiri --server http://localhost:8080 # built-in suite
// compare-tawhiri --lat 52.2 --lng 0.1 --burst 30000 # single site
package main
import (
"encoding/json"
"flag"
"fmt"
"io"
"math"
"net/http"
"net/url"
"os"
"text/tabwriter"
"time"
)
func main() {
var (
server = flag.String("server", "http://localhost:8080", "local predictor base URL")
tawhiri = flag.String("tawhiri", "https://api.v2.sondehub.org/tawhiri", "hosted Tawhiri base URL")
lat = flag.Float64("lat", math.NaN(), "launch latitude (single-site mode)")
lng = flag.Float64("lng", math.NaN(), "launch longitude (single-site mode)")
alt = flag.Float64("alt", 0, "launch altitude m")
ascent = flag.Float64("ascent-rate", 5, "ascent rate m/s")
burst = flag.Float64("burst", 30000, "burst altitude m")
descent = flag.Float64("descent-rate", 5, "descent rate m/s")
launch = flag.String("launch", "", "launch time RFC3339 (default: epoch + 3h)")
align = flag.Bool("align-dataset", true, "ask Tawhiri to use the local predictor's GFS run")
)
flag.Parse()
epoch, err := fetchActiveEpoch(*server)
if err != nil {
fmt.Fprintln(os.Stderr, "local /ready:", err)
os.Exit(1)
}
fmt.Printf("local dataset epoch: %s\n", epoch.Format(time.RFC3339))
launchTime := epoch.Add(3 * time.Hour)
if *launch != "" {
launchTime, err = time.Parse(time.RFC3339, *launch)
if err != nil {
fmt.Fprintln(os.Stderr, "invalid --launch:", err)
os.Exit(1)
}
}
datasetParam := ""
if *align {
datasetParam = epoch.Format(time.RFC3339)
}
sites := suite()
if !math.IsNaN(*lat) && !math.IsNaN(*lng) {
sites = []site{{name: "custom", lat: *lat, lng: *lng}}
}
tw := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(tw, "\nsite\tburst Δ\tlanding Δ\tapex alt Δ\tland alt Δ\tasc pts\tdesc pts\tnotes")
fmt.Fprintln(tw, "----\t-------\t---------\t----------\t----------\t-------\t--------\t-----")
var worst float64
compared := 0
for _, s := range sites {
p := params{lat: s.lat, lng: s.lng, alt: *alt, launch: launchTime,
ascent: *ascent, burst: *burst, descent: *descent}
ours, err := predict(*server+"/api/v1/prediction", p, "")
if err != nil {
fmt.Fprintf(tw, "%s\tlocal error: %v\n", s.name, err)
continue
}
theirs, err := predict(*tawhiri, p, datasetParam)
if err != nil {
fmt.Fprintf(tw, "%s\ttawhiri error: %v\n", s.name, err)
continue
}
compared++
burstD := haversine(ours.apexLat, ours.apexLng, theirs.apexLat, theirs.apexLng)
landD := haversine(ours.landLat, ours.landLng, theirs.landLat, theirs.landLng)
if landD > worst {
worst = landD
}
note := ""
if theirs.dataset != "" && ours.dataset != "" && theirs.dataset != ours.dataset {
note = fmt.Sprintf("dataset mismatch (theirs=%s)", theirs.dataset)
}
fmt.Fprintf(tw, "%s\t%.0f m\t%.2f km\t%.0f m\t%.0f m\t%d/%d\t%d/%d\t%s\n",
s.name, burstD, landD/1000,
math.Abs(ours.apexAlt-theirs.apexAlt), math.Abs(ours.landAlt-theirs.landAlt),
ours.ascPts, theirs.ascPts, ours.descPts, theirs.descPts, note)
}
tw.Flush()
if compared == 0 {
fmt.Println("\nVERDICT: NO COMPARISONS (every site errored — see rows above)")
os.Exit(1)
}
fmt.Printf("\ncompared %d/%d sites; worst landing distance: %.2f km\n", compared, len(sites), worst/1000)
switch {
case worst < 1000:
fmt.Println("VERDICT: MATCH (all landings < 1 km — engine agrees with Tawhiri)")
case worst < 50000:
fmt.Println("VERDICT: CLOSE (< 50 km — consistent with elevation/dataset differences)")
default:
fmt.Println("VERDICT: DIVERGENT (> 50 km — investigate)")
os.Exit(2)
}
}
type site struct {
name string
lat, lng float64
}
// suite is a small set of diverse launch points: UK (lands on land/sea
// depending on winds), mid-Atlantic and mid-Pacific (ocean landings, so the
// sea-level-vs-terrain difference vanishes), and southern hemisphere.
func suite() []site {
return []site{
{"cambridge-uk", 52.2135, 0.0964},
{"mid-atlantic", 35.0, -40.0},
{"mid-pacific", 0.0, -160.0},
{"new-zealand", -41.3, 174.8},
{"colorado-us", 39.0, -105.5},
}
}
type params struct {
lat, lng, alt float64
launch time.Time
ascent, burst, descent float64
}
type result struct {
apexLat, apexLng, apexAlt float64
landLat, landLng, landAlt float64
ascPts, descPts int
dataset string
}
func predict(endpoint string, p params, dataset string) (result, error) {
// Tawhiri requires longitude in [0, 360); normalize so both endpoints get
// the same request. Returned trajectory longitudes are [-180, 180] on both
// sides, so the comparison stays consistent.
lng := p.lng
if lng < 0 {
lng += 360
}
q := url.Values{}
q.Set("launch_latitude", fmt.Sprintf("%.4f", p.lat))
q.Set("launch_longitude", fmt.Sprintf("%.4f", lng))
q.Set("launch_altitude", fmt.Sprintf("%.0f", p.alt))
q.Set("launch_datetime", p.launch.Format(time.RFC3339))
q.Set("ascent_rate", fmt.Sprintf("%.2f", p.ascent))
q.Set("burst_altitude", fmt.Sprintf("%.0f", p.burst))
q.Set("descent_rate", fmt.Sprintf("%.2f", p.descent))
if dataset != "" {
q.Set("dataset", dataset)
}
full := endpoint + "?" + q.Encode()
var body []byte
var status int
var lastErr error
for range 3 {
resp, err := http.Get(full)
if err != nil {
lastErr = err
time.Sleep(time.Second)
continue
}
body, _ = io.ReadAll(resp.Body)
status = resp.StatusCode
resp.Body.Close()
lastErr = nil
break
}
if lastErr != nil {
return result{}, lastErr
}
if status != 200 {
return result{}, fmt.Errorf("HTTP %d: %s", status, truncate(string(body), 160))
}
var doc struct {
Prediction []struct {
Stage string `json:"stage"`
Trajectory []struct {
Latitude float64 `json:"latitude"`
Longitude float64 `json:"longitude"`
Altitude float64 `json:"altitude"`
} `json:"trajectory"`
} `json:"prediction"`
Request struct {
Dataset string `json:"dataset"`
} `json:"request"`
}
if err := json.Unmarshal(body, &doc); err != nil {
return result{}, err
}
var r result
r.dataset = doc.Request.Dataset
for _, st := range doc.Prediction {
if len(st.Trajectory) == 0 {
continue
}
last := st.Trajectory[len(st.Trajectory)-1]
switch st.Stage {
case "ascent":
r.ascPts = len(st.Trajectory)
r.apexLat, r.apexLng, r.apexAlt = last.Latitude, last.Longitude, last.Altitude
case "descent":
r.descPts = len(st.Trajectory)
r.landLat, r.landLng, r.landAlt = last.Latitude, last.Longitude, last.Altitude
}
}
return r, nil
}
type readinessResp struct {
Status string `json:"status"`
DatasetTime string `json:"dataset_time"`
}
func fetchActiveEpoch(base string) (time.Time, error) {
resp, err := http.Get(base + "/ready")
if err != nil {
return time.Time{}, err
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
return time.Time{}, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
}
var r readinessResp
if err := json.Unmarshal(body, &r); err != nil {
return time.Time{}, err
}
if r.Status != "ok" {
return time.Time{}, fmt.Errorf("server status %q (no dataset loaded yet)", r.Status)
}
return time.Parse(time.RFC3339, r.DatasetTime)
}
func haversine(lat1, lng1, lat2, lng2 float64) float64 {
const R = 6371000.0
phi1 := lat1 * math.Pi / 180
phi2 := lat2 * math.Pi / 180
dphi := (lat2 - lat1) * math.Pi / 180
dlam := (lng2 - lng1) * math.Pi / 180
a := math.Sin(dphi/2)*math.Sin(dphi/2) + math.Cos(phi1)*math.Cos(phi2)*math.Sin(dlam/2)*math.Sin(dlam/2)
return R * 2 * math.Atan2(math.Sqrt(a), math.Sqrt(1-a))
}
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "…"
}

View file

@ -1,195 +0,0 @@
package main
import (
"context"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"os"
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/downloader"
"predictor-refactored/internal/prediction"
"go.uber.org/zap"
)
// Downloads a few forecast steps and runs a prediction, then compares
// against the public Tawhiri API.
func main() {
log, _ := zap.NewDevelopment()
cfg := &downloader.Config{
DataDir: os.TempDir(),
Parallel: 4,
}
dl := downloader.NewDownloader(cfg, log)
ctx := context.Background()
// Find latest run
run, err := dl.FindLatestRun(ctx)
if err != nil {
fmt.Fprintf(os.Stderr, "FindLatestRun: %v\n", err)
os.Exit(1)
}
fmt.Printf("Using run: %s\n", run.Format("2006010215"))
// Create dataset and download first 10 steps (0-27 hours, enough for a prediction)
dsPath := fmt.Sprintf("/tmp/pred_test_%s.bin", run.Format("2006010215"))
defer os.Remove(dsPath)
ds, err := dataset.Create(dsPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Create: %v\n", err)
os.Exit(1)
}
date := run.Format("20060102")
runHour := run.Hour()
stepsToDownload := []int{0, 3, 6, 9, 12, 15, 18, 21, 24, 27}
fmt.Printf("Downloading %d steps...\n", len(stepsToDownload))
for _, step := range stepsToDownload {
hourIdx := dataset.HourIndex(step)
fmt.Printf(" step %d (hour idx %d)...\n", step, hourIdx)
urlA := dataset.GribURL(date, runHour, step)
if err := dl.DownloadAndBlit(ctx, ds, urlA, hourIdx, dataset.LevelSetA); err != nil {
fmt.Fprintf(os.Stderr, " pgrb2 step %d: %v\n", step, err)
os.Exit(1)
}
urlB := dataset.GribURLB(date, runHour, step)
if err := dl.DownloadAndBlit(ctx, ds, urlB, hourIdx, dataset.LevelSetB); err != nil {
fmt.Fprintf(os.Stderr, " pgrb2b step %d: %v\n", step, err)
os.Exit(1)
}
}
ds.Flush()
fmt.Println("Download complete")
// Set dataset time
ds.DSTime = run
// Run our prediction
launchLat := 52.2135
launchLon := 0.0964 // already in [0, 360)
launchAlt := 0.0
ascentRate := 5.0
burstAlt := 30000.0
descentRate := 5.0
// Launch 3 hours into the forecast
launchTime := run.Add(3 * time.Hour)
launchTimestamp := float64(launchTime.Unix())
dsEpoch := float64(run.Unix())
warnings := &prediction.Warnings{}
stages := prediction.StandardProfile(ascentRate, burstAlt, descentRate, ds, dsEpoch, warnings, nil)
results := prediction.RunPrediction(launchTimestamp, launchLat, launchLon, launchAlt, stages)
fmt.Printf("\n=== Our prediction ===\n")
for i, sr := range results {
stage := "ascent"
if i == 1 {
stage = "descent"
}
first := sr.Points[0]
last := sr.Points[len(sr.Points)-1]
fmt.Printf(" %s: %d points, start=(%.4f, %.4f, %.0fm) end=(%.4f, %.4f, %.0fm)\n",
stage, len(sr.Points),
first.Lat, first.Lng, first.Alt,
last.Lat, last.Lng, last.Alt)
}
// Get landing point
var ourLandLat, ourLandLon float64
if len(results) >= 2 {
last := results[1].Points[len(results[1].Points)-1]
ourLandLat = last.Lat
ourLandLon = last.Lng
if ourLandLon > 180 {
ourLandLon -= 360
}
}
fmt.Printf(" Landing: lat=%.4f, lon=%.4f\n", ourLandLat, ourLandLon)
// Compare against public Tawhiri API
fmt.Printf("\n=== Tawhiri API comparison ===\n")
tawhiriLandLat, tawhiriLandLon, err := queryTawhiri(launchLat, launchLon, launchAlt, launchTime, ascentRate, burstAlt, descentRate)
if err != nil {
fmt.Printf(" Tawhiri API error: %v\n", err)
fmt.Println(" (Cannot compare — Tawhiri may use a different dataset)")
ds.Close()
return
}
fmt.Printf(" Tawhiri landing: lat=%.4f, lon=%.4f\n", tawhiriLandLat, tawhiriLandLon)
dist := haversine(ourLandLat, ourLandLon, tawhiriLandLat, tawhiriLandLon)
fmt.Printf(" Distance between landing points: %.2f km\n", dist/1000)
if dist < 1000 {
fmt.Println(" CLOSE MATCH (< 1 km)")
} else if dist < 50000 {
fmt.Printf(" MODERATE DIFFERENCE (%.1f km) — likely different datasets\n", dist/1000)
} else {
fmt.Printf(" LARGE DIFFERENCE (%.1f km) — possible bug\n", dist/1000)
}
ds.Close()
}
func queryTawhiri(lat, lon, alt float64, launchTime time.Time, ascentRate, burstAlt, descentRate float64) (landLat, landLon float64, err error) {
url := fmt.Sprintf(
"https://api.v2.sondehub.org/tawhiri?launch_latitude=%.4f&launch_longitude=%.4f&launch_altitude=%.0f&launch_datetime=%s&ascent_rate=%.1f&burst_altitude=%.0f&descent_rate=%.1f",
lat, lon, alt, launchTime.Format(time.RFC3339), ascentRate, burstAlt, descentRate)
resp, err := http.Get(url)
if err != nil {
return 0, 0, err
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
return 0, 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
}
var result struct {
Prediction []struct {
Stage string `json:"stage"`
Trajectory []struct {
Latitude float64 `json:"latitude"`
Longitude float64 `json:"longitude"`
Altitude float64 `json:"altitude"`
} `json:"trajectory"`
} `json:"prediction"`
}
if err := json.Unmarshal(body, &result); err != nil {
return 0, 0, err
}
for _, stage := range result.Prediction {
if stage.Stage == "descent" && len(stage.Trajectory) > 0 {
last := stage.Trajectory[len(stage.Trajectory)-1]
return last.Latitude, last.Longitude, nil
}
}
return 0, 0, fmt.Errorf("no descent stage found")
}
func haversine(lat1, lon1, lat2, lon2 float64) float64 {
const R = 6371000.0
phi1 := lat1 * math.Pi / 180
phi2 := lat2 * math.Pi / 180
dphi := (lat2 - lat1) * math.Pi / 180
dlam := (lon2 - lon1) * math.Pi / 180
a := math.Sin(dphi/2)*math.Sin(dphi/2) + math.Cos(phi1)*math.Cos(phi2)*math.Sin(dlam/2)*math.Sin(dlam/2)
return R * 2 * math.Atan2(math.Sqrt(a), math.Sqrt(1-a))
}

View file

@ -1,104 +0,0 @@
package main
import (
"context"
"fmt"
"os"
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/downloader"
"go.uber.org/zap"
)
// Downloads step 0 of a given run and writes a minimal dataset for comparison.
// Usage: go run ./cmd/compare_step0 <run_YYYYMMDDHH> <output_path>
func main() {
if len(os.Args) < 3 {
fmt.Fprintf(os.Stderr, "Usage: %s <run_YYYYMMDDHH> <output_path>\n", os.Args[0])
os.Exit(1)
}
runStr := os.Args[1]
outPath := os.Args[2]
run, err := time.Parse("2006010215", runStr)
if err != nil {
fmt.Fprintf(os.Stderr, "Invalid run time %q: %v\n", runStr, err)
os.Exit(1)
}
log, _ := zap.NewDevelopment()
// Create a full-size dataset (we only fill step 0)
fmt.Printf("Creating dataset at %s (%d bytes)...\n", outPath, dataset.DatasetSize)
ds, err := dataset.Create(outPath)
if err != nil {
fmt.Fprintf(os.Stderr, "Create dataset: %v\n", err)
os.Exit(1)
}
defer ds.Close()
cfg := &downloader.Config{
DataDir: os.TempDir(),
Parallel: 4,
}
dl := downloader.NewDownloader(cfg, log)
ctx := context.Background()
date := run.Format("20060102")
runHour := run.Hour()
// Download and blit step 0 from pgrb2
fmt.Println("Downloading pgrb2 step 0...")
urlA := dataset.GribURL(date, runHour, 0)
if err := dl.DownloadAndBlit(ctx, ds, urlA, 0, dataset.LevelSetA); err != nil {
fmt.Fprintf(os.Stderr, "pgrb2: %v\n", err)
os.Exit(1)
}
fmt.Println(" done")
// Download and blit step 0 from pgrb2b
fmt.Println("Downloading pgrb2b step 0...")
urlB := dataset.GribURLB(date, runHour, 0)
if err := dl.DownloadAndBlit(ctx, ds, urlB, 0, dataset.LevelSetB); err != nil {
fmt.Fprintf(os.Stderr, "pgrb2b: %v\n", err)
os.Exit(1)
}
fmt.Println(" done")
if err := ds.Flush(); err != nil {
fmt.Fprintf(os.Stderr, "Flush: %v\n", err)
os.Exit(1)
}
// Spot-check: print same values as the Python script
fmt.Println("\n=== Go dataset values (spot check) ===")
type testPoint struct {
varName string
varIdx int
levelIdx int
lat, lon int
}
points := []testPoint{
{"HGT", 0, 0, 0, 0}, // HGT @ 1000mb, lat=-90, lon=0
{"HGT", 0, 0, 180, 0}, // HGT @ 1000mb, lat=0, lon=0
{"HGT", 0, 0, 360, 0}, // HGT @ 1000mb, lat=+90, lon=0
{"HGT", 0, 20, 180, 360}, // HGT @ 500mb, lat=0, lon=180
{"UGRD", 1, 0, 180, 0}, // UGRD @ 1000mb, lat=0, lon=0
{"VGRD", 2, 0, 180, 0}, // VGRD @ 1000mb, lat=0, lon=0
{"UGRD", 1, 20, 284, 0}, // UGRD @ 500mb, lat=52N, lon=0
}
for _, p := range points {
val := ds.Val(0, p.levelIdx, p.varIdx, p.lat, p.lon)
actualLat := -90.0 + float64(p.lat)*0.5
actualLon := float64(p.lon) * 0.5
fmt.Printf(" %-4s %4dmb lat=%+7.1f lon=%6.1f: %12.4f\n",
p.varName, dataset.Pressures[p.levelIdx], actualLat, actualLon, val)
}
fmt.Printf("\nDataset written to %s\n", outPath)
}

215
cmd/predictor-cli/main.go Normal file
View file

@ -0,0 +1,215 @@
// Command predictor-cli is a small HTTP client for stratoflights-predictor.
//
// It is intended for operations and development; production callers should
// use the REST API directly.
package main
import (
"bytes"
"encoding/json"
"flag"
"fmt"
"io"
"net/http"
"net/url"
"os"
"strings"
"time"
)
const usage = `predictor-cli HTTP client for stratoflights-predictor
USAGE
predictor-cli [--server URL] <command> [args...]
COMMANDS
ready Check service health
predict <KEY=VAL>... Run a Tawhiri-compat prediction (key=value pairs)
datasets list List stored dataset epochs
datasets download [--latest|--epoch RFC3339]
Trigger a dataset download
datasets delete <epoch> Delete a stored dataset
jobs list List download jobs
jobs get <id> Show one job
jobs cancel <id> Cancel a running job
ENVIRONMENT
PREDICTOR_SERVER Default --server (overridden by the flag)
`
func main() {
fs := flag.NewFlagSet("predictor-cli", flag.ContinueOnError)
fs.Usage = func() { fmt.Fprint(os.Stderr, usage) }
server := fs.String("server", envDefault("PREDICTOR_SERVER", "http://localhost:8080"), "predictor server URL")
if err := fs.Parse(os.Args[1:]); err != nil {
os.Exit(2)
}
args := fs.Args()
if len(args) == 0 {
fs.Usage()
os.Exit(2)
}
c := &client{base: strings.TrimRight(*server, "/"), http: &http.Client{Timeout: 30 * time.Second}}
if err := dispatch(c, args); err != nil {
fmt.Fprintln(os.Stderr, "error:", err)
os.Exit(1)
}
}
func envDefault(name, fallback string) string {
if v := os.Getenv(name); v != "" {
return v
}
return fallback
}
func dispatch(c *client, args []string) error {
switch args[0] {
case "ready":
return c.ready()
case "predict":
return c.predict(args[1:])
case "datasets":
if len(args) < 2 {
return fmt.Errorf("usage: datasets {list|download|delete}")
}
switch args[1] {
case "list":
return c.datasetsList()
case "download":
return c.datasetsDownload(args[2:])
case "delete":
if len(args) < 3 {
return fmt.Errorf("usage: datasets delete <epoch>")
}
return c.datasetsDelete(args[2])
}
case "jobs":
if len(args) < 2 {
return fmt.Errorf("usage: jobs {list|get|cancel}")
}
switch args[1] {
case "list":
return c.jobsList()
case "get":
if len(args) < 3 {
return fmt.Errorf("usage: jobs get <id>")
}
return c.jobsGet(args[2])
case "cancel":
if len(args) < 3 {
return fmt.Errorf("usage: jobs cancel <id>")
}
return c.jobsCancel(args[2])
}
}
return fmt.Errorf("unknown command %q", args[0])
}
type client struct {
base string
http *http.Client
}
func (c *client) ready() error {
return c.getPrint("/ready")
}
func (c *client) predict(kv []string) error {
q := url.Values{}
for _, p := range kv {
idx := strings.IndexByte(p, '=')
if idx <= 0 {
return fmt.Errorf("expected key=value, got %q", p)
}
q.Set(p[:idx], p[idx+1:])
}
return c.getPrint("/api/v1/prediction?" + q.Encode())
}
func (c *client) datasetsList() error {
return c.getPrint("/api/v1/admin/datasets")
}
func (c *client) datasetsDownload(args []string) error {
fs := flag.NewFlagSet("datasets download", flag.ContinueOnError)
latest := fs.Bool("latest", false, "download the latest available run")
epoch := fs.String("epoch", "", "RFC3339 epoch to download")
if err := fs.Parse(args); err != nil {
return err
}
body := map[string]any{}
if *latest {
body["latest"] = true
}
if *epoch != "" {
body["epoch"] = *epoch
}
return c.postPrint("/api/v1/admin/datasets", body)
}
func (c *client) datasetsDelete(epoch string) error {
return c.deletePrint("/api/v1/admin/datasets/" + url.PathEscape(epoch))
}
func (c *client) jobsList() error { return c.getPrint("/api/v1/admin/jobs") }
func (c *client) jobsGet(id string) error {
return c.getPrint("/api/v1/admin/jobs/" + url.PathEscape(id))
}
func (c *client) jobsCancel(id string) error {
return c.deletePrint("/api/v1/admin/jobs/" + url.PathEscape(id))
}
func (c *client) getPrint(path string) error {
resp, err := c.http.Get(c.base + path)
if err != nil {
return err
}
return printResp(resp)
}
func (c *client) postPrint(path string, body any) error {
buf, err := json.Marshal(body)
if err != nil {
return err
}
resp, err := c.http.Post(c.base+path, "application/json", bytes.NewReader(buf))
if err != nil {
return err
}
return printResp(resp)
}
func (c *client) deletePrint(path string) error {
req, err := http.NewRequest(http.MethodDelete, c.base+path, nil)
if err != nil {
return err
}
resp, err := c.http.Do(req)
if err != nil {
return err
}
return printResp(resp)
}
func printResp(resp *http.Response) error {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode >= 400 {
return fmt.Errorf("HTTP %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
}
// Pretty-print JSON when possible; raw bytes otherwise.
if strings.Contains(resp.Header.Get("Content-Type"), "json") && len(body) > 0 {
var any any
if err := json.Unmarshal(body, &any); err == nil {
pretty, _ := json.MarshalIndent(any, "", " ")
fmt.Println(string(pretty))
return nil
}
}
if len(body) > 0 {
fmt.Println(strings.TrimSpace(string(body)))
}
return nil
}

258
cmd/predictor/main.go Normal file
View file

@ -0,0 +1,258 @@
// Command predictor is the stratoflights-predictor HTTP server.
//
// It wires the configuration, dataset manager, scheduler, and API layer
// into a single process and exits cleanly on SIGINT/SIGTERM.
package main
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"github.com/go-co-op/gocron"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
"predictor-refactored/internal/api"
"predictor-refactored/internal/config"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/datasets/gefs"
"predictor-refactored/internal/datasets/gfs"
"predictor-refactored/internal/elevation"
"predictor-refactored/internal/metrics"
wgfs "predictor-refactored/internal/weather/gfs"
"predictor-refactored/internal/windviz"
)
// Build metadata, injected via -ldflags at build time (see Dockerfile).
var (
version = "dev"
revision = "unknown"
)
func main() {
// `predictor -healthcheck` probes the local /health endpoint and exits
// 0/1. The container HEALTHCHECK uses it so the (distroless) image needs
// no shell or curl.
for _, a := range os.Args[1:] {
if a == "-healthcheck" || a == "--healthcheck" {
os.Exit(healthcheck())
}
}
if err := run(os.Args[1:]); err != nil {
fmt.Fprintln(os.Stderr, "fatal:", err)
os.Exit(1)
}
}
// healthcheck performs a liveness probe against the local server. It resolves
// the port through the same config loader as the server, so the probe always
// matches the bind port regardless of how it was set (flag, env, or file).
func healthcheck() int {
port := 8080
if cfg, err := config.Load(withoutHealthcheckFlag(os.Args[1:])); err == nil {
port = cfg.HTTP.Port
}
client := &http.Client{Timeout: 3 * time.Second}
resp, err := client.Get(fmt.Sprintf("http://127.0.0.1:%d/health", port))
if err != nil {
fmt.Fprintln(os.Stderr, "healthcheck:", err)
return 1
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
fmt.Fprintln(os.Stderr, "healthcheck: status", resp.StatusCode)
return 1
}
return 0
}
// withoutHealthcheckFlag drops the -healthcheck flag so the remaining args
// parse cleanly through config.Load (which does not define it).
func withoutHealthcheckFlag(args []string) []string {
out := make([]string, 0, len(args))
for _, a := range args {
if a == "-healthcheck" || a == "--healthcheck" {
continue
}
out = append(out, a)
}
return out
}
func run(args []string) error {
cfg, err := config.Load(args)
if err != nil {
return fmt.Errorf("load config: %w", err)
}
log, err := newLogger(cfg.Log.Level)
if err != nil {
return fmt.Errorf("init logger: %w", err)
}
defer log.Sync()
log.Info("starting stratoflights-predictor",
zap.String("version", version),
zap.String("revision", revision))
log.Info("configuration loaded",
zap.Int("port", cfg.HTTP.Port),
zap.String("data_dir", cfg.Data.Dir),
zap.String("source", cfg.Data.Source),
zap.Int("download_parallel", cfg.Download.Parallel),
zap.Duration("update_interval", cfg.Download.UpdateInterval),
zap.Duration("freshness_ttl", cfg.Download.FreshnessTTL),
zap.Bool("metrics_enabled", cfg.Metrics.Enabled),
)
store, err := datasets.NewLocalStore(cfg.Data.Dir, cfg.Data.Source)
if err != nil {
return fmt.Errorf("init store: %w", err)
}
variant, err := wgfs.VariantByID(cfg.Data.Source)
if err != nil {
return fmt.Errorf("unsupported source %q: %w", cfg.Data.Source, err)
}
var source datasets.Source
switch variant.Family {
case wgfs.FamilyGFS:
s := gfs.NewSource(variant, log)
s.Parallel = cfg.Download.Parallel
source = s
case wgfs.FamilyGEFS:
s := gefs.NewSource(variant, log)
s.Parallel = cfg.Download.Parallel
source = s
default:
return fmt.Errorf("unsupported family for %q", cfg.Data.Source)
}
var throttle datasets.Throttle
if cfg.Download.BandwidthBytesPerSecond > 0 {
throttle = datasets.NewTokenBucket(cfg.Download.BandwidthBytesPerSecond)
}
// Metrics (optional).
var sink metrics.Sink = metrics.Noop()
var metricsHandler http.Handler
if cfg.Metrics.Enabled {
prom := metrics.NewProm()
sink = prom
metricsHandler = prom
}
mgr := datasets.New(source, store, throttle, log)
defer mgr.Close()
// Optional elevation dataset. Missing or unreadable elevation is logged
// but non-fatal; descent terminates at sea level instead.
var elev *elevation.Dataset
if cfg.Data.ElevationPath != "" {
if d, err := elevation.Open(cfg.Data.ElevationPath); err == nil {
elev = d
log.Info("elevation dataset loaded", zap.String("path", cfg.Data.ElevationPath))
defer elev.Close()
} else {
log.Warn("elevation dataset not available, using sea-level termination",
zap.String("path", cfg.Data.ElevationPath),
zap.Error(err))
}
}
// Kick off the initial refresh in the background so the server can start
// answering /ready immediately.
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
defer cancel()
if _, err := mgr.Refresh(ctx, cfg.Download.FreshnessTTL); err != nil {
log.Error("initial dataset refresh failed", zap.Error(err))
}
if a := mgr.Active(); a != nil {
sink.ActiveEpoch(a.Epoch())
}
}()
scheduler := gocron.NewScheduler(time.UTC)
scheduler.Every(cfg.Download.UpdateInterval).Do(func() {
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
defer cancel()
log.Info("scheduled dataset refresh starting")
if _, err := mgr.Refresh(ctx, cfg.Download.FreshnessTTL); err != nil {
log.Error("scheduled dataset refresh failed", zap.Error(err))
}
if a := mgr.Active(); a != nil {
sink.ActiveEpoch(a.Epoch())
}
})
scheduler.StartAsync()
defer scheduler.Stop()
var windCache *windviz.Cache
if cfg.Wind.Enabled {
windCache = windviz.NewCache(cfg.Wind.CacheSize, cfg.Wind.CacheTTL)
}
server, err := api.New(cfg.HTTP.Port, api.Deps{
Manager: mgr,
Elevation: elev,
Metrics: sink,
MetricsHandler: metricsHandler,
MetricsPath: cfg.Metrics.Path,
EnableWind: cfg.Wind.Enabled,
WindCache: windCache,
AsyncWorkers: cfg.HTTP.AsyncWorkers,
AsyncQueueSize: cfg.HTTP.AsyncQueueSize,
AsyncResultTTL: cfg.HTTP.AsyncResultTTL,
Log: log,
})
if err != nil {
return fmt.Errorf("init server: %w", err)
}
defer server.Close()
// Graceful shutdown
ctx, cancel := signalContext()
defer cancel()
log.Info("service started")
if err := server.Run(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("http server: %w", err)
}
log.Info("service stopped")
return nil
}
func signalContext() (context.Context, context.CancelFunc) {
ctx, cancel := context.WithCancel(context.Background())
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
go func() {
<-sigCh
cancel()
}()
return ctx, cancel
}
func newLogger(level string) (*zap.Logger, error) {
cfg := zap.NewProductionConfig()
switch level {
case "debug":
cfg.Level = zap.NewAtomicLevelAt(zapcore.DebugLevel)
case "info":
cfg.Level = zap.NewAtomicLevelAt(zapcore.InfoLevel)
case "warn":
cfg.Level = zap.NewAtomicLevelAt(zapcore.WarnLevel)
case "error":
cfg.Level = zap.NewAtomicLevelAt(zapcore.ErrorLevel)
default:
cfg.Level = zap.NewAtomicLevelAt(zapcore.InfoLevel)
}
return cfg.Build()
}

11
deploy/prometheus.yml Normal file
View file

@ -0,0 +1,11 @@
# Minimal Prometheus config for the staging compose stack. In production a
# central Prometheus scrapes the predictor via Docker Swarm service discovery
# (see DEPLOYMENT.md); this file just proves the metrics pipeline locally.
global:
scrape_interval: 15s
scrape_configs:
- job_name: predictor
metrics_path: /metrics
static_configs:
- targets: ["predictor:8080"]

41
deploy/swarmpit-deploy.sh Executable file
View file

@ -0,0 +1,41 @@
#!/usr/bin/env sh
# Deploy (or update) the predictor stack to a Docker Swarm via the Swarmpit
# REST API, then trigger a redeploy so running services pick up the new image.
#
# Required env: SWARMPIT_URL, SWARMPIT_TOKEN, STACK_NAME, TAG
# Optional env: CA_CERTIFICATES (PEM bundle for a private Swarmpit TLS CA)
set -eu
: "${SWARMPIT_URL:?SWARMPIT_URL is required}"
: "${SWARMPIT_TOKEN:?SWARMPIT_TOKEN is required}"
: "${STACK_NAME:?STACK_NAME is required}"
TAG="${TAG:-latest}"
# Pin the image tag in the compose we send (replace the ${TAG:-latest} default
# with the concrete tag) so the exact built image is what gets deployed.
sed "s|:\${TAG:-latest}|:${TAG}|g" docker-compose.swarm.yml > /tmp/stack.yml
CA_OPT=""
if [ -n "${CA_CERTIFICATES:-}" ]; then
echo "${CA_CERTIFICATES}" > /tmp/swarmpit-ca.crt
CA_OPT="--cacert /tmp/swarmpit-ca.crt"
fi
compose_json=$(jq -Rs . < /tmp/stack.yml)
jq -n --arg name "${STACK_NAME}" --argjson compose "${compose_json}" \
'{name: $name, spec: {compose: $compose}}' > /tmp/swarmpit-payload.json
echo "Deploying stack '${STACK_NAME}' (tag ${TAG}) to ${SWARMPIT_URL}"
curl -fsS -X POST "${SWARMPIT_URL}/api/stacks/${STACK_NAME}" \
-H "authorization: Bearer ${SWARMPIT_TOKEN}" \
-H "Content-Type: application/json" \
-d @/tmp/swarmpit-payload.json \
--max-time 60 ${CA_OPT}
echo "Triggering redeploy"
curl -fsS -X POST "${SWARMPIT_URL}/api/stacks/${STACK_NAME}/redeploy" \
-H "authorization: Bearer ${SWARMPIT_TOKEN}" \
--max-time 60 ${CA_OPT} || echo "redeploy trigger failed; services may still roll forward via autoredeploy"
rm -f /tmp/stack.yml /tmp/swarmpit-payload.json /tmp/swarmpit-ca.crt
echo "Done."

View file

@ -0,0 +1,47 @@
# Staging: resembles production on a single host — all features enabled
# (metrics, wind visualization, async predictions) plus a bundled Prometheus
# so the metrics pipeline can be exercised end to end. Runs non-root like prod.
#
# docker compose -f docker-compose.staging.yml up --build
# curl localhost:8080/api/v1/admin/status
# open http://localhost:9090 (Prometheus, predictor target should be UP)
services:
init-perms:
image: busybox:1.36
command: ["sh", "-c", "mkdir -p /data && chown -R 65532:65532 /data"]
volumes:
- predictor-data:/data
predictor:
build:
context: .
args:
VERSION: staging
REVISION: staging
image: stratoflights-predictor:staging
depends_on:
init-perms:
condition: service_completed_successfully
ports:
- "8080:8080"
environment:
PREDICTOR_DATA_DIR: /data
PREDICTOR_METRICS_ENABLED: "true"
PREDICTOR_METRICS_PATH: /metrics
PREDICTOR_LOG_LEVEL: info
PREDICTOR_DOWNLOAD_PARALLEL: "16"
volumes:
- predictor-data:/data
# - ./elevation:/srv/ruaumoko-dataset:ro
prometheus:
image: prom/prometheus:v2.54.1
depends_on:
- predictor
ports:
- "9090:9090"
volumes:
- ./deploy/prometheus.yml:/etc/prometheus/prometheus.yml:ro
volumes:
predictor-data:

98
docker-compose.swarm.yml Normal file
View file

@ -0,0 +1,98 @@
version: "3.8"
# Production Docker Swarm stack for stratoflights-predictor.
#
# Deploy: TAG=v1.0.0 docker stack deploy -c docker-compose.swarm.yml --with-registry-auth predictor
# (or import via Swarmpit; the CI pipeline deploys it through the Swarmpit API)
#
# Storage & placement (see DEPLOYMENT.md):
# * The wind dataset (~8.9 GiB) lives on NODE-LOCAL disk — never NFS. To keep
# the number of copies bounded, the service is pinned to nodes labelled
# `predictor.data=true`; label at most two such nodes. Each carries one copy.
# * Replicas are spread one-per-node by default (redundancy + load balancing);
# scaling to multiple replicas per node is safe because they share the
# node-local volume and coordinate downloads via an flock (no duplicate fetch).
#
# The predictor is an internal backend: it has no public Traefik router. The
# Django API gateway and Prometheus reach it over the shared `stratoflights-net`
# overlay by the alias `predictor`.
services:
predictor:
image: git.intra.yksa.space/web/predictor:${TAG:-latest}
networks:
stratoflights-net:
aliases:
- predictor
environment:
PREDICTOR_DATA_DIR: /data
PREDICTOR_ELEVATION_DATASET: /srv/ruaumoko-dataset
PREDICTOR_SOURCE: ${PREDICTOR_SOURCE:-gfs-0p50-3h}
PREDICTOR_DOWNLOAD_PARALLEL: ${PREDICTOR_DOWNLOAD_PARALLEL:-16}
PREDICTOR_UPDATE_INTERVAL: 6h
PREDICTOR_DATASET_TTL: 48h
PREDICTOR_METRICS_ENABLED: "true"
PREDICTOR_METRICS_PATH: /metrics
PREDICTOR_LOG_LEVEL: info
volumes:
# Node-local storage. Provision these directories on each labelled node
# (chown to 65532:65532 — see DEPLOYMENT.md). NOT a shared/NFS volume.
- type: bind
source: /srv/predictor/data
target: /data
- type: bind
source: /srv/predictor/elevation
target: /srv/ruaumoko-dataset
read_only: true
healthcheck:
test: ["CMD", "/predictor", "-healthcheck"]
interval: 30s
timeout: 5s
retries: 3
start_period: 120s
logging:
driver: json-file
options:
max-size: "10m"
max-file: "3"
deploy:
mode: replicated
replicas: 2
placement:
max_replicas_per_node: 2
constraints:
- node.labels.predictor.data == true
preferences:
# Spread across the labelled nodes so the two default replicas land
# on different hosts (redundancy across both dataset copies).
- spread: node.labels.predictor.data
update_config:
parallelism: 1
delay: 15s
order: start-first
failure_action: rollback
rollback_config:
parallelism: 1
order: stop-first
restart_policy:
condition: on-failure
delay: 5s
max_attempts: 3
resources:
limits:
memory: 3072M
reservations:
memory: 512M
labels:
# Prometheus Swarm service-discovery hints (adjust to your SD relabel rules).
- "prometheus.scrape=true"
- "prometheus.port=8080"
- "prometheus.path=/metrics"
# Let Swarmpit auto-redeploy when a new :latest (or pinned TAG) is pushed.
- "swarmpit.service.deployment.autoredeploy=true"
networks:
# Shared overlay also joined by the API gateway and Prometheus.
# Create once: docker network create -d overlay --attachable stratoflights-net
stratoflights-net:
external: true

39
docker-compose.yml Normal file
View file

@ -0,0 +1,39 @@
# Local development: a single predictor instance, metrics off.
#
# docker compose up --build
# curl localhost:8080/ready
#
# First start downloads the latest GFS 0.5° run (~8.9 GiB) into the named
# volume; subsequent starts reuse it. The init service chowns the volume so
# the non-root image (uid 65532) can write to it.
services:
init-perms:
image: busybox:1.36
command: ["sh", "-c", "mkdir -p /data && chown -R 65532:65532 /data"]
volumes:
- predictor-data:/data
predictor:
build:
context: .
args:
VERSION: dev
REVISION: local
image: stratoflights-predictor:dev
depends_on:
init-perms:
condition: service_completed_successfully
ports:
- "8080:8080"
environment:
PREDICTOR_DATA_DIR: /data
PREDICTOR_METRICS_ENABLED: "false"
PREDICTOR_LOG_LEVEL: debug
# Mount and point at an elevation dataset for ground-level descent:
# PREDICTOR_ELEVATION_DATASET: /srv/ruaumoko-dataset
volumes:
- predictor-data:/data
# - ./elevation:/srv/ruaumoko-dataset:ro
volumes:
predictor-data:

365
docs/numerics.tex Normal file
View file

@ -0,0 +1,365 @@
\documentclass[a4paper,11pt]{article}
\usepackage[margin=1in]{geometry}
\usepackage{amsmath, amssymb}
\usepackage{algorithm, algpseudocode}
\usepackage{hyperref}
\title{stratoflights-predictor: Mathematical Reference}
\author{stratoflights-predictor}
\date{}
\begin{document}
\maketitle
\noindent This document is the end-to-end mathematical specification of the
trajectory calculation performed by stratoflights-predictor. It is meant
to be detailed enough to permit hand verification of the numerical
output. Section~\ref{sec:numerics} covers the small numerics library
(\verb|internal/numerics|); the remaining sections describe the engine
(\verb|internal/engine|) and the data plane.
\tableofcontents
% =========================================================================
\section{State vector and equations of motion}
\label{sec:state}
\paragraph{State vector.} The balloon state is the four-tuple
\[
\mathbf{s}(t) \;=\; \bigl(t,\; \varphi(t),\; \lambda(t),\; h(t)\bigr),
\]
where $t$ is UNIX seconds, $\varphi \in [-90, 90]$ is geographic latitude
in degrees, $\lambda \in [0, 360)$ is geographic longitude in degrees,
and $h$ is altitude above mean sea level in metres. Inside the
implementation, the spatial part $(\varphi, \lambda, h)$ is the
\verb|engine.State| struct; $t$ is tracked separately by the integrator.
\paragraph{Equations of motion.} The time derivative of state is the
direction-agnostic vector
\[
\dot{\mathbf{s}}(t) \;=\; \bigl(\dot \varphi,\; \dot \lambda,\; \dot h\bigr)
\;=\; \sum_{m \in \text{Models}(t)} \mathbf{F}_m(t, \mathbf{s}),
\]
i.e.\ the sum of the active stage's models evaluated at the current
state. The supported per-model contributions are:
\paragraph{Constant rate.} A purely vertical model with no horizontal
component, used for the standard balloon ascent profile:
\[
\mathbf{F}_{\text{const}}(t, \mathbf{s}) = (0, 0, r), \qquad r \in \mathbb{R}.
\]
Positive $r$ is upward; a negative $r$ may be combined with reverse-time
integration to model an ascent backwards from a known apex.
\paragraph{Parachute descent.} The vertical contribution under a constant
drag coefficient with the NASA atmosphere model density $\rho(h)$:
\[
\mathbf{F}_{\text{drag}}(t, \mathbf{s}) = \Bigl(0, 0, -\frac{k}{\sqrt{\rho(h)}}\Bigr),
\qquad k = r_0 \cdot 1{.}1045,
\]
where $r_0$ is the descent rate at sea level. Density is computed
piecewise from the layered model described in
\href{https://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html}{NASA's
atmosphere page}.
\paragraph{Piecewise rate.} A schedule $\{(\tau_i, r_i)\}_{i=1}^N$
parameterised in either absolute UNIX time, or seconds since
profile start, or seconds since the propagator's own start. Resolution
happens lazily through the \verb|Propagator.BuildModel| hook so the same
spec can be reused across profiles with different launch times.
The contribution at time $t$ is
\[
\mathbf{F}_{\text{pwc}}(t, \mathbf{s}) = (0, 0, r_{i^\star}),
\qquad i^\star = \min\{i : \tau_i > t\}.
\]
\paragraph{Wind transport.} The horizontal contribution from sampling the
loaded wind field $W$:
\[
\mathbf{F}_{\text{wind}}(t, \mathbf{s}) = \Bigl(
\frac{180}{\pi}\,\frac{v}{R + h},\;\;
\frac{180}{\pi}\,\frac{u}{(R + h)\cos\bigl(\varphi\,\pi/180\bigr)},\;\;
0
\Bigr),
\]
where $(u, v) = W(t, \varphi, \lambda, h)$ are the eastward and northward
wind components in metres per second, and $R = 6{,}371{,}009$~m is the
spherical Earth radius. The implementation lives in
\verb|engine.WindTransport| (\verb|engine/models.go|).
\paragraph{Coordinate system.} The model is a spherical Earth in
plate-carrée (latitude/longitude/altitude) coordinates. This matches the
reference Tawhiri predictor exactly and is necessary for bit-identical
back-to-back testing. A WGS84/ECEF variant is planned but deferred: it
would require converting U/V wind components from the GFS sphere model
to the ellipsoid, which is not a trivial coordinate transform.
% =========================================================================
\section{Profiles and propagators}
\label{sec:profile}
\paragraph{Propagator.} A propagator owns one Model and a list of
Constraints; it produces a sequence of trajectory points via classical
Runge--Kutta--4 integration with step $\Delta t$ (positive for forward,
negative for reverse propagation):
\[
\Pi : (t_0, \mathbf{s}_0) \;\longmapsto\; \bigl[(t_k, \mathbf{s}_k)\bigr]_{k=0}^{K}.
\]
The sequence ends at the first $k$ where any constraint is violated;
the violation point is refined by binary search (see
\S\ref{sec:numerics}).
\paragraph{Profile.} A profile is an ordered chain of propagators
$[\Pi_1, \Pi_2, \ldots, \Pi_N]$. Stage $i$ starts where stage $i-1$
ended; the time direction (sign of $\Delta t$) is shared.
\paragraph{Constraint actions.} When a constraint $c$ is violated at the
refined point $(t^\star, \mathbf{s}^\star)$, $c.\text{Action}$ controls
the dispatch:
\begin{itemize}
\item \texttt{stop} — the profile ends at $(t^\star, \mathbf{s}^\star)$.
\item \texttt{fallback} — the current propagator hands off to its
\texttt{Fallback} propagator (chains supported).
\item \texttt{clip} — the violated coordinate is clipped to the
constraint's boundary and integration continues. Useful for soft
constraints such as ``hold altitude above 500~m''.
\end{itemize}
Constraints fire on full RK4 steps only, never on intermediate
sub-evaluations. This matches the reference Tawhiri behaviour
bit-for-bit.
\paragraph{Reverse propagation.} A profile with \verb|Direction = Reverse|
runs every propagator with $\Delta t = -\Delta t$. Models are
direction-agnostic: their derivative formulas above hold unchanged. The
typical use is to start from a known landing point and recover the
launch position by integrating backwards in time.
% =========================================================================
\section{Constraint geometry}
\label{sec:constraints}
The engine ships four constraint primitives.
\paragraph{Scalar comparison: altitude.}
\(
c_{\text{alt}}(t, \mathbf{s}) \;=\; h \,\mathrel{\bigotimes}\, h_0,
\)
where $\bigotimes \in \{<, \le, >, \ge, =\}$ is the configured operator
and $h_0$ is the limit (metres). The implementation is
\verb|engine.Altitude| (\verb|engine/constraints.go|).
\paragraph{Scalar comparison: time.} Same shape as altitude, but acting
on $t$ in UNIX seconds. Implementation: \verb|engine.Time|.
\paragraph{Terrain contact.}
\(
c_{\text{terr}}(t, \mathbf{s}) = \bigl(z(\varphi, \lambda) > h\bigr),
\)
with $z$ provided by the ruaumoko-compatible elevation dataset.
\paragraph{Polygon.} For a polygon $P$ with vertices
$(\varphi_i, \lambda_i)_{i=1}^N$ and mode $\mu \in
\{\text{inside}, \text{outside}\}$, the constraint is
$c_{\text{poly}}(\mathbf{s}) = \bigl(\mathbf{s} \in P\bigr) \oplus
[\mu = \text{outside}]$. Containment is tested by ray casting in
plate-carrée after normalising every longitude to within 180\textdegree{}
of the first vertex; this handles antimeridian-crossing edges so long
as the polygon spans no more than 180\textdegree{} in longitude.
% =========================================================================
\section{Numerics library}
\label{sec:numerics}
The numerics package (\verb|internal/numerics|) provides four primitives:
regular-grid bracketing, multilinear interpolation, monotone bisection,
and classical RK4 with termination-point refinement.
\subsection{Regular-grid bracketing}
\paragraph{Definition.} An \emph{axis} is the regularly-spaced sequence
$x_i = \ell + i \cdot s$ for $i = 0, 1, \ldots, N - 1$, parameterised by
the left edge $\ell$, the step $s > 0$, and the point count $N$.
Given a query $v$, the \emph{bracket} is the pair $(i_0, i_1)$ with
$x_{i_0} \le v < x_{i_1}$ and the dimensionless position
\[
f = \frac{v - x_{i_0}}{s} \in [0, 1).
\]
Implemented as \verb|Axis.Locate| in \verb|internal/numerics/grid.go|.
\paragraph{Wrapping axes.} For periodic axes (e.g.\ longitude), the
sequence is extended by the convention $x_N = x_0$ so a value approaching
$x_N$ from below brackets $(N{-}1, 0)$ with fraction
$f = (v - x_{N-1})/s$.
\paragraph{Worked example.} Latitude axis with $\ell = -90$, $s = 0{.}5$,
$N = 361$. Query $v = -89{.}75$ yields $p = 0{.}5$, so $i_0 = 0$,
$i_1 = 1$, $f = 0{.}5$.
\subsection{Multilinear interpolation}
\paragraph{Definition.} For a scalar field $u$ defined at the grid nodes
of three axes, the trilinear interpolant at brackets
$b_a, b_b, b_c$ is
\[
\tilde u = \sum_{i, j, k \in \{0, 1\}} w_{a,i} \, w_{b,j} \, w_{c,k}
\; u\bigl(b_a^i, b_b^j, b_c^k\bigr),
\]
where $w_{\bullet, 0} = 1 - f_\bullet$ and $w_{\bullet, 1} = f_\bullet$.
The corner terms are accumulated in the order
$(0,0,0), (0,0,1), \ldots, (1,1,1)$, matching the reference Cython
implementation so that double-precision results agree byte for byte.
\paragraph{Linear exactness.} For any affine field
$u(i, j, k) = \alpha i + \beta j + \gamma k + \delta$, the formula returns
$\alpha p_a + \beta p_b + \gamma p_c + \delta$ exactly (modulo
floating-point rounding), where $p_\bullet = b_\bullet^0 + f_\bullet$.
\subsection{Monotone bisection}
For an integer-indexed monotone non-decreasing sequence
$f : \{i_{\min}, \ldots, i_{\max}\} \to \mathbb{R}$ and a target $t$,
\verb|Bisect| returns the largest index $i^\star$ with $f(i^\star) < t$.
Used by the wind sampler to locate the pressure level bracketing the
query altitude. Time complexity:
$\mathcal{O}(\log(i_{\max} - i_{\min}))$.
\subsection{Classical RK4}
\paragraph{Definition.} For a state $y$, derivative $\dot y = f(t, y)$,
and step $\Delta t$, \verb|RK4Step| applies
\[
\begin{aligned}
k_1 &= f(t, y), \\
k_2 &= f\bigl(t + \tfrac{\Delta t}{2}, \; y + \tfrac{\Delta t}{2} k_1\bigr), \\
k_3 &= f\bigl(t + \tfrac{\Delta t}{2}, \; y + \tfrac{\Delta t}{2} k_2\bigr), \\
k_4 &= f\bigl(t + \Delta t, \; y + \Delta t \cdot k_3\bigr), \\
y(t + \Delta t) &= y + \tfrac{\Delta t}{6}\bigl(k_1 + 2 k_2 + 2 k_3 + k_4\bigr).
\end{aligned}
\]
Reverse-time integration uses $\Delta t < 0$ unchanged; the implementation
contains no branch on the sign of $\Delta t$. Domain-specific vector
arithmetic (longitude wrap) is injected via \verb|VecAdd|.
\subsection{Termination refinement}
After each integration step the propagator checks one or more
constraints. When a constraint reports a violation between $(t_1, y_1)$
(not violated) and $(t_2, y_2)$ (violated), \verb|RefineTrigger|
locates the crossing within tolerance $\tau \in (0, 1)$ by binary
search in the linear-interpolation parameter $\lambda \in [0, 1]$:
\begin{algorithm}[H]
\caption{RefineTrigger}\label{alg:refine}
\begin{algorithmic}[1]
\State $L \gets 0,\; R \gets 1$
\State $t_3 \gets t_2,\; y_3 \gets y_2$
\While{$R - L > \tau$}
\State $m \gets (L + R)/2$
\State $t_3 \gets (1 - m)\,t_1 + m\,t_2$
\State $y_3 \gets \mathrm{lerp}(y_1, y_2, m)$
\If{constraint violated at $(t_3, y_3)$}
\State $R \gets m$
\Else
\State $L \gets m$
\EndIf
\EndWhile
\State \Return $(t_3, y_3)$
\end{algorithmic}
\end{algorithm}
After $\lceil \log_2 \tau^{-1} \rceil$ iterations, $R - L \le \tau$.
With $\tau = 0{.}01$ and $\Delta t = 60$~s, the returned point is within
$0{.}6$~s of the true crossing in parameter space; the corresponding
altitude error is bounded by $0{.}6\,|\dot y|$, which for typical
balloon ascent and parachute descent rates is at most $\sim 3$~m.
The returned point is the \emph{last midpoint sampled} rather than
guaranteed to lie on the triggered side; this matches the reference
Tawhiri implementation byte for byte.
% =========================================================================
\section{Wind data pipeline}
\label{sec:winddata}
\paragraph{Data source.} NOAA GFS 0.5\textdegree{} (default) or 0.25\textdegree{}
forecasts, optionally subset by region or hour range. GEFS ensemble
runs are supported by selecting one of the 21 members; each member is a
separate dataset (\verb|DatasetID.Subset.Members = \{m\}|).
\paragraph{Cube layout.} A flat C-order row-major float32 array, shape
$(N_{\text{hours}}, N_{\text{levels}}, 3, N_{\text{lat}}, N_{\text{lng}})$,
where the variable axis is fixed to (HGT, UGRD, VGRD). Per-variant sizes
live in \verb|internal/weather/gfs/variant.go|.
\paragraph{Sampling.} Given a query $(t, \varphi, \lambda, h)$, the
sampler computes the time-in-hours offset
$\tau = (t - t_0)/3600$ from the dataset epoch $t_0$, brackets
$(\tau, \varphi, \lambda)$ on the three horizontal axes, then bisects
the pressure-level axis to find the largest level $\ell$ whose
trilinearly-interpolated HGT is below $h$. Wind components are
extracted via two more trilinear evaluations (at levels $\ell$ and
$\ell + 1$) and linearly interpolated in altitude:
\[
W(t, \varphi, \lambda, h) = \alpha \cdot W_\ell + (1 - \alpha) \cdot W_{\ell+1},
\qquad
\alpha = \frac{H_{\ell+1} - h}{H_{\ell+1} - H_\ell}.
\]
% =========================================================================
\section{Coverage and dataset selection}
\label{sec:coverage}
A loaded dataset $\mathcal{D}$ exposes its \emph{coverage}
$C_\mathcal{D} = (R_\mathcal{D}, [t_0, t_1])$ where $R_\mathcal{D}$ is a
geographic bounding box (possibly antimeridian-spanning) and
$[t_0, t_1]$ is the temporal extent. When more than one dataset is
loaded simultaneously, the predictor selects the first one whose
$C_\mathcal{D}$ contains the launch query. Regional / sub-range
datasets thus complement the global default.
% =========================================================================
\section{Deferrals and design notes}
\label{sec:deferrals}
\paragraph{Mass-aware drift.} The current model assumes the payload moves
horizontally at exactly the local wind velocity. A heavier payload
exhibits a velocity defect proportional to inertial coupling. A
plausible extension is the Stokes-style first-order lag model
\[
\dot{\mathbf{v}}_p = \frac{1}{\tau}\bigl(\mathbf{v}_{\text{wind}}(t,\mathbf{s}) - \mathbf{v}_p\bigr),
\]
introduced as an additional state variable $\mathbf{v}_p$ alongside the
existing $\mathbf{s}$. The Propagator interface already accepts
arbitrary State types via generics in numerics; the engine could lift
its State to $(\mathbf{s}, \mathbf{v}_p)$ for a future mass-aware
propagator without breaking the existing models.
\paragraph{Coordinate system upgrades.} Migrating to WGS84/ECEF would
remove the cosine factor in the horizontal wind transport equation and
make distances metric directly. GFS itself uses a spherical Earth; the
wind components are not directly portable. A clean implementation
provides a coordinate-system parameter on the profile request; for now,
the spherical model is used uniformly so that outputs remain bit
identical to the upstream Tawhiri.
\paragraph{Monte Carlo.} GEFS already provides 21 ensemble members per
epoch. A Monte Carlo prediction would sample $K$ trajectories per
request, each picking a (member, parameter perturbation) pair. The
recommended architecture is to keep the perturbation inside the
predictor (so the same wind sample can serve many members and any
piecewise rate noise is correlated with the wind step), exposed as a
\verb|POST /api/v1/montecarlo| endpoint that returns one job per
sample and aggregates outcomes.
% =========================================================================
\section{Implementation notes}
\label{sec:impl}
The numerics library is intentionally small (under 300 lines of Go) and
uses no allocations on the hot path. The generic \verb|RK4Step| and
\verb|RefineTrigger| compile to per-type specialisations under Go's
generics, so a future C or Rust port can mirror the implementation
verbatim without changing the call sites in the trajectory engine.
\end{document}

View file

@ -0,0 +1,51 @@
# Wind layer demo
A minimal browser client that renders the predictor's wind field as an
animated particle layer using [Leaflet](https://leafletjs.com/) and
[leaflet-velocity](https://github.com/onaci/leaflet-velocity).
The predictor's `GET /api/v1/wind/field` endpoint emits the
[wind-js-server](https://github.com/danwild/wind-js-server) "gfs.json" format
(a two-element `[U, V]` array of `{header, data}` records), which is exactly
what leaflet-velocity and [sakitam-fdd/wind-layer](https://github.com/sakitam-fdd/wind-layer)
consume — so no transformation is needed in the frontend.
## Running
Serve this directory and the predictor from the same origin (or set `API` in
`index.html` to the predictor's base URL and rely on the predictor's CORS
headers):
```bash
# Terminal 1: the predictor (must have a dataset loaded for real data)
./bin/predictor
# Terminal 2: serve the demo
cd examples/wind-demo && python3 -m http.server 8090
# open http://localhost:8090 (set API="http://localhost:8080" in index.html)
```
## API contract
`GET /api/v1/wind/field` query parameters (all optional):
| Param | Default | Meaning |
|---|---|---|
| `time` | dataset epoch | RFC3339 forecast time to sample |
| `altitude` | `0` | altitude in metres |
| `min_lat`,`max_lat`,`min_lng`,`max_lng` | global | bounding box (degrees) |
| `step` | `1.0` | grid resolution in degrees (min `0.25`) |
`GET /api/v1/wind/meta` returns the active dataset's source, epoch, suggested
altitudes, and bounding box so a client can populate its controls.
The full OpenAPI definition is served at `/openapi.yaml`, with a browsable
ReDoc rendering at `/docs`.
## Minimal fetch
```js
const res = await fetch("/api/v1/wind/field?altitude=10000&step=2");
const data = await res.json(); // [ {header, data}, {header, data} ]
L.velocityLayer({ data }).addTo(map);
```

View file

@ -0,0 +1,98 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>stratoflights-predictor — wind layer demo</title>
<!-- Leaflet -->
<link rel="stylesheet" href="https://unpkg.com/leaflet@1.9.4/dist/leaflet.css" />
<script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js"></script>
<!-- leaflet-velocity: consumes the wind-js-server JSON this API emits -->
<link rel="stylesheet" href="https://unpkg.com/leaflet-velocity@1.7.0/dist/leaflet-velocity.css" />
<script src="https://unpkg.com/leaflet-velocity@1.7.0/dist/leaflet-velocity.js"></script>
<style>
html, body, #map { height: 100%; margin: 0; }
#controls {
position: absolute; z-index: 1000; top: 10px; right: 10px;
background: #fff; padding: 8px 10px; border-radius: 6px;
font: 13px sans-serif; box-shadow: 0 1px 4px rgba(0,0,0,.3);
}
#controls label { display: block; margin: 4px 0; }
</style>
</head>
<body>
<div id="map"></div>
<div id="controls">
<strong>Wind layer</strong>
<label>Altitude (m):
<input id="altitude" type="number" value="10000" step="1000" style="width:80px">
</label>
<label>Step (deg):
<input id="step" type="number" value="2" step="0.5" min="0.25" style="width:60px">
</label>
<button id="reload">Reload</button>
<div id="status"></div>
</div>
<script>
// Base URL of the predictor API. Same-origin by default.
const API = "";
const map = L.map("map").setView([30, 0], 2);
L.tileLayer("https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png", {
attribution: "&copy; OpenStreetMap",
}).addTo(map);
let velocityLayer = null;
// fetchWindField pulls the leaflet-velocity-compatible grid from the API.
//
// The endpoint returns a two-element array [uComponent, vComponent], each
// with a {header, data} object — exactly the gfs.json / wind-js-server
// shape leaflet-velocity and wind-layer expect.
async function fetchWindField({ altitude, step, time, bbox } = {}) {
const q = new URLSearchParams();
if (altitude != null) q.set("altitude", altitude);
if (step != null) q.set("step", step);
if (time) q.set("time", time);
if (bbox) {
q.set("min_lat", bbox.minLat); q.set("max_lat", bbox.maxLat);
q.set("min_lng", bbox.minLng); q.set("max_lng", bbox.maxLng);
}
const res = await fetch(`${API}/api/v1/wind/field?` + q.toString());
if (!res.ok) throw new Error(`HTTP ${res.status}: ${await res.text()}`);
return res.json();
}
async function reload() {
const status = document.getElementById("status");
const altitude = Number(document.getElementById("altitude").value);
const step = Number(document.getElementById("step").value);
status.textContent = "loading…";
try {
const data = await fetchWindField({ altitude, step });
if (velocityLayer) map.removeLayer(velocityLayer);
velocityLayer = L.velocityLayer({
displayValues: true,
displayOptions: {
velocityType: "Wind",
displayPosition: "bottomleft",
displayEmptyString: "No wind data",
},
data,
maxVelocity: 60,
}).addTo(map);
status.textContent = `loaded ${data[0].header.nx}×${data[0].header.ny} grid`;
} catch (err) {
status.textContent = "error: " + err.message;
console.error(err);
}
}
document.getElementById("reload").addEventListener("click", reload);
reload();
</script>
</body>
</html>

4
go.mod
View file

@ -7,6 +7,7 @@ require (
github.com/go-co-op/gocron v1.37.0
github.com/go-faster/errors v0.7.1
github.com/go-faster/jx v1.2.0
github.com/google/uuid v1.6.0
github.com/nilsmagnus/grib v1.2.8
github.com/ogen-go/ogen v1.20.2
go.opentelemetry.io/otel v1.42.0
@ -14,6 +15,7 @@ require (
go.opentelemetry.io/otel/trace v1.42.0
go.uber.org/zap v1.27.1
golang.org/x/sync v0.20.0
gopkg.in/yaml.v2 v2.4.0
)
require (
@ -24,7 +26,6 @@ require (
github.com/go-faster/yaml v0.4.6 // indirect
github.com/go-logr/logr v1.4.3 // indirect
github.com/go-logr/stdr v1.2.2 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/mattn/go-colorable v0.1.14 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect
@ -37,5 +38,4 @@ require (
golang.org/x/net v0.52.0 // indirect
golang.org/x/sys v0.42.0 // indirect
golang.org/x/text v0.35.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
)

View file

@ -0,0 +1,279 @@
// Package async runs profile-driven predictions on a bounded worker pool and
// retains their results in memory for a configurable TTL. It is the engine
// behind the asynchronous prediction endpoints; the HTTP surface itself is
// the ogen-generated server in the parent package.
//
// The package is decoupled from the request/response wire types: a RunFunc is
// injected at construction, so this file imports only the generated API types
// it stores and returns.
package async
import (
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"predictor-refactored/internal/metrics"
apirest "predictor-refactored/pkg/rest"
)
// RunFunc executes one prediction synchronously.
type RunFunc func(req *apirest.PredictionV2Request) (*apirest.PredictionV2Response, error)
// Status is the lifecycle state of a prediction job.
type Status string
const (
StatusPending Status = "pending"
StatusRunning Status = "running"
StatusComplete Status = "complete"
StatusFailed Status = "failed"
StatusCancelled Status = "cancelled"
)
// JobInfo is a snapshot of one prediction job.
type JobInfo struct {
ID string
Status Status
CreatedAt time.Time
StartedAt *time.Time
CompletedAt *time.Time
Error string
Result *apirest.PredictionV2Response
}
type job struct {
id string
req *apirest.PredictionV2Request
createdAt time.Time
mu sync.Mutex
status Status
startedAt time.Time
completedAt time.Time
errStr string
result *apirest.PredictionV2Response
}
func (j *job) snapshot() JobInfo {
j.mu.Lock()
defer j.mu.Unlock()
info := JobInfo{
ID: j.id, Status: j.status, CreatedAt: j.createdAt,
Error: j.errStr, Result: j.result,
}
if !j.startedAt.IsZero() {
t := j.startedAt
info.StartedAt = &t
}
if !j.completedAt.IsZero() {
t := j.completedAt
info.CompletedAt = &t
}
return info
}
// Manager runs a fixed pool of workers and retains job results for a TTL.
type Manager struct {
run RunFunc
metrics metrics.Sink
log *zap.Logger
queue chan *job
ttl time.Duration
jobsMu sync.RWMutex
jobs map[string]*job
inflight atomic.Int64
closed chan struct{}
wg sync.WaitGroup
}
// Config controls Manager construction.
type Config struct {
Workers int // max concurrent executions
QueueSize int // pending-queue bound
ResultTTL time.Duration // retention of terminal jobs
}
// New constructs a Manager and starts its workers. run executes one
// prediction; sink and log may be nil.
func New(cfg Config, run RunFunc, sink metrics.Sink, log *zap.Logger) *Manager {
if cfg.Workers <= 0 {
cfg.Workers = 4
}
if cfg.QueueSize <= 0 {
cfg.QueueSize = 64
}
if cfg.ResultTTL <= 0 {
cfg.ResultTTL = time.Hour
}
if sink == nil {
sink = metrics.Noop()
}
if log == nil {
log = zap.NewNop()
}
m := &Manager{
run: run, metrics: sink, log: log,
queue: make(chan *job, cfg.QueueSize),
jobs: make(map[string]*job),
ttl: cfg.ResultTTL,
closed: make(chan struct{}),
}
for range cfg.Workers {
m.wg.Add(1)
go m.worker()
}
m.wg.Add(1)
go m.evictor()
return m
}
// Enqueue creates a job from req and returns its snapshot. The bool is false
// when the queue is full (the returned job is marked failed).
func (m *Manager) Enqueue(req *apirest.PredictionV2Request) (JobInfo, bool) {
j := &job{
id: uuid.New().String(),
req: req,
createdAt: time.Now().UTC(),
status: StatusPending,
}
m.jobsMu.Lock()
m.jobs[j.id] = j
m.jobsMu.Unlock()
select {
case m.queue <- j:
return j.snapshot(), true
default:
j.mu.Lock()
j.status = StatusFailed
j.errStr = "prediction queue full"
j.completedAt = time.Now().UTC()
j.mu.Unlock()
return j.snapshot(), false
}
}
// Get returns a job's snapshot.
func (m *Manager) Get(id string) (JobInfo, bool) {
m.jobsMu.RLock()
j, ok := m.jobs[id]
m.jobsMu.RUnlock()
if !ok {
return JobInfo{}, false
}
return j.snapshot(), true
}
// Cancel marks a still-queued job cancelled. Returns false when the job is
// unknown or already running/terminal — a running prediction cannot be
// interrupted (the worker would otherwise overwrite the cancelled status with
// its result), so callers get an honest "too late" rather than a 204 that the
// worker silently undoes.
func (m *Manager) Cancel(id string) bool {
m.jobsMu.RLock()
j, ok := m.jobs[id]
m.jobsMu.RUnlock()
if !ok {
return false
}
j.mu.Lock()
defer j.mu.Unlock()
if j.status != StatusPending {
return false
}
j.status = StatusCancelled
j.completedAt = time.Now().UTC()
return true
}
// Inflight returns the number of running jobs.
func (m *Manager) Inflight() int64 { return m.inflight.Load() }
// Close stops the workers and the evictor.
func (m *Manager) Close() {
close(m.closed)
close(m.queue)
m.wg.Wait()
}
func (m *Manager) worker() {
defer m.wg.Done()
for j := range m.queue {
j.mu.Lock()
cancelled := j.status == StatusCancelled
if !cancelled {
j.status = StatusRunning
j.startedAt = time.Now().UTC()
}
j.mu.Unlock()
if cancelled {
continue
}
m.execute(j)
}
}
// execute runs one job, recovering from a panic in the injected RunFunc so a
// single bad prediction can't leak the inflight counter or kill the worker.
func (m *Manager) execute(j *job) {
m.inflight.Add(1)
defer m.inflight.Add(-1)
resp, err := func() (resp *apirest.PredictionV2Response, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("prediction panicked: %v", r)
}
}()
return m.run(j.req)
}()
j.mu.Lock()
j.completedAt = time.Now().UTC()
if err != nil {
j.status = StatusFailed
j.errStr = err.Error()
} else {
j.status = StatusComplete
j.result = resp
}
dur := j.completedAt.Sub(j.startedAt)
j.mu.Unlock()
m.metrics.Prediction("async", dur, err)
}
func (m *Manager) evictor() {
defer m.wg.Done()
ticker := time.NewTicker(m.ttl / 4)
defer ticker.Stop()
for {
select {
case <-m.closed:
return
case <-ticker.C:
m.evictExpired()
}
}
}
func (m *Manager) evictExpired() {
now := time.Now().UTC()
m.jobsMu.Lock()
defer m.jobsMu.Unlock()
for id, j := range m.jobs {
j.mu.Lock()
expired := !j.completedAt.IsZero() && now.Sub(j.completedAt) > m.ttl
j.mu.Unlock()
if expired {
delete(m.jobs, id)
}
}
}

189
internal/api/datasets.go Normal file
View file

@ -0,0 +1,189 @@
package api
import (
"context"
"net/http"
"runtime"
"time"
"predictor-refactored/internal/datasets"
apirest "predictor-refactored/pkg/rest"
)
// ListDatasets implements GET /api/v1/admin/datasets.
func (h *Handler) ListDatasets(_ context.Context) (*apirest.DatasetList, error) {
stored, err := h.mgr.ListEpochs()
if err != nil {
return nil, apiError(http.StatusInternalServerError, err.Error())
}
loaded := make(map[string]datasets.LoadedDatasetInfo)
for _, ld := range h.mgr.LoadedDatasets() {
loaded[ld.ID.Filename()] = ld
}
out := &apirest.DatasetList{Source: h.mgr.Source(), Datasets: make([]apirest.DatasetEntry, 0, len(stored))}
for _, id := range stored {
entry := apirest.DatasetEntry{
Filename: id.Filename(),
Epoch: id.Epoch.UTC(),
}
if !id.Subset.IsGlobal() {
entry.Subset = apirest.NewOptSubsetSpec(subsetToAPI(id.Subset))
}
if ld, ok := loaded[id.Filename()]; ok {
entry.Loaded = true
entry.Coverage = apirest.NewOptCoverage(coverageToAPI(ld.Coverage))
}
out.Datasets = append(out.Datasets, entry)
}
return out, nil
}
// TriggerDatasetDownload implements POST /api/v1/admin/datasets.
func (h *Handler) TriggerDatasetDownload(ctx context.Context, req *apirest.DownloadRequest) (*apirest.DownloadAccepted, error) {
if req.Latest.Or(false) {
dctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
jobID, err := h.mgr.Refresh(dctx, 0)
if err != nil {
return nil, apiError(http.StatusInternalServerError, err.Error())
}
return &apirest.DownloadAccepted{JobID: jobID}, nil
}
epoch, ok := req.Epoch.Get()
if !ok {
return nil, apiError(http.StatusBadRequest, "specify either epoch or latest=true")
}
id := datasets.DatasetID{Epoch: epoch.UTC()}
if s, ok := req.Subset.Get(); ok {
id.Subset = subsetFromAPI(s)
}
return &apirest.DownloadAccepted{JobID: h.mgr.Download(id)}, nil
}
// DeleteDataset implements DELETE /api/v1/admin/datasets/{name}.
func (h *Handler) DeleteDataset(_ context.Context, params apirest.DeleteDatasetParams) error {
stored, err := h.mgr.ListEpochs()
if err != nil {
return apiError(http.StatusInternalServerError, err.Error())
}
for _, id := range stored {
if id.Filename() == params.Name {
if err := h.mgr.Remove(id); err != nil {
return apiError(http.StatusInternalServerError, err.Error())
}
return nil
}
}
return apiError(http.StatusNotFound, "dataset not found")
}
// ListDatasetJobs implements GET /api/v1/admin/jobs.
func (h *Handler) ListDatasetJobs(_ context.Context) ([]apirest.DownloadJob, error) {
jobs := h.mgr.ListJobs()
out := make([]apirest.DownloadJob, 0, len(jobs))
for _, j := range jobs {
out = append(out, downloadJobToAPI(j))
}
return out, nil
}
// GetDatasetJob implements GET /api/v1/admin/jobs/{id}.
func (h *Handler) GetDatasetJob(_ context.Context, params apirest.GetDatasetJobParams) (*apirest.DownloadJob, error) {
j, ok := h.mgr.GetJob(params.ID)
if !ok {
return nil, apiError(http.StatusNotFound, "job not found")
}
dto := downloadJobToAPI(j)
return &dto, nil
}
// CancelDatasetJob implements DELETE /api/v1/admin/jobs/{id}.
func (h *Handler) CancelDatasetJob(_ context.Context, params apirest.CancelDatasetJobParams) error {
if !h.mgr.CancelJob(params.ID) {
return apiError(http.StatusConflict, "job not found or already terminal")
}
return nil
}
// GetServiceStatus implements GET /api/v1/admin/status.
func (h *Handler) GetServiceStatus(_ context.Context) (*apirest.StatusResponse, error) {
jobs := h.mgr.ListJobs()
stored, _ := h.mgr.ListEpochs()
loaded := h.mgr.LoadedDatasets()
byStatus := apirest.StatusResponseJobsByStatus{}
for _, j := range jobs {
byStatus[string(j.Status)]++
}
var mem runtime.MemStats
runtime.ReadMemStats(&mem)
return &apirest.StatusResponse{
Source: h.mgr.Source(),
Uptime: time.Since(h.started).Round(time.Second).String(),
Goroutines: runtime.NumGoroutine(),
MemoryMB: int64(mem.Alloc / 1024 / 1024),
JobsByStatus: byStatus,
StoredDatasets: len(stored),
LoadedDatasets: len(loaded),
}, nil
}
// --- dataset mapping helpers ----------------------------------------------
func downloadJobToAPI(j datasets.JobInfo) apirest.DownloadJob {
dto := apirest.DownloadJob{
ID: j.ID,
Source: j.Source,
Dataset: j.Dataset.Filename(),
Epoch: j.Dataset.Epoch.UTC(),
Status: apirest.DownloadJobStatus(j.Status),
StartedAt: j.StartedAt.UTC(),
TotalUnits: j.Total,
DoneUnits: j.Done,
Bytes: j.Bytes,
}
if j.EndedAt != nil {
dto.EndedAt = apirest.NewOptDateTime(j.EndedAt.UTC())
}
if j.Err != "" {
dto.Error = apirest.NewOptString(j.Err)
}
return dto
}
func subsetToAPI(s datasets.SubsetSpec) apirest.SubsetSpec {
out := apirest.SubsetSpec{Members: s.Members}
if s.Region != nil {
out.Region = apirest.NewOptRegion(regionToAPI(*s.Region))
}
if s.HourRange != nil {
out.HourRange = apirest.NewOptHourRange(apirest.HourRange{MinHour: s.HourRange.MinHour, MaxHour: s.HourRange.MaxHour})
}
return out
}
func subsetFromAPI(s apirest.SubsetSpec) datasets.SubsetSpec {
out := datasets.SubsetSpec{Members: s.Members}
if r, ok := s.Region.Get(); ok {
out.Region = &datasets.Region{MinLat: r.MinLat, MaxLat: r.MaxLat, MinLng: r.MinLng, MaxLng: r.MaxLng}
}
if hr, ok := s.HourRange.Get(); ok {
out.HourRange = &datasets.HourRange{MinHour: hr.MinHour, MaxHour: hr.MaxHour}
}
return out
}
func regionToAPI(r datasets.Region) apirest.Region {
return apirest.Region{MinLat: r.MinLat, MaxLat: r.MaxLat, MinLng: r.MinLng, MaxLng: r.MaxLng}
}
func coverageToAPI(c datasets.Coverage) apirest.Coverage {
return apirest.Coverage{
Region: regionToAPI(c.Region),
StartTime: c.StartTime.UTC(),
EndTime: c.EndTime.UTC(),
}
}

48
internal/api/docs/docs.go Normal file
View file

@ -0,0 +1,48 @@
// Package docs serves the human-facing API documentation: the OpenAPI
// document and a ReDoc rendering of it. The spec is embedded in the binary
// (see package apispec) so the documentation needs no external files or a
// separate server.
package docs
import (
"net/http"
apispec "predictor-refactored/api"
)
// redocHTML renders the embedded spec with ReDoc loaded from a CDN.
const redocHTML = `<!DOCTYPE html>
<html>
<head>
<title>stratoflights-predictor API</title>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1">
<style>body { margin: 0; padding: 0; }</style>
</head>
<body>
<redoc spec-url="/openapi.yaml"></redoc>
<script src="https://cdn.redoc.ly/redoc/latest/bundles/redoc.standalone.js"></script>
</body>
</html>`
// Handler serves the documentation endpoints.
type Handler struct{}
// New returns a docs Handler.
func New() *Handler { return &Handler{} }
// Register installs GET /docs and GET /openapi.yaml on mux.
func (h *Handler) Register(mux *http.ServeMux) {
mux.HandleFunc("GET /openapi.yaml", h.spec)
mux.HandleFunc("GET /docs", h.redoc)
}
func (h *Handler) spec(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/yaml")
_, _ = w.Write(apispec.Spec)
}
func (h *Handler) redoc(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
_, _ = w.Write([]byte(redocHTML))
}

70
internal/api/handler.go Normal file
View file

@ -0,0 +1,70 @@
package api
import (
"context"
"errors"
"net/http"
"time"
"go.uber.org/zap"
"predictor-refactored/internal/api/async"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/elevation"
"predictor-refactored/internal/engine"
"predictor-refactored/internal/metrics"
"predictor-refactored/internal/windviz"
apirest "predictor-refactored/pkg/rest"
)
// Handler implements the ogen-generated apirest.Handler interface for every
// operation in the OpenAPI spec. Operation methods are grouped by concern
// across prediction.go, datasets.go, and wind.go.
type Handler struct {
mgr *datasets.Manager
elev *elevation.Dataset
async *async.Manager
metrics metrics.Sink
cache *windviz.Cache
started time.Time
log *zap.Logger
}
var _ apirest.Handler = (*Handler)(nil)
// terrain returns the elevation dataset as an engine.TerrainProvider, or an
// untyped nil interface when no elevation dataset is loaded. Returning the
// concrete nil *elevation.Dataset directly would produce a non-nil interface
// wrapping a nil pointer, which then panics on first use — so the nil check
// must happen here, on the concrete type.
func (h *Handler) terrain() engine.TerrainProvider {
if h.elev == nil {
return nil
}
return h.elev
}
// NewError converts an error returned by a handler into the spec's default
// error response. Handlers return *apirest.DefaultErrorStatusCode (via the
// apiError helper) to control the status code; anything else is a 500.
func (h *Handler) NewError(_ context.Context, err error) *apirest.DefaultErrorStatusCode {
var coded *apirest.DefaultErrorStatusCode
if errors.As(err, &coded) {
return coded
}
h.log.Error("unhandled handler error", zap.Error(err))
return apiError(http.StatusInternalServerError, err.Error())
}
// apiError builds a coded error response carrying an HTTP status.
func apiError(status int, description string) *apirest.DefaultErrorStatusCode {
return &apirest.DefaultErrorStatusCode{
StatusCode: status,
Response: apirest.Error{
Error: apirest.ErrorError{
Type: http.StatusText(status),
Description: description,
},
},
}
}

217
internal/api/mapping.go Normal file
View file

@ -0,0 +1,217 @@
package api
import (
"fmt"
"time"
"predictor-refactored/internal/api/async"
"predictor-refactored/internal/engine"
apirest "predictor-refactored/pkg/rest"
)
// normalizeLng folds a longitude into [0, 360) for internal use.
func normalizeLng(lng float64) float64 {
if lng < 0 {
return lng + 360
}
return lng
}
// signedLng converts an internal [0, 360) longitude back to [-180, 180).
func signedLng(lng float64) float64 {
if lng > 180 {
return lng - 360
}
return lng
}
// buildProfile translates an API prediction request into an engine profile
// using the engine's model/constraint registry.
// maxProfileStages bounds the propagator chain length to keep a single
// request's work bounded.
const maxProfileStages = 32
func buildProfile(req *apirest.PredictionV2Request, deps engine.BuildDeps) (engine.Profile, error) {
if len(req.Profile) == 0 {
return engine.Profile{}, fmt.Errorf("profile must contain at least one stage")
}
if len(req.Profile) > maxProfileStages {
return engine.Profile{}, fmt.Errorf("profile has %d stages; maximum is %d", len(req.Profile), maxProfileStages)
}
step := 60.0
tol := 0.01
if o, ok := req.Options.Get(); ok {
step = o.StepSeconds.Or(step)
tol = o.Tolerance.Or(tol)
}
if step <= 0 || step > 3600 {
return engine.Profile{}, fmt.Errorf("options.step_seconds must be in (0, 3600], got %g", step)
}
if tol <= 0 || tol >= 1 {
return engine.Profile{}, fmt.Errorf("options.tolerance must be in (0, 1), got %g", tol)
}
dir := engine.Forward
if req.Direction.Or(apirest.PredictionV2RequestDirectionForward) == apirest.PredictionV2RequestDirectionReverse {
dir = engine.Reverse
}
props := make([]*engine.Propagator, len(req.Profile))
for i, stage := range req.Profile {
if stage.Name == "" {
return engine.Profile{}, fmt.Errorf("stage %d: name is required", i)
}
built, err := engine.BuildModel(toEngineModelSpec(stage.Model), deps)
if err != nil {
return engine.Profile{}, fmt.Errorf("stage %q model: %w", stage.Name, err)
}
constraints, err := toEngineConstraints(stage.Constraints, deps)
if err != nil {
return engine.Profile{}, fmt.Errorf("stage %q: %w", stage.Name, err)
}
props[i] = &engine.Propagator{
Name: stage.Name,
Step: step,
Model: built.Model,
BuildModel: built.Build,
Constraints: constraints,
Tolerance: tol,
}
}
for i, stage := range req.Profile {
idx, ok := stage.FallbackIndex.Get()
if !ok {
continue
}
if idx < 0 || idx >= len(props) {
return engine.Profile{}, fmt.Errorf("stage %q: fallback_index %d out of range", stage.Name, idx)
}
props[i].Fallback = props[idx]
}
globals, err := toEngineConstraints(req.Globals, deps)
if err != nil {
return engine.Profile{}, fmt.Errorf("globals: %w", err)
}
return engine.Profile{Stages: props, Direction: dir, Globals: globals}, nil
}
func toEngineModelSpec(m apirest.ModelSpec) engine.ModelSpec {
out := engine.ModelSpec{
Type: string(m.Type),
Rate: m.Rate.Or(0),
SeaLevelRate: m.SeaLevelRate.Or(0),
IncludeWind: m.IncludeWind.Or(false),
}
for _, s := range m.Segments {
out.Segments = append(out.Segments, engine.PiecewiseSegmentSpec{
Until: s.Until,
Rate: s.Rate,
Reference: string(s.Reference.Or(apirest.PiecewiseSegmentReferenceAbsolute)),
})
}
return out
}
func toEngineConstraints(specs []apirest.ConstraintSpec, deps engine.BuildDeps) ([]engine.Constraint, error) {
out := make([]engine.Constraint, 0, len(specs))
for i, s := range specs {
c, err := engine.BuildConstraint(toEngineConstraintSpec(s), deps)
if err != nil {
return nil, fmt.Errorf("constraint[%d]: %w", i, err)
}
out = append(out, c)
}
return out, nil
}
func toEngineConstraintSpec(c apirest.ConstraintSpec) engine.ConstraintSpec {
spec := engine.ConstraintSpec{
Type: string(c.Type),
Op: string(c.Op.Or("")),
Limit: c.Limit.Or(0),
Action: string(c.Action.Or(apirest.ConstraintSpecActionStop)),
Mode: string(c.Mode.Or("")),
Label: c.Label.Or(""),
}
for _, v := range c.Vertices {
spec.Vertices = append(spec.Vertices, engine.PolygonVertex{Lat: v.Lat, Lng: v.Lng})
}
return spec
}
// stageResultToAPI maps one engine stage result to the API representation.
func stageResultToAPI(r engine.Result) apirest.StageResult {
out := apirest.StageResult{
Name: r.Propagator,
Outcome: apirest.StageResultOutcome(r.Outcome.String()),
Events: eventsToAPI(r.Events),
}
if r.Constraint != nil {
out.Constraint = apirest.NewOptString(r.ConstraintName)
out.Termination = apirest.NewOptTerminationInfo(apirest.TerminationInfo{
ViolationTime: time.Unix(int64(r.ViolationTime), 0).UTC(),
ViolationState: geoStateToAPI(r.ViolationState),
RefinedTime: time.Unix(int64(r.RefinedTime), 0).UTC(),
RefinedState: geoStateToAPI(r.RefinedState),
})
}
n := r.Path.Len()
out.Trajectory = make([]apirest.TrajectoryPoint, n)
for i := range n {
t, p := r.Path.At(i)
out.Trajectory[i] = apirest.TrajectoryPoint{
Time: time.Unix(int64(t), 0).UTC(),
Latitude: p.Lat,
Longitude: signedLng(p.Lng),
Altitude: p.Altitude,
}
}
return out
}
func geoStateToAPI(s engine.State) apirest.GeoState {
return apirest.GeoState{Lat: s.Lat, Lng: signedLng(s.Lng), Altitude: s.Altitude}
}
func eventsToAPI(in []engine.EventSummary) []apirest.EventSummary {
if len(in) == 0 {
return nil
}
out := make([]apirest.EventSummary, 0, len(in))
for _, e := range in {
out = append(out, apirest.EventSummary{
Type: e.Type,
Count: e.Count,
FirstTime: apirest.NewOptFloat64(e.FirstTime),
LastTime: apirest.NewOptFloat64(e.LastTime),
FirstState: apirest.NewOptGeoState(geoStateToAPI(e.FirstState)),
LastState: apirest.NewOptGeoState(geoStateToAPI(e.LastState)),
Message: apirest.NewOptString(e.Message),
})
}
return out
}
// asyncJobToAPI maps an async job snapshot to the API PredictionJob.
func asyncJobToAPI(info async.JobInfo) *apirest.PredictionJob {
job := &apirest.PredictionJob{
ID: info.ID,
Status: apirest.PredictionJobStatus(info.Status),
CreatedAt: info.CreatedAt,
}
if info.StartedAt != nil {
job.StartedAt = apirest.NewOptDateTime(*info.StartedAt)
}
if info.CompletedAt != nil {
job.CompletedAt = apirest.NewOptDateTime(*info.CompletedAt)
}
if info.Error != "" {
job.Error = apirest.NewOptString(info.Error)
}
if info.Result != nil {
job.Result = apirest.NewOptPredictionV2Response(*info.Result)
}
return job
}

View file

@ -0,0 +1,20 @@
package middleware
import "net/http"
// CORS wraps next with permissive CORS headers and short-circuits OPTIONS preflight.
//
// This service is meant to sit behind an authenticated gateway, so we set
// "Access-Control-Allow-Origin: *". Tighten this if you deploy elsewhere.
func CORS(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, DELETE, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
next.ServeHTTP(w, r)
})
}

View file

@ -0,0 +1,34 @@
package middleware
import (
"time"
"github.com/ogen-go/ogen/middleware"
"go.uber.org/zap"
)
// statusCoder is implemented by ogen's *...StatusCode error wrappers.
type statusCoder interface{ GetStatusCode() int }
// OgenLogging is an ogen middleware that logs each operation's duration and
// outcome. Handler errors carrying a 4xx/5xx-class status are logged at the
// appropriate level: client errors (and expected 503s during startup) at
// warn without a stacktrace, server errors at error.
func OgenLogging(log *zap.Logger) middleware.Middleware {
return func(req middleware.Request, next func(req middleware.Request) (middleware.Response, error)) (middleware.Response, error) {
start := time.Now()
resp, err := next(req)
lg := log.With(zap.String("operation", req.OperationID), zap.Duration("duration", time.Since(start)))
if err == nil {
lg.Info("request completed")
return resp, err
}
if sc, ok := err.(statusCoder); ok && sc.GetStatusCode() < 500 {
lg.Warn("request rejected", zap.Int("status", sc.GetStatusCode()), zap.NamedError("reason", err))
} else {
lg.Error("request failed", zap.Error(err))
}
return resp, err
}
}

239
internal/api/prediction.go Normal file
View file

@ -0,0 +1,239 @@
package api
import (
"context"
"net/http"
"time"
"predictor-refactored/internal/engine"
"predictor-refactored/internal/weather"
apirest "predictor-refactored/pkg/rest"
)
// ReadinessCheck implements GET /ready.
func (h *Handler) ReadinessCheck(_ context.Context) (*apirest.ReadinessResponse, error) {
resp := &apirest.ReadinessResponse{}
if field := h.mgr.Active(); field != nil {
resp.Status = apirest.ReadinessResponseStatusOk
resp.DatasetTime = apirest.NewOptDateTime(field.Epoch())
} else {
resp.Status = apirest.ReadinessResponseStatusNotReady
resp.ErrorMessage = apirest.NewOptString("no dataset loaded")
}
return resp, nil
}
// PerformPredictionV2 implements POST /api/v2/prediction.
func (h *Handler) PerformPredictionV2(_ context.Context, req *apirest.PredictionV2Request) (*apirest.PredictionV2Response, error) {
resp, err := h.runPredictionV2(req)
if err == nil {
h.metrics.Prediction("v2", resp.CompletedAt.Sub(resp.StartedAt), nil)
}
return resp, err
}
// CreatePredictionJob implements POST /api/v1/predictions.
func (h *Handler) CreatePredictionJob(_ context.Context, req *apirest.PredictionV2Request) (*apirest.PredictionJob, error) {
info, accepted := h.async.Enqueue(req)
if !accepted {
return nil, apiError(http.StatusServiceUnavailable, info.Error)
}
return asyncJobToAPI(info), nil
}
// GetPredictionJob implements GET /api/v1/predictions/{id}.
func (h *Handler) GetPredictionJob(_ context.Context, params apirest.GetPredictionJobParams) (*apirest.PredictionJob, error) {
info, ok := h.async.Get(params.ID)
if !ok {
return nil, apiError(http.StatusNotFound, "prediction job not found")
}
return asyncJobToAPI(info), nil
}
// CancelPredictionJob implements DELETE /api/v1/predictions/{id}.
func (h *Handler) CancelPredictionJob(_ context.Context, params apirest.CancelPredictionJobParams) error {
if !h.async.Cancel(params.ID) {
return apiError(http.StatusConflict, "job not found or already terminal")
}
return nil
}
// runPredictionV2 is the synchronous prediction core, shared by the v2
// endpoint and the async worker pool.
func (h *Handler) runPredictionV2(req *apirest.PredictionV2Request) (*apirest.PredictionV2Response, error) {
// Validate the request shape before checking dataset availability, so a
// malformed request is a 400 regardless of startup state.
lat := req.Launch.Latitude
rawLng := req.Launch.Longitude
alt := req.Launch.Altitude.Or(0)
if lat < -90 || lat > 90 {
return nil, apiError(http.StatusBadRequest, "launch.latitude must be in [-90, 90]")
}
if rawLng < -180 || rawLng >= 360 {
return nil, apiError(http.StatusBadRequest, "launch.longitude must be in [-180, 360)")
}
lng := normalizeLng(rawLng)
field := h.mgr.Active()
if field == nil {
return nil, apiError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up")
}
events := engine.NewEventSink()
deps := engine.BuildDeps{Wind: field, Events: events, Terrain: h.terrain()}
prof, err := buildProfile(req, deps)
if err != nil {
return nil, apiError(http.StatusBadRequest, err.Error())
}
started := time.Now().UTC()
results := prof.Run(float64(req.Launch.Time.Unix()), engine.State{Lat: lat, Lng: lng, Altitude: alt}, events)
completed := time.Now().UTC()
resp := &apirest.PredictionV2Response{
Stages: make([]apirest.StageResult, 0, len(results)),
Events: eventsToAPI(events.Snapshot()),
Dataset: apirest.DatasetInfo{Source: field.Source(), Epoch: field.Epoch()},
StartedAt: started,
CompletedAt: completed,
}
for _, r := range results {
resp.Stages = append(resp.Stages, stageResultToAPI(r))
}
return resp, nil
}
// PerformPrediction implements GET /api/v1/prediction (Tawhiri-compatible).
func (h *Handler) PerformPrediction(_ context.Context, params apirest.PerformPredictionParams) (*apirest.PredictionResponse, error) {
field := h.mgr.Active()
if field == nil {
return nil, apiError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up")
}
profileKind := "standard_profile"
if p, ok := params.Profile.Get(); ok {
profileKind = string(p)
}
ascentRate := params.AscentRate.Or(5)
descentRate := params.DescentRate.Or(5)
launchAlt := params.LaunchAltitude.Or(0)
lng := normalizeLng(params.LaunchLongitude)
launchTime := float64(params.LaunchDatetime.Unix())
events := engine.NewEventSink()
var stageNames []string
var prof engine.Profile
switch profileKind {
case "standard_profile":
stageNames = []string{"ascent", "descent"}
prof = standardProfile(field, h.terrain(), events, ascentRate, params.BurstAltitude.Or(28000), descentRate)
case "float_profile":
stopTime := params.LaunchDatetime.Add(24 * time.Hour)
if v, ok := params.StopDatetime.Get(); ok {
stopTime = v
}
stageNames = []string{"ascent", "float"}
prof = floatProfile(field, events, ascentRate, params.FloatAltitude.Or(25000), stopTime)
default:
return nil, apiError(http.StatusBadRequest, "unknown profile: "+profileKind)
}
started := time.Now().UTC()
results := prof.Run(launchTime, engine.State{Lat: params.LaunchLatitude, Lng: lng, Altitude: launchAlt}, events)
completed := time.Now().UTC()
h.metrics.Prediction(profileKind, completed.Sub(started), nil)
resp := &apirest.PredictionResponse{
Metadata: apirest.PredictionResponseMetadata{StartDatetime: started, CompleteDatetime: completed},
}
for i, r := range results {
name := "ascent"
if i < len(stageNames) {
name = stageNames[i]
}
resp.Prediction = append(resp.Prediction, tawhiriItem(name, r))
}
resp.Request = apirest.NewOptPredictionResponseRequest(apirest.PredictionResponseRequest{
Dataset: apirest.NewOptString(field.Epoch().Format("2006-01-02T15:04:05Z")),
LaunchLatitude: apirest.NewOptFloat64(params.LaunchLatitude),
LaunchLongitude: apirest.NewOptFloat64(params.LaunchLongitude),
LaunchDatetime: apirest.NewOptString(params.LaunchDatetime.Format(time.RFC3339)),
LaunchAltitude: params.LaunchAltitude,
})
if ev := events.Snapshot(); len(ev) > 0 {
resp.Warnings = apirest.NewOptPredictionResponseWarnings(apirest.PredictionResponseWarnings{})
}
return resp, nil
}
// standardProfile builds the Tawhiri ascent → descent chain.
func standardProfile(field weather.WindField, elev engine.TerrainProvider, events *engine.EventSink, ascentRate, burst, descentRate float64) engine.Profile {
wind := engine.WindTransport(field, events)
descentTerm := []engine.Constraint{engine.Altitude{Op: engine.OpLessEqual, Limit: 0, On: engine.ActionStop}}
if elev != nil {
descentTerm = []engine.Constraint{engine.TerrainContact{Provider: elev, On: engine.ActionStop}}
}
return engine.Profile{
Direction: engine.Forward,
Stages: []*engine.Propagator{
{
Name: "ascent",
Step: 60,
Model: engine.Sum(engine.ConstantRate(ascentRate), wind),
Constraints: []engine.Constraint{engine.Altitude{Op: engine.OpGreaterEqual, Limit: burst, On: engine.ActionStop}},
},
{
Name: "descent",
Step: 60,
Model: engine.Sum(engine.ParachuteDescent(descentRate), wind),
Constraints: descentTerm,
},
},
}
}
// floatProfile builds the Tawhiri ascent → float chain.
func floatProfile(field weather.WindField, events *engine.EventSink, ascentRate, floatAlt float64, stopTime time.Time) engine.Profile {
wind := engine.WindTransport(field, events)
return engine.Profile{
Direction: engine.Forward,
Stages: []*engine.Propagator{
{
Name: "ascent",
Step: 60,
Model: engine.Sum(engine.ConstantRate(ascentRate), wind),
Constraints: []engine.Constraint{engine.Altitude{Op: engine.OpGreaterEqual, Limit: floatAlt, On: engine.ActionStop}},
},
{
Name: "float",
Step: 60,
Model: wind,
Constraints: []engine.Constraint{engine.Time{Op: engine.OpGreater, Limit: float64(stopTime.Unix()), On: engine.ActionStop}},
},
},
}
}
// tawhiriItem maps one engine stage result to a v1 prediction item.
func tawhiriItem(name string, r engine.Result) apirest.PredictionResponsePredictionItem {
stage := apirest.PredictionResponsePredictionItemStageAscent
switch name {
case "descent":
stage = apirest.PredictionResponsePredictionItemStageDescent
case "float":
stage = apirest.PredictionResponsePredictionItemStageFloat
}
n := r.Path.Len()
traj := make([]apirest.TawhiriPoint, 0, n)
for i := range n {
t, p := r.Path.At(i)
traj = append(traj, apirest.TawhiriPoint{
Datetime: time.Unix(int64(t), 0).UTC(),
Latitude: p.Lat,
Longitude: signedLng(p.Lng),
Altitude: p.Altitude,
})
}
return apirest.PredictionResponsePredictionItem{Stage: stage, Trajectory: traj}
}

131
internal/api/transport.go Normal file
View file

@ -0,0 +1,131 @@
// Package api is the HTTP surface of the service. Every REST operation is
// defined in the OpenAPI spec (api/rest/predictor.swagger.yml) and served by
// the ogen-generated server in pkg/rest; this package implements the
// generated Handler interface and wires the server together with the
// non-OpenAPI endpoints (Prometheus metrics, ReDoc docs).
package api
import (
"context"
"fmt"
"net/http"
"time"
"go.uber.org/zap"
"predictor-refactored/internal/api/async"
"predictor-refactored/internal/api/docs"
"predictor-refactored/internal/api/middleware"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/elevation"
"predictor-refactored/internal/metrics"
"predictor-refactored/internal/windviz"
apirest "predictor-refactored/pkg/rest"
)
// Server is the top-level HTTP server.
type Server struct {
port int
mux *http.ServeMux
async *async.Manager
log *zap.Logger
}
// Deps are the runtime dependencies the API layer needs.
type Deps struct {
Manager *datasets.Manager
Elevation *elevation.Dataset
Metrics metrics.Sink
MetricsHandler http.Handler // optional; mounted at MetricsPath when non-nil
MetricsPath string
EnableWind bool
WindCache *windviz.Cache // optional; created if nil and EnableWind
AsyncWorkers int
AsyncQueueSize int
AsyncResultTTL time.Duration
Log *zap.Logger
}
// New wires the HTTP server. The returned Server is not yet started.
func New(port int, d Deps) (*Server, error) {
if d.Log == nil {
d.Log = zap.NewNop()
}
if d.Metrics == nil {
d.Metrics = metrics.Noop()
}
if d.EnableWind && d.WindCache == nil {
d.WindCache = windviz.NewCache(64, 10*time.Minute)
}
h := &Handler{
mgr: d.Manager,
elev: d.Elevation,
metrics: d.Metrics,
cache: d.WindCache,
started: time.Now().UTC(),
log: d.Log,
}
// The async worker pool runs the same prediction core as the synchronous
// endpoint; inject it so async stays decoupled from the wire types.
h.async = async.New(async.Config{
Workers: d.AsyncWorkers,
QueueSize: d.AsyncQueueSize,
ResultTTL: d.AsyncResultTTL,
}, h.runPredictionV2, d.Metrics, d.Log)
ogenSrv, err := apirest.NewServer(h, apirest.WithMiddleware(middleware.OgenLogging(d.Log)))
if err != nil {
return nil, fmt.Errorf("create ogen server: %w", err)
}
mux := http.NewServeMux()
// Liveness: always 200 while the process is up, independent of whether a
// dataset is loaded. Container/orchestrator health checks use this; the
// readiness of the data plane is /ready (an OpenAPI operation).
mux.HandleFunc("GET /health", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"status":"alive"}`))
})
docs.New().Register(mux)
if d.MetricsHandler != nil && d.MetricsPath != "" {
mux.Handle(d.MetricsPath, d.MetricsHandler)
}
// The ogen server owns every OpenAPI route; mount it last as the catch-all.
mux.Handle("/", ogenSrv)
return &Server{port: port, mux: mux, async: h.async, log: d.Log}, nil
}
// Run starts the HTTP server and blocks until ctx is cancelled or the server
// fails. The handler chain is CORS → mux (ogen routes + docs + metrics).
func (s *Server) Run(ctx context.Context) error {
srv := &http.Server{
Addr: fmt.Sprintf(":%d", s.port),
Handler: middleware.CORS(s.mux),
}
errCh := make(chan error, 1)
go func() {
s.log.Info("HTTP server starting", zap.Int("port", s.port))
errCh <- srv.ListenAndServe()
}()
select {
case err := <-errCh:
return err
case <-ctx.Done():
shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
return srv.Shutdown(shutdownCtx)
}
}
// Close releases background resources (the async worker pool).
func (s *Server) Close() {
if s.async != nil {
s.async.Close()
}
}

92
internal/api/wind.go Normal file
View file

@ -0,0 +1,92 @@
package api
import (
"context"
"fmt"
"net/http"
"predictor-refactored/internal/windviz"
apirest "predictor-refactored/pkg/rest"
)
// GetWindMeta implements GET /api/v1/wind/meta.
func (h *Handler) GetWindMeta(_ context.Context) (*apirest.WindMeta, error) {
field := h.mgr.Active()
if field == nil {
return nil, apiError(http.StatusServiceUnavailable, "no dataset loaded")
}
return &apirest.WindMeta{
Source: field.Source(),
Epoch: field.Epoch().UTC(),
DefaultStep: 1.0,
MinStep: 0.25,
SuggestedAltitudes: []int{0, 1000, 5000, 10000, 15000, 20000, 30000},
Bbox: apirest.Region{MinLat: -90, MaxLat: 90, MinLng: 0, MaxLng: 360},
}, nil
}
// GetWindField implements GET /api/v1/wind/field.
func (h *Handler) GetWindField(_ context.Context, params apirest.GetWindFieldParams) ([]apirest.WindComponent, error) {
field := h.mgr.Active()
if field == nil {
return nil, apiError(http.StatusServiceUnavailable, "no dataset loaded")
}
when := field.Epoch()
if t, ok := params.Time.Get(); ok {
when = t
}
req := windviz.Request{
Time: float64(when.Unix()),
Altitude: params.Altitude.Or(0),
MinLat: params.MinLat.Or(0),
MaxLat: params.MaxLat.Or(0),
MinLng: params.MinLng.Or(0),
MaxLng: params.MaxLng.Or(0),
Step: params.Step.Or(0),
}
key := fmt.Sprintf("%s|%v|%.3f|%.3f|%.3f|%.3f|%.3f|%.3f",
field.Source(), req.Time, req.Altitude, req.MinLat, req.MaxLat, req.MinLng, req.MaxLng, req.Step)
if h.cache != nil {
if cached, ok := h.cache.Get(key); ok {
return windFieldToAPI(cached), nil
}
}
out, err := windviz.Rasterize(field, req)
if err != nil {
return nil, apiError(http.StatusBadRequest, err.Error())
}
if h.cache != nil {
h.cache.Put(key, out)
}
return windFieldToAPI(out), nil
}
// windFieldToAPI maps a rasterized field to the generated component slice.
func windFieldToAPI(f windviz.Field) []apirest.WindComponent {
out := make([]apirest.WindComponent, 0, len(f))
for _, c := range f {
out = append(out, apirest.WindComponent{
Header: apirest.WindHeader{
ParameterCategory: c.Header.ParameterCategory,
ParameterNumber: c.Header.ParameterNumber,
ParameterNumberName: apirest.NewOptString(c.Header.ParameterNumberName),
ParameterUnit: apirest.NewOptString(c.Header.ParameterUnit),
Nx: c.Header.Nx,
Ny: c.Header.Ny,
Lo1: c.Header.Lo1,
La1: c.Header.La1,
Lo2: c.Header.Lo2,
La2: c.Header.La2,
Dx: c.Header.Dx,
Dy: c.Header.Dy,
RefTime: c.Header.RefTime,
ForecastTime: c.Header.ForecastTime,
},
Data: c.Data,
})
}
return out
}

277
internal/config/config.go Normal file
View file

@ -0,0 +1,277 @@
// Package config holds the service's runtime configuration, loaded by
// merging (in order of increasing precedence): built-in defaults, a YAML
// config file, environment variables, and command-line flags.
//
// Validation is performed once on load; downstream consumers receive an
// immutable struct.
package config
import (
"flag"
"fmt"
"os"
"strconv"
"time"
"gopkg.in/yaml.v2"
)
// Config is the top-level configuration tree.
type Config struct {
HTTP HTTPConfig `yaml:"http"`
Data DataConfig `yaml:"data"`
Download DownloadConfig `yaml:"download"`
Metrics MetricsConfig `yaml:"metrics"`
Wind WindConfig `yaml:"wind"`
Log LogConfig `yaml:"log"`
}
// HTTPConfig configures the HTTP server.
type HTTPConfig struct {
Port int `yaml:"port"`
// AsyncWorkers caps concurrent prediction executions for the async endpoint.
AsyncWorkers int `yaml:"async_workers"`
// AsyncQueueSize bounds the async pending queue.
AsyncQueueSize int `yaml:"async_queue_size"`
// AsyncResultTTL is how long completed async results are retained.
AsyncResultTTL time.Duration `yaml:"async_result_ttl"`
}
// DataConfig configures dataset and elevation storage.
type DataConfig struct {
Dir string `yaml:"dir"`
ElevationPath string `yaml:"elevation_path"`
// Source is the dataset variant ID: gfs-0p50-3h (default), gfs-0p25-3h,
// gfs-0p25-1h, or gefs-0p50-3h. See weather/gfs.VariantByID.
Source string `yaml:"source"`
}
// DownloadConfig configures the dataset downloader.
type DownloadConfig struct {
Parallel int `yaml:"parallel"`
BandwidthBytesPerSecond int64 `yaml:"bandwidth_bytes_per_second"`
UpdateInterval time.Duration `yaml:"update_interval"`
FreshnessTTL time.Duration `yaml:"freshness_ttl"`
}
// MetricsConfig configures the metrics endpoint.
type MetricsConfig struct {
Enabled bool `yaml:"enabled"`
Path string `yaml:"path"`
}
// WindConfig configures the wind-visualization endpoints.
type WindConfig struct {
Enabled bool `yaml:"enabled"`
CacheSize int `yaml:"cache_size"`
CacheTTL time.Duration `yaml:"cache_ttl"`
}
// LogConfig configures logging.
type LogConfig struct {
Level string `yaml:"level"` // "debug", "info", "warn", "error"
}
// Defaults returns a Config with reasonable default values.
func Defaults() Config {
return Config{
HTTP: HTTPConfig{
Port: 8080,
AsyncWorkers: 4,
AsyncQueueSize: 64,
AsyncResultTTL: time.Hour,
},
Data: DataConfig{
Dir: "/tmp/predictor-data",
ElevationPath: "/srv/ruaumoko-dataset",
Source: "gfs-0p50-3h",
},
Download: DownloadConfig{
Parallel: 8,
BandwidthBytesPerSecond: 0,
UpdateInterval: 6 * time.Hour,
FreshnessTTL: 48 * time.Hour,
},
Metrics: MetricsConfig{
Enabled: true,
Path: "/metrics",
},
Wind: WindConfig{
Enabled: true,
CacheSize: 64,
CacheTTL: 10 * time.Minute,
},
Log: LogConfig{Level: "info"},
}
}
// Load resolves the configuration by merging built-in defaults with
// (in increasing precedence): a YAML file (path from PREDICTOR_CONFIG_FILE
// env var or --config flag), environment variables, and command-line flags.
//
// args is os.Args[1:] in production code; tests pass a custom slice.
func Load(args []string) (Config, error) {
cfg := Defaults()
fs := flag.NewFlagSet("predictor", flag.ContinueOnError)
// Surface a deterministic usage by suppressing the default output:
fs.SetOutput(os.Stderr)
var (
configPath = fs.String("config", os.Getenv("PREDICTOR_CONFIG_FILE"), "path to YAML config file")
// Flag-driven overrides. Empty / -1 means "not specified".
flagPort = fs.Int("port", -1, "HTTP listen port")
flagDataDir = fs.String("data-dir", "", "directory for dataset files")
flagElevation = fs.String("elevation", "", "path to ruaumoko elevation dataset")
flagParallel = fs.Int("download-parallel", -1, "max concurrent GRIB downloads")
flagBandwidth = fs.Int64("download-bandwidth", -1, "download bandwidth limit in bytes/sec (0 = unlimited)")
flagInterval = fs.Duration("update-interval", 0, "scheduler refresh interval")
flagTTL = fs.Duration("freshness-ttl", 0, "max age before a dataset is considered stale")
flagMetricsEnabled = fs.Bool("metrics", true, "enable Prometheus-compatible metrics endpoint")
flagMetricsPath = fs.String("metrics-path", "", "HTTP path for the metrics endpoint")
flagLogLevel = fs.String("log-level", "", "log level: debug|info|warn|error")
)
if err := fs.Parse(args); err != nil {
return Config{}, fmt.Errorf("parse flags: %w", err)
}
// 1. File.
if *configPath != "" {
data, err := os.ReadFile(*configPath)
if err != nil {
return Config{}, fmt.Errorf("read config %s: %w", *configPath, err)
}
if err := yaml.UnmarshalStrict(data, &cfg); err != nil {
return Config{}, fmt.Errorf("parse config %s: %w", *configPath, err)
}
}
// 2. Env vars.
applyEnv(&cfg)
// 3. Flags (only when explicitly set).
if *flagPort >= 0 {
cfg.HTTP.Port = *flagPort
}
if *flagDataDir != "" {
cfg.Data.Dir = *flagDataDir
}
if *flagElevation != "" {
cfg.Data.ElevationPath = *flagElevation
}
if *flagParallel >= 0 {
cfg.Download.Parallel = *flagParallel
}
if *flagBandwidth >= 0 {
cfg.Download.BandwidthBytesPerSecond = *flagBandwidth
}
if *flagInterval != 0 {
cfg.Download.UpdateInterval = *flagInterval
}
if *flagTTL != 0 {
cfg.Download.FreshnessTTL = *flagTTL
}
// flag.Bool defaults to true here so we only override if user explicitly disables it.
if isFlagSet(fs, "metrics") {
cfg.Metrics.Enabled = *flagMetricsEnabled
}
if *flagMetricsPath != "" {
cfg.Metrics.Path = *flagMetricsPath
}
if *flagLogLevel != "" {
cfg.Log.Level = *flagLogLevel
}
if err := cfg.Validate(); err != nil {
return Config{}, err
}
return cfg, nil
}
func isFlagSet(fs *flag.FlagSet, name string) bool {
set := false
fs.Visit(func(f *flag.Flag) {
if f.Name == name {
set = true
}
})
return set
}
// applyEnv overlays PREDICTOR_* environment variables onto cfg.
func applyEnv(cfg *Config) {
if v := os.Getenv("PREDICTOR_PORT"); v != "" {
if n, err := strconv.Atoi(v); err == nil {
cfg.HTTP.Port = n
}
}
if v := os.Getenv("PREDICTOR_DATA_DIR"); v != "" {
cfg.Data.Dir = v
}
if v := os.Getenv("PREDICTOR_ELEVATION_DATASET"); v != "" {
cfg.Data.ElevationPath = v
}
if v := os.Getenv("PREDICTOR_SOURCE"); v != "" {
cfg.Data.Source = v
}
if v := os.Getenv("PREDICTOR_DOWNLOAD_PARALLEL"); v != "" {
if n, err := strconv.Atoi(v); err == nil {
cfg.Download.Parallel = n
}
}
if v := os.Getenv("PREDICTOR_DOWNLOAD_BANDWIDTH"); v != "" {
if n, err := strconv.ParseInt(v, 10, 64); err == nil {
cfg.Download.BandwidthBytesPerSecond = n
}
}
if v := os.Getenv("PREDICTOR_UPDATE_INTERVAL"); v != "" {
if d, err := time.ParseDuration(v); err == nil {
cfg.Download.UpdateInterval = d
}
}
if v := os.Getenv("PREDICTOR_DATASET_TTL"); v != "" {
if d, err := time.ParseDuration(v); err == nil {
cfg.Download.FreshnessTTL = d
}
}
if v := os.Getenv("PREDICTOR_METRICS_ENABLED"); v != "" {
cfg.Metrics.Enabled = v == "1" || v == "true" || v == "yes"
}
if v := os.Getenv("PREDICTOR_METRICS_PATH"); v != "" {
cfg.Metrics.Path = v
}
if v := os.Getenv("PREDICTOR_LOG_LEVEL"); v != "" {
cfg.Log.Level = v
}
}
// Validate reports configuration errors.
func (c Config) Validate() error {
if c.HTTP.Port < 0 || c.HTTP.Port > 65535 {
return fmt.Errorf("http.port %d outside [0, 65535]", c.HTTP.Port)
}
if c.Data.Dir == "" {
return fmt.Errorf("data.dir is required")
}
if c.Data.Source == "" {
return fmt.Errorf("data.source is required")
}
if c.Download.Parallel <= 0 {
return fmt.Errorf("download.parallel must be > 0")
}
if c.Download.UpdateInterval <= 0 {
return fmt.Errorf("download.update_interval must be > 0")
}
if c.Download.FreshnessTTL <= 0 {
return fmt.Errorf("download.freshness_ttl must be > 0")
}
if c.Metrics.Enabled && c.Metrics.Path == "" {
return fmt.Errorf("metrics.path is required when metrics enabled")
}
switch c.Log.Level {
case "debug", "info", "warn", "error":
default:
return fmt.Errorf("log.level %q is not one of debug|info|warn|error", c.Log.Level)
}
return nil
}

View file

@ -0,0 +1,76 @@
package config
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestLoadDefaults(t *testing.T) {
t.Setenv("PREDICTOR_DATA_DIR", "")
t.Setenv("PREDICTOR_PORT", "")
t.Setenv("PREDICTOR_CONFIG_FILE", "")
cfg, err := Load(nil)
if err != nil {
t.Fatalf("Load: %v", err)
}
if cfg.HTTP.Port != 8080 {
t.Errorf("default port = %d, want 8080", cfg.HTTP.Port)
}
if cfg.Download.Parallel != 8 {
t.Errorf("default parallel = %d, want 8", cfg.Download.Parallel)
}
}
func TestLoadEnvOverridesDefaults(t *testing.T) {
t.Setenv("PREDICTOR_PORT", "9090")
t.Setenv("PREDICTOR_UPDATE_INTERVAL", "30m")
cfg, err := Load(nil)
if err != nil {
t.Fatalf("Load: %v", err)
}
if cfg.HTTP.Port != 9090 {
t.Errorf("env port = %d, want 9090", cfg.HTTP.Port)
}
if cfg.Download.UpdateInterval != 30*time.Minute {
t.Errorf("env update interval = %v, want 30m", cfg.Download.UpdateInterval)
}
}
func TestLoadFlagsOverrideEnv(t *testing.T) {
t.Setenv("PREDICTOR_PORT", "9090")
cfg, err := Load([]string{"-port", "7777"})
if err != nil {
t.Fatalf("Load: %v", err)
}
if cfg.HTTP.Port != 7777 {
t.Errorf("flag should override env: port = %d, want 7777", cfg.HTTP.Port)
}
}
func TestLoadFileOverridesDefaults(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "predictor.yml")
if err := os.WriteFile(path, []byte("http:\n port: 12345\n"), 0o644); err != nil {
t.Fatal(err)
}
cfg, err := Load([]string{"-config", path})
if err != nil {
t.Fatalf("Load: %v", err)
}
if cfg.HTTP.Port != 12345 {
t.Errorf("file port = %d, want 12345", cfg.HTTP.Port)
}
}
func TestValidate(t *testing.T) {
cfg := Defaults()
cfg.Data.Dir = ""
if err := cfg.Validate(); err == nil {
t.Error("expected validation error for empty data dir")
}
}

View file

@ -1,158 +0,0 @@
package dataset
import "fmt"
// Dataset shape constants.
// Shape: (65, 47, 3, 361, 720) = (hour, pressure_level, variable, latitude, longitude)
// This matches the reference predictor exactly.
const (
NumHours = 65 // 0, 3, 6, ..., 192
NumLevels = 47 // pressure levels
NumVariables = 3 // height, wind_u, wind_v
NumLatitudes = 361 // -90.0 to +90.0 in 0.5 degree steps
NumLongitudes = 720 // 0.0 to 359.5 in 0.5 degree steps
HourStep = 3 // hours between forecast time steps
MaxHour = 192 // maximum forecast hour
Resolution = 0.5 // grid resolution in degrees
LatStart = -90.0 // first latitude in the dataset
LonStart = 0.0 // first longitude in the dataset
// Variable indices within the dataset.
VarHeight = 0
VarWindU = 1
VarWindV = 2
ElementSize = 4 // float32 = 4 bytes
)
// DatasetSize is the total size of the dataset file in bytes.
// 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200
const DatasetSize int64 = int64(NumHours) * int64(NumLevels) * int64(NumVariables) *
int64(NumLatitudes) * int64(NumLongitudes) * int64(ElementSize)
// LevelSet identifies which GRIB file set a pressure level belongs to.
type LevelSet int
const (
LevelSetA LevelSet = iota // pgrb2 (primary)
LevelSetB // pgrb2b (secondary)
)
// Pressures contains the 47 pressure levels in hPa, sorted descending.
// Index 0 = 1000 hPa (near surface), Index 46 = 1 hPa (high atmosphere).
var Pressures = [NumLevels]int{
1000, 975, 950, 925, 900, 875, 850, 825, 800, 775,
750, 725, 700, 675, 650, 625, 600, 575, 550, 525,
500, 475, 450, 425, 400, 375, 350, 325, 300, 275,
250, 225, 200, 175, 150, 125, 100, 70, 50, 30,
20, 10, 7, 5, 3, 2, 1,
}
// pressureIndex maps pressure in hPa to its index in the Pressures array.
var pressureIndex map[int]int
// pressureLevelSet maps pressure in hPa to its GRIB file set.
var pressureLevelSet map[int]LevelSet
func init() {
pressureIndex = make(map[int]int, NumLevels)
for i, p := range Pressures {
pressureIndex[p] = i
}
pressureLevelSet = make(map[int]LevelSet, NumLevels)
for _, p := range PressuresPgrb2 {
pressureLevelSet[p] = LevelSetA
}
for _, p := range PressuresPgrb2b {
pressureLevelSet[p] = LevelSetB
}
}
// PressuresPgrb2 contains levels found in the primary pgrb2 file (26 levels).
var PressuresPgrb2 = []int{
10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 350, 400,
450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 925,
950, 975, 1000,
}
// PressuresPgrb2b contains levels found in the secondary pgrb2b file (21 levels).
var PressuresPgrb2b = []int{
1, 2, 3, 5, 7, 125, 175, 225, 275, 325, 375, 425,
475, 525, 575, 625, 675, 725, 775, 825, 875,
}
// PressureIndex returns the dataset index for a given pressure level in hPa.
// Returns -1 if the level is not found.
func PressureIndex(hPa int) int {
idx, ok := pressureIndex[hPa]
if !ok {
return -1
}
return idx
}
// PressureLevelSet returns which GRIB file set a pressure level belongs to.
func PressureLevelSet(hPa int) (LevelSet, bool) {
ls, ok := pressureLevelSet[hPa]
return ls, ok
}
// HourIndex returns the dataset time index for a forecast hour.
// Returns -1 if the hour is invalid (not a multiple of HourStep or out of range).
func HourIndex(hour int) int {
if hour < 0 || hour > MaxHour || hour%HourStep != 0 {
return -1
}
return hour / HourStep
}
// Hours returns all forecast hours as a slice: [0, 3, 6, ..., 192].
func Hours() []int {
out := make([]int, 0, NumHours)
for h := 0; h <= MaxHour; h += HourStep {
out = append(out, h)
}
return out
}
// S3 URL configuration for NOAA GFS data.
const S3BaseURL = "https://noaa-gfs-bdp-pds.s3.amazonaws.com"
// GribURL returns the S3 URL for a primary (pgrb2) GRIB file.
func GribURL(date string, runHour, forecastStep int) string {
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d",
S3BaseURL, date, runHour, runHour, forecastStep)
}
// GribURLB returns the S3 URL for a secondary (pgrb2b) GRIB file.
func GribURLB(date string, runHour, forecastStep int) string {
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2b.0p50.f%03d",
S3BaseURL, date, runHour, runHour, forecastStep)
}
// GribFileName returns the local filename for a primary GRIB file.
func GribFileName(runHour, forecastStep int) string {
return fmt.Sprintf("gfs.t%02dz.pgrb2.0p50.f%03d", runHour, forecastStep)
}
// GribFileNameB returns the local filename for a secondary GRIB file.
func GribFileNameB(runHour, forecastStep int) string {
return fmt.Sprintf("gfs.t%02dz.pgrb2b.0p50.f%03d", runHour, forecastStep)
}
// VariableIndex returns the dataset variable index for a GRIB parameter.
// Returns -1 if the parameter is not recognized.
func VariableIndex(parameterCategory, parameterNumber int) int {
switch {
case parameterCategory == 3 && parameterNumber == 5:
return VarHeight // Geopotential Height
case parameterCategory == 2 && parameterNumber == 2:
return VarWindU // U-component of wind
case parameterCategory == 2 && parameterNumber == 3:
return VarWindV // V-component of wind
default:
return -1
}
}

View file

@ -1,152 +0,0 @@
package dataset
import (
"testing"
)
func TestDatasetShape(t *testing.T) {
if NumHours != 65 {
t.Errorf("NumHours = %d, want 65", NumHours)
}
if NumLevels != 47 {
t.Errorf("NumLevels = %d, want 47", NumLevels)
}
if NumVariables != 3 {
t.Errorf("NumVariables = %d, want 3", NumVariables)
}
if NumLatitudes != 361 {
t.Errorf("NumLatitudes = %d, want 361", NumLatitudes)
}
if NumLongitudes != 720 {
t.Errorf("NumLongitudes = %d, want 720", NumLongitudes)
}
}
func TestDatasetSize(t *testing.T) {
// 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200
want := int64(9_528_667_200)
if DatasetSize != want {
t.Errorf("DatasetSize = %d, want %d", DatasetSize, want)
}
}
func TestPressureLevels(t *testing.T) {
if len(Pressures) != 47 {
t.Fatalf("len(Pressures) = %d, want 47", len(Pressures))
}
// First should be 1000 (highest pressure, near surface)
if Pressures[0] != 1000 {
t.Errorf("Pressures[0] = %d, want 1000", Pressures[0])
}
// Last should be 1 (lowest pressure, high atmosphere)
if Pressures[46] != 1 {
t.Errorf("Pressures[46] = %d, want 1", Pressures[46])
}
// Should be sorted descending
for i := 1; i < len(Pressures); i++ {
if Pressures[i] >= Pressures[i-1] {
t.Errorf("Pressures not descending at [%d]: %d >= %d", i, Pressures[i], Pressures[i-1])
}
}
// Total levels: 26 from pgrb2 + 21 from pgrb2b = 47
if len(PressuresPgrb2) != 26 {
t.Errorf("len(PressuresPgrb2) = %d, want 26", len(PressuresPgrb2))
}
if len(PressuresPgrb2b) != 21 {
t.Errorf("len(PressuresPgrb2b) = %d, want 21", len(PressuresPgrb2b))
}
}
func TestPressureIndex(t *testing.T) {
if PressureIndex(1000) != 0 {
t.Errorf("PressureIndex(1000) = %d, want 0", PressureIndex(1000))
}
if PressureIndex(1) != 46 {
t.Errorf("PressureIndex(1) = %d, want 46", PressureIndex(1))
}
if PressureIndex(500) != 20 {
t.Errorf("PressureIndex(500) = %d, want 20", PressureIndex(500))
}
if PressureIndex(9999) != -1 {
t.Errorf("PressureIndex(9999) = %d, want -1", PressureIndex(9999))
}
}
func TestPressureLevelSet(t *testing.T) {
// 1000 mb should be in pgrb2 (A)
ls, ok := PressureLevelSet(1000)
if !ok || ls != LevelSetA {
t.Errorf("PressureLevelSet(1000) = %v, %v; want A, true", ls, ok)
}
// 125 mb should be in pgrb2b (B)
ls, ok = PressureLevelSet(125)
if !ok || ls != LevelSetB {
t.Errorf("PressureLevelSet(125) = %v, %v; want B, true", ls, ok)
}
// 1, 2, 3, 5, 7 should be in pgrb2b (B)
for _, p := range []int{1, 2, 3, 5, 7} {
ls, ok := PressureLevelSet(p)
if !ok || ls != LevelSetB {
t.Errorf("PressureLevelSet(%d) = %v, %v; want B, true", p, ls, ok)
}
}
// Every pressure level should have a level set assignment
for _, p := range Pressures {
_, ok := PressureLevelSet(p)
if !ok {
t.Errorf("PressureLevelSet(%d) not found", p)
}
}
}
func TestHourIndex(t *testing.T) {
if HourIndex(0) != 0 {
t.Errorf("HourIndex(0) = %d, want 0", HourIndex(0))
}
if HourIndex(3) != 1 {
t.Errorf("HourIndex(3) = %d, want 1", HourIndex(3))
}
if HourIndex(192) != 64 {
t.Errorf("HourIndex(192) = %d, want 64", HourIndex(192))
}
if HourIndex(1) != -1 {
t.Errorf("HourIndex(1) = %d, want -1 (not multiple of 3)", HourIndex(1))
}
if HourIndex(195) != -1 {
t.Errorf("HourIndex(195) = %d, want -1 (out of range)", HourIndex(195))
}
}
func TestHours(t *testing.T) {
hours := Hours()
if len(hours) != NumHours {
t.Fatalf("len(Hours()) = %d, want %d", len(hours), NumHours)
}
if hours[0] != 0 {
t.Errorf("Hours()[0] = %d, want 0", hours[0])
}
if hours[len(hours)-1] != MaxHour {
t.Errorf("Hours()[last] = %d, want %d", hours[len(hours)-1], MaxHour)
}
}
func TestVariableIndex(t *testing.T) {
if VariableIndex(3, 5) != VarHeight {
t.Errorf("HGT: got %d, want %d", VariableIndex(3, 5), VarHeight)
}
if VariableIndex(2, 2) != VarWindU {
t.Errorf("UGRD: got %d, want %d", VariableIndex(2, 2), VarWindU)
}
if VariableIndex(2, 3) != VarWindV {
t.Errorf("VGRD: got %d, want %d", VariableIndex(2, 3), VarWindV)
}
if VariableIndex(0, 0) != -1 {
t.Errorf("unknown: got %d, want -1", VariableIndex(0, 0))
}
}

View file

@ -1,140 +0,0 @@
package dataset
import (
"encoding/binary"
"fmt"
"math"
"os"
"time"
mmap "github.com/edsrzf/mmap-go"
)
// File represents an mmap-backed wind dataset file.
type File struct {
mm mmap.MMap
file *os.File
writable bool
DSTime time.Time // forecast run time (UTC)
}
// Open opens an existing dataset file for reading.
func Open(path string, dsTime time.Time) (*File, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open dataset: %w", err)
}
info, err := f.Stat()
if err != nil {
f.Close()
return nil, fmt.Errorf("stat dataset: %w", err)
}
if info.Size() != DatasetSize {
f.Close()
return nil, fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, info.Size())
}
mm, err := mmap.Map(f, mmap.RDONLY, 0)
if err != nil {
f.Close()
return nil, fmt.Errorf("mmap dataset: %w", err)
}
return &File{mm: mm, file: f, writable: false, DSTime: dsTime}, nil
}
// Create creates a new dataset file for writing.
// The file is truncated to the correct size and mmap'd read-write.
func Create(path string) (*File, error) {
f, err := os.Create(path)
if err != nil {
return nil, fmt.Errorf("create dataset: %w", err)
}
if err := f.Truncate(DatasetSize); err != nil {
f.Close()
return nil, fmt.Errorf("truncate dataset: %w", err)
}
mm, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0)
if err != nil {
f.Close()
return nil, fmt.Errorf("mmap dataset: %w", err)
}
return &File{mm: mm, file: f, writable: true}, nil
}
// offset computes the byte offset for element [hour][level][variable][lat][lon].
// Row-major C-order indexing matching the reference implementation:
// shape = (65, 47, 3, 361, 720)
func offset(hour, level, variable, lat, lon int) int64 {
idx := int64(hour)
idx = idx*int64(NumLevels) + int64(level)
idx = idx*int64(NumVariables) + int64(variable)
idx = idx*int64(NumLatitudes) + int64(lat)
idx = idx*int64(NumLongitudes) + int64(lon)
return idx * int64(ElementSize)
}
// Val reads a float32 value from the dataset at [hour][level][variable][lat][lon].
func (d *File) Val(hour, level, variable, lat, lon int) float32 {
off := offset(hour, level, variable, lat, lon)
bits := binary.LittleEndian.Uint32(d.mm[off : off+4])
return math.Float32frombits(bits)
}
// SetVal writes a float32 value to the dataset at [hour][level][variable][lat][lon].
// Only valid on writable (created) datasets.
func (d *File) SetVal(hour, level, variable, lat, lon int, val float32) {
off := offset(hour, level, variable, lat, lon)
binary.LittleEndian.PutUint32(d.mm[off:off+4], math.Float32bits(val))
}
// BlitGribData copies decoded GRIB grid data into the dataset at the given position.
// gribData is 361*720 float64 values in GRIB scan order (north-to-south, west-to-east).
// This function flips the latitude so that dataset index 0 = -90 (south) and 360 = +90 (north).
func (d *File) BlitGribData(hourIdx, levelIdx, varIdx int, gribData []float64) error {
expected := NumLatitudes * NumLongitudes
if len(gribData) != expected {
return fmt.Errorf("grib data has %d values, expected %d", len(gribData), expected)
}
for lat := 0; lat < NumLatitudes; lat++ {
for lon := 0; lon < NumLongitudes; lon++ {
// GRIB scans north-to-south: row 0 = 90N, row 360 = 90S
// Dataset stores south-to-north: index 0 = -90 (90S), index 360 = +90 (90N)
gribIdx := (360-lat)*NumLongitudes + lon
val := float32(gribData[gribIdx])
d.SetVal(hourIdx, levelIdx, varIdx, lat, lon, val)
}
}
return nil
}
// Flush flushes the mmap to disk.
func (d *File) Flush() error {
if d.mm != nil {
return d.mm.Flush()
}
return nil
}
// Close unmaps and closes the dataset file.
func (d *File) Close() error {
if d.mm != nil {
if err := d.mm.Unmap(); err != nil {
d.file.Close()
return fmt.Errorf("unmap: %w", err)
}
d.mm = nil
}
if d.file != nil {
err := d.file.Close()
d.file = nil
return err
}
return nil
}

11
internal/datasets/doc.go Normal file
View file

@ -0,0 +1,11 @@
// Package datasets manages the lifecycle of atmospheric datasets. It exposes:
//
// - A Source interface for pluggable dataset origins (GFS now, ECMWF later).
// - A Storage interface for transactional, resumable on-disk persistence.
// - A Manager that coordinates downloads, tracks job state, and owns the
// currently-active weather.WindField.
//
// The package is the only one in the service that knows about download
// scheduling, manifests, or bandwidth throttling — engine and API layers
// only see WindField + Manager-as-admin.
package datasets

View file

@ -0,0 +1,151 @@
// Package gefs implements datasets.Source for NOAA GEFS (Global Ensemble
// Forecast System) forecasts.
//
// Each ensemble member is treated as its own dataset, selected via
// DatasetID.Subset.Members. The download skeleton (HTTP, idx parsing,
// parallel blit) lives in internal/datasets/grib; this package only
// supplies GEFS-specific URL templating and member resolution.
package gefs
import (
"context"
"fmt"
"net/http"
"time"
"go.uber.org/zap"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/datasets/grib"
"predictor-refactored/internal/weather"
wgfs "predictor-refactored/internal/weather/gfs"
)
// Source is the GEFS implementation of datasets.Source.
type Source struct {
Variant *wgfs.Variant
Parallel int
Client *http.Client
Log *zap.Logger
}
// NewSource returns a default Source over variant. If variant is nil,
// GEFS 0.5° 3-hour is used.
func NewSource(variant *wgfs.Variant, log *zap.Logger) *Source {
if variant == nil {
variant = wgfs.GEFS0p50_3h
}
return &Source{
Variant: variant,
Parallel: 8,
Client: &http.Client{Timeout: 2 * time.Minute},
Log: log,
}
}
func (s *Source) ID() string { return s.Variant.ID }
func (s *Source) downloader() *grib.Downloader {
return &grib.Downloader{
Variant: s.Variant,
URLs: s.url,
Parallel: s.Parallel,
Client: s.Client,
Log: s.Log,
}
}
// url generates the GEFS URL for (date, runHour, member, step, levelSet).
func (s *Source) url(date string, runHour, member, step int, ls wgfs.LevelSet) string {
if ls == wgfs.LevelSetB {
return wgfs.GefsGribURLB(date, runHour, member, step, s.Variant.ResToken)
}
return wgfs.GefsGribURL(date, runHour, member, step, s.Variant.ResToken)
}
// LatestEpoch HEAD-checks the control member's final forecast hour.
func (s *Source) LatestEpoch(ctx context.Context) (time.Time, error) {
now := time.Now().UTC()
hour := now.Hour() - (now.Hour() % 6)
current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC)
client := s.Client
if client == nil {
client = &http.Client{Timeout: 2 * time.Minute}
}
log := s.Log
if log == nil {
log = zap.NewNop()
}
for range 8 {
date := current.Format("20060102")
url := wgfs.GefsGribURL(date, current.Hour(), 0, s.Variant.MaxHour, s.Variant.ResToken) + ".idx"
req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
if err == nil {
resp, err := client.Do(req)
if err == nil {
resp.Body.Close()
if resp.StatusCode == http.StatusOK {
log.Info("latest GEFS run discovered",
zap.Time("run", current),
zap.String("verified_url", url))
return current, nil
}
}
}
current = current.Add(-6 * time.Hour)
}
return time.Time{}, fmt.Errorf("no recent GEFS run found")
}
// Coverage returns the extent of id.
func (s *Source) Coverage(id datasets.DatasetID) datasets.Coverage {
v := s.Variant
cov := datasets.Coverage{
Region: datasets.Region{MinLat: -90, MaxLat: 90, MinLng: 0, MaxLng: 360},
StartTime: id.Epoch,
EndTime: id.Epoch.Add(time.Duration(v.MaxHour) * time.Hour),
}
if r := id.Subset.Region; r != nil {
cov.Region = *r
}
if h := id.Subset.HourRange; h != nil {
cov.StartTime = id.Epoch.Add(time.Duration(h.MinHour) * time.Hour)
cov.EndTime = id.Epoch.Add(time.Duration(h.MaxHour) * time.Hour)
}
return cov
}
// Open loads a stored GEFS dataset as a WindField.
func (s *Source) Open(_ context.Context, id datasets.DatasetID, store datasets.Storage) (weather.WindField, error) {
if !store.Exists(id) {
return nil, fmt.Errorf("dataset %s not found", id.Filename())
}
file, err := wgfs.Open(store.Path(id), s.Variant, id.Epoch.UTC())
if err != nil {
return nil, err
}
return wgfs.NewWind(file), nil
}
// memberOf extracts the single member index encoded by id.Subset.Members.
func memberOf(id datasets.DatasetID) (int, error) {
if len(id.Subset.Members) != 1 {
return 0, fmt.Errorf("gefs dataset id must specify exactly one member (got %v)", id.Subset.Members)
}
m := id.Subset.Members[0]
if m < 0 || m >= wgfs.GEFSMembers {
return 0, fmt.Errorf("gefs member %d out of range [0, %d)", m, wgfs.GEFSMembers)
}
return m, nil
}
// Download fetches one ensemble member's dataset.
func (s *Source) Download(ctx context.Context, id datasets.DatasetID, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error {
member, err := memberOf(id)
if err != nil {
return err
}
return s.downloader().Run(ctx, id, member, store, prog, throttle)
}

View file

@ -0,0 +1,138 @@
// Package gfs implements datasets.Source for NOAA GFS forecasts.
//
// The package serves multiple GFS variants (0.5° 3-hour, 0.25° 3-hour,
// 0.25° 1-hour); the variant is selected at construction time. The
// download skeleton (HTTP, idx parsing, parallel blit) lives in
// internal/datasets/grib; this package only supplies URL templating and
// the Source-interface plumbing.
package gfs
import (
"context"
"fmt"
"net/http"
"time"
"go.uber.org/zap"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/datasets/grib"
"predictor-refactored/internal/weather"
wgfs "predictor-refactored/internal/weather/gfs"
)
// Source is the GFS implementation of datasets.Source.
type Source struct {
Variant *wgfs.Variant
Parallel int
Client *http.Client
Log *zap.Logger
}
// NewSource returns a default Source over variant. If variant is nil,
// GFS 0.5° 3-hour is used (the historical Tawhiri default).
func NewSource(variant *wgfs.Variant, log *zap.Logger) *Source {
if variant == nil {
variant = wgfs.GFS0p50_3h
}
return &Source{
Variant: variant,
Parallel: 8,
Client: &http.Client{Timeout: 2 * time.Minute},
Log: log,
}
}
// ID returns the variant's ID.
func (s *Source) ID() string { return s.Variant.ID }
func (s *Source) downloader() *grib.Downloader {
return &grib.Downloader{
Variant: s.Variant,
URLs: s.url,
Parallel: s.Parallel,
Client: s.Client,
Log: s.Log,
}
}
// url generates the GFS URL for one (date, runHour, _, step, levelSet).
// member is unused for GFS.
func (s *Source) url(date string, runHour, _, step int, ls wgfs.LevelSet) string {
if ls == wgfs.LevelSetB {
return s.Variant.GribURLB(date, runHour, step)
}
return s.Variant.GribURL(date, runHour, step)
}
// LatestEpoch returns the most recent run NOAA has finished publishing.
func (s *Source) LatestEpoch(ctx context.Context) (time.Time, error) {
now := time.Now().UTC()
hour := now.Hour() - (now.Hour() % 6)
current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC)
client := s.Client
if client == nil {
client = &http.Client{Timeout: 2 * time.Minute}
}
log := s.Log
if log == nil {
log = zap.NewNop()
}
for range 8 {
date := current.Format("20060102")
url := s.Variant.GribURL(date, current.Hour(), s.Variant.MaxHour) + ".idx"
req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
if err == nil {
resp, err := client.Do(req)
if err == nil {
resp.Body.Close()
if resp.StatusCode == http.StatusOK {
log.Info("latest run discovered",
zap.String("variant", s.Variant.ID),
zap.Time("run", current),
zap.String("verified_url", url))
return current, nil
}
}
}
current = current.Add(-6 * time.Hour)
}
return time.Time{}, fmt.Errorf("no recent %s run found (checked 8 runs)", s.Variant.ID)
}
// Coverage returns the geographic and temporal extent of id.
func (s *Source) Coverage(id datasets.DatasetID) datasets.Coverage {
v := s.Variant
cov := datasets.Coverage{
Region: datasets.Region{MinLat: -90, MaxLat: 90, MinLng: 0, MaxLng: 360},
StartTime: id.Epoch,
EndTime: id.Epoch.Add(time.Duration(v.MaxHour) * time.Hour),
}
if r := id.Subset.Region; r != nil {
cov.Region = *r
}
if h := id.Subset.HourRange; h != nil {
cov.StartTime = id.Epoch.Add(time.Duration(h.MinHour) * time.Hour)
cov.EndTime = id.Epoch.Add(time.Duration(h.MaxHour) * time.Hour)
}
return cov
}
// Open loads a stored dataset as a WindField.
func (s *Source) Open(_ context.Context, id datasets.DatasetID, store datasets.Storage) (weather.WindField, error) {
if !store.Exists(id) {
return nil, fmt.Errorf("dataset %s not found", id.Filename())
}
file, err := wgfs.Open(store.Path(id), s.Variant, id.Epoch.UTC())
if err != nil {
return nil, err
}
return wgfs.NewWind(file), nil
}
// Download fetches the dataset for id. GFS ignores Subset.Members.
func (s *Source) Download(ctx context.Context, id datasets.DatasetID, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error {
return s.downloader().Run(ctx, id, 0, store, prog, throttle)
}

View file

@ -0,0 +1,369 @@
package grib
import (
"context"
"errors"
"fmt"
"io"
"math"
"net/http"
"os"
"sync"
"time"
"github.com/nilsmagnus/grib/griblib"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"predictor-refactored/internal/datasets"
wgfs "predictor-refactored/internal/weather/gfs"
)
// URLFunc returns the GRIB URL for one (date, runHour, member, step, levelSet).
// Sources that don't have members (GFS) ignore the member argument.
type URLFunc func(date string, runHour, member, step int, ls wgfs.LevelSet) string
// Downloader is the generic GRIB-cube downloader.
//
// A Source plugs in its variant, URL templating, and member-resolution
// logic; the Downloader runs the parallel idx fetch, byte-range download,
// GRIB decode, and blit loop with manifest-based resume.
type Downloader struct {
Variant *wgfs.Variant
URLs URLFunc
Parallel int
Client *http.Client
Log *zap.Logger
}
func (d *Downloader) log() *zap.Logger {
if d.Log == nil {
return zap.NewNop()
}
return d.Log
}
func (d *Downloader) client() *http.Client {
if d.Client == nil {
return &http.Client{Timeout: 2 * time.Minute}
}
return d.Client
}
func (d *Downloader) parallel() int {
if d.Parallel <= 0 {
return 8
}
return d.Parallel
}
// neededVariables is the GRIB variable set every source extracts.
var neededVariables = map[string]bool{"HGT": true, "UGRD": true, "VGRD": true}
// Run downloads the dataset for id, member into store. The caller may
// pass member=0 for non-ensemble sources.
func (d *Downloader) Run(ctx context.Context, id datasets.DatasetID, member int, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error {
if prog == nil {
prog = noopSink{}
}
handle, err := store.BeginWrite(id)
if err != nil {
return fmt.Errorf("begin write: %w", err)
}
manifest := handle.Manifest()
file, err := openOrCreateCube(handle.Path(), d.Variant)
if err != nil {
_ = handle.Abort()
return err
}
epoch := id.Epoch.UTC()
date := epoch.Format("20060102")
runHour := epoch.Hour()
steps := d.Variant.Hours()
if hr := id.Subset.HourRange; hr != nil {
filtered := steps[:0]
for _, step := range steps {
if step >= hr.MinHour && step <= hr.MaxHour {
filtered = append(filtered, step)
}
}
steps = filtered
}
prog.SetTotal(len(steps) * 2)
for range manifest.Units() {
prog.StepComplete()
}
start := time.Now()
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(d.parallel())
var fileMu sync.Mutex
for _, step := range steps {
hourIdx := d.Variant.HourIndex(step)
if hourIdx < 0 {
continue
}
for _, ls := range []wgfs.LevelSet{wgfs.LevelSetA, wgfs.LevelSetB} {
unit := unitKey(step, ls)
if manifest.Has(unit) {
continue
}
g.Go(func() error {
url := d.URLs(date, runHour, member, step, ls)
if err := d.downloadAndBlit(ctx, file, &fileMu, url, hourIdx, ls, prog, throttle); err != nil {
return fmt.Errorf("step %d %s: %w", step, levelSetLabel(ls), err)
}
if err := manifest.Mark(unit); err != nil {
return fmt.Errorf("mark unit: %w", err)
}
prog.StepComplete()
return nil
})
}
}
if err := g.Wait(); err != nil {
_ = file.Close()
if errors.Is(err, context.Canceled) {
return err
}
if len(manifest.Units()) == 0 {
_ = handle.Abort()
}
return err
}
if err := file.Flush(); err != nil {
_ = file.Close()
return fmt.Errorf("flush: %w", err)
}
if err := file.Close(); err != nil {
return fmt.Errorf("close: %w", err)
}
if err := handle.Commit(); err != nil {
return fmt.Errorf("commit: %w", err)
}
d.log().Info("download complete",
zap.String("variant", d.Variant.ID),
zap.Time("epoch", epoch),
zap.Duration("elapsed", time.Since(start)))
return nil
}
// openOrCreateCube opens an existing cube at path if it matches variant's
// expected size, else truncate-creates a new one.
func openOrCreateCube(path string, variant *wgfs.Variant) (*wgfs.File, error) {
info, err := os.Stat(path)
if err == nil && info.Size() == variant.DatasetSize() {
return wgfs.OpenWritable(path, variant)
}
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("stat cube: %w", err)
}
return wgfs.Create(path, variant)
}
// downloadAndBlit fetches and decodes one (URL, level-set) chunk.
func (d *Downloader) downloadAndBlit(
ctx context.Context,
file *wgfs.File,
fileMu *sync.Mutex,
baseURL string,
hourIdx int,
ls wgfs.LevelSet,
prog datasets.ProgressSink,
throttle datasets.Throttle,
) error {
idxBody, err := d.httpGet(ctx, baseURL+".idx", throttle, prog)
if err != nil {
return fmt.Errorf("idx: %w", err)
}
entries := ParseIdx(idxBody)
filtered := FilterIdx(entries, neededVariables)
var relevant []IdxEntry
for _, e := range filtered {
set, ok := d.Variant.PressureLevelSet(e.LevelMB)
if ok && set == ls {
relevant = append(relevant, e)
}
}
if len(relevant) == 0 {
return nil
}
ranges := EntriesToRanges(relevant)
tmp, err := os.CreateTemp("", "grib-msg-*.tmp")
if err != nil {
return fmt.Errorf("temp: %w", err)
}
tmpPath := tmp.Name()
defer os.Remove(tmpPath)
for _, r := range ranges {
body, err := d.httpGetRange(ctx, baseURL, r.Start, r.End, throttle, prog)
if err != nil {
tmp.Close()
return fmt.Errorf("range: %w", err)
}
if _, err := tmp.Write(body); err != nil {
tmp.Close()
return fmt.Errorf("write tmp: %w", err)
}
}
if err := tmp.Close(); err != nil {
return err
}
f, err := os.Open(tmpPath)
if err != nil {
return err
}
messages, err := griblib.ReadMessages(f)
f.Close()
if err != nil {
return fmt.Errorf("read grib: %w", err)
}
for _, msg := range messages {
if msg.Section4.ProductDefinitionTemplateNumber != 0 {
continue
}
p := msg.Section4.ProductDefinitionTemplate
varIdx := d.Variant.VariableIndex(int(p.ParameterCategory), int(p.ParameterNumber))
if varIdx < 0 {
continue
}
if p.FirstSurface.Type != 100 {
continue
}
pressureMB := int(math.Round(float64(p.FirstSurface.Value) / 100.0))
levelIdx := d.Variant.PressureIndex(pressureMB)
if levelIdx < 0 {
continue
}
data := msg.Data()
fileMu.Lock()
err := file.BlitGribData(hourIdx, levelIdx, varIdx, data)
fileMu.Unlock()
if err != nil {
return fmt.Errorf("blit: %w", err)
}
}
return nil
}
func (d *Downloader) httpGet(ctx context.Context, url string, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
var lastErr error
for attempt := range 3 {
if attempt > 0 {
select {
case <-time.After(time.Duration(attempt*2) * time.Second):
case <-ctx.Done():
return nil, ctx.Err()
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := d.client().Do(req)
if err != nil {
lastErr = err
continue
}
body, err := readThrottled(ctx, resp.Body, throttle, prog)
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url)
continue
}
if err != nil {
lastErr = err
continue
}
return body, nil
}
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}
func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
var lastErr error
for attempt := range 3 {
if attempt > 0 {
select {
case <-time.After(time.Duration(attempt*2) * time.Second):
case <-ctx.Done():
return nil, ctx.Err()
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
resp, err := d.client().Do(req)
if err != nil {
lastErr = err
continue
}
body, err := readThrottled(ctx, resp.Body, throttle, prog)
resp.Body.Close()
if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("HTTP %d for range", resp.StatusCode)
continue
}
if err != nil {
lastErr = err
continue
}
return body, nil
}
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}
func readThrottled(ctx context.Context, r io.Reader, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
buf := make([]byte, 0, 64*1024)
chunk := make([]byte, 32*1024)
for {
if throttle != nil {
if err := throttle.Wait(ctx, len(chunk)); err != nil {
return nil, err
}
}
n, err := r.Read(chunk)
if n > 0 {
buf = append(buf, chunk[:n]...)
prog.Bytes(int64(n))
}
if errors.Is(err, io.EOF) {
return buf, nil
}
if err != nil {
return nil, err
}
}
}
func unitKey(step int, ls wgfs.LevelSet) string {
return fmt.Sprintf("step%03d-%s", step, levelSetLabel(ls))
}
func levelSetLabel(ls wgfs.LevelSet) string {
if ls == wgfs.LevelSetB {
return "B"
}
return "A"
}
type noopSink struct{}
func (noopSink) SetTotal(int) {}
func (noopSink) StepComplete() {}
func (noopSink) Bytes(int64) {}

View file

@ -0,0 +1,129 @@
// Package grib contains the GRIB-cube download skeleton shared by every
// NOAA source (GFS, GEFS, future families). It exposes the .idx parser,
// HTTP helpers, and a parallel download loop; source-specific URL
// templating is injected by the caller.
package grib
import (
"fmt"
"strconv"
"strings"
)
// IdxEntry is one parsed line from a NOAA GRIB .idx file.
//
// Example line: "15:1207405:d=2024010100:HGT:1000 mb:0 hour fcst:"
type IdxEntry struct {
Index int
Offset int64
Variable string
LevelMB int // 0 when the level is not isobaric
Hour int // forecast hour; 0 for analysis ("anl"); -1 if unparseable
EndOffset int64 // computed from the next entry's Offset; -1 for the final entry
}
// Length returns the byte length of this GRIB message, or -1 if unknown
// (the final entry in an idx file).
func (e *IdxEntry) Length() int64 {
if e.EndOffset <= 0 {
return -1
}
return e.EndOffset - e.Offset
}
// ParseIdx parses a .idx file body. Unparseable lines are silently skipped.
func ParseIdx(body []byte) []IdxEntry {
lines := strings.Split(string(body), "\n")
var entries []IdxEntry
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
parts := strings.Split(line, ":")
if len(parts) < 7 {
continue
}
idx, err := strconv.Atoi(parts[0])
if err != nil {
continue
}
off, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
continue
}
entries = append(entries, IdxEntry{
Index: idx,
Offset: off,
Variable: parts[3],
LevelMB: parseLevelMB(parts[4]),
Hour: parseHour(parts[5]),
EndOffset: -1,
})
}
for i := 0; i < len(entries)-1; i++ {
entries[i].EndOffset = entries[i+1].Offset
}
return entries
}
// FilterIdx returns entries matching one of the wanted variables at a known
// pressure level with a computable byte length.
func FilterIdx(entries []IdxEntry, wanted map[string]bool) []IdxEntry {
var out []IdxEntry
for _, e := range entries {
if !wanted[e.Variable] || e.LevelMB <= 0 || e.Length() <= 0 {
continue
}
out = append(out, e)
}
return out
}
func parseLevelMB(s string) int {
s = strings.TrimSpace(s)
if !strings.HasSuffix(s, " mb") {
return 0
}
n, err := strconv.Atoi(strings.TrimSuffix(s, " mb"))
if err != nil {
return 0
}
return n
}
func parseHour(s string) int {
s = strings.TrimSpace(s)
if s == "anl" {
return 0
}
n, err := strconv.Atoi(strings.TrimSuffix(s, " hour fcst"))
if err != nil {
return -1
}
return n
}
// ByteRange is one HTTP range download corresponding to one GRIB message.
type ByteRange struct {
Start int64
End int64 // inclusive
Entry IdxEntry
}
// EntriesToRanges converts idx entries to inclusive HTTP byte ranges.
func EntriesToRanges(entries []IdxEntry) []ByteRange {
out := make([]ByteRange, 0, len(entries))
for _, e := range entries {
if e.Length() <= 0 {
continue
}
out = append(out, ByteRange{Start: e.Offset, End: e.EndOffset - 1, Entry: e})
}
return out
}
// FormatRange returns an HTTP Range header value for the byte range.
func (r ByteRange) FormatRange() string {
return fmt.Sprintf("bytes=%d-%d", r.Start, r.End)
}

View file

@ -0,0 +1,70 @@
package grib
import "testing"
const sampleIdx = `1:0:d=2024010100:HGT:1000 mb:0 hour fcst:
2:289012:d=2024010100:HGT:975 mb:0 hour fcst:
3:541876:d=2024010100:TMP:1000 mb:0 hour fcst:
4:789012:d=2024010100:UGRD:1000 mb:0 hour fcst:
5:1045678:d=2024010100:VGRD:1000 mb:0 hour fcst:
6:1298765:d=2024010100:UGRD:975 mb:0 hour fcst:
7:1567890:d=2024010100:UGRD:2 m above ground:0 hour fcst:
8:1812345:d=2024010100:VGRD:975 mb:0 hour fcst:
9:2098765:d=2024010100:HGT:500 mb:3 hour fcst:
`
func TestParseIdx(t *testing.T) {
entries := ParseIdx([]byte(sampleIdx))
if len(entries) != 9 {
t.Fatalf("expected 9 entries, got %d", len(entries))
}
if e := entries[0]; e.Index != 1 || e.Offset != 0 || e.Variable != "HGT" || e.LevelMB != 1000 || e.Hour != 0 || e.EndOffset != 289012 {
t.Errorf("entry 0: %+v", e)
}
if e := entries[6]; e.LevelMB != 0 {
t.Errorf("non-pressure level should have LevelMB=0, got %d", e.LevelMB)
}
if e := entries[len(entries)-1]; e.EndOffset != -1 {
t.Errorf("last entry EndOffset: got %d, want -1", e.EndOffset)
}
}
func TestFilterIdx(t *testing.T) {
entries := ParseIdx([]byte(sampleIdx))
want := map[string]bool{"HGT": true, "UGRD": true, "VGRD": true}
filtered := FilterIdx(entries, want)
// HGT@1000, HGT@975, UGRD@1000, VGRD@1000, UGRD@975, VGRD@975 = 6
// HGT@500 at 3hr is last entry (no EndOffset), so dropped.
if len(filtered) != 6 {
t.Errorf("expected 6, got %d", len(filtered))
}
}
func TestParseLevelMB(t *testing.T) {
cases := []struct {
in string
want int
}{
{"1000 mb", 1000}, {"975 mb", 975}, {"1 mb", 1},
{"2 m above ground", 0}, {"surface", 0}, {"tropopause", 0},
}
for _, c := range cases {
if got := parseLevelMB(c.in); got != c.want {
t.Errorf("parseLevelMB(%q) = %d, want %d", c.in, got, c.want)
}
}
}
func TestParseHour(t *testing.T) {
cases := []struct {
in string
want int
}{
{"0 hour fcst", 0}, {"3 hour fcst", 3}, {"192 hour fcst", 192}, {"anl", 0},
}
for _, c := range cases {
if got := parseHour(c.in); got != c.want {
t.Errorf("parseHour(%q) = %d, want %d", c.in, got, c.want)
}
}
}

View file

@ -0,0 +1,11 @@
//go:build !unix
package datasets
import "context"
// flockExclusive is a no-op on platforms without flock. The service targets
// Linux containers; this stub only keeps non-Unix builds compiling.
func flockExclusive(_ context.Context, _ string) (func(), error) {
return func() {}, nil
}

View file

@ -0,0 +1,50 @@
//go:build unix
package datasets
import (
"context"
"errors"
"fmt"
"os"
"syscall"
"time"
)
// lockPollInterval is how often a contended lock is retried. The lock is held
// for the duration of a dataset download (minutes), so sub-second acquisition
// latency is irrelevant.
const lockPollInterval = 150 * time.Millisecond
// flockExclusive acquires an exclusive flock on path, creating the lock file
// if needed, and blocks until it is held or ctx is cancelled.
//
// It uses non-blocking LOCK_NB attempts in a poll loop rather than a blocking
// flock in a goroutine: the file descriptor is only ever touched by this
// goroutine, so there is no race between a pending syscall and Close on
// cancellation.
func flockExclusive(ctx context.Context, path string) (func(), error) {
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644)
if err != nil {
return nil, fmt.Errorf("open lock file: %w", err)
}
for {
err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
if err == nil {
return func() {
_ = syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
_ = f.Close()
}, nil
}
if !errors.Is(err, syscall.EWOULDBLOCK) {
f.Close()
return nil, fmt.Errorf("flock: %w", err)
}
select {
case <-ctx.Done():
f.Close()
return nil, ctx.Err()
case <-time.After(lockPollInterval):
}
}
}

View file

@ -0,0 +1,466 @@
package datasets
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"predictor-refactored/internal/weather"
)
// JobStatus is the lifecycle state of a download job.
type JobStatus string
const (
JobPending JobStatus = "pending"
JobRunning JobStatus = "running"
JobComplete JobStatus = "complete"
JobFailed JobStatus = "failed"
JobCancelled JobStatus = "cancelled"
)
// JobInfo is the externally-visible snapshot of a download job.
type JobInfo struct {
ID string
Source string
Dataset DatasetID
Status JobStatus
StartedAt time.Time
EndedAt *time.Time
Err string
Total int
Done int
Bytes int64
}
type jobEntry struct {
id string
source string
dataset DatasetID
startedAt time.Time
cancel context.CancelFunc
mu sync.Mutex
status JobStatus
endedAt time.Time
errStr string
total atomic.Int64
done atomic.Int64
bytes atomic.Int64
}
func (e *jobEntry) snapshot() JobInfo {
e.mu.Lock()
info := JobInfo{
ID: e.id, Source: e.source, Dataset: e.dataset,
StartedAt: e.startedAt, Status: e.status, Err: e.errStr,
}
if !e.endedAt.IsZero() {
ts := e.endedAt
info.EndedAt = &ts
}
e.mu.Unlock()
info.Total = int(e.total.Load())
info.Done = int(e.done.Load())
info.Bytes = e.bytes.Load()
return info
}
type jobProgress struct{ e *jobEntry }
func (p jobProgress) SetTotal(n int) { p.e.total.Store(int64(n)) }
func (p jobProgress) StepComplete() { p.e.done.Add(1) }
func (p jobProgress) Bytes(n int64) { p.e.bytes.Add(n) }
// loadedDataset bundles a loaded WindField with its identity and coverage.
type loadedDataset struct {
ID DatasetID
Field weather.WindField
Coverage Coverage
}
// Manager coordinates dataset downloads and exposes the active WindFields.
type Manager struct {
src Source
store Storage
throttle Throttle
log *zap.Logger
activeMu sync.RWMutex
active []loadedDataset
jobsMu sync.RWMutex
jobs map[string]*jobEntry
inFlight sync.Map // key: dataset filename, value: jobID
}
// New wires a Manager.
func New(src Source, store Storage, throttle Throttle, log *zap.Logger) *Manager {
if log == nil {
log = zap.NewNop()
}
if src.ID() != store.SourceID() {
log.Warn("source/store ID mismatch",
zap.String("src", src.ID()),
zap.String("store", store.SourceID()))
}
return &Manager{
src: src, store: store, throttle: throttle, log: log,
jobs: make(map[string]*jobEntry),
}
}
// Source returns the underlying source ID.
func (m *Manager) Source() string { return m.src.ID() }
// Active returns the currently-loaded global WindField (the dataset with
// IsGlobal subset, most recently loaded). Returns nil if no global
// dataset is loaded; in cluster setups with only regional subsets, callers
// should use SelectFor.
func (m *Manager) Active() weather.WindField {
m.activeMu.RLock()
defer m.activeMu.RUnlock()
for _, d := range m.active {
if d.ID.Subset.IsGlobal() {
return d.Field
}
}
if len(m.active) > 0 {
return m.active[0].Field
}
return nil
}
// Ready reports whether at least one dataset is loaded.
func (m *Manager) Ready() bool { return m.Active() != nil }
// SelectFor returns a loaded WindField whose coverage contains (t, lat, lng).
// Returns nil when no loaded dataset covers the query.
func (m *Manager) SelectFor(t time.Time, lat, lng float64) weather.WindField {
m.activeMu.RLock()
defer m.activeMu.RUnlock()
for _, d := range m.active {
if d.Coverage.Covers(t, lat, lng) {
return d.Field
}
}
// Fallback: any global dataset is permissive about region.
for _, d := range m.active {
if d.ID.Subset.IsGlobal() {
return d.Field
}
}
return nil
}
// LoadedDatasets returns snapshots of every currently-loaded dataset.
func (m *Manager) LoadedDatasets() []LoadedDatasetInfo {
m.activeMu.RLock()
defer m.activeMu.RUnlock()
out := make([]LoadedDatasetInfo, 0, len(m.active))
for _, d := range m.active {
out = append(out, LoadedDatasetInfo{ID: d.ID, Coverage: d.Coverage})
}
return out
}
// LoadedDatasetInfo is a serializable snapshot of one active dataset.
type LoadedDatasetInfo struct {
ID DatasetID
Coverage Coverage
}
// ListEpochs returns all stored datasets, newest first.
func (m *Manager) ListEpochs() ([]DatasetID, error) { return m.store.List() }
// ListJobs returns snapshots of every job recorded since startup.
func (m *Manager) ListJobs() []JobInfo {
m.jobsMu.RLock()
defer m.jobsMu.RUnlock()
out := make([]JobInfo, 0, len(m.jobs))
for _, e := range m.jobs {
out = append(out, e.snapshot())
}
return out
}
// GetJob returns the snapshot for a job.
func (m *Manager) GetJob(id string) (JobInfo, bool) {
m.jobsMu.RLock()
e, ok := m.jobs[id]
m.jobsMu.RUnlock()
if !ok {
return JobInfo{}, false
}
return e.snapshot(), true
}
// CancelJob cancels a running job.
func (m *Manager) CancelJob(id string) bool {
m.jobsMu.RLock()
e, ok := m.jobs[id]
m.jobsMu.RUnlock()
if !ok {
return false
}
e.mu.Lock()
terminal := e.status == JobComplete || e.status == JobFailed || e.status == JobCancelled
e.mu.Unlock()
if terminal {
return false
}
e.cancel()
return true
}
// Remove deletes a stored dataset. If the dataset is currently loaded,
// it is unloaded first.
func (m *Manager) Remove(id DatasetID) error {
m.activeMu.Lock()
out := m.active[:0]
var removed *loadedDataset
for i := range m.active {
d := m.active[i]
if d.ID.Equals(id) {
removed = &d
continue
}
out = append(out, d)
}
m.active = out
m.activeMu.Unlock()
if removed != nil {
closeField(removed.Field, m.log)
}
return m.store.Remove(id)
}
// Download starts (or resumes) a download job for id in the background.
func (m *Manager) Download(id DatasetID) string {
key := id.Filename()
if existing, ok := m.inFlight.Load(key); ok {
return existing.(string)
}
jobID := uuid.New().String()
if other, loaded := m.inFlight.LoadOrStore(key, jobID); loaded {
return other.(string)
}
ctx, cancel := context.WithCancel(context.Background())
now := time.Now().UTC()
e := &jobEntry{
id: jobID,
source: m.src.ID(),
dataset: id,
startedAt: now,
status: JobPending,
cancel: cancel,
}
m.jobsMu.Lock()
m.jobs[jobID] = e
m.jobsMu.Unlock()
if m.store.Exists(id) {
go m.completeShortCircuit(ctx, e)
return jobID
}
go m.runDownload(ctx, e)
return jobID
}
// Load swaps in id's stored dataset, making it available to predictions.
func (m *Manager) Load(ctx context.Context, id DatasetID) error {
if !m.store.Exists(id) {
return fmt.Errorf("dataset %s not present on disk", id.Filename())
}
field, err := m.src.Open(ctx, id, m.store)
if err != nil {
return fmt.Errorf("open dataset: %w", err)
}
cov := m.src.Coverage(id)
m.activeMu.Lock()
// Replace any previously-loaded dataset with the same ID.
for i := range m.active {
if m.active[i].ID.Equals(id) {
closeField(m.active[i].Field, m.log)
m.active[i] = loadedDataset{ID: id, Field: field, Coverage: cov}
m.activeMu.Unlock()
return nil
}
}
m.active = append(m.active, loadedDataset{ID: id, Field: field, Coverage: cov})
m.activeMu.Unlock()
m.log.Info("loaded dataset",
zap.String("filename", id.Filename()),
zap.String("source", m.src.ID()))
return nil
}
// Refresh ensures the freshest global dataset is downloaded and active.
//
// Returns the JobID started, or empty string when nothing was scheduled.
func (m *Manager) Refresh(ctx context.Context, freshnessTTL time.Duration) (string, error) {
if a := m.activeGlobal(); a != nil && time.Since(a.ID.Epoch) < freshnessTTL {
return "", nil
}
if datasets, err := m.store.List(); err == nil {
for _, id := range datasets {
if !id.Subset.IsGlobal() {
continue
}
if time.Since(id.Epoch) > freshnessTTL {
continue
}
if a := m.activeGlobal(); a != nil && a.ID.Equals(id) {
return "", nil
}
if err := m.Load(ctx, id); err == nil {
return "", nil
}
}
}
latest, err := m.src.LatestEpoch(ctx)
if err != nil {
return "", fmt.Errorf("latest epoch: %w", err)
}
id := DatasetID{Epoch: latest}
if a := m.activeGlobal(); a != nil && !latest.After(a.ID.Epoch) {
return "", nil
}
jobID := m.Download(id)
go m.loadAfterCompletion(jobID, id)
return jobID, nil
}
// activeGlobal returns the currently-loaded global dataset, if any.
func (m *Manager) activeGlobal() *loadedDataset {
m.activeMu.RLock()
defer m.activeMu.RUnlock()
for i := range m.active {
if m.active[i].ID.Subset.IsGlobal() {
d := m.active[i]
return &d
}
}
return nil
}
func (m *Manager) loadAfterCompletion(jobID string, id DatasetID) {
for {
info, ok := m.GetJob(jobID)
if !ok {
return
}
switch info.Status {
case JobComplete:
if err := m.Load(context.Background(), id); err != nil {
m.log.Error("load after download", zap.Error(err))
}
return
case JobFailed, JobCancelled:
return
}
time.Sleep(2 * time.Second)
}
}
func (m *Manager) runDownload(ctx context.Context, e *jobEntry) {
defer m.inFlight.Delete(e.dataset.Filename())
e.mu.Lock()
e.status = JobRunning
e.mu.Unlock()
m.log.Info("download started",
zap.String("job", e.id),
zap.String("dataset", e.dataset.Filename()))
err := m.downloadLocked(ctx, e)
now := time.Now().UTC()
e.mu.Lock()
e.endedAt = now
switch {
case errors.Is(err, context.Canceled):
e.status = JobCancelled
case err != nil:
e.status = JobFailed
e.errStr = err.Error()
default:
e.status = JobComplete
}
finalStatus := e.status
e.mu.Unlock()
m.log.Info("download finished",
zap.String("job", e.id),
zap.String("status", string(finalStatus)),
zap.NamedError("err", err))
}
// downloadLocked runs the source download while holding the storage's
// cross-process lock, so multiple replicas sharing a node-local dataset
// volume coordinate instead of each fetching ~9 GB. After acquiring the lock
// it re-checks existence: if another replica committed the dataset while this
// one waited, it skips the download and lets the caller load the committed file.
func (m *Manager) downloadLocked(ctx context.Context, e *jobEntry) error {
release, err := m.store.Lock(ctx)
if err != nil {
return fmt.Errorf("acquire download lock: %w", err)
}
defer release()
if m.store.Exists(e.dataset) {
m.log.Info("dataset committed by another instance while waiting; skipping download",
zap.String("dataset", e.dataset.Filename()))
return nil
}
return m.src.Download(ctx, e.dataset, m.store, jobProgress{e: e}, m.throttle)
}
func (m *Manager) completeShortCircuit(ctx context.Context, e *jobEntry) {
_ = ctx
defer m.inFlight.Delete(e.dataset.Filename())
now := time.Now().UTC()
e.mu.Lock()
e.status = JobComplete
e.endedAt = now
e.mu.Unlock()
}
// Close releases all resources, cancelling any in-flight jobs.
func (m *Manager) Close() error {
m.jobsMu.Lock()
for _, e := range m.jobs {
e.cancel()
}
m.jobsMu.Unlock()
m.activeMu.Lock()
for _, d := range m.active {
closeField(d.Field, m.log)
}
m.active = nil
m.activeMu.Unlock()
return nil
}
func closeField(f weather.WindField, log *zap.Logger) {
if c, ok := f.(interface{ Close() error }); ok && c != nil {
if err := c.Close(); err != nil && log != nil {
log.Warn("close dataset", zap.Error(err))
}
}
}

View file

@ -0,0 +1,118 @@
package datasets
import (
"encoding/json"
"errors"
"fmt"
"os"
"sort"
"sync"
)
// Manifest tracks completed work units for a partial dataset download.
// Units are arbitrary opaque strings; sources choose the format
// (e.g. "step12-A" for "forecast step 12, level set A").
//
// A Manifest is persisted as a JSON object: {"units": ["step0-A", "step0-B", ...]}.
type Manifest struct {
path string
mu sync.Mutex
units map[string]struct{}
}
// LoadManifest opens or creates the manifest at path. Missing or unreadable
// files are treated as empty; a corrupt file returns an error.
func LoadManifest(path string) (*Manifest, error) {
m := &Manifest{path: path, units: make(map[string]struct{})}
data, err := os.ReadFile(path)
if errors.Is(err, os.ErrNotExist) {
return m, nil
}
if err != nil {
return nil, fmt.Errorf("read manifest %s: %w", path, err)
}
if len(data) == 0 {
return m, nil
}
var doc struct {
Units []string `json:"units"`
}
if err := json.Unmarshal(data, &doc); err != nil {
return nil, fmt.Errorf("parse manifest %s: %w", path, err)
}
for _, u := range doc.Units {
m.units[u] = struct{}{}
}
return m, nil
}
// Has reports whether unit has been recorded as completed.
func (m *Manifest) Has(unit string) bool {
m.mu.Lock()
defer m.mu.Unlock()
_, ok := m.units[unit]
return ok
}
// Mark records unit as completed and persists the manifest to disk.
func (m *Manifest) Mark(unit string) error {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.units[unit]; ok {
return nil
}
m.units[unit] = struct{}{}
return m.persistLocked()
}
// Units returns the completed units in sorted order.
func (m *Manifest) Units() []string {
m.mu.Lock()
defer m.mu.Unlock()
out := make([]string, 0, len(m.units))
for u := range m.units {
out = append(out, u)
}
sort.Strings(out)
return out
}
// Reset clears all recorded units and removes the manifest file.
func (m *Manifest) Reset() error {
m.mu.Lock()
defer m.mu.Unlock()
m.units = make(map[string]struct{})
if err := os.Remove(m.path); err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("remove manifest %s: %w", m.path, err)
}
return nil
}
// persistLocked writes the manifest to disk via temp+rename.
// The caller must hold m.mu.
func (m *Manifest) persistLocked() error {
units := make([]string, 0, len(m.units))
for u := range m.units {
units = append(units, u)
}
sort.Strings(units)
data, err := json.Marshal(struct {
Units []string `json:"units"`
}{Units: units})
if err != nil {
return err
}
tmp := m.path + ".new"
if err := os.WriteFile(tmp, data, 0o644); err != nil {
return fmt.Errorf("write manifest temp: %w", err)
}
if err := os.Rename(tmp, m.path); err != nil {
os.Remove(tmp)
return fmt.Errorf("rename manifest: %w", err)
}
return nil
}

View file

@ -0,0 +1,188 @@
package datasets
import (
"context"
"errors"
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"time"
)
// LocalStore stores dataset files on the local filesystem.
//
// Layout under Root:
//
// <filename>.bin — committed dataset
// <filename>.bin.downloading — in-progress dataset
// <filename>.bin.manifest.json — completed work units
//
// where <filename> is DatasetID.Filename() — typically
// "20060102T150405Z" for the global subset or
// "20060102T150405Z_r-10.10.-30.30_h0.72" for a subset.
type LocalStore struct {
Root string
Source string
Extension string // default ".bin"
}
// NewLocalStore returns a LocalStore at root. The directory is created if missing.
func NewLocalStore(root, sourceID string) (*LocalStore, error) {
if err := os.MkdirAll(root, 0o755); err != nil {
return nil, fmt.Errorf("create store root %s: %w", root, err)
}
return &LocalStore{Root: root, Source: sourceID, Extension: ".bin"}, nil
}
// SourceID returns the source ID this store is configured for.
func (s *LocalStore) SourceID() string { return s.Source }
func (s *LocalStore) ext() string {
if s.Extension == "" {
return ".bin"
}
return s.Extension
}
// Path returns the canonical path for id's committed dataset.
func (s *LocalStore) Path(id DatasetID) string {
return filepath.Join(s.Root, id.Filename()+s.ext())
}
func (s *LocalStore) tempPath(id DatasetID) string {
return s.Path(id) + ".downloading"
}
func (s *LocalStore) manifestPath(id DatasetID) string {
return s.Path(id) + ".manifest.json"
}
// Exists reports whether a committed dataset for id is present.
func (s *LocalStore) Exists(id DatasetID) bool {
info, err := os.Stat(s.Path(id))
return err == nil && !info.IsDir()
}
// List returns all committed dataset IDs, newest first.
func (s *LocalStore) List() ([]DatasetID, error) {
entries, err := os.ReadDir(s.Root)
if err != nil {
return nil, fmt.Errorf("read store: %w", err)
}
var out []DatasetID
ext := s.ext()
for _, e := range entries {
if e.IsDir() {
continue
}
name := e.Name()
if !strings.HasSuffix(name, ext) {
continue
}
stem := strings.TrimSuffix(name, ext)
// Skip in-progress files (their stem ends in .downloading or .manifest)
if strings.Contains(stem, ".") {
continue
}
id, ok := parseFilename(stem)
if !ok {
continue
}
out = append(out, id)
}
sort.Slice(out, func(i, j int) bool {
if !out[i].Epoch.Equal(out[j].Epoch) {
return out[i].Epoch.After(out[j].Epoch)
}
return out[i].Subset.Key() < out[j].Subset.Key()
})
return out, nil
}
// parseFilename inverts DatasetID.Filename(). The subset portion is not
// fully reversible (Key encoding is one-way for floats), so List returns
// IDs whose Subset is zero — the storage layer treats names as opaque
// identifiers. Callers wanting structured subset metadata should keep an
// out-of-band record.
func parseFilename(stem string) (DatasetID, bool) {
parts := strings.SplitN(stem, "_", 2)
epoch, err := time.Parse("20060102T150405Z", parts[0])
if err != nil {
return DatasetID{}, false
}
id := DatasetID{Epoch: epoch.UTC()}
// Subset key is opaque on disk; we don't reconstruct its parameters
// here. Admin callers track subset specs separately when they need
// the structured form.
return id, true
}
// Remove deletes the committed dataset and any sidecar files for id.
func (s *LocalStore) Remove(id DatasetID) error {
var errs []error
for _, p := range []string{s.Path(id), s.tempPath(id), s.manifestPath(id)} {
if err := os.Remove(p); err != nil && !errors.Is(err, os.ErrNotExist) {
errs = append(errs, err)
}
}
if len(errs) > 0 {
return fmt.Errorf("remove dataset: %v", errs)
}
return nil
}
// Lock acquires the storage-wide download lock (an exclusive flock on a
// sentinel file in the root), serialising downloads across processes that
// share this directory.
func (s *LocalStore) Lock(ctx context.Context) (func(), error) {
return flockExclusive(ctx, filepath.Join(s.Root, ".download.lock"))
}
// BeginWrite opens or resumes a TempHandle for id.
func (s *LocalStore) BeginWrite(id DatasetID) (TempHandle, error) {
man, err := LoadManifest(s.manifestPath(id))
if err != nil {
return nil, err
}
return &localHandle{store: s, id: id, manifest: man}, nil
}
type localHandle struct {
store *LocalStore
id DatasetID
manifest *Manifest
closed bool
}
func (h *localHandle) Path() string { return h.store.tempPath(h.id) }
func (h *localHandle) Manifest() *Manifest { return h.manifest }
func (h *localHandle) Commit() error {
if h.closed {
return nil
}
h.closed = true
if err := os.Rename(h.store.tempPath(h.id), h.store.Path(h.id)); err != nil {
return fmt.Errorf("commit rename: %w", err)
}
if err := os.Remove(h.store.manifestPath(h.id)); err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("commit remove manifest: %w", err)
}
return nil
}
func (h *localHandle) Abort() error {
if h.closed {
return nil
}
h.closed = true
var firstErr error
for _, p := range []string{h.store.tempPath(h.id), h.store.manifestPath(h.id)} {
if err := os.Remove(p); err != nil && !errors.Is(err, os.ErrNotExist) && firstErr == nil {
firstErr = err
}
}
return firstErr
}

View file

@ -0,0 +1,145 @@
package datasets
import (
"context"
"os"
"testing"
"time"
)
func TestLocalStoreLockSerializes(t *testing.T) {
dir := t.TempDir()
store, _ := NewLocalStore(dir, "gfs-test")
ctx := context.Background()
release, err := store.Lock(ctx)
if err != nil {
t.Fatalf("first Lock: %v", err)
}
// A second acquisition must block until the first releases.
got := make(chan struct{})
go func() {
r2, err := store.Lock(ctx)
if err == nil {
r2()
}
close(got)
}()
select {
case <-got:
t.Fatal("second Lock acquired while first was held")
case <-time.After(100 * time.Millisecond):
// expected: still blocked
}
release()
select {
case <-got:
// expected: acquired after release
case <-time.After(2 * time.Second):
t.Fatal("second Lock did not acquire after release")
}
}
func TestLocalStoreLockContextCancel(t *testing.T) {
dir := t.TempDir()
store, _ := NewLocalStore(dir, "gfs-test")
release, err := store.Lock(context.Background())
if err != nil {
t.Fatalf("Lock: %v", err)
}
defer release()
ctx, cancel := context.WithCancel(context.Background())
cancel()
if _, err := store.Lock(ctx); err == nil {
t.Error("expected Lock to fail on cancelled context while held elsewhere")
}
}
func TestLocalStoreBeginWriteResume(t *testing.T) {
dir := t.TempDir()
store, err := NewLocalStore(dir, "gfs-test")
if err != nil {
t.Fatalf("NewLocalStore: %v", err)
}
id := DatasetID{Epoch: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)}
h, err := store.BeginWrite(id)
if err != nil {
t.Fatalf("BeginWrite: %v", err)
}
if err := os.WriteFile(h.Path(), []byte("partial"), 0o644); err != nil {
t.Fatalf("write partial: %v", err)
}
if err := h.Manifest().Mark("step000-A"); err != nil {
t.Fatalf("mark: %v", err)
}
// Re-open should see the previous manifest entry.
h2, err := store.BeginWrite(id)
if err != nil {
t.Fatalf("BeginWrite resume: %v", err)
}
if !h2.Manifest().Has("step000-A") {
t.Errorf("resumed manifest missing step000-A; units = %v", h2.Manifest().Units())
}
if err := h2.Commit(); err != nil {
t.Fatalf("Commit: %v", err)
}
if !store.Exists(id) {
t.Errorf("Exists after commit returned false")
}
if _, err := os.Stat(store.manifestPath(id)); !os.IsNotExist(err) {
t.Errorf("manifest should be removed, got err=%v", err)
}
stored, err := store.List()
if err != nil {
t.Fatalf("List: %v", err)
}
if len(stored) != 1 || !stored[0].Epoch.Equal(id.Epoch) {
t.Errorf("List = %v, want one item with epoch %v", stored, id.Epoch)
}
if err := store.Remove(id); err != nil {
t.Fatalf("Remove: %v", err)
}
if store.Exists(id) {
t.Errorf("Exists after remove returned true")
}
}
func TestLocalStoreSubsetPath(t *testing.T) {
dir := t.TempDir()
store, _ := NewLocalStore(dir, "gfs-test")
epoch := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
regional := DatasetID{
Epoch: epoch,
Subset: SubsetSpec{
Region: &Region{MinLat: -10, MaxLat: 10, MinLng: 0, MaxLng: 30},
HourRange: &HourRange{MinHour: 0, MaxHour: 72},
},
}
global := DatasetID{Epoch: epoch}
if store.Path(global) == store.Path(regional) {
t.Errorf("global and regional should have distinct paths")
}
}
func TestSubsetSpecCoverage(t *testing.T) {
r := Region{MinLat: -10, MaxLat: 10, MinLng: 350, MaxLng: 10} // wraps antimeridian
s := SubsetSpec{Region: &r}
if !s.IncludesLatLng(0, 0) {
t.Errorf("(0,0) should be inside antimeridian region")
}
if !s.IncludesLatLng(0, 359) {
t.Errorf("(0,359) should be inside antimeridian region")
}
if s.IncludesLatLng(0, 180) {
t.Errorf("(0,180) should be outside antimeridian region")
}
}

156
internal/datasets/subset.go Normal file
View file

@ -0,0 +1,156 @@
package datasets
import (
"fmt"
"slices"
"strings"
"time"
)
// SubsetSpec describes which portion of a dataset to download.
//
// A zero-value SubsetSpec means "the full dataset". The Region and
// HourRange fields independently restrict what is fetched and stored.
type SubsetSpec struct {
// Region restricts the geographic extent. nil means global.
Region *Region `json:"region,omitempty"`
// HourRange restricts the forecast horizon. nil means the source's
// full horizon (e.g. 0..192h for GFS 0.5°).
HourRange *HourRange `json:"hour_range,omitempty"`
// Members restricts ensemble members for sources that support them (GEFS).
// nil means all available members.
Members []int `json:"members,omitempty"`
}
// Region is an axis-aligned geographic bounding box.
//
// Longitudes are in [0, 360); a box crossing the antimeridian has
// MinLng > MaxLng.
type Region struct {
MinLat float64 `json:"min_lat"`
MaxLat float64 `json:"max_lat"`
MinLng float64 `json:"min_lng"`
MaxLng float64 `json:"max_lng"`
}
// HourRange is an inclusive forecast-hour range.
type HourRange struct {
MinHour int `json:"min_hour"`
MaxHour int `json:"max_hour"`
}
// IsGlobal reports whether the spec selects the entire dataset.
func (s SubsetSpec) IsGlobal() bool {
return s.Region == nil && s.HourRange == nil && len(s.Members) == 0
}
// IncludesLatLng reports whether (lat, lng) lies inside the spec's Region,
// or the spec has no Region.
func (s SubsetSpec) IncludesLatLng(lat, lng float64) bool {
if s.Region == nil {
return true
}
r := s.Region
if lat < r.MinLat || lat > r.MaxLat {
return false
}
if r.MinLng <= r.MaxLng {
return lng >= r.MinLng && lng <= r.MaxLng
}
// Wraps the antimeridian.
return lng >= r.MinLng || lng <= r.MaxLng
}
// IncludesHour reports whether the forecast hour is in range.
func (s SubsetSpec) IncludesHour(h int) bool {
if s.HourRange == nil {
return true
}
return h >= s.HourRange.MinHour && h <= s.HourRange.MaxHour
}
// IncludesMember reports whether the ensemble member is in range.
func (s SubsetSpec) IncludesMember(m int) bool {
if len(s.Members) == 0 {
return true
}
return slices.Contains(s.Members, m)
}
// Key returns a deterministic short identifier for the spec. The empty
// string represents the global subset.
func (s SubsetSpec) Key() string {
if s.IsGlobal() {
return ""
}
var b strings.Builder
if s.Region != nil {
fmt.Fprintf(&b, "r%g.%g.%g.%g", s.Region.MinLat, s.Region.MaxLat, s.Region.MinLng, s.Region.MaxLng)
}
if s.HourRange != nil {
if b.Len() > 0 {
b.WriteByte('_')
}
fmt.Fprintf(&b, "h%d.%d", s.HourRange.MinHour, s.HourRange.MaxHour)
}
if len(s.Members) > 0 {
if b.Len() > 0 {
b.WriteByte('_')
}
fmt.Fprintf(&b, "m")
for i, m := range s.Members {
if i > 0 {
b.WriteByte('.')
}
fmt.Fprintf(&b, "%d", m)
}
}
return b.String()
}
// DatasetID identifies one storable dataset.
type DatasetID struct {
Epoch time.Time
Subset SubsetSpec
}
// Equals reports whether two DatasetIDs refer to the same dataset.
// DatasetID is not comparable with == because SubsetSpec contains slices.
func (id DatasetID) Equals(other DatasetID) bool {
return id.Epoch.Equal(other.Epoch) && id.Subset.Key() == other.Subset.Key()
}
// Filename returns the canonical filename stem for the dataset. The
// extension is appended by the Storage implementation.
func (id DatasetID) Filename() string {
stem := id.Epoch.UTC().Format("20060102T150405Z")
if k := id.Subset.Key(); k != "" {
return stem + "_" + k
}
return stem
}
// Coverage is the spatial and temporal extent of a loaded dataset, used by
// the Manager to select which dataset can serve a given query.
type Coverage struct {
Region Region `json:"region"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
}
// Covers reports whether (t, lat, lng) lies inside the coverage.
func (c Coverage) Covers(t time.Time, lat, lng float64) bool {
if t.Before(c.StartTime) || t.After(c.EndTime) {
return false
}
r := c.Region
if lat < r.MinLat || lat > r.MaxLat {
return false
}
if r.MinLng <= r.MaxLng {
return lng >= r.MinLng && lng <= r.MaxLng
}
return lng >= r.MinLng || lng <= r.MaxLng
}

View file

@ -0,0 +1,63 @@
package datasets
import (
"context"
"sync"
"time"
)
// TokenBucket is a simple bytes-per-second rate limiter.
//
// The bucket is initialised full (capacity = rate × 1 second). Calls to Wait
// block until enough tokens have accumulated.
type TokenBucket struct {
mu sync.Mutex
rate float64 // tokens per second
tokens float64
cap float64
last time.Time
}
// NewTokenBucket returns a TokenBucket emitting at most bytesPerSecond.
// A non-positive rate disables throttling (Wait becomes a no-op).
func NewTokenBucket(bytesPerSecond int64) *TokenBucket {
if bytesPerSecond <= 0 {
return &TokenBucket{rate: 0}
}
r := float64(bytesPerSecond)
return &TokenBucket{rate: r, tokens: r, cap: r, last: time.Now()}
}
// Wait blocks until n tokens are available or ctx is cancelled.
func (t *TokenBucket) Wait(ctx context.Context, n int) error {
if t.rate <= 0 {
return nil
}
want := float64(n)
for {
t.mu.Lock()
now := time.Now()
elapsed := now.Sub(t.last).Seconds()
t.last = now
t.tokens += elapsed * t.rate
if t.tokens > t.cap {
t.tokens = t.cap
}
if t.tokens >= want {
t.tokens -= want
t.mu.Unlock()
return nil
}
// Sleep until we expect enough tokens.
need := want - t.tokens
sleep := time.Duration(need / t.rate * float64(time.Second))
t.mu.Unlock()
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(sleep):
}
}
}

View file

@ -0,0 +1,91 @@
package datasets
import (
"context"
"time"
"predictor-refactored/internal/weather"
)
// Source is a pluggable origin for atmospheric datasets.
//
// Implementations download dataset files in a transactional, resumable
// manner and load them as weather.WindField. A Source must be safe for
// concurrent use across many Manager calls.
type Source interface {
// ID is a stable identifier, e.g. "gfs-0p50-3h".
ID() string
// LatestEpoch returns the most recent dataset epoch this source can provide.
LatestEpoch(ctx context.Context) (time.Time, error)
// Download fetches the dataset identified by id into store. Sources
// must honour any partial progress recorded in store's manifest and
// skip already-completed work so re-invocation after a crash resumes
// cleanly.
//
// prog receives progress events; nil is acceptable.
// throttle, if non-nil, is consulted before each network read for
// bandwidth limiting; nil means no throttling.
Download(ctx context.Context, id DatasetID, store Storage, prog ProgressSink, throttle Throttle) error
// Open loads id's stored dataset and returns it as a WindField.
Open(ctx context.Context, id DatasetID, store Storage) (weather.WindField, error)
// Coverage returns the geographical/temporal extent of a downloaded
// dataset. Used by the Manager to decide which loaded dataset can
// serve a given prediction query.
Coverage(id DatasetID) Coverage
}
// Storage abstracts the on-disk location of dataset files and their manifests.
//
// Atomicity: only datasets promoted via TempHandle.Commit appear in Exists
// or List. Aborted or in-progress downloads are invisible to readers.
type Storage interface {
// SourceID identifies the data source these files belong to.
SourceID() string
// Path returns the canonical local path for id's dataset.
Path(id DatasetID) string
// Exists reports whether a committed dataset for id is present.
Exists(id DatasetID) bool
// List returns all committed dataset IDs available, newest first.
List() ([]DatasetID, error)
// Remove deletes the dataset and any sidecar manifest for id.
Remove(id DatasetID) error
// BeginWrite opens (or resumes) a transactional handle for downloading
// id's dataset.
BeginWrite(id DatasetID) (TempHandle, error)
// Lock acquires an exclusive, storage-wide lock that serialises downloads
// across every process sharing this storage (e.g. multiple replicas on a
// node that share a dataset volume). It blocks until the lock is held or
// ctx is cancelled. The returned function releases the lock.
Lock(ctx context.Context) (release func(), err error)
}
// TempHandle is the storage state for one in-progress download.
type TempHandle interface {
Path() string
Manifest() *Manifest
Commit() error
Abort() error
}
// ProgressSink receives progress events during a download.
type ProgressSink interface {
SetTotal(n int)
StepComplete()
Bytes(n int64)
}
// Throttle is an optional bandwidth limiter consulted by sources before
// each network read.
type Throttle interface {
Wait(ctx context.Context, n int) error
}

View file

@ -1,58 +0,0 @@
package downloader
import (
"os"
"strconv"
"time"
)
// Config holds downloader configuration, loaded from environment variables.
type Config struct {
// DataDir is the directory for storing dataset files and temporary GRIB data.
DataDir string
// Parallel is the maximum number of concurrent GRIB downloads.
Parallel int
// UpdateInterval is how often the scheduler checks for new forecast data.
UpdateInterval time.Duration
// DatasetTTL is how long a dataset is considered fresh before a new one is needed.
DatasetTTL time.Duration
}
// DefaultConfig returns the default configuration.
func DefaultConfig() *Config {
return &Config{
DataDir: "/tmp/predictor-data",
Parallel: 8,
UpdateInterval: 6 * time.Hour,
DatasetTTL: 48 * time.Hour,
}
}
// LoadConfig loads configuration from environment variables, falling back to defaults.
func LoadConfig() *Config {
cfg := DefaultConfig()
if v := os.Getenv("PREDICTOR_DATA_DIR"); v != "" {
cfg.DataDir = v
}
if v := os.Getenv("PREDICTOR_DOWNLOAD_PARALLEL"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
cfg.Parallel = n
}
}
if v := os.Getenv("PREDICTOR_UPDATE_INTERVAL"); v != "" {
if d, err := time.ParseDuration(v); err == nil {
cfg.UpdateInterval = d
}
}
if v := os.Getenv("PREDICTOR_DATASET_TTL"); v != "" {
if d, err := time.ParseDuration(v); err == nil {
cfg.DatasetTTL = d
}
}
return cfg
}

View file

@ -1,441 +0,0 @@
package downloader
import (
"context"
"fmt"
"io"
"math"
"net/http"
"os"
"path/filepath"
"sync/atomic"
"time"
"predictor-refactored/internal/dataset"
"github.com/nilsmagnus/grib/griblib"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
)
// Downloader handles fetching GFS forecast data from S3 and assembling dataset files.
type Downloader struct {
cfg *Config
client *http.Client
log *zap.Logger
}
// NewDownloader creates a new Downloader.
func NewDownloader(cfg *Config, log *zap.Logger) *Downloader {
return &Downloader{
cfg: cfg,
client: &http.Client{
Timeout: 2 * time.Minute,
},
log: log,
}
}
// neededVariables is the set of GRIB variable names we need.
var neededVariables = map[string]bool{
"HGT": true,
"UGRD": true,
"VGRD": true,
}
// FindLatestRun finds the most recent available GFS model run on S3.
// It checks the last forecast step of each run to confirm availability.
func (d *Downloader) FindLatestRun(ctx context.Context) (time.Time, error) {
now := time.Now().UTC()
hour := now.Hour() - (now.Hour() % 6)
current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC)
for i := 0; i < 8; i++ {
date := current.Format("20060102")
url := dataset.GribURL(date, current.Hour(), dataset.MaxHour) + ".idx"
req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
if err != nil {
current = current.Add(-6 * time.Hour)
continue
}
resp, err := d.client.Do(req)
if err == nil {
resp.Body.Close()
if resp.StatusCode == http.StatusOK {
d.log.Info("found latest model run",
zap.Time("run", current),
zap.String("verified_url", url))
return current, nil
}
}
current = current.Add(-6 * time.Hour)
}
return time.Time{}, fmt.Errorf("no recent GFS forecast found (checked 8 runs)")
}
// progress tracks download progress across concurrent goroutines.
type progress struct {
bytesDownloaded atomic.Int64
stepsCompleted atomic.Int64
totalSteps int64
startTime time.Time
log *zap.Logger
}
func newProgress(totalSteps int, log *zap.Logger) *progress {
return &progress{
totalSteps: int64(totalSteps),
startTime: time.Now(),
log: log,
}
}
func (p *progress) addBytes(n int64) {
p.bytesDownloaded.Add(n)
}
func (p *progress) completeStep() {
done := p.stepsCompleted.Add(1)
total := p.totalSteps
bytes := p.bytesDownloaded.Load()
elapsed := time.Since(p.startTime).Seconds()
pct := float64(done) / float64(total) * 100
mbDownloaded := float64(bytes) / (1024 * 1024)
mbPerSec := 0.0
if elapsed > 0 {
mbPerSec = mbDownloaded / elapsed
}
// Estimate remaining
eta := ""
if done > 0 && done < total {
secsPerStep := elapsed / float64(done)
remaining := secsPerStep * float64(total-done)
if remaining > 60 {
eta = fmt.Sprintf("%.0fm%02.0fs", math.Floor(remaining/60), math.Mod(remaining, 60))
} else {
eta = fmt.Sprintf("%.0fs", remaining)
}
}
p.log.Info("download progress",
zap.String("progress", fmt.Sprintf("%d/%d", done, total)),
zap.String("percent", fmt.Sprintf("%.1f%%", pct)),
zap.String("downloaded", fmt.Sprintf("%.1f MB", mbDownloaded)),
zap.String("speed", fmt.Sprintf("%.1f MB/s", mbPerSec)),
zap.String("eta", eta))
}
// Download downloads a complete forecast and assembles a dataset file.
// Returns the path to the completed dataset file.
func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error) {
date := run.Format("20060102")
runHour := run.Hour()
finalPath := filepath.Join(d.cfg.DataDir, run.Format("2006010215"))
tempPath := finalPath + ".downloading"
// Check if final dataset already exists
if info, err := os.Stat(finalPath); err == nil && info.Size() == dataset.DatasetSize {
d.log.Info("dataset already exists", zap.String("path", finalPath))
return finalPath, nil
}
steps := dataset.Hours()
totalSteps := len(steps) * 2 // pgrb2 + pgrb2b per step
prog := newProgress(totalSteps, d.log)
d.log.Info("starting dataset download",
zap.Time("run", run),
zap.Int("total_steps", totalSteps),
zap.String("temp_path", tempPath))
// Create the dataset file
ds, err := dataset.Create(tempPath)
if err != nil {
return "", fmt.Errorf("create dataset: %w", err)
}
defer ds.Close()
// Process each forecast step with bounded concurrency
g, ctx := errgroup.WithContext(ctx)
sem := make(chan struct{}, d.cfg.Parallel)
for _, step := range steps {
step := step
hourIdx := dataset.HourIndex(step)
if hourIdx < 0 {
continue
}
// Download pgrb2 (level set A)
sem <- struct{}{}
g.Go(func() error {
defer func() { <-sem }()
url := dataset.GribURL(date, runHour, step)
err := d.downloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetA, prog)
if err != nil {
return fmt.Errorf("step %d pgrb2: %w", step, err)
}
prog.completeStep()
return nil
})
// Download pgrb2b (level set B)
sem <- struct{}{}
g.Go(func() error {
defer func() { <-sem }()
url := dataset.GribURLB(date, runHour, step)
err := d.downloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetB, prog)
if err != nil {
return fmt.Errorf("step %d pgrb2b: %w", step, err)
}
prog.completeStep()
return nil
})
}
if err := g.Wait(); err != nil {
os.Remove(tempPath)
return "", err
}
elapsed := time.Since(prog.startTime)
totalMB := float64(prog.bytesDownloaded.Load()) / (1024 * 1024)
d.log.Info("download complete, flushing to disk",
zap.String("downloaded", fmt.Sprintf("%.1f MB", totalMB)),
zap.Duration("elapsed", elapsed),
zap.String("avg_speed", fmt.Sprintf("%.1f MB/s", totalMB/elapsed.Seconds())))
// Flush to disk
if err := ds.Flush(); err != nil {
os.Remove(tempPath)
return "", fmt.Errorf("flush dataset: %w", err)
}
// Close before rename
ds.Close()
// Atomic rename
if err := os.Rename(tempPath, finalPath); err != nil {
os.Remove(tempPath)
return "", fmt.Errorf("rename dataset: %w", err)
}
d.log.Info("dataset ready", zap.String("path", finalPath))
return finalPath, nil
}
// DownloadAndBlit downloads needed GRIB fields from a URL and writes them into the dataset.
func (d *Downloader) DownloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet) error {
return d.downloadAndBlit(ctx, ds, baseURL, hourIdx, levelSet, nil)
}
// downloadAndBlit is the internal implementation with optional progress tracking.
func (d *Downloader) downloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet, prog *progress) error {
// 1. Download .idx
idxURL := baseURL + ".idx"
idxBody, err := d.httpGet(ctx, idxURL)
if err != nil {
return fmt.Errorf("download idx: %w", err)
}
// 2. Parse and filter
entries := ParseIdx(idxBody)
filtered := FilterIdx(entries, neededVariables)
// Further filter to only levels in this level set
var relevant []IdxEntry
for _, e := range filtered {
ls, ok := dataset.PressureLevelSet(e.LevelMB)
if ok && ls == levelSet {
relevant = append(relevant, e)
}
}
if len(relevant) == 0 {
d.log.Warn("no relevant entries found in idx",
zap.String("url", idxURL),
zap.Int("total_entries", len(entries)),
zap.Int("filtered", len(filtered)))
return nil
}
// 3. Download byte ranges and write to temp file
ranges := EntriesToRanges(relevant)
tmpFile, err := d.downloadRangesToTempFile(ctx, baseURL, ranges, prog)
if err != nil {
return fmt.Errorf("download ranges: %w", err)
}
defer os.Remove(tmpFile)
// 4. Read GRIB messages from temp file
f, err := os.Open(tmpFile)
if err != nil {
return fmt.Errorf("open temp grib: %w", err)
}
messages, err := griblib.ReadMessages(f)
f.Close()
if err != nil {
return fmt.Errorf("read grib messages: %w", err)
}
// 5. Decode and blit each message into the dataset
for _, msg := range messages {
if msg.Section4.ProductDefinitionTemplateNumber != 0 {
continue
}
product := msg.Section4.ProductDefinitionTemplate
varIdx := dataset.VariableIndex(int(product.ParameterCategory), int(product.ParameterNumber))
if varIdx < 0 {
continue
}
if product.FirstSurface.Type != 100 { // isobaric surface
continue
}
pressurePa := float64(product.FirstSurface.Value)
pressureMB := int(math.Round(pressurePa / 100.0))
levelIdx := dataset.PressureIndex(pressureMB)
if levelIdx < 0 {
continue
}
data := msg.Data()
if err := ds.BlitGribData(hourIdx, levelIdx, varIdx, data); err != nil {
d.log.Warn("blit failed",
zap.Int("var", varIdx),
zap.Int("level_mb", pressureMB),
zap.Error(err))
continue
}
}
return nil
}
// downloadRangesToTempFile downloads multiple byte ranges from a URL,
// concatenating them into a single temp file (valid concatenated GRIB messages).
func (d *Downloader) downloadRangesToTempFile(ctx context.Context, baseURL string, ranges []ByteRange, prog *progress) (string, error) {
tmpFile, err := os.CreateTemp(d.cfg.DataDir, "grib-*.tmp")
if err != nil {
return "", fmt.Errorf("create temp file: %w", err)
}
tmpPath := tmpFile.Name()
for _, r := range ranges {
data, err := d.httpGetRange(ctx, baseURL, r.Start, r.End)
if err != nil {
tmpFile.Close()
os.Remove(tmpPath)
return "", fmt.Errorf("download range %d-%d: %w", r.Start, r.End, err)
}
if _, err := tmpFile.Write(data); err != nil {
tmpFile.Close()
os.Remove(tmpPath)
return "", fmt.Errorf("write temp: %w", err)
}
if prog != nil {
prog.addBytes(int64(len(data)))
}
}
if err := tmpFile.Close(); err != nil {
os.Remove(tmpPath)
return "", err
}
return tmpPath, nil
}
// httpGet downloads a URL and returns the body bytes.
func (d *Downloader) httpGet(ctx context.Context, url string) ([]byte, error) {
var lastErr error
for attempt := 0; attempt < 3; attempt++ {
if attempt > 0 {
select {
case <-time.After(time.Duration(attempt*2) * time.Second):
case <-ctx.Done():
return nil, ctx.Err()
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := d.client.Do(req)
if err != nil {
lastErr = err
continue
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url)
continue
}
if err != nil {
lastErr = err
continue
}
return body, nil
}
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}
// httpGetRange downloads a byte range from a URL.
func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64) ([]byte, error) {
var lastErr error
for attempt := 0; attempt < 3; attempt++ {
if attempt > 0 {
select {
case <-time.After(time.Duration(attempt*2) * time.Second):
case <-ctx.Done():
return nil, ctx.Err()
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
resp, err := d.client.Do(req)
if err != nil {
lastErr = err
continue
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("HTTP %d for range %d-%d of %s", resp.StatusCode, start, end, url)
continue
}
if err != nil {
lastErr = err
continue
}
return body, nil
}
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}

View file

@ -1,157 +0,0 @@
package downloader
import (
"fmt"
"strconv"
"strings"
)
// IdxEntry represents a single parsed line from a GRIB .idx file.
// Example line: "15:1207405:d=2024010100:HGT:1000 mb:0 hour fcst:"
type IdxEntry struct {
Index int
Offset int64
Variable string // "HGT", "UGRD", "VGRD", etc.
LevelMB int // pressure level in mb (0 if not a pressure level)
Hour int // forecast hour
EndOffset int64 // byte after this message (from next entry's offset, or -1 if last)
}
// Length returns the byte length of this GRIB message, or -1 if unknown.
func (e *IdxEntry) Length() int64 {
if e.EndOffset <= 0 {
return -1
}
return e.EndOffset - e.Offset
}
// ParseIdx parses a .idx file body and returns all entries.
// Lines that can't be parsed are silently skipped.
func ParseIdx(body []byte) []IdxEntry {
lines := strings.Split(string(body), "\n")
var entries []IdxEntry
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
parts := strings.Split(line, ":")
if len(parts) < 7 {
continue
}
idx, err := strconv.Atoi(parts[0])
if err != nil {
continue
}
offset, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
continue
}
variable := parts[3]
levelStr := parts[4]
hourStr := parts[5]
levelMB := parseLevelMB(levelStr)
hour := parseHour(hourStr)
entries = append(entries, IdxEntry{
Index: idx,
Offset: offset,
Variable: variable,
LevelMB: levelMB,
Hour: hour,
EndOffset: -1, // filled in below
})
}
// Fill in EndOffset from the next entry's Offset.
for i := 0; i < len(entries)-1; i++ {
entries[i].EndOffset = entries[i+1].Offset
}
return entries
}
// FilterIdx returns entries matching the given variables at pressure levels.
// Only entries with a recognized pressure level (levelMB > 0) are returned.
func FilterIdx(entries []IdxEntry, variables map[string]bool) []IdxEntry {
var filtered []IdxEntry
for _, e := range entries {
if !variables[e.Variable] {
continue
}
if e.LevelMB <= 0 {
continue
}
// Must have a known length (not the last entry) or be handled specially
if e.Length() <= 0 {
continue
}
filtered = append(filtered, e)
}
return filtered
}
// parseLevelMB parses a level string like "1000 mb" and returns the pressure in mb.
// Returns 0 if not a pressure level.
func parseLevelMB(s string) int {
s = strings.TrimSpace(s)
if !strings.HasSuffix(s, " mb") {
return 0
}
numStr := strings.TrimSuffix(s, " mb")
n, err := strconv.Atoi(numStr)
if err != nil {
return 0
}
return n
}
// parseHour parses a forecast hour string like "0 hour fcst" or "anl".
// Returns -1 if it can't be parsed.
func parseHour(s string) int {
s = strings.TrimSpace(s)
if s == "anl" {
return 0
}
s = strings.TrimSuffix(s, " hour fcst")
n, err := strconv.Atoi(s)
if err != nil {
return -1
}
return n
}
// GroupByRange groups idx entries into byte ranges suitable for HTTP Range downloads.
// Each range covers one contiguous GRIB message.
type ByteRange struct {
Start int64
End int64 // inclusive
Entry IdxEntry
}
// EntriesToRanges converts filtered idx entries to byte ranges.
func EntriesToRanges(entries []IdxEntry) []ByteRange {
ranges := make([]ByteRange, 0, len(entries))
for _, e := range entries {
if e.Length() <= 0 {
continue
}
ranges = append(ranges, ByteRange{
Start: e.Offset,
End: e.EndOffset - 1, // inclusive
Entry: e,
})
}
return ranges
}
// FormatRange returns an HTTP Range header value for a byte range.
func (r ByteRange) FormatRange() string {
return fmt.Sprintf("bytes=%d-%d", r.Start, r.End)
}

View file

@ -1,110 +0,0 @@
package downloader
import (
"testing"
)
const sampleIdx = `1:0:d=2024010100:HGT:1000 mb:0 hour fcst:
2:289012:d=2024010100:HGT:975 mb:0 hour fcst:
3:541876:d=2024010100:TMP:1000 mb:0 hour fcst:
4:789012:d=2024010100:UGRD:1000 mb:0 hour fcst:
5:1045678:d=2024010100:VGRD:1000 mb:0 hour fcst:
6:1298765:d=2024010100:UGRD:975 mb:0 hour fcst:
7:1567890:d=2024010100:UGRD:2 m above ground:0 hour fcst:
8:1812345:d=2024010100:VGRD:975 mb:0 hour fcst:
9:2098765:d=2024010100:HGT:500 mb:3 hour fcst:
`
func TestParseIdx(t *testing.T) {
entries := ParseIdx([]byte(sampleIdx))
if len(entries) != 9 {
t.Fatalf("expected 9 entries, got %d", len(entries))
}
// Check first entry
e := entries[0]
if e.Index != 1 || e.Offset != 0 || e.Variable != "HGT" || e.LevelMB != 1000 || e.Hour != 0 {
t.Errorf("entry 0: got %+v", e)
}
if e.EndOffset != 289012 {
t.Errorf("entry 0 EndOffset: got %d, want 289012", e.EndOffset)
}
// Check "2 m above ground" is not a pressure level
e = entries[6] // UGRD at "2 m above ground"
if e.LevelMB != 0 {
t.Errorf("non-pressure level should have LevelMB=0, got %d", e.LevelMB)
}
// Last entry should have EndOffset = -1
last := entries[len(entries)-1]
if last.EndOffset != -1 {
t.Errorf("last entry EndOffset: got %d, want -1", last.EndOffset)
}
}
func TestFilterIdx(t *testing.T) {
entries := ParseIdx([]byte(sampleIdx))
filtered := FilterIdx(entries, neededVariables)
// Should include HGT/UGRD/VGRD at pressure levels, exclude TMP and "above ground"
// And exclude last entry (no EndOffset)
for _, e := range filtered {
if !neededVariables[e.Variable] {
t.Errorf("unexpected variable %s", e.Variable)
}
if e.LevelMB <= 0 {
t.Errorf("non-pressure level included: %+v", e)
}
if e.Length() <= 0 {
t.Errorf("entry with unknown length included: %+v", e)
}
}
// Count expected: HGT@1000, HGT@975, UGRD@1000, VGRD@1000, UGRD@975, VGRD@975 = 6
// But HGT@500 at 3hr fcst is the last entry (no EndOffset), so excluded
if len(filtered) != 6 {
t.Errorf("expected 6 filtered entries, got %d", len(filtered))
for _, e := range filtered {
t.Logf(" %s %d mb (offset %d, len %d)", e.Variable, e.LevelMB, e.Offset, e.Length())
}
}
}
func TestParseLevelMB(t *testing.T) {
tests := []struct {
input string
want int
}{
{"1000 mb", 1000},
{"975 mb", 975},
{"1 mb", 1},
{"2 m above ground", 0},
{"surface", 0},
{"tropopause", 0},
}
for _, tt := range tests {
got := parseLevelMB(tt.input)
if got != tt.want {
t.Errorf("parseLevelMB(%q) = %d, want %d", tt.input, got, tt.want)
}
}
}
func TestParseHour(t *testing.T) {
tests := []struct {
input string
want int
}{
{"0 hour fcst", 0},
{"3 hour fcst", 3},
{"192 hour fcst", 192},
{"anl", 0},
}
for _, tt := range tests {
got := parseHour(tt.input)
if got != tt.want {
t.Errorf("parseHour(%q) = %d, want %d", tt.input, got, tt.want)
}
}
}

View file

@ -71,9 +71,10 @@ func (d *Dataset) getCell(latIdx, lngIdx int) int16 {
return int16(binary.LittleEndian.Uint16(d.mm[off : off+2]))
}
// Get returns the interpolated elevation in metres at the given coordinates.
// lat: -90 to +90, lng: 0 to 360 (or -180 to 180, will be normalised).
func (d *Dataset) Get(lat, lng float64) float64 {
// Elevation returns the bilinearly-interpolated ground elevation in metres at
// the given coordinates. lat is in [-90, +90]; lng accepts either [0, 360) or
// [-180, 180) and is normalised internally.
func (d *Dataset) Elevation(lat, lng float64) float64 {
// Normalise longitude to [0, 360)
if lng < 0 {
lng += 360

View file

@ -0,0 +1,117 @@
package engine
import (
"fmt"
"predictor-refactored/internal/numerics"
)
// Altitude triggers when the balloon altitude satisfies Op against Limit.
//
// Examples:
//
// Altitude{Op: OpGreaterEqual, Limit: 30000} — burst at 30 km
// Altitude{Op: OpLessEqual, Limit: 0} — sea-level descent termination
type Altitude struct {
Op Operator
Limit float64
On Action
}
func (c Altitude) Name() string {
return fmt.Sprintf("altitude %s %g", c.Op, c.Limit)
}
func (c Altitude) Violated(_ float64, s State) bool { return c.Op.Test(s.Altitude, c.Limit) }
func (c Altitude) Action() Action { return c.On }
// Time triggers when the integration time t (UNIX seconds) satisfies Op
// against Limit.
type Time struct {
Op Operator
Limit float64
On Action
}
func (c Time) Name() string { return fmt.Sprintf("time %s %g", c.Op, c.Limit) }
func (c Time) Violated(t float64, _ State) bool { return c.Op.Test(t, c.Limit) }
func (c Time) Action() Action { return c.On }
// TerrainContact triggers when the ground elevation exceeds the balloon's
// altitude — i.e. the balloon has hit the ground.
type TerrainContact struct {
Provider TerrainProvider
On Action
}
func (c TerrainContact) Name() string { return "terrain_contact" }
func (c TerrainContact) Violated(_ float64, s State) bool {
return c.Provider.Elevation(s.Lat, s.Lng) > s.Altitude
}
func (c TerrainContact) Action() Action { return c.On }
// PolygonMode selects whether Polygon fires when the balloon is inside or
// outside the configured polygon.
type PolygonMode int
const (
// PolygonInside fires when (lat, lng) lies inside the polygon — useful
// for "must not enter restricted airspace".
PolygonInside PolygonMode = iota
// PolygonOutside fires when (lat, lng) lies outside the polygon —
// useful for "must remain over the test range".
PolygonOutside
)
// PolygonVertex is one vertex of a geographic polygon. Latitudes are in
// degrees [-90, 90]; longitudes in degrees [0, 360) or [-180, 180]
// (callers normalise — see Polygon.Violated).
type PolygonVertex struct {
Lat float64
Lng float64
}
// Polygon is a constraint over a closed geographic polygon, evaluated in
// plate-carrée coordinates with antimeridian handling (see
// numerics.PointInPolygon). Build one with NewPolygon so the flattened
// vertex slices used by the hot path are precomputed.
type Polygon struct {
Vertices []PolygonVertex
Mode PolygonMode
On Action
// Label, if set, is returned by Name. Defaults to "polygon_inside" or
// "polygon_outside" based on Mode.
Label string
// Precomputed parallel vertex slices for numerics.PointInPolygon.
polyLat, polyLng []float64
}
// NewPolygon builds a Polygon, precomputing the flattened vertex slices.
func NewPolygon(verts []PolygonVertex, mode PolygonMode, on Action, label string) Polygon {
lat := make([]float64, len(verts))
lng := make([]float64, len(verts))
for i, v := range verts {
lat[i], lng[i] = v.Lat, v.Lng
}
return Polygon{Vertices: verts, Mode: mode, On: on, Label: label, polyLat: lat, polyLng: lng}
}
func (c Polygon) Name() string {
if c.Label != "" {
return c.Label
}
if c.Mode == PolygonOutside {
return "polygon_outside"
}
return "polygon_inside"
}
func (c Polygon) Action() Action { return c.On }
// Violated reports whether the state satisfies the polygon-containment rule.
func (c Polygon) Violated(_ float64, s State) bool {
in := numerics.PointInPolygon(s.Lat, s.Lng, c.polyLat, c.polyLng)
if c.Mode == PolygonInside {
return in
}
return !in
}

View file

@ -0,0 +1,265 @@
package engine
import (
"math"
"testing"
"time"
"predictor-refactored/internal/weather"
)
// noWind is a WindField that always returns zero wind.
type noWind struct{ epoch time.Time }
func (n noWind) Wind(_ float64, _, _, _ float64) (weather.Sample, error) {
return weather.Sample{}, nil
}
func (n noWind) Epoch() time.Time { return n.epoch }
func (n noWind) Source() string { return "test" }
// flatGround returns 0 metres everywhere.
type flatGround struct{}
func (flatGround) Elevation(_, _ float64) float64 { return 0 }
func TestConstantAscentToBurst(t *testing.T) {
burst := 30000.0
rate := 5.0
ascend := &Propagator{
Name: "ascent",
Step: 60,
Model: Sum(ConstantRate(rate), WindTransport(noWind{}, nil)),
Constraints: []Constraint{Altitude{Op: OpGreaterEqual, Limit: burst, On: ActionStop}},
}
prof := Profile{Stages: []*Propagator{ascend}, Direction: Forward}
results := prof.Run(0, State{Lat: 0, Lng: 0, Altitude: 0}, NewEventSink())
if len(results) != 1 || results[0].Outcome != OutcomeStopped {
t.Fatalf("expected one stopped stage, got %+v", results)
}
if results[0].ConstraintName == "" {
t.Errorf("ConstraintName not populated")
}
if results[0].RefinedState.Altitude == 0 {
t.Errorf("RefinedState not populated")
}
lastT, last := results[0].Path.Last()
if math.Abs(last.Altitude-burst) > 5 {
t.Errorf("burst altitude = %v, want within 5m of %v", last.Altitude, burst)
}
wantTime := burst / rate
if math.Abs(lastT-wantTime) > 1 {
t.Errorf("burst time = %v, want within 1s of %v", lastT, wantTime)
}
}
func TestProfileWithFallback(t *testing.T) {
burst := 1000.0
rate := 5.0
descent := &Propagator{
Name: "descent",
Step: 60,
Model: ParachuteDescent(rate),
Constraints: []Constraint{TerrainContact{Provider: flatGround{}, On: ActionStop}},
}
ascend := &Propagator{
Name: "ascent",
Step: 60,
Model: ConstantRate(rate),
Constraints: []Constraint{Altitude{Op: OpGreaterEqual, Limit: burst, On: ActionFallback}},
Fallback: descent,
}
prof := Profile{Stages: []*Propagator{ascend}, Direction: Forward}
results := prof.Run(0, State{Altitude: 0}, NewEventSink())
if len(results) != 2 {
t.Fatalf("expected 2 results (ascent then descent fallback), got %d", len(results))
}
if results[0].Outcome != OutcomeFallback {
t.Errorf("first outcome = %v, want OutcomeFallback", results[0].Outcome)
}
if results[1].Outcome != OutcomeStopped {
t.Errorf("second outcome = %v, want OutcomeStopped", results[1].Outcome)
}
_, last := results[1].Path.Last()
if math.Abs(last.Altitude) > 5 {
t.Errorf("final altitude = %v, want within 5m of 0", last.Altitude)
}
}
func TestReverseDirection(t *testing.T) {
desc := &Propagator{
Name: "rewind",
Step: 1,
Model: ConstantRate(-1),
Constraints: []Constraint{Altitude{Op: OpGreaterEqual, Limit: 200, On: ActionStop}},
}
prof := Profile{Stages: []*Propagator{desc}, Direction: Reverse}
results := prof.Run(0, State{Altitude: 100}, NewEventSink())
lastT, last := results[0].Path.Last()
if math.Abs(last.Altitude-200) > 1 {
t.Errorf("reverse final altitude = %v, want ~200", last.Altitude)
}
if lastT >= 0 {
t.Errorf("reverse final time = %v, want < 0", lastT)
}
}
func TestPiecewiseRate(t *testing.T) {
m := Piecewise([]RateSegment{
{Until: 100, Rate: 5},
{Until: 200, Rate: 3},
{Until: math.Inf(1), Rate: 0},
})
if r := m(50, State{}); r.Altitude != 5 {
t.Errorf("rate at t=50 = %v, want 5", r.Altitude)
}
if r := m(150, State{}); r.Altitude != 3 {
t.Errorf("rate at t=150 = %v, want 3", r.Altitude)
}
if r := m(300, State{}); r.Altitude != 0 {
t.Errorf("rate at t=300 = %v, want 0", r.Altitude)
}
}
func TestPiecewiseReferenceResolution(t *testing.T) {
// Build via the registry with propagator_start segments.
spec := ModelSpec{
Type: "piecewise",
Segments: []PiecewiseSegmentSpec{
{Until: 100, Rate: 5, Reference: "propagator_start"},
{Until: 200, Rate: 3, Reference: "propagator_start"},
},
}
built, err := BuildModel(spec, BuildDeps{})
if err != nil {
t.Fatalf("BuildModel: %v", err)
}
if built.Build == nil {
t.Fatalf("expected lazy build for propagator_start references")
}
ctx := StageContext{ProfileStart: 1000, PropagatorStart: 5000}
m := built.Build(ctx)
// Until=100 from propagator_start=5000 → absolute 5100.
if r := m(5050, State{}); r.Altitude != 5 {
t.Errorf("rate at t=5050 = %v, want 5", r.Altitude)
}
if r := m(5150, State{}); r.Altitude != 3 {
t.Errorf("rate at t=5150 = %v, want 3", r.Altitude)
}
}
// fixedWind returns a constant wind sample.
type fixedWind struct{ u, v float64 }
func (w fixedWind) Wind(_ float64, _, _, _ float64) (weather.Sample, error) {
return weather.Sample{U: w.u, V: w.v}, nil
}
func (fixedWind) Epoch() time.Time { return time.Unix(0, 0) }
func (fixedWind) Source() string { return "test-fixed" }
func TestWindTransportUnitConversion(t *testing.T) {
wind := WindTransport(fixedWind{u: 10, v: 0}, nil)
d := wind(0, State{Lat: 0, Lng: 0, Altitude: 0})
wantLng := (180.0 / math.Pi) * 10.0 / 6371009.0
if math.Abs(d.Lng-wantLng) > 1e-12 {
t.Errorf("dlng = %v, want %v", d.Lng, wantLng)
}
if math.Abs(d.Lat) > 1e-12 {
t.Errorf("dlat = %v, want 0 for u=10 v=0", d.Lat)
}
wind2 := WindTransport(fixedWind{u: 0, v: 5}, nil)
d = wind2(0, State{Lat: 60, Lng: 0, Altitude: 0})
wantLat := (180.0 / math.Pi) * 5.0 / 6371009.0
if math.Abs(d.Lat-wantLat) > 1e-12 {
t.Errorf("dlat at lat=60 = %v, want %v", d.Lat, wantLat)
}
}
// aboveModelWind reports AboveModel on every sample. Used to verify event emission.
type aboveModelWind struct{}
func (aboveModelWind) Wind(_ float64, _, _, _ float64) (weather.Sample, error) {
return weather.Sample{AboveModel: true}, nil
}
func (aboveModelWind) Epoch() time.Time { return time.Unix(0, 0) }
func (aboveModelWind) Source() string { return "above" }
func TestWindTransportEmitsAboveModel(t *testing.T) {
sink := NewEventSink()
wind := WindTransport(aboveModelWind{}, sink)
for range 3 {
_ = wind(0, State{})
}
events := sink.Snapshot()
if len(events) != 1 || events[0].Type != "above_model" || events[0].Count != 3 {
t.Errorf("expected one above_model event with count=3, got %+v", events)
}
}
func TestNoTerminatorStopsAtStepCap(t *testing.T) {
// A stage that ascends forever with no constraint must not loop endlessly;
// the integrator's step backstop stops it and records a max_steps event.
sink := NewEventSink()
prof := Profile{
Stages: []*Propagator{{Name: "runaway", Step: 60, Model: ConstantRate(5)}},
Direction: Forward,
}
results := prof.Run(0, State{}, sink)
if results[0].Outcome != OutcomeContinued {
t.Errorf("outcome = %v, want OutcomeContinued (step cap)", results[0].Outcome)
}
if results[0].Path.Len() != DefaultMaxSteps+1 {
t.Errorf("path len = %d, want %d", results[0].Path.Len(), DefaultMaxSteps+1)
}
ev := sink.Snapshot()
if len(ev) != 1 || ev[0].Type != "max_steps" {
t.Errorf("expected a max_steps event, got %+v", ev)
}
}
func TestPolygonInside(t *testing.T) {
// Unit square at the equator.
square := []PolygonVertex{
{Lat: -1, Lng: -1},
{Lat: -1, Lng: 1},
{Lat: 1, Lng: 1},
{Lat: 1, Lng: -1},
}
c := NewPolygon(square, PolygonInside, ActionStop, "")
if !c.Violated(0, State{Lat: 0, Lng: 0}) {
t.Errorf("origin should be inside the square")
}
if c.Violated(0, State{Lat: 5, Lng: 0}) {
t.Errorf("(5, 0) should be outside the square")
}
}
func TestPolygonOutsideAntimeridian(t *testing.T) {
// A polygon centred near the antimeridian, spanning lng 170..-170
// (i.e. lng 170..190 in [0, 360) form).
poly := []PolygonVertex{
{Lat: -10, Lng: 170},
{Lat: -10, Lng: 190},
{Lat: 10, Lng: 190},
{Lat: 10, Lng: 170},
}
c := NewPolygon(poly, PolygonInside, ActionStop, "")
// A point at the antimeridian.
if !c.Violated(0, State{Lat: 0, Lng: 180}) {
t.Errorf("(0, 180) should be inside the antimeridian polygon")
}
if c.Violated(0, State{Lat: 0, Lng: 0}) {
t.Errorf("(0, 0) should be outside")
}
}

89
internal/engine/events.go Normal file
View file

@ -0,0 +1,89 @@
package engine
import "sync"
// Event is a non-fatal observation made during integration.
//
// Events generalise the warnings counter from the original Tawhiri port:
// any model or constraint can emit them, the EventSink aggregates by Type,
// and each Result carries a summary slice for the API to surface.
type Event struct {
Type string // short identifier, e.g. "above_model"
Time float64 // UNIX seconds when the event was emitted
State State
Message string
}
// EventSummary is the per-type aggregation of repeated emissions.
type EventSummary struct {
Type string `json:"type"`
Count int64 `json:"count"`
FirstTime float64 `json:"first_time"`
LastTime float64 `json:"last_time"`
FirstState State `json:"first_state"`
LastState State `json:"last_state"`
Message string `json:"message"`
}
// EventSink collects events from models and the integrator, aggregating
// duplicate types into a single EventSummary. Safe for concurrent use.
type EventSink struct {
mu sync.Mutex
summaries map[string]*EventSummary
}
// NewEventSink returns an empty sink.
func NewEventSink() *EventSink { return &EventSink{summaries: make(map[string]*EventSummary)} }
// Emit records one occurrence of typ at (t, s) with the provided message.
// Subsequent emits with the same typ update LastTime/LastState and Count.
func (s *EventSink) Emit(typ string, t float64, state State, message string) {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
sum, ok := s.summaries[typ]
if !ok {
s.summaries[typ] = &EventSummary{
Type: typ, Count: 1,
FirstTime: t, LastTime: t,
FirstState: state, LastState: state,
Message: message,
}
return
}
sum.Count++
sum.LastTime = t
sum.LastState = state
if sum.Message == "" && message != "" {
sum.Message = message
}
}
// Snapshot returns a stable copy of every summary in deterministic order
// (sorted by Type).
func (s *EventSink) Snapshot() []EventSummary {
if s == nil {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
out := make([]EventSummary, 0, len(s.summaries))
for _, sum := range s.summaries {
out = append(out, *sum)
}
sortEventSummaries(out)
return out
}
func sortEventSummaries(s []EventSummary) {
// Insertion sort: usually one or two entries.
for i := 1; i < len(s); i++ {
j := i
for j > 0 && s[j-1].Type > s[j].Type {
s[j-1], s[j] = s[j], s[j-1]
j--
}
}
}

96
internal/engine/models.go Normal file
View file

@ -0,0 +1,96 @@
package engine
import (
"sort"
"predictor-refactored/internal/numerics"
"predictor-refactored/internal/weather"
)
// Sum composes models by summing their derivatives at each evaluation point.
//
// Useful for combining a vertical-rate model with a horizontal wind model
// into a single propagator. Equivalent to Tawhiri's LinearModel.
func Sum(models ...Model) Model {

Move all computations into numerics

Move all computations into numerics
if len(models) == 1 {
return models[0]
}
return func(t float64, s State) State {
var sum State
for _, m := range models {
sum = numerics.AddGeo(sum, m(t, s))
}
return sum
}
}
// ConstantRate returns a model with a constant vertical velocity (m/s).
// Positive rates are upward.
func ConstantRate(rate float64) Model {
return func(_ float64, _ State) State { return State{Altitude: rate} }
}
// ParachuteDescent returns a model where vertical velocity grows with
// altitude because thinner air provides less drag. seaLevelRate is the
// descent speed at sea level (m/s, positive).
//
// Terminal velocity at altitude is computed as
//
// v = -k / sqrt(rho(alt)), k = seaLevelRate * 1.1045,
//
// using the NASA atmosphere model for rho. Equivalent to Tawhiri's drag_descent.
func ParachuteDescent(seaLevelRate float64) Model {
return func(_ float64, s State) State {
return State{Altitude: numerics.DragTerminalVelocity(seaLevelRate, s.Altitude)}
}

Move into numerics

Move into numerics
}
// RateSegment is one entry in a Piecewise rate schedule. Until is the UNIX
// timestamp at which this segment ends — the model emits the segment's
// Rate for all t < Until. The final segment's Rate is held indefinitely.
type RateSegment struct {
Until float64
Rate float64
}
// Piecewise returns a model that produces a piecewise-constant vertical
// rate over a sequence of intervals. The input is sorted ascending by
// Until on construction; later segments shadow earlier ones.
func Piecewise(segments []RateSegment) Model {
if len(segments) == 0 {
return ConstantRate(0)
}
sorted := append([]RateSegment(nil), segments...)
sort.Slice(sorted, func(i, j int) bool { return sorted[i].Until < sorted[j].Until })
finalRate := sorted[len(sorted)-1].Rate
return func(t float64, _ State) State {
idx := sort.Search(len(sorted), func(i int) bool { return sorted[i].Until > t })
if idx == len(sorted) {
return State{Altitude: finalRate}
}
return State{Altitude: sorted[idx].Rate}
}
}
// WindTransport returns a model that moves laterally at the wind velocity
// sampled from field. The vertical component is zero. Sampling and the
// non-fatal "above_model" event live here (orchestration); the m/s → deg/s
// conversion is numerics.WindToGeoRate.
//
// If events is non-nil, an "above_model" event is emitted whenever the
// wind field reports altitude above the highest pressure level.
func WindTransport(field weather.WindField, events *EventSink) Model {
return func(t float64, s State) State {
sample, err := field.Wind(t, s.Lat, s.Lng, s.Altitude)
if err != nil {
return State{}
}
if sample.AboveModel && events != nil {
events.Emit("above_model", t, s,
"altitude exceeded the highest pressure level of the wind dataset; samples extrapolated")
}
dLat, dLng := numerics.WindToGeoRate(sample.U, sample.V, s.Lat, s.Altitude)
return State{Lat: dLat, Lng: dLng}
}
}

View file

@ -0,0 +1,69 @@
package engine
import "fmt"
// Operator is a scalar comparison used by generalised constraints like
// Altitude and Time. A constraint fires when its Operator.Test(value, limit)
// returns true.
type Operator int
const (
OpLess Operator = iota // value < limit
OpLessEqual // value ≤ limit
OpGreater // value > limit
OpGreaterEqual // value ≥ limit
OpEqual // value == limit
)
// Test evaluates op(value, limit).
func (o Operator) Test(value, limit float64) bool {
switch o {
case OpLess:
return value < limit
case OpLessEqual:
return value <= limit
case OpGreater:
return value > limit
case OpGreaterEqual:
return value >= limit
case OpEqual:
return value == limit
}
return false
}
// String returns the symbol "<", "<=", ">", ">=", "==".
func (o Operator) String() string {
switch o {
case OpLess:
return "<"
case OpLessEqual:
return "<="
case OpGreater:
return ">"
case OpGreaterEqual:
return ">="
case OpEqual:
return "=="
}
return "?"
}
// ParseOperator maps a textual operator to its Operator constant.
// Accepts "<", "<=", "le", ">", ">=", "ge", "==", "eq".
func ParseOperator(s string) (Operator, error) {
switch s {
case "<", "lt":
return OpLess, nil
case "<=", "le":
return OpLessEqual, nil
case ">", "gt":
return OpGreater, nil
case ">=", "ge":
return OpGreaterEqual, nil
case "==", "eq":
return OpEqual, nil
default:
return 0, fmt.Errorf("unknown operator %q", s)
}
}

View file

@ -0,0 +1,59 @@
package engine
// Profile is an ordered chain of propagators executed sequentially. Each
// propagator picks up where the previous one finished.
type Profile struct {
// Stages are run in order. For Direction=Reverse they are still
// iterated from index 0 onwards but each propagator integrates with
// negative dt.
Stages []*Propagator
// Direction controls the sign of dt across the profile.
Direction Direction
// Globals are constraints evaluated alongside each stage's local
// Constraints. Useful for profile-wide bounds like "stop after N hours".
Globals []Constraint
}
// Run executes the profile from the given launch point. Returns one
// Result per executed stage, including any Fallback chains that were
// activated. The supplied EventSink is shared across stages and aggregates
// non-fatal observations.
//
// events may be nil; pass NewEventSink() to capture observations.
func (p *Profile) Run(t0 float64, launch State, events *EventSink) []Result {
if p.Direction == 0 {
p.Direction = Forward
}
results := make([]Result, 0, len(p.Stages))
t, s := t0, launch
for _, stage := range p.Stages {
res := stage.run(p.context(t0, t, launch, s), t, s, p.Globals, events)
results = append(results, res)
t, s = res.Path.Last()
// Follow Fallback chains until none remains.
for res.Outcome == OutcomeFallback && stage.Fallback != nil {
stage = stage.Fallback
res = stage.run(p.context(t0, t, launch, s), t, s, p.Globals, events)
results = append(results, res)
t, s = res.Path.Last()
}
}
return results
}
// context builds the StageContext for a stage starting at (tStart, sStart).
func (p *Profile) context(t0, tStart float64, launch, sStart State) StageContext {
return StageContext{
ProfileStart: t0,
PropagatorStart: tStart,
Launch: launch,
PropagatorState: sStart,
Direction: p.Direction,
}
}

View file

@ -0,0 +1,147 @@
package engine
import "predictor-refactored/internal/numerics"
// Propagator advances state under one Model, checking a set of Constraints
// after every integration step.
//
// When a constraint fires, the propagator binary-search refines the
// violation point and emits it as its final trajectory point. The Action of
// the triggering constraint controls what the surrounding Profile does
// next: stop the profile, transfer to Fallback, or clip and continue.
//
// The per-step numerics (RK4 stepping, crossing refinement) are delegated to
// the numerics package; this type owns only the orchestration: constraint
// evaluation, action dispatch, and trajectory assembly.
type Propagator struct {
// Name identifies the propagator in trajectory metadata. Optional.
Name string
// Step is the magnitude of the integration step in seconds (always positive).
// The Profile flips its sign for Reverse direction.
Step float64
// Model is the per-second derivative function used for integration.
// One of Model or BuildModel must be non-nil. If both are set, BuildModel
// takes precedence (it is invoked once per stage with a StageContext).
Model Model
BuildModel func(ctx StageContext) Model
// Constraints are evaluated after each step. The first violation wins.
Constraints []Constraint
BuildConstraints func(ctx StageContext) []Constraint
// Fallback is the propagator to switch to when a constraint with
// ActionFallback fires. Optional.
Fallback *Propagator
// Tolerance is the binary-search refinement tolerance in parameter
// space (default 0.01, matching Tawhiri).
Tolerance float64
}
// estimatedSteps is the initial Path capacity; a typical balloon stage is a
// few hundred 60-second steps.
const estimatedSteps = 256
// DefaultMaxSteps bounds the number of integration steps a single propagator
// may take. It is a safety backstop, not a physical limit: a profile whose
// constraints never fire (e.g. a stage with no effective terminator) would
// otherwise integrate forever and exhaust memory. At the default 60-second
// step this allows ~8 simulated years, far beyond any real flight, so it only
// ever trips on a misconfigured profile.
const DefaultMaxSteps = 1_000_000
// run integrates the model from (t0, s0) in direction dir, returning a Result.
// globals are constraints injected by the Profile and checked alongside the
// propagator's local Constraints. events receives non-fatal observations.
func (p *Propagator) run(ctx StageContext, t0 float64, s0 State, globals []Constraint, events *EventSink) Result {
dt := p.Step * float64(ctx.Direction)
tol := p.Tolerance
if tol == 0 {
tol = 0.01
}
model := p.Model
if p.BuildModel != nil {
model = p.BuildModel(ctx)
}
constraints := p.Constraints
if p.BuildConstraints != nil {
constraints = p.BuildConstraints(ctx)
}
field := numerics.Field(model)
out := Result{Propagator: p.Name, Outcome: OutcomeContinued, Path: numerics.NewPath(estimatedSteps)}
out.Path.Append(t0, s0)
t, s := t0, s0
for range DefaultMaxSteps {
s2 := numerics.RK4Step(t, s, dt, field)
t2 := t + dt
c, fired := firstFiring(constraints, globals, t2, s2)
if !fired {
t, s = t2, s2
out.Path.Append(t, s)
continue
}
out.ViolationTime, out.ViolationState = t2, s2
t3, s3 := numerics.RefineCrossing(t, s, t2, s2, c.Violated, tol)
out.Constraint, out.ConstraintName = c, c.Name()
if c.Action() == ActionClip {
s3 = clipToConstraint(c, s3)
out.RefinedTime, out.RefinedState = t3, s3
out.Path.Append(t3, s3)
t, s = t3, s3
continue
}
out.RefinedTime, out.RefinedState = t3, s3
out.Path.Append(t3, s3)
if c.Action() == ActionFallback {
out.Outcome = OutcomeFallback
} else {
out.Outcome = OutcomeStopped
}
out.Events = events.Snapshot()
return out
}
// Step cap reached without any constraint firing — the profile has no
// effective terminator for this stage. Stop safely rather than loop forever.
events.Emit("max_steps", t, s,
"integration step limit reached without a constraint firing; check the stage's terminator")
out.Outcome = OutcomeContinued
out.Events = events.Snapshot()
return out
}
// firstFiring scans local then global constraints for the first one whose
// Violated returns true at (t, s).
func firstFiring(local, globals []Constraint, t float64, s State) (Constraint, bool) {
for _, c := range local {
if c.Violated(t, s) {
return c, true
}
}
for _, c := range globals {
if c.Violated(t, s) {
return c, true
}
}
return nil, false
}
// clipToConstraint adjusts s so the given constraint is exactly satisfied.
// Defined only for constraints with a well-defined coordinate boundary;
// others fall through unchanged.
func clipToConstraint(c Constraint, s State) State {
if alt, ok := c.(Altitude); ok {
s.Altitude = alt.Limit
}
return s
}

278
internal/engine/registry.go Normal file
View file

@ -0,0 +1,278 @@
package engine
import (
"fmt"
"sync"
"predictor-refactored/internal/weather"
)
// ConstraintSpec is the source-agnostic JSON-shape used to declare a
// constraint. The Type field is the registry key; remaining fields are
// extracted by the registered factory.
type ConstraintSpec struct {
Type string `json:"type"`
Action string `json:"action,omitempty"`
// Op is the comparison operator for scalar constraints (altitude, time).
Op string `json:"op,omitempty"`
Limit float64 `json:"limit,omitempty"`
// Vertices and Mode are used by the polygon constraint.
Vertices []PolygonVertex `json:"vertices,omitempty"`
Mode string `json:"mode,omitempty"`
// Label is an optional human-readable identifier surfaced via Name().
Label string `json:"label,omitempty"`
}
// ModelSpec is the source-agnostic JSON shape used to declare a model.
type ModelSpec struct {
Type string `json:"type"`
// Rate (m/s) for constant_rate.
Rate float64 `json:"rate,omitempty"`
// SeaLevelRate (m/s, positive) for parachute_descent.
SeaLevelRate float64 `json:"sea_level_rate,omitempty"`
// Segments for piecewise.
Segments []PiecewiseSegmentSpec `json:"segments,omitempty"`
// IncludeWind sums a WindTransport model into the resulting derivative.
IncludeWind bool `json:"include_wind,omitempty"`
}
// PiecewiseSegmentSpec is one entry in a piecewise rate schedule.
//
// Reference selects how the Until field is interpreted:
//
// - "absolute" (default): UNIX seconds.
// - "profile_start": seconds since the profile's launch time.
// - "propagator_start": seconds since this propagator began running.
type PiecewiseSegmentSpec struct {
Until float64 `json:"until"`
Rate float64 `json:"rate"`
Reference string `json:"reference,omitempty"`
}
// BuildDeps bundle the runtime dependencies factories may consult.
type BuildDeps struct {
Wind weather.WindField
Terrain TerrainProvider
Events *EventSink
}
// ConstraintFactory builds one Constraint from a spec.
type ConstraintFactory func(spec ConstraintSpec, deps BuildDeps) (Constraint, error)
// ModelFactory builds one model from a spec. The returned Built is held by
// a Propagator; if Build is set, it is invoked lazily by the profile
// runner before every stage so it can capture per-stage start times.
type ModelFactory func(spec ModelSpec, deps BuildDeps) (BuiltModel, error)
// BuiltModel is either an eager Model, a lazy Build, or both. The profile
// runner prefers Build when present.
type BuiltModel struct {
Model Model
Build func(ctx StageContext) Model
}
var (
regMu sync.RWMutex
constraintFactories = map[string]ConstraintFactory{}
modelFactories = map[string]ModelFactory{}
)
// RegisterConstraint installs a factory for typeName. Subsequent calls
// overwrite the previous factory.
func RegisterConstraint(typeName string, f ConstraintFactory) {
regMu.Lock()
defer regMu.Unlock()
constraintFactories[typeName] = f
}
// RegisterModel installs a model factory.
func RegisterModel(typeName string, f ModelFactory) {
regMu.Lock()
defer regMu.Unlock()
modelFactories[typeName] = f
}
// BuildConstraint dispatches spec to its registered factory.
func BuildConstraint(spec ConstraintSpec, deps BuildDeps) (Constraint, error) {
regMu.RLock()
f, ok := constraintFactories[spec.Type]
regMu.RUnlock()
if !ok {
return nil, fmt.Errorf("unknown constraint type %q", spec.Type)
}
return f(spec, deps)
}
// BuildModel dispatches spec to its registered factory.
func BuildModel(spec ModelSpec, deps BuildDeps) (BuiltModel, error) {
regMu.RLock()
f, ok := modelFactories[spec.Type]
regMu.RUnlock()
if !ok {
return BuiltModel{}, fmt.Errorf("unknown model type %q", spec.Type)
}
return f(spec, deps)
}
// RegisteredConstraints returns the names of every registered constraint type.
func RegisteredConstraints() []string {
regMu.RLock()
defer regMu.RUnlock()
out := make([]string, 0, len(constraintFactories))
for k := range constraintFactories {
out = append(out, k)
}
return out
}
// RegisteredModels returns the names of every registered model type.
func RegisteredModels() []string {
regMu.RLock()
defer regMu.RUnlock()
out := make([]string, 0, len(modelFactories))
for k := range modelFactories {
out = append(out, k)
}
return out
}
// --- Built-in registrations ------------------------------------------------
func init() {
RegisterConstraint("altitude", buildAltitude)
RegisterConstraint("time", buildTime)
RegisterConstraint("terrain_contact", buildTerrainContact)
RegisterConstraint("polygon", buildPolygon)
RegisterModel("constant_rate", buildConstantRate)
RegisterModel("parachute_descent", buildParachuteDescent)
RegisterModel("piecewise", buildPiecewise)
RegisterModel("wind", buildWind)
}
func buildAltitude(spec ConstraintSpec, _ BuildDeps) (Constraint, error) {
op, err := ParseOperator(spec.Op)
if err != nil {
return nil, fmt.Errorf("altitude: %w", err)
}
act, err := ParseAction(spec.Action)
if err != nil {
return nil, fmt.Errorf("altitude: %w", err)
}
return Altitude{Op: op, Limit: spec.Limit, On: act}, nil
}
func buildTime(spec ConstraintSpec, _ BuildDeps) (Constraint, error) {
op, err := ParseOperator(spec.Op)
if err != nil {
return nil, fmt.Errorf("time: %w", err)
}
act, err := ParseAction(spec.Action)
if err != nil {
return nil, fmt.Errorf("time: %w", err)
}
return Time{Op: op, Limit: spec.Limit, On: act}, nil
}
func buildTerrainContact(spec ConstraintSpec, deps BuildDeps) (Constraint, error) {
if deps.Terrain == nil {
return nil, fmt.Errorf("terrain_contact requires a terrain provider")
}
act, err := ParseAction(spec.Action)
if err != nil {
return nil, fmt.Errorf("terrain_contact: %w", err)
}
return TerrainContact{Provider: deps.Terrain, On: act}, nil
}
func buildPolygon(spec ConstraintSpec, _ BuildDeps) (Constraint, error) {
if len(spec.Vertices) < 3 {
return nil, fmt.Errorf("polygon requires at least 3 vertices")
}
act, err := ParseAction(spec.Action)
if err != nil {
return nil, fmt.Errorf("polygon: %w", err)
}
mode := PolygonInside
switch spec.Mode {
case "", "inside":
mode = PolygonInside
case "outside":
mode = PolygonOutside
default:
return nil, fmt.Errorf("polygon: unknown mode %q", spec.Mode)
}
return NewPolygon(spec.Vertices, mode, act, spec.Label), nil
}
func buildConstantRate(spec ModelSpec, _ BuildDeps) (BuiltModel, error) {
return BuiltModel{Model: ConstantRate(spec.Rate)}, nil
}
func buildParachuteDescent(spec ModelSpec, _ BuildDeps) (BuiltModel, error) {
if spec.SeaLevelRate <= 0 {
return BuiltModel{}, fmt.Errorf("parachute_descent requires positive sea_level_rate")
}
return BuiltModel{Model: ParachuteDescent(spec.SeaLevelRate)}, nil
}
func buildWind(_ ModelSpec, deps BuildDeps) (BuiltModel, error) {
if deps.Wind == nil {
return BuiltModel{}, fmt.Errorf("wind model requires a loaded wind field")
}
return BuiltModel{Model: WindTransport(deps.Wind, deps.Events)}, nil
}
func buildPiecewise(spec ModelSpec, deps BuildDeps) (BuiltModel, error) {
for _, s := range spec.Segments {
switch s.Reference {
case "", "absolute", "profile_start", "propagator_start":
default:
return BuiltModel{}, fmt.Errorf("piecewise: unknown segment reference %q", s.Reference)
}
}
// Always build lazily: the profile runner supplies a StageContext before
// each stage, which is what resolves absolute / profile-relative /
// propagator-relative segment times uniformly.
return BuiltModel{
Build: func(ctx StageContext) Model {
return maybeAddWind(Piecewise(resolveSegments(spec.Segments, ctx)), spec.IncludeWind, deps)
},
}, nil
}
// resolveSegments converts spec segments to engine.RateSegment, turning each
// segment's reference-relative Until into an absolute UNIX time. References
// are validated by buildPiecewise, so an unrecognised one here is treated as
// absolute rather than re-erroring.
func resolveSegments(in []PiecewiseSegmentSpec, ctx StageContext) []RateSegment {
out := make([]RateSegment, 0, len(in))
for _, s := range in {
out = append(out, RateSegment{Until: segmentBase(s.Reference, ctx) + s.Until, Rate: s.Rate})
}
return out
}
// segmentBase returns the absolute time a piecewise segment's Until is
// measured from, per its reference.
func segmentBase(reference string, ctx StageContext) float64 {
switch reference {
case "profile_start":
return ctx.ProfileStart
case "propagator_start":
return ctx.PropagatorStart
default: // "", "absolute"
return 0
}
}
// maybeAddWind sums a WindTransport model into base when the spec asks for it.
func maybeAddWind(base Model, includeWind bool, deps BuildDeps) Model {
if !includeWind {
return base
}
if deps.Wind == nil {
return base
}
return Sum(base, WindTransport(deps.Wind, deps.Events))
}

155
internal/engine/types.go Normal file
View file

@ -0,0 +1,155 @@
// Package engine is the trajectory calculation engine. It composes
// propagators (model-driven integrators) into profiles (ordered chains)
// over a wind field.
//
// The engine orchestrates the calculation; the numerically heavy work
// (RK4 stepping, crossing refinement, interpolation, atmosphere density,
// vector and polygon math) lives in the numerics package so it can be
// reimplemented in a faster language without touching this layer.
//
// The engine has no direct dependency on any specific data source: wind
// data is consumed through weather.WindField and terrain data through
// any type satisfying TerrainProvider.
package engine
import "predictor-refactored/internal/numerics"
// State is the spatial state of the balloon: latitude/longitude in degrees,
// altitude in metres. When returned by a Model the same struct is the
// per-second derivative. It is an alias of numerics.GeoVec so the engine and
// the numeric core share one hot-path value type without conversions.
type State = numerics.GeoVec
// Model returns the time derivative of state at (t, s).
//
// The derivative is direction-independent; the integrator applies the
// sign of dt for reverse propagation.
type Model func(t float64, s State) State
// Direction is the time direction of integration.
type Direction int8
const (
Forward Direction = +1
Reverse Direction = -1
)
// Action is what the profile runner does on a constraint violation.
type Action int
const (
// ActionStop ends the current propagator at the refined violation point.
ActionStop Action = iota
// ActionFallback ends the current propagator and starts its Fallback
// propagator from the refined violation point.
ActionFallback
// ActionClip clips the violated coordinate to the boundary and continues
// integration.
ActionClip
)
// ParseAction maps "stop" | "fallback" | "clip" to an Action.
func ParseAction(s string) (Action, error) {
switch s {
case "", "stop":
return ActionStop, nil
case "fallback":
return ActionFallback, nil
case "clip":
return ActionClip, nil
default:
return 0, errUnknownAction(s)
}
}
type errUnknownAction string
func (e errUnknownAction) Error() string { return "unknown constraint action " + string(e) }
// Constraint defines a stopping, branching, or clipping condition.
type Constraint interface {
// Name identifies the constraint in logs and result metadata.
Name() string
// Violated reports whether the constraint is breached at (t, s).
Violated(t float64, s State) bool
// Action is the behaviour to take on violation.
Action() Action
}
// TerrainProvider returns ground elevation in metres at a coordinate.
type TerrainProvider interface {
Elevation(lat, lng float64) float64
}
// StageContext is provided to Propagator.BuildModel and BuildConstraints by
// the profile runner immediately before each stage executes.
type StageContext struct {
// ProfileStart is the UNIX timestamp of the profile's initial launch.
ProfileStart float64
// PropagatorStart is the UNIX timestamp at which this propagator begins
// running — equal to ProfileStart for the first stage; the end-time of
// the previous stage thereafter.
PropagatorStart float64
// Launch is the profile's initial state.
Launch State
// PropagatorState is the state at which this propagator begins.
PropagatorState State
// Direction is the integration direction the profile is configured with.
Direction Direction
}
// Outcome describes how a propagator's run ended.
type Outcome int
const (
// OutcomeStopped means a Constraint with ActionStop fired.
OutcomeStopped Outcome = iota
// OutcomeFallback means a Constraint with ActionFallback fired.
OutcomeFallback
// OutcomeContinued means the propagator finished without a constraint
// firing — only seen when a propagator is misconfigured to run unbounded.
OutcomeContinued
)
// String renders the outcome as a stable string for API serialisation.
func (o Outcome) String() string {
switch o {
case OutcomeStopped:
return "stopped"
case OutcomeFallback:
return "fallback"
default:
return "continued"
}
}
// Result is the output of running one propagator.
type Result struct {
// Propagator is the propagator's Name.
Propagator string
// Path is the emitted trajectory in struct-of-arrays form.
Path numerics.Path
// Outcome describes how the propagator terminated.
Outcome Outcome
// Constraint is the constraint that fired, or nil if Outcome is OutcomeContinued.
Constraint Constraint
// ConstraintName captures Constraint.Name() at fire time so callers can
// serialise the result after the Constraint has been garbage collected.
ConstraintName string
// ViolationTime / ViolationState describe the first integration step at
// which the constraint reported a violation, before binary-search refinement.
ViolationTime float64
ViolationState State
// RefinedTime / RefinedState describe the refined violation point that
// appears as the propagator's last trajectory point.
RefinedTime float64
RefinedState State
// Events is the aggregated set of non-fatal observations from this stage.
Events []EventSummary
}

146
internal/metrics/prom.go Normal file
View file

@ -0,0 +1,146 @@
package metrics
import (
"fmt"
"io"
"net/http"
"sort"
"strings"
"sync"
"time"
)
// Prom is a minimal Sink that exposes counters and gauges in Prometheus's
// text exposition format. No external dependencies.
//
// The Prom sink supports labelled counters, sums (for durations and byte
// counts), and labelled gauges. Histograms are intentionally omitted; if
// they are needed later, swap Prom for an OTel-based sink.
type Prom struct {
mu sync.Mutex
counters map[string]map[string]float64 // name → label-key → value
gauges map[string]map[string]float64 // name → label-key → value
}
// NewProm returns an empty Prom sink.
func NewProm() *Prom {
return &Prom{
counters: make(map[string]map[string]float64),
gauges: make(map[string]map[string]float64),
}
}
// Prediction implements Sink.
func (p *Prom) Prediction(profile string, d time.Duration, err error) {
status := "ok"
if err != nil {
status = "error"
}
labels := map[string]string{"profile": profile, "status": status}
p.incCounter("predictor_predictions_total", labels, 1)
p.incCounter("predictor_prediction_duration_seconds_sum", labels, d.Seconds())
p.incCounter("predictor_prediction_duration_seconds_count", labels, 1)
}
// Download implements Sink.
func (p *Prom) Download(source string, d time.Duration, status string, bytes int64) {
labels := map[string]string{"source": source, "status": status}
p.incCounter("predictor_downloads_total", labels, 1)
p.incCounter("predictor_download_duration_seconds_sum", labels, d.Seconds())
p.incCounter("predictor_download_bytes_total", map[string]string{"source": source}, float64(bytes))
}
// ActiveEpoch implements Sink.
func (p *Prom) ActiveEpoch(t time.Time) {
var v float64
if !t.IsZero() {
v = float64(t.Unix())
}
p.setGauge("predictor_active_dataset_epoch_seconds", map[string]string{}, v)
}
// ServeHTTP writes the metrics in Prometheus text exposition format.
func (p *Prom) ServeHTTP(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain; version=0.0.4")
p.Write(w)
}
// Write writes the metrics in Prometheus exposition format to w.
func (p *Prom) Write(w io.Writer) {
p.mu.Lock()
defer p.mu.Unlock()
names := make([]string, 0, len(p.counters)+len(p.gauges))
for n := range p.counters {
names = append(names, n)
}
for n := range p.gauges {
names = append(names, n)
}
sort.Strings(names)
for _, name := range names {
if labels, ok := p.counters[name]; ok {
fmt.Fprintf(w, "# TYPE %s counter\n", name)
writeMetricFamily(w, name, labels)
}
if labels, ok := p.gauges[name]; ok {
fmt.Fprintf(w, "# TYPE %s gauge\n", name)
writeMetricFamily(w, name, labels)
}
}
}
func writeMetricFamily(w io.Writer, name string, labels map[string]float64) {
keys := make([]string, 0, len(labels))
for k := range labels {
keys = append(keys, k)
}
sort.Strings(keys)
for _, key := range keys {
fmt.Fprintf(w, "%s%s %g\n", name, key, labels[key])
}
}
func (p *Prom) incCounter(name string, labels map[string]string, n float64) {
key := labelKey(labels)
p.mu.Lock()
defer p.mu.Unlock()
if p.counters[name] == nil {
p.counters[name] = make(map[string]float64)
}
p.counters[name][key] += n
}
func (p *Prom) setGauge(name string, labels map[string]string, v float64) {
key := labelKey(labels)
p.mu.Lock()
defer p.mu.Unlock()
if p.gauges[name] == nil {
p.gauges[name] = make(map[string]float64)
}
p.gauges[name][key] = v
}
// labelKey renders the labels into a Prometheus-format "{k1="v1",k2="v2"}"
// suffix, empty if no labels.
func labelKey(labels map[string]string) string {
if len(labels) == 0 {
return ""
}
keys := make([]string, 0, len(labels))
for k := range labels {
keys = append(keys, k)
}
sort.Strings(keys)
var sb strings.Builder
sb.WriteByte('{')
for i, k := range keys {
if i > 0 {
sb.WriteByte(',')
}
fmt.Fprintf(&sb, "%s=%q", k, labels[k])
}
sb.WriteByte('}')
return sb.String()
}

View file

@ -0,0 +1,49 @@
package metrics
import (
"bytes"
"strings"
"testing"
"time"
)
func TestPromCounters(t *testing.T) {
p := NewProm()
p.Prediction("standard_profile", 100*time.Millisecond, nil)
p.Prediction("standard_profile", 200*time.Millisecond, nil)
p.Prediction("float_profile", 50*time.Millisecond, nil)
var buf bytes.Buffer
p.Write(&buf)
out := buf.String()
if !strings.Contains(out, `predictor_predictions_total{profile="standard_profile",status="ok"} 2`) {
t.Errorf("expected count=2 for standard_profile, got: %s", out)
}
if !strings.Contains(out, `predictor_predictions_total{profile="float_profile",status="ok"} 1`) {
t.Errorf("expected count=1 for float_profile, got: %s", out)
}
// Sum of durations: 0.1 + 0.2 = 0.3 seconds.
if !strings.Contains(out, "predictor_prediction_duration_seconds_sum") {
t.Errorf("expected sum present, got: %s", out)
}
}
func TestPromGauge(t *testing.T) {
p := NewProm()
p.ActiveEpoch(time.Unix(1700000000, 0))
var buf bytes.Buffer
p.Write(&buf)
out := buf.String()
if !strings.Contains(out, "predictor_active_dataset_epoch_seconds 1.7e+09") {
t.Errorf("expected gauge with epoch 1700000000, got: %s", out)
}
}
func TestNoop(t *testing.T) {
sink := Noop()
sink.Prediction("any", time.Second, nil)
sink.Download("any", time.Second, "complete", 0)
sink.ActiveEpoch(time.Now())
}

36
internal/metrics/types.go Normal file
View file

@ -0,0 +1,36 @@
// Package metrics defines the Sink interface used to record service metrics
// and ships two implementations: a Noop sink (default, zero-cost) and a Prom
// sink that exposes counters in the Prometheus text exposition format.
//
// The metrics layer is optional: if no Sink is wired (or Noop is wired), the
// service runs unchanged.
package metrics
import "time"
// Sink collects observations from the rest of the service.
//
// Implementations must be safe for concurrent use across many goroutines.
// All methods are advisory; implementations may ignore any observation.
type Sink interface {
// Prediction records the duration and outcome of one prediction.
// err is nil on success; otherwise the error's class is used as a label.
Prediction(profile string, duration time.Duration, err error)
// Download records the outcome of one dataset download job.
// status is "complete", "failed", or "cancelled".
Download(source string, duration time.Duration, status string, bytes int64)
// ActiveEpoch reports the forecast time of the currently-loaded dataset.
// Pass time.Time{} when no dataset is loaded.
ActiveEpoch(t time.Time)
}
// Noop returns a Sink that discards every observation.
func Noop() Sink { return noop{} }
type noop struct{}
func (noop) Prediction(string, time.Duration, error) {}
func (noop) Download(string, time.Duration, string, int64) {}
func (noop) ActiveEpoch(time.Time) {}

View file

@ -0,0 +1,38 @@
package numerics
import "math"
// NasaDensity returns air density in kg/m^3 at the given altitude in metres,
// using the NASA piecewise standard-atmosphere model.
// See https://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html.
//
// The model is split into three altitude bands (troposphere, lower
// stratosphere, upper stratosphere); density is pressure / (0.2869 * T_K).
func NasaDensity(alt float64) float64 {
var temp, pressure float64
switch {
case alt > 25000:
temp = -131.21 + 0.00299*alt
pressure = 2.488 * math.Pow((temp+273.1)/216.6, -11.388)
case alt > 11000:
temp = -56.46
pressure = 22.65 * math.Exp(1.73-0.000157*alt)
default:
temp = 15.04 - 0.00649*alt
pressure = 101.29 * math.Pow((temp+273.1)/288.08, 5.256)
}
return pressure / (0.2869 * (temp + 273.1))
}
// DragTerminalVelocity returns the vertical velocity (m/s, negative = downward)
// of a parachute descent at the given altitude. seaLevelRate is the descent
// speed at sea level (positive m/s); the rate grows with altitude as the
// thinner air provides less drag:
//
// v = -k / sqrt(rho(alt)), k = seaLevelRate * 1.1045
//
// Matches Tawhiri's drag_descent.
func DragTerminalVelocity(seaLevelRate, alt float64) float64 {
k := seaLevelRate * 1.1045
return -k / math.Sqrt(NasaDensity(alt))
}

11
internal/numerics/doc.go Normal file
View file

@ -0,0 +1,11 @@
// Package numerics provides the numerical primitives used by the trajectory
// engine: regular-grid multilinear interpolation, monotone bisection, and
// a generic explicit Runge-Kutta-4 integrator with binary-search refinement
// of a termination point.
//
// The package has no dependencies on any domain type. State and derivative
// types are generic, and all coordinate-wrap or unit-conversion semantics
// live in the caller.
//
// All algorithms are documented in docs/numerics.tex.
package numerics

View file

@ -0,0 +1,41 @@
package numerics
import "math"
// PointInPolygon reports whether (lat, lng) lies inside the closed polygon
// whose vertices are given as parallel latitude/longitude slices (degrees).
//
// The test is ray casting in plate-carrée space. Every longitude is
// normalised to within 180° of the first vertex before testing, so a polygon
// spanning the antimeridian is handled correctly as long as it spans no more
// than 180° in longitude. polyLat and polyLng must have equal length >= 3.
func PointInPolygon(lat, lng float64, polyLat, polyLng []float64) bool {
n := len(polyLat)
if n < 3 || len(polyLng) != n {
return false
}
ref := polyLng[0]
qx := NormalizeLng(lng, ref)
inside := false
for i, j := 0, n-1; i < n; j, i = i, i+1 {
yi, yj := polyLat[i], polyLat[j]
xi := NormalizeLng(polyLng[i], ref)
xj := NormalizeLng(polyLng[j], ref)
if (yi > lat) != (yj > lat) {
xIntersect := (xj-xi)*(lat-yi)/(yj-yi) + xi
if qx < xIntersect {
inside = !inside
}
}
}
return inside
}
// NormalizeLng rewrites v so that it lies within 180° of ref. For example,
// NormalizeLng(350, 10) returns -10. Used to make longitude comparisons
// continuous across the antimeridian.
func NormalizeLng(v, ref float64) float64 {
return ref + math.Mod(v-ref+540, 360) - 180
}

129
internal/numerics/grid.go Normal file
View file

@ -0,0 +1,129 @@
package numerics
import "fmt"
// Axis describes a regularly-spaced grid axis with N grid points,
// values left, left+step, left+2*step, ..., left+(N-1)*step.
//
// If Wrap is true, the axis is periodic with period N*step (e.g. longitude).
// A query value at left+N*step wraps to the value at left+0*step. Locate
// returns Hi = 0 in that case.
type Axis struct {
Left float64
Step float64
N int
Wrap bool
Name string
}
// AxisError is returned by Axis.Locate when value lies outside a non-wrapping axis.
type AxisError struct {
Axis string
Value float64
}
func (e *AxisError) Error() string {
return fmt.Sprintf("%s=%v out of range", e.Axis, e.Value)
}
// Bracket holds the two surrounding grid indices and the fractional position
// of a value within an axis. The weight at Lo is (1 - Frac); the weight at Hi
// is Frac. Frac lies in [0, 1).
type Bracket struct {
Lo, Hi int
Frac float64
}
// Locate returns the bracket containing value within the axis.
// For a non-wrapping axis, value must lie in [Left, Left + (N-1)*Step);
// for a wrapping axis, value must lie in [Left, Left + N*Step).
func (a Axis) Locate(value float64) (Bracket, error) {
pos := (value - a.Left) / a.Step
lo := int(pos) // truncates toward zero; pos is non-negative for valid inputs
maxLo := a.N - 2
if a.Wrap {
maxLo = a.N - 1
}
if lo < 0 || lo > maxLo {
return Bracket{}, &AxisError{Axis: a.Name, Value: value}
}
hi := lo + 1
if a.Wrap && hi == a.N {
hi = 0
}
return Bracket{Lo: lo, Hi: hi, Frac: pos - float64(lo)}, nil
}
// TrilinearWeights returns the eight corner weights for a (axis0, axis1,
// axis2) bracket triple, in the canonical visiting order
//
// (0,0,0) (0,0,1) (0,1,0) (0,1,1) (1,0,0) (1,0,1) (1,1,0) (1,1,1)
//
// where the bit triple selects Lo (0) or Hi (1) on each axis. The weights sum
// to 1. Pair this with Dot8 over corner values fetched in the same order.
func TrilinearWeights(b3 [3]Bracket) [8]float64 {
wa0, wa1 := 1-b3[0].Frac, b3[0].Frac
wb0, wb1 := 1-b3[1].Frac, b3[1].Frac
wc0, wc1 := 1-b3[2].Frac, b3[2].Frac
wa0wb0 := wa0 * wb0
wa0wb1 := wa0 * wb1
wa1wb0 := wa1 * wb0
wa1wb1 := wa1 * wb1
return [8]float64{
wa0wb0 * wc0,
wa0wb0 * wc1,
wa0wb1 * wc0,
wa0wb1 * wc1,
wa1wb0 * wc0,
wa1wb0 * wc1,
wa1wb1 * wc0,
wa1wb1 * wc1,
}
}
// Dot8 returns the multiply-accumulate sum w[0]*v[0] + ... + w[7]*v[7].
//
// The fixed length and straight-line accumulation are written so the Go
// compiler can keep the values in registers and a future hand-vectorised
// port can replace the body with a single SIMD MAC. The accumulation order
// is fixed (ascending index) so results are reproducible.
func Dot8(w, v *[8]float64) float64 {
acc := w[0] * v[0]
acc = w[1]*v[1] + acc
acc = w[2]*v[2] + acc
acc = w[3]*v[3] + acc
acc = w[4]*v[4] + acc
acc = w[5]*v[5] + acc
acc = w[6]*v[6] + acc
acc = w[7]*v[7] + acc
return acc
}
// EvalTrilinear samples a 3D field via f at the eight corners defined by b3
// and returns the trilinearly interpolated value.
//
// Corners are visited in the canonical order documented on TrilinearWeights.
// With f(i,j,k) = a*i + b*j + c*k + d this returns a*pos0 + b*pos1 + c*pos2
// + d, modulo floating-point rounding. For the hot path prefer precomputing
// weights once via TrilinearWeights and reducing with Dot8.
func EvalTrilinear(b3 [3]Bracket, f func(i, j, k int) float64) float64 {
w := TrilinearWeights(b3)
a0, a1 := b3[0].Lo, b3[0].Hi
b0, b1 := b3[1].Lo, b3[1].Hi
c0, c1 := b3[2].Lo, b3[2].Hi
v := [8]float64{
f(a0, b0, c0),
f(a0, b0, c1),
f(a0, b1, c0),
f(a0, b1, c1),
f(a1, b0, c0),
f(a1, b0, c1),
f(a1, b1, c0),
f(a1, b1, c1),
}
return Dot8(&w, &v)
}

View file

@ -0,0 +1,94 @@
package numerics
import (
"math"
"testing"
)
func TestAxisLocate(t *testing.T) {
a := Axis{Left: -90, Step: 0.5, N: 361, Name: "lat"}
b, err := a.Locate(-90)
if err != nil || b.Lo != 0 || b.Hi != 1 || b.Frac != 0 {
t.Errorf("Locate(-90) = %+v, %v; want {0 1 0}, nil", b, err)
}
b, err = a.Locate(0)
if err != nil || b.Lo != 180 || b.Hi != 181 || b.Frac != 0 {
t.Errorf("Locate(0) = %+v, %v; want {180 181 0}, nil", b, err)
}
b, err = a.Locate(-89.75)
if err != nil || b.Lo != 0 || b.Hi != 1 || math.Abs(b.Frac-0.5) > 1e-12 {
t.Errorf("Locate(-89.75) = %+v, %v; want frac=0.5", b, err)
}
// 90 is exactly on the upper boundary — there's no Hi above it
if _, err := a.Locate(90); err == nil {
t.Errorf("Locate(90) should error, got nil")
}
if _, err := a.Locate(-91); err == nil {
t.Errorf("Locate(-91) should error, got nil")
}
}
func TestAxisLocateWrap(t *testing.T) {
a := Axis{Left: 0, Step: 0.5, N: 720, Wrap: true, Name: "lng"}
b, err := a.Locate(0)
if err != nil || b.Lo != 0 || b.Hi != 1 || b.Frac != 0 {
t.Errorf("Locate(0) = %+v, %v", b, err)
}
// Right up against the wrap boundary
b, err = a.Locate(359.75)
if err != nil || b.Lo != 719 || b.Hi != 0 || math.Abs(b.Frac-0.5) > 1e-12 {
t.Errorf("Locate(359.75) = %+v, %v; want {719 0 0.5}", b, err)
}
// 360 is outside the half-open interval
if _, err := a.Locate(360); err == nil {
t.Errorf("Locate(360) should error, got nil")
}
}
func TestEvalTrilinear(t *testing.T) {
// Field f(i,j,k) = 100*i + 10*j + k.
f := func(i, j, k int) float64 { return 100*float64(i) + 10*float64(j) + float64(k) }
// At all fractions = 0.5, expected value is the mean of the 8 corners.
bs := [3]Bracket{{Lo: 0, Hi: 1, Frac: 0.5}, {Lo: 0, Hi: 1, Frac: 0.5}, {Lo: 0, Hi: 1, Frac: 0.5}}
got := EvalTrilinear(bs, f)
want := (0 + 1 + 10 + 11 + 100 + 101 + 110 + 111) / 8.0
if math.Abs(got-want) > 1e-12 {
t.Errorf("EvalTrilinear at center = %v, want %v", got, want)
}
// At all fractions = 0, expected value is f(lo, lo, lo) = 0.
bs = [3]Bracket{{Lo: 0, Hi: 1, Frac: 0}, {Lo: 0, Hi: 1, Frac: 0}, {Lo: 0, Hi: 1, Frac: 0}}
got = EvalTrilinear(bs, f)
if got != 0 {
t.Errorf("EvalTrilinear at (lo,lo,lo) = %v, want 0", got)
}
// Asymmetric: linear field f(i,j,k) = i should give frac of axis 0 exactly.
f2 := func(i, _, _ int) float64 { return float64(i) }
bs = [3]Bracket{{Lo: 0, Hi: 1, Frac: 0.3}, {Lo: 0, Hi: 1, Frac: 0.7}, {Lo: 0, Hi: 1, Frac: 0.9}}
got = EvalTrilinear(bs, f2)
if math.Abs(got-0.3) > 1e-12 {
t.Errorf("EvalTrilinear of i-field = %v, want 0.3", got)
}
}
func TestLerp(t *testing.T) {
if Lerp(10, 20, 0) != 10 {
t.Errorf("Lerp(10, 20, 0) != 10")
}
if Lerp(10, 20, 1) != 20 {
t.Errorf("Lerp(10, 20, 1) != 20")
}
if math.Abs(Lerp(10, 20, 0.25)-12.5) > 1e-12 {
t.Errorf("Lerp(10, 20, 0.25) != 12.5")
}
}

View file

@ -0,0 +1,58 @@
package numerics
import (
"math"
"testing"
)
func TestAddGeo(t *testing.T) {
// Rates sum component-wise with no longitude wrapping.
got := AddGeo(GeoVec{Lat: 1, Lng: 350, Altitude: 2}, GeoVec{Lat: 3, Lng: 20, Altitude: 4})
want := GeoVec{Lat: 4, Lng: 370, Altitude: 6}
if got != want {
t.Errorf("AddGeo = %+v, want %+v (no wrap on rates)", got, want)
}
}
func TestWindToGeoRate(t *testing.T) {
// Pure eastward 10 m/s at the equator, sea level.
dLat, dLng := WindToGeoRate(10, 0, 0, 0)
wantLng := (180.0 / math.Pi) * 10.0 / EarthRadius
if math.Abs(dLat) > 1e-15 {
t.Errorf("dLat = %v, want 0", dLat)
}
if math.Abs(dLng-wantLng) > 1e-15 {
t.Errorf("dLng = %v, want %v", dLng, wantLng)
}
// Northward 5 m/s at 60°N: dLat independent of longitude scaling.
dLat, _ = WindToGeoRate(0, 5, 60, 0)
wantLat := (180.0 / math.Pi) * 5.0 / EarthRadius
if math.Abs(dLat-wantLat) > 1e-15 {
t.Errorf("dLat at 60N = %v, want %v", dLat, wantLat)
}
// cos(lat) factor makes eastward motion span more degrees nearer the poles.
_, dLngEq := WindToGeoRate(10, 0, 0, 0)
_, dLng60 := WindToGeoRate(10, 0, 60, 0)
if dLng60 <= dLngEq {
t.Errorf("eastward deg/s should grow with latitude: eq=%v 60N=%v", dLngEq, dLng60)
}
}
func TestDragTerminalVelocity(t *testing.T) {
// Descent is downward (negative) and faster (more negative) at altitude
// where the air is thinner.
sea := DragTerminalVelocity(5, 0)
high := DragTerminalVelocity(5, 20000)
if sea >= 0 {
t.Errorf("sea-level rate = %v, want negative (downward)", sea)
}
if high >= sea {
t.Errorf("expected faster descent at altitude: sea=%v high=%v", sea, high)
}
// Sanity: at sea level rho≈1.225, so v ≈ -5*1.1045/sqrt(1.225) ≈ -4.99 m/s.
if math.Abs(sea-(-5*1.1045/math.Sqrt(NasaDensity(0)))) > 1e-12 {
t.Errorf("sea-level formula mismatch: %v", sea)
}
}

94
internal/numerics/ode.go Normal file
View file

@ -0,0 +1,94 @@
package numerics
// Field returns the time derivative of a geographic state at (t, y).
// The derivative is direction-independent; the integrator applies the sign
// of dt for reverse-time integration.
type Field func(t float64, y GeoVec) GeoVec
// Crossed reports whether a termination condition holds at (t, y).
type Crossed func(t float64, y GeoVec) bool
// RK4Step performs one classical Runge-Kutta-4 step from (t, y) with step dt.
// dt may be negative to integrate backwards in time. Longitude wrapping is
// applied at every intermediate add via GeoAdd, matching the reference
// integrator. The function performs no heap allocation.
func RK4Step(t float64, y GeoVec, dt float64, f Field) GeoVec {
half := dt / 2
k1 := f(t, y)
k2 := f(t+half, GeoAdd(y, half, k1))
k3 := f(t+half, GeoAdd(y, half, k2))
k4 := f(t+dt, GeoAdd(y, dt, k3))
y2 := GeoAdd(y, dt/6, k1)
y2 = GeoAdd(y2, dt/3, k2)
y2 = GeoAdd(y2, dt/3, k3)
y2 = GeoAdd(y2, dt/6, k4)
return y2
}
// RefineCrossing locates a crossing between (t1, y1) (not crossed) and
// (t2, y2) (crossed) by binary search in the linear-interpolation parameter
// space, stopping when the parameter interval is narrower than tol.
//
// It returns the final midpoint sampled, matching Tawhiri's solver.pyx: the
// returned point is not guaranteed to satisfy the predicate, but for tol << 1
// it is within one tolerance-width of the true crossing.
func RefineCrossing(t1 float64, y1 GeoVec, t2 float64, y2 GeoVec, crossed Crossed, tol float64) (float64, GeoVec) {
left, right := 0.0, 1.0
t3, y3 := t2, y2
for right-left > tol {
mid := (left + right) / 2
t3 = Lerp(t1, t2, mid)
y3 = GeoLerp(y1, y2, mid)
if crossed(t3, y3) {
right = mid
} else {
left = mid
}
}
return t3, y3
}
// Path is a struct-of-arrays trajectory: parallel slices of time and the
// three state components. SoA layout keeps each component contiguous, which
// is friendlier to cache and to vectorised post-processing than a slice of
// point structs, and lets the integrator append with a single bounds check
// per component.
type Path struct {
T []float64
Lat []float64
Lng []float64
Altitude []float64
}
// NewPath returns a Path with capacity reserved for n points.
func NewPath(n int) Path {
return Path{
T: make([]float64, 0, n),
Lat: make([]float64, 0, n),
Lng: make([]float64, 0, n),
Altitude: make([]float64, 0, n),
}
}
// Len returns the number of points in the path.
func (p *Path) Len() int { return len(p.T) }
// Append adds one point to the path.
func (p *Path) Append(t float64, y GeoVec) {
p.T = append(p.T, t)
p.Lat = append(p.Lat, y.Lat)
p.Lng = append(p.Lng, y.Lng)
p.Altitude = append(p.Altitude, y.Altitude)
}
// Last returns the final (t, state) of the path. It panics on an empty path.
func (p *Path) Last() (float64, GeoVec) {
i := len(p.T) - 1
return p.T[i], GeoVec{Lat: p.Lat[i], Lng: p.Lng[i], Altitude: p.Altitude[i]}
}
// At returns the point at index i.
func (p *Path) At(i int) (float64, GeoVec) {
return p.T[i], GeoVec{Lat: p.Lat[i], Lng: p.Lng[i], Altitude: p.Altitude[i]}
}

View file

@ -0,0 +1,78 @@
package numerics
import (
"math"
"testing"
)
func TestRK4ExponentialDecay(t *testing.T) {
// dAlt/dt = -Alt → exact: Alt(t) = Alt0 * exp(-t).
f := func(_ float64, y GeoVec) GeoVec { return GeoVec{Altitude: -y.Altitude} }
y := GeoVec{Altitude: 1}
tnow, dt := 0.0, 0.01
for range 100 {
y = RK4Step(tnow, y, dt, f)
tnow += dt
}
want := math.Exp(-1.0)
if math.Abs(y.Altitude-want) > 1e-8 {
t.Errorf("RK4 exp decay at t=1: got %v, want %v", y.Altitude, want)
}
}
func TestRK4ReverseTime(t *testing.T) {
// dAlt/dt = Alt → exact: Alt(t) = Alt0 * exp(t).
f := func(_ float64, y GeoVec) GeoVec { return GeoVec{Altitude: y.Altitude} }
y := GeoVec{Altitude: math.E}
tnow, dt := 1.0, -0.01
for range 100 {
y = RK4Step(tnow, y, dt, f)
tnow += dt
}
if math.Abs(y.Altitude-1.0) > 1e-8 {
t.Errorf("RK4 reverse: got %v, want 1.0", y.Altitude)
}
}
func TestRefineCrossing(t *testing.T) {
y1 := GeoVec{Altitude: 1}
y2 := GeoVec{Altitude: -1.5}
crossed := func(_ float64, y GeoVec) bool { return y.Altitude <= 0 }
tr, yr := RefineCrossing(0, y1, 1, y2, crossed, 0.001)
if math.Abs(tr-0.4) > 0.01 {
t.Errorf("refined t = %v, want ~0.4", tr)
}
if math.Abs(yr.Altitude) > 0.01 {
t.Errorf("refined alt = %v, want ~0", yr.Altitude)
}
}
func TestGeoAddWrapsLongitude(t *testing.T) {
y := GeoAdd(GeoVec{Lng: 350}, 1, GeoVec{Lng: 20})
if math.Abs(y.Lng-10) > 1e-9 {
t.Errorf("GeoAdd wrap: lng = %v, want 10", y.Lng)
}
}
func TestGeoLerpWrap(t *testing.T) {
mid := GeoLerp(GeoVec{Lng: 350}, GeoVec{Lng: 10}, 0.5)
if math.Abs(mid.Lng) > 1e-9 && math.Abs(mid.Lng-360) > 1e-9 {
t.Errorf("GeoLerp lng wrap: %v, want 0 or 360", mid.Lng)
}
}
func TestPathSoA(t *testing.T) {
p := NewPath(4)
p.Append(0, GeoVec{Lat: 1, Lng: 2, Altitude: 3})
p.Append(60, GeoVec{Lat: 4, Lng: 5, Altitude: 6})
if p.Len() != 2 {
t.Fatalf("len = %d, want 2", p.Len())
}
tt, last := p.Last()
if tt != 60 || last.Lat != 4 {
t.Errorf("last = %v, %+v", tt, last)
}
}

View file

@ -0,0 +1,19 @@
package numerics
// Bisect returns the largest index i in [imin, imax] such that f(i) < target,
// assuming f is monotonically nondecreasing on that range.
//
// If target <= f(imin), returns imin. If target > f(imax), returns imax.
// Performs O(log(imax-imin)) evaluations of f.
func Bisect(imin, imax int, target float64, f func(i int) float64) int {
lo, hi := imin, imax
for lo < hi {
mid := (lo + hi + 1) / 2
if target <= f(mid) {
hi = mid - 1
} else {
lo = mid
}
}
return lo
}

View file

@ -0,0 +1,28 @@
package numerics
import "testing"
func TestBisect(t *testing.T) {
// f(i) = 10*i, monotone increasing.
f := func(i int) float64 { return 10 * float64(i) }
// target = 25 → largest i with 10i < 25 is i=2
if got := Bisect(0, 10, 25, f); got != 2 {
t.Errorf("Bisect target=25 = %d, want 2", got)
}
// target on boundary: target = 30, condition is target <= f(mid) so f(3)=30 → not less; want 2
if got := Bisect(0, 10, 30, f); got != 2 {
t.Errorf("Bisect target=30 = %d, want 2", got)
}
// target below all values
if got := Bisect(0, 10, -5, f); got != 0 {
t.Errorf("Bisect target=-5 = %d, want 0", got)
}
// target above all values
if got := Bisect(0, 10, 1000, f); got != 10 {
t.Errorf("Bisect target=1000 = %d, want 10", got)
}
}

89
internal/numerics/vec.go Normal file
View file

@ -0,0 +1,89 @@
package numerics
import "math"
// GeoVec is a geographic state vector: latitude and longitude in degrees and
// altitude in metres. The same struct represents a per-second derivative,
// in which case the fields are deg/s and m/s.
//
// GeoVec is the hot-path state type for the integrator. It is a small value
// type (three float64) and is passed by value to stay allocation-free; a
// future SIMD/SoA batch integrator can lift these fields into parallel
// slices (see Path).
type GeoVec struct {
Lat float64 `json:"lat"`
Lng float64 `json:"lng"`
Altitude float64 `json:"altitude"`
}
// PyMod returns a mod b with Python semantics: the result carries the sign of
// b, so for b > 0 it always lies in [0, b).
func PyMod(a, b float64) float64 {
r := math.Mod(a, b)
if r < 0 {
r += b
}
return r
}
// GeoAdd returns y + k*dy with longitude wrapped to [0, 360). Latitude and
// altitude accumulate linearly. This is the integrator's state-update step.
func GeoAdd(y GeoVec, k float64, dy GeoVec) GeoVec {
return GeoVec{
Lat: y.Lat + k*dy.Lat,
Lng: PyMod(y.Lng+k*dy.Lng, 360),
Altitude: y.Altitude + k*dy.Altitude,
}
}
// GeoLerp linearly interpolates two geographic states by parameter l in
// [0, 1]. Longitude takes the shorter great-circle arc.
func GeoLerp(a, b GeoVec, l float64) GeoVec {
return GeoVec{
Lat: (1-l)*a.Lat + l*b.Lat,
Lng: LngLerp(a.Lng, b.Lng, l),
Altitude: (1-l)*a.Altitude + l*b.Altitude,
}
}
// LngLerp interpolates between two longitudes in [0, 360), choosing the
// shorter arc and wrapping the result back into range.
func LngLerp(a, b, l float64) float64 {
l2 := 1 - l
if a > b {
a, b = b, a
l, l2 = l2, l
}
if b-a < 180 {
return l2*a + l*b
}
return PyMod(l2*(a+360)+l*b, 360)
}
// Lerp returns (1-l)*a + l*b.
func Lerp(a, b, l float64) float64 {
return (1-l)*a + l*b
}
// AddGeo returns the component-wise sum a+b without longitude wrapping. Use it
// to combine derivative (rate) vectors — rates accumulate linearly, unlike
// positions, which wrap via GeoAdd.
func AddGeo(a, b GeoVec) GeoVec {
return GeoVec{Lat: a.Lat + b.Lat, Lng: a.Lng + b.Lng, Altitude: a.Altitude + b.Altitude}
}
// EarthRadius is the spherical Earth radius (metres) used for horizontal
// motion, matching the reference Tawhiri implementation.
const EarthRadius = 6371009.0
// WindToGeoRate converts eastward (u) and northward (v) wind in m/s at the
// given latitude (deg) and altitude (m) into the geographic rate in deg/s on a
// spherical Earth. The returned dLng diverges near the poles as cos(lat) → 0.
func WindToGeoRate(u, v, lat, alt float64) (dLat, dLng float64) {
const degPerRad = 180.0 / math.Pi
const piOver180 = math.Pi / 180.0
r := EarthRadius + alt
dLat = degPerRad * v / r
dLng = degPerRad * u / (r * math.Cos(lat*piOver180))
return dLat, dLng
}

View file

@ -1,153 +0,0 @@
package prediction
import (
"fmt"
"predictor-refactored/internal/dataset"
)
// Exact port of the reference interpolation logic (interpolate.pyx).
// 4D interpolation: time, latitude, longitude, altitude (via geopotential height).
// lerp1 holds an index and interpolation weight for one axis.
type lerp1 struct {
index int
lerp float64
}
// lerp3 holds indices and a combined weight for the (hour, lat, lon) axes.
type lerp3 struct {
hour, lat, lng int
lerp float64
}
// RangeError indicates a coordinate is outside the dataset bounds.
type RangeError struct {
Variable string
Value float64
}
func (e *RangeError) Error() string {
return fmt.Sprintf("%s=%f out of range", e.Variable, e.Value)
}
// pick computes interpolation indices and weights for a single axis.
// left: axis start, step: axis spacing, n: number of points, value: query value.
// Returns two lerp1 values (lower and upper bracket).
func pick(left, step float64, n int, value float64, variableName string) ([2]lerp1, error) {
a := (value - left) / step
b := int(a) // truncation toward zero, same as Cython <long> cast
if b < 0 || b >= n-1 {
return [2]lerp1{}, &RangeError{Variable: variableName, Value: value}
}
l := a - float64(b)
return [2]lerp1{
{index: b, lerp: 1 - l},
{index: b + 1, lerp: l},
}, nil
}
// pick3 computes 8 trilinear interpolation weights for (hour, lat, lng).
func pick3(hour, lat, lng float64) ([8]lerp3, error) {
lhour, err := pick(0, 3, 65, hour, "hour")
if err != nil {
return [8]lerp3{}, err
}
llat, err := pick(-90, 0.5, 361, lat, "lat")
if err != nil {
return [8]lerp3{}, err
}
// Longitude wraps: tell pick the axis is one larger, then wrap index 720 → 0
llng, err := pick(0, 0.5, 720+1, lng, "lng")
if err != nil {
return [8]lerp3{}, err
}
if llng[1].index == 720 {
llng[1].index = 0
}
var out [8]lerp3
i := 0
for _, a := range lhour {
for _, b := range llat {
for _, c := range llng {
out[i] = lerp3{
hour: a.index,
lat: b.index,
lng: c.index,
lerp: a.lerp * b.lerp * c.lerp,
}
i++
}
}
}
return out, nil
}
// interp3 performs 8-point weighted interpolation at a given variable and pressure level.
func interp3(ds *dataset.File, lerps [8]lerp3, variable, level int) float64 {
var r float64
for i := 0; i < 8; i++ {
v := ds.Val(lerps[i].hour, level, variable, lerps[i].lat, lerps[i].lng)
r += float64(v) * lerps[i].lerp
}
return r
}
// search finds the largest pressure level index where interpolated geopotential
// height is less than the target altitude. Searches levels 0..45 (excludes topmost).
func search(ds *dataset.File, lerps [8]lerp3, target float64) int {
lower, upper := 0, 45
for lower < upper {
mid := (lower + upper + 1) / 2
test := interp3(ds, lerps, dataset.VarHeight, mid)
if target <= test {
upper = mid - 1
} else {
lower = mid
}
}
return lower
}
// interp4 performs altitude-interpolated wind lookup using two bracketing levels.
func interp4(ds *dataset.File, lerps [8]lerp3, altLerp lerp1, variable int) float64 {
lower := interp3(ds, lerps, variable, altLerp.index)
upper := interp3(ds, lerps, variable, altLerp.index+1)
return lower*altLerp.lerp + upper*(1-altLerp.lerp)
}
// GetWind returns interpolated (u, v) wind components for the given position.
// hour: fractional hours since dataset start.
// lat: latitude in degrees (-90 to +90).
// lng: longitude in degrees (0 to 360).
// alt: altitude in metres above sea level.
func GetWind(ds *dataset.File, warnings *Warnings, hour, lat, lng, alt float64) (u, v float64, err error) {
lerps, err := pick3(hour, lat, lng)
if err != nil {
return 0, 0, err
}
altidx := search(ds, lerps, alt)
lower := interp3(ds, lerps, dataset.VarHeight, altidx)
upper := interp3(ds, lerps, dataset.VarHeight, altidx+1)
var altLerp float64
if lower != upper {
altLerp = (upper - alt) / (upper - lower)
} else {
altLerp = 0.5
}
if altLerp < 0 {
warnings.AltitudeTooHigh.Add(1)
}
alt1 := lerp1{index: altidx, lerp: altLerp}
u = interp4(ds, lerps, alt1, dataset.VarWindU)
v = interp4(ds, lerps, alt1, dataset.VarWindV)
return u, v, nil
}

View file

@ -1,188 +0,0 @@
package prediction
import (
"math"
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/elevation"
)
// Exact port of the reference flight models (models.py).
const (
pi180 = math.Pi / 180.0
_180pi = 180.0 / math.Pi
)
// --- Up/Down Models ---
// ConstantAscent returns a model with constant vertical velocity (m/s).
func ConstantAscent(ascentRate float64) Model {
return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
return 0, 0, ascentRate
}
}
// DragDescent returns a descent-under-parachute model.
// seaLevelDescentRate is the descent rate at sea level (m/s, positive value).
// Uses the NASA atmosphere model for density at altitude.
func DragDescent(seaLevelDescentRate float64) Model {
dragCoefficient := seaLevelDescentRate * 1.1045
return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
return 0, 0, -dragCoefficient / math.Sqrt(nasaDensity(alt))
}
}
// nasaDensity computes air density using the NASA atmosphere model.
// Reference: http://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html
func nasaDensity(alt float64) float64 {
var temp, pressure float64
switch {
case alt > 25000:
temp = -131.21 + 0.00299*alt
pressure = 2.488 * math.Pow((temp+273.1)/216.6, -11.388)
case alt > 11000:
temp = -56.46
pressure = 22.65 * math.Exp(1.73-0.000157*alt)
default:
temp = 15.04 - 0.00649*alt
pressure = 101.29 * math.Pow((temp+273.1)/288.08, 5.256)
}
return pressure / (0.2869 * (temp + 273.1))
}
// --- Sideways Models ---
// WindVelocity returns a model that gives lateral movement at the wind velocity.
// ds is the wind dataset, dsEpoch is the dataset start time as UNIX timestamp.
func WindVelocity(ds *dataset.File, dsEpoch float64, warnings *Warnings) Model {
return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
tHours := (t - dsEpoch) / 3600.0
u, v, err := GetWind(ds, warnings, tHours, lat, lng, alt)
if err != nil {
return 0, 0, 0
}
R := 6371009.0 + alt
dlat = _180pi * v / R
dlng = _180pi * u / (R * math.Cos(lat*pi180))
return dlat, dlng, 0
}
}
// --- Model Combinations ---
// LinearModel returns a model that sums all component models.
func LinearModel(models ...Model) Model {
return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
for _, m := range models {
d1, d2, d3 := m(t, lat, lng, alt)
dlat += d1
dlng += d2
dalt += d3
}
return
}
}
// --- Termination Criteria ---
// BurstTermination returns a terminator that fires when altitude >= burstAltitude.
func BurstTermination(burstAltitude float64) Terminator {
return func(t, lat, lng, alt float64) bool {
return alt >= burstAltitude
}
}
// SeaLevelTermination fires when altitude <= 0.
func SeaLevelTermination(t, lat, lng, alt float64) bool {
return alt <= 0
}
// TimeTermination returns a terminator that fires when t > maxTime.
func TimeTermination(maxTime float64) Terminator {
return func(t, lat, lng, alt float64) bool {
return t > maxTime
}
}
// ElevationTermination returns a terminator that fires when alt < ground level.
// Uses ruaumoko-compatible elevation data. Longitude is normalised internally.
func ElevationTermination(elev *elevation.Dataset) Terminator {
return func(t, lat, lng, alt float64) bool {
return elev.Get(lat, lng) > alt
}
}
// --- Pre-Defined Profiles ---
// Stage pairs a model with its termination criterion.
type Stage struct {
Model Model
Terminator Terminator
}
// StandardProfile creates the chain for a standard high-altitude balloon flight:
// ascent at constant rate → burst → descent under parachute.
// If elev is non-nil, descent terminates at ground level; otherwise at sea level.
func StandardProfile(ascentRate, burstAltitude, descentRate float64,
ds *dataset.File, dsEpoch float64, warnings *Warnings,
elev *elevation.Dataset) []Stage {
wind := WindVelocity(ds, dsEpoch, warnings)
modelUp := LinearModel(ConstantAscent(ascentRate), wind)
termUp := BurstTermination(burstAltitude)
modelDown := LinearModel(DragDescent(descentRate), wind)
var termDown Terminator
if elev != nil {
termDown = ElevationTermination(elev)
} else {
termDown = Terminator(SeaLevelTermination)
}
return []Stage{
{Model: modelUp, Terminator: termUp},
{Model: modelDown, Terminator: termDown},
}
}
// FloatProfile creates the chain for a floating balloon flight:
// ascent to float altitude → float until stop time.
func FloatProfile(ascentRate, floatAltitude float64, stopTime time.Time,
ds *dataset.File, dsEpoch float64, warnings *Warnings) []Stage {
wind := WindVelocity(ds, dsEpoch, warnings)
modelUp := LinearModel(ConstantAscent(ascentRate), wind)
termUp := BurstTermination(floatAltitude)
modelFloat := wind
termFloat := TimeTermination(float64(stopTime.Unix()))
return []Stage{
{Model: modelUp, Terminator: termUp},
{Model: modelFloat, Terminator: termFloat},
}
}
// RunPrediction runs a prediction with the given profile stages.
// launchTime is a UNIX timestamp.
func RunPrediction(launchTime float64, lat, lng, alt float64, stages []Stage) []StageResult {
chain := make([]struct {
Model Model
Terminator Terminator
}, len(stages))
for i, s := range stages {
chain[i].Model = s.Model
chain[i].Terminator = s.Terminator
}
return Solve(launchTime, lat, lng, alt, chain)
}

View file

@ -1,180 +0,0 @@
package prediction
import "math"
// Exact port of the reference RK4 solver (solver.pyx).
// Integrates balloon state using RK4 with dt=60 seconds.
// Termination uses binary search refinement (tolerance 0.01).
// Vec holds the balloon state: latitude, longitude, altitude.
type Vec struct {
Lat float64
Lng float64
Alt float64
}
// Model is a function that returns (dlat/dt, dlng/dt, dalt/dt) given state.
// t is UNIX timestamp, lat/lng in degrees, alt in metres.
type Model func(t float64, lat, lng, alt float64) (dlat, dlng, dalt float64)
// Terminator returns true when integration should stop.
type Terminator func(t float64, lat, lng, alt float64) bool
// StageResult holds the trajectory points for one flight stage.
type StageResult struct {
Points []TrajectoryPoint
}
// TrajectoryPoint is a single point in a trajectory (used by solver).
type TrajectoryPoint struct {
T float64 // UNIX timestamp
Lat float64
Lng float64
Alt float64
}
// pymod returns a % b with Python semantics (always non-negative when b > 0).
func pymod(a, b float64) float64 {
r := math.Mod(a, b)
if r < 0 {
r += b
}
return r
}
// vecadd returns a + k*b, with lng wrapped to [0, 360).
func vecadd(a Vec, k float64, b Vec) Vec {
return Vec{
Lat: a.Lat + k*b.Lat,
Lng: pymod(a.Lng+k*b.Lng, 360.0),
Alt: a.Alt + k*b.Alt,
}
}
// scalarLerp returns (1-l)*a + l*b.
func scalarLerp(a, b, l float64) float64 {
return (1-l)*a + l*b
}
// lngLerp interpolates longitude handling the 0/360 wrap-around.
func lngLerp(a, b, l float64) float64 {
l2 := 1 - l
if a > b {
a, b = b, a
l, l2 = l2, l
}
// distance round one way: b - a
// distance around other: (a + 360) - b
if b-a < 180.0 {
return l2*a + l*b
}
return pymod(l2*(a+360)+l*b, 360.0)
}
// vecLerp returns (1-l)*a + l*b with proper longitude wrapping.
func vecLerp(a, b Vec, l float64) Vec {
return Vec{
Lat: scalarLerp(a.Lat, b.Lat, l),
Lng: lngLerp(a.Lng, b.Lng, l),
Alt: scalarLerp(a.Alt, b.Alt, l),
}
}
// rk4 integrates from initial conditions using RK4.
// dt=60.0 seconds, terminationTolerance=0.01.
func rk4(t float64, lat, lng, alt float64, model Model, terminator Terminator) []TrajectoryPoint {
const dt = 60.0
const terminationTolerance = 0.01
y := Vec{Lat: lat, Lng: lng, Alt: alt}
result := []TrajectoryPoint{{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt}}
for {
// Evaluate model at 4 points (standard RK4)
k1lat, k1lng, k1alt := model(t, y.Lat, y.Lng, y.Alt)
k1 := Vec{Lat: k1lat, Lng: k1lng, Alt: k1alt}
mid1 := vecadd(y, dt/2, k1)
k2lat, k2lng, k2alt := model(t+dt/2, mid1.Lat, mid1.Lng, mid1.Alt)
k2 := Vec{Lat: k2lat, Lng: k2lng, Alt: k2alt}
mid2 := vecadd(y, dt/2, k2)
k3lat, k3lng, k3alt := model(t+dt/2, mid2.Lat, mid2.Lng, mid2.Alt)
k3 := Vec{Lat: k3lat, Lng: k3lng, Alt: k3alt}
end := vecadd(y, dt, k3)
k4lat, k4lng, k4alt := model(t+dt, end.Lat, end.Lng, end.Alt)
k4 := Vec{Lat: k4lat, Lng: k4lng, Alt: k4alt}
// y2 = y + dt/6*k1 + dt/3*k2 + dt/3*k3 + dt/6*k4
y2 := y
y2 = vecadd(y2, dt/6, k1)
y2 = vecadd(y2, dt/3, k2)
y2 = vecadd(y2, dt/3, k3)
y2 = vecadd(y2, dt/6, k4)
t2 := t + dt
if terminator(t2, y2.Lat, y2.Lng, y2.Alt) {
// Binary search to refine the termination point.
// Find l in [0, 1] such that (t3, y3) = lerp((t, y), (t2, y2), l)
// is near where the terminator fires.
left := 0.0
right := 1.0
var t3 float64
var y3 Vec
t3 = t2
y3 = y2
for right-left > terminationTolerance {
mid := (left + right) / 2
t3 = scalarLerp(t, t2, mid)
y3 = vecLerp(y, y2, mid)
if terminator(t3, y3.Lat, y3.Lng, y3.Alt) {
right = mid
} else {
left = mid
}
}
result = append(result, TrajectoryPoint{T: t3, Lat: y3.Lat, Lng: y3.Lng, Alt: y3.Alt})
break
}
// Update current state
t = t2
y = y2
result = append(result, TrajectoryPoint{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt})
}
return result
}
// Solve runs through a chain of (model, terminator) stages.
// Returns one StageResult per stage.
func Solve(t, lat, lng, alt float64, chain []struct {
Model Model
Terminator Terminator
}) []StageResult {
var results []StageResult
for _, stage := range chain {
points := rk4(t, lat, lng, alt, stage.Model, stage.Terminator)
results = append(results, StageResult{Points: points})
// Next stage starts where this one ended
if len(points) > 0 {
last := points[len(points)-1]
t = last.T
lat = last.Lat
lng = last.Lng
alt = last.Alt
}
}
return results
}

View file

@ -1,21 +0,0 @@
package prediction
import "sync/atomic"
// Warnings tracks warning conditions during a prediction run.
type Warnings struct {
AltitudeTooHigh atomic.Int64
}
// ToMap returns warnings as a map suitable for JSON serialization.
// Only includes warnings that have fired.
func (w *Warnings) ToMap() map[string]any {
result := make(map[string]any)
if n := w.AltitudeTooHigh.Load(); n > 0 {
result["altitude_too_high"] = map[string]any{
"count": n,
"description": "The altitude went too high, above the max forecast wind. Wind data will be unreliable",
}
}
return result
}

View file

@ -1,245 +0,0 @@
package service
import (
"context"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/downloader"
"predictor-refactored/internal/elevation"
"go.uber.org/zap"
)
// Service orchestrates the dataset lifecycle and provides prediction capabilities.
type Service struct {
mu sync.RWMutex
ds *dataset.File
elev *elevation.Dataset
cfg *downloader.Config
dl *downloader.Downloader
log *zap.Logger
updating sync.Mutex // prevents concurrent downloads
}
// New creates a new Service.
func New(cfg *downloader.Config, log *zap.Logger) *Service {
return &Service{
cfg: cfg,
dl: downloader.NewDownloader(cfg, log),
log: log,
}
}
// LoadElevation loads the ruaumoko-compatible elevation dataset from path.
// If the file doesn't exist, elevation termination is disabled (falls back to sea level).
func (s *Service) LoadElevation(path string) {
ds, err := elevation.Open(path)
if err != nil {
s.log.Warn("elevation dataset not available, using sea-level termination",
zap.String("path", path), zap.Error(err))
return
}
s.elev = ds
s.log.Info("elevation dataset loaded", zap.String("path", path))
}
// Elevation returns the elevation dataset (may be nil).
func (s *Service) Elevation() *elevation.Dataset {
return s.elev
}
// Ready returns true if the service has a loaded dataset.
func (s *Service) Ready() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.ds != nil
}
// DatasetTime returns the forecast time of the currently loaded dataset.
func (s *Service) DatasetTime() (time.Time, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
if s.ds == nil {
return time.Time{}, false
}
return s.ds.DSTime, true
}
// Dataset returns the current dataset for reading.
func (s *Service) Dataset() *dataset.File {
s.mu.RLock()
defer s.mu.RUnlock()
return s.ds
}
// Update checks for and downloads new forecast data if needed.
func (s *Service) Update(ctx context.Context) error {
if !s.updating.TryLock() {
s.log.Info("update already in progress, skipping")
return nil
}
defer s.updating.Unlock()
// Check if current dataset is still fresh
if dsTime, ok := s.DatasetTime(); ok {
if time.Since(dsTime) < s.cfg.DatasetTTL {
s.log.Info("dataset still fresh, skipping update",
zap.Time("dataset_time", dsTime),
zap.Duration("age", time.Since(dsTime)))
return nil
}
}
// Try loading an existing dataset from disk first
if err := s.loadExistingDataset(); err == nil {
return nil
}
// Find latest available model run
run, err := s.dl.FindLatestRun(ctx)
if err != nil {
return err
}
// Download and assemble
path, err := s.dl.Download(ctx, run)
if err != nil {
return err
}
// Open the new dataset
ds, err := dataset.Open(path, run)
if err != nil {
return err
}
// Swap in the new dataset
s.setDataset(ds)
s.log.Info("dataset loaded", zap.Time("run", run), zap.String("path", path))
// Clean old datasets
s.cleanOldDatasets(path)
return nil
}
// loadExistingDataset tries to find and load an existing dataset from the data directory.
func (s *Service) loadExistingDataset() error {
entries, err := os.ReadDir(s.cfg.DataDir)
if err != nil {
return err
}
// Collect valid dataset files (name is YYYYMMDDHH, no extension, correct size)
type candidate struct {
name string
path string
run time.Time
}
var candidates []candidate
for _, e := range entries {
if e.IsDir() || strings.Contains(e.Name(), ".") {
continue
}
if len(e.Name()) != 10 {
continue
}
run, err := time.Parse("2006010215", e.Name())
if err != nil {
continue
}
path := filepath.Join(s.cfg.DataDir, e.Name())
info, err := os.Stat(path)
if err != nil || info.Size() != dataset.DatasetSize {
continue
}
if time.Since(run) > s.cfg.DatasetTTL {
continue
}
candidates = append(candidates, candidate{name: e.Name(), path: path, run: run})
}
if len(candidates) == 0 {
return os.ErrNotExist
}
// Pick the newest
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].run.After(candidates[j].run)
})
best := candidates[0]
ds, err := dataset.Open(best.path, best.run)
if err != nil {
return err
}
s.setDataset(ds)
s.log.Info("loaded existing dataset",
zap.Time("run", best.run),
zap.String("path", best.path))
return nil
}
// setDataset swaps the current dataset with a new one, closing the old one.
func (s *Service) setDataset(ds *dataset.File) {
s.mu.Lock()
old := s.ds
s.ds = ds
s.mu.Unlock()
if old != nil {
if err := old.Close(); err != nil {
s.log.Error("failed to close old dataset", zap.Error(err))
}
}
}
// cleanOldDatasets removes dataset files other than the one at keepPath.
func (s *Service) cleanOldDatasets(keepPath string) {
entries, err := os.ReadDir(s.cfg.DataDir)
if err != nil {
return
}
for _, e := range entries {
if e.IsDir() {
continue
}
path := filepath.Join(s.cfg.DataDir, e.Name())
if path == keepPath {
continue
}
// Remove old datasets and temp files
if len(e.Name()) == 10 || strings.HasSuffix(e.Name(), ".downloading") {
if err := os.Remove(path); err != nil {
s.log.Warn("failed to remove old file", zap.String("path", path), zap.Error(err))
} else {
s.log.Info("removed old dataset", zap.String("path", path))
}
}
}
}
// Close releases all resources.
func (s *Service) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.ds != nil {
err := s.ds.Close()
s.ds = nil
return err
}
return nil
}

View file

@ -1,30 +0,0 @@
package middleware
import (
"time"
"github.com/ogen-go/ogen/middleware"
"go.uber.org/zap"
)
// Logging returns an ogen middleware that logs request duration.
func Logging(log *zap.Logger) middleware.Middleware {
return func(req middleware.Request, next func(req middleware.Request) (middleware.Response, error)) (middleware.Response, error) {
lg := log.With(zap.String("operation", req.OperationID))
start := time.Now()
resp, err := next(req)
dur := time.Since(start)
if err != nil {
lg.Error("request failed",
zap.Duration("duration", dur),
zap.Error(err))
} else {
lg.Info("request completed",
zap.Duration("duration", dur))
}
return resp, err
}
}

View file

@ -1,16 +0,0 @@
package handler
import (
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/elevation"
)
// Service defines the interface the handler needs from the service layer.
type Service interface {
Ready() bool
DatasetTime() (time.Time, bool)
Dataset() *dataset.File
Elevation() *elevation.Dataset
}

View file

@ -1,216 +0,0 @@
package handler
import (
"context"
"net/http"
"time"
"predictor-refactored/internal/prediction"
api "predictor-refactored/pkg/rest"
"go.uber.org/zap"
)
var _ api.Handler = (*Handler)(nil)
// Handler implements the ogen-generated api.Handler interface.
type Handler struct {
svc Service
log *zap.Logger
}
// New creates a new Handler.
func New(svc Service, log *zap.Logger) *Handler {
return &Handler{svc: svc, log: log}
}
// PerformPrediction implements the prediction endpoint.
func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredictionParams) (*api.PredictionResponse, error) {
if !h.svc.Ready() {
return nil, newError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up")
}
ds := h.svc.Dataset()
if ds == nil {
return nil, newError(http.StatusServiceUnavailable, "dataset unavailable")
}
dsEpoch := float64(ds.DSTime.Unix())
// Parse parameters with defaults
profile := "standard_profile"
if p, ok := params.Profile.Get(); ok {
profile = string(p)
}
ascentRate := 5.0
if v, ok := params.AscentRate.Get(); ok {
ascentRate = v
}
burstAltitude := 28000.0
if v, ok := params.BurstAltitude.Get(); ok {
burstAltitude = v
}
descentRate := 5.0
if v, ok := params.DescentRate.Get(); ok {
descentRate = v
}
launchAlt := 0.0
if v, ok := params.LaunchAltitude.Get(); ok {
launchAlt = v
}
// Normalize longitude to [0, 360)
lng := params.LaunchLongitude
if lng < 0 {
lng += 360.0
}
launchTime := float64(params.LaunchDatetime.Unix())
warnings := &prediction.Warnings{}
// Build profile chain
elev := h.svc.Elevation()
var stages []prediction.Stage
switch profile {
case "standard_profile":
stages = prediction.StandardProfile(
ascentRate, burstAltitude, descentRate,
ds, dsEpoch, warnings, elev)
case "float_profile":
floatAlt := 25000.0
if v, ok := params.FloatAltitude.Get(); ok {
floatAlt = v
}
stopTime := params.LaunchDatetime.Add(24 * time.Hour)
if v, ok := params.StopDatetime.Get(); ok {
stopTime = v
}
stages = prediction.FloatProfile(
ascentRate, floatAlt, stopTime,
ds, dsEpoch, warnings)
default:
return nil, newError(http.StatusBadRequest, "unknown profile: "+profile)
}
// Run prediction
startTime := time.Now().UTC()
results := prediction.RunPrediction(launchTime, params.LaunchLatitude, lng, launchAlt, stages)
completeTime := time.Now().UTC()
// Build response
stageNames := []string{"ascent", "descent"}
if profile == "float_profile" {
stageNames = []string{"ascent", "float"}
}
var predItems []api.PredictionResponsePredictionItem
for i, sr := range results {
stageName := "ascent"
if i < len(stageNames) {
stageName = stageNames[i]
}
var stageEnum api.PredictionResponsePredictionItemStage
switch stageName {
case "ascent":
stageEnum = api.PredictionResponsePredictionItemStageAscent
case "descent":
stageEnum = api.PredictionResponsePredictionItemStageDescent
case "float":
stageEnum = api.PredictionResponsePredictionItemStageFloat
}
var traj []api.PredictionResponsePredictionItemTrajectoryItem
for _, pt := range sr.Points {
ptLng := pt.Lng
if ptLng > 180 {
ptLng -= 360
}
traj = append(traj, api.PredictionResponsePredictionItemTrajectoryItem{
Datetime: time.Unix(int64(pt.T), 0).UTC(),
Latitude: pt.Lat,
Longitude: ptLng,
Altitude: pt.Alt,
})
}
predItems = append(predItems, api.PredictionResponsePredictionItem{
Stage: stageEnum,
Trajectory: traj,
})
}
resp := &api.PredictionResponse{
Prediction: predItems,
Metadata: api.PredictionResponseMetadata{
StartDatetime: startTime,
CompleteDatetime: completeTime,
},
}
// Echo request
resp.Request = api.NewOptPredictionResponseRequest(api.PredictionResponseRequest{
Dataset: api.NewOptString(ds.DSTime.Format("2006-01-02T15:04:05Z")),
LaunchLatitude: api.NewOptFloat64(params.LaunchLatitude),
LaunchLongitude: api.NewOptFloat64(params.LaunchLongitude),
LaunchDatetime: api.NewOptString(params.LaunchDatetime.Format(time.RFC3339)),
LaunchAltitude: params.LaunchAltitude,
})
// Warnings
warnMap := warnings.ToMap()
if len(warnMap) > 0 {
resp.Warnings = api.NewOptPredictionResponseWarnings(api.PredictionResponseWarnings{})
}
h.log.Info("prediction complete",
zap.String("profile", profile),
zap.Int("stages", len(results)),
zap.Duration("elapsed", completeTime.Sub(startTime)))
return resp, nil
}
// ReadinessCheck implements the health check endpoint.
func (h *Handler) ReadinessCheck(ctx context.Context) (*api.ReadinessResponse, error) {
resp := &api.ReadinessResponse{}
if h.svc.Ready() {
resp.Status = api.ReadinessResponseStatusOk
if dsTime, ok := h.svc.DatasetTime(); ok {
resp.DatasetTime = api.NewOptDateTime(dsTime)
}
} else {
resp.Status = api.ReadinessResponseStatusNotReady
resp.ErrorMessage = api.NewOptString("no dataset loaded")
}
return resp, nil
}
// NewError creates an ErrorStatusCode from an error returned by a handler.
func (h *Handler) NewError(ctx context.Context, err error) *api.ErrorStatusCode {
if statusErr, ok := err.(*api.ErrorStatusCode); ok {
return statusErr
}
h.log.Error("unhandled error", zap.Error(err))
return newError(http.StatusInternalServerError, err.Error())
}
func newError(status int, description string) *api.ErrorStatusCode {
return &api.ErrorStatusCode{
StatusCode: status,
Response: api.Error{
Error: api.ErrorError{
Type: http.StatusText(status),
Description: description,
},
},
}
}

View file

@ -1,75 +0,0 @@
package rest
import (
"context"
"fmt"
"net/http"
"predictor-refactored/internal/transport/middleware"
"predictor-refactored/internal/transport/rest/handler"
api "predictor-refactored/pkg/rest"
"go.uber.org/zap"
)
// Transport wraps the ogen HTTP server.
type Transport struct {
srv *api.Server
handler *handler.Handler
port int
log *zap.Logger
}
// New creates a new REST transport.
func New(h *handler.Handler, port int, log *zap.Logger) (*Transport, error) {
srv, err := api.NewServer(
h,
api.WithMiddleware(middleware.Logging(log)),
)
if err != nil {
return nil, fmt.Errorf("create ogen server: %w", err)
}
return &Transport{
srv: srv,
handler: h,
port: port,
log: log,
}, nil
}
// Run starts the HTTP server. Blocks until the server stops.
func (t *Transport) Run() error {
mux := http.NewServeMux()
mux.Handle("/", t.srv)
httpSrv := &http.Server{
Addr: fmt.Sprintf(":%d", t.port),
Handler: corsMiddleware(mux),
}
t.log.Info("starting HTTP server", zap.Int("port", t.port))
return httpSrv.ListenAndServe()
}
// Shutdown gracefully stops the HTTP server.
func (t *Transport) Shutdown(ctx context.Context) error {
// The ogen server doesn't have a shutdown method;
// shutdown is handled by the http.Server in main.go
return nil
}
func corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
next.ServeHTTP(w, r)
})
}

View file

@ -0,0 +1,34 @@
package gfs
// Cross-variant constants. Per-variant geometry (latitudes, longitudes,
// pressure levels, hour step, max hour, URL token) lives on the Variant
// type; see variant.go.
const (
// NumVariables is the number of dataset variables: HGT, UGRD, VGRD.
NumVariables = 3
// ElementSize is the cell size in bytes (float32).
ElementSize = 4
// LatStart is the first latitude in the cube (south to north).
LatStart = -90.0
// LonStart is the first longitude in the cube (0..360 east).
LonStart = 0.0
// Variable indices within the cube's 3rd axis.
VarHeight = 0
VarWindU = 1
VarWindV = 2
)
// LevelSet identifies which GRIB file (primary or secondary) carries a
// pressure level.
type LevelSet int
const (
LevelSetA LevelSet = iota // pgrb2 — primary file
LevelSetB // pgrb2b — secondary file
)
// S3BaseURL is the public NOAA S3 mirror.
const S3BaseURL = "https://noaa-gfs-bdp-pds.s3.amazonaws.com"

View file

@ -0,0 +1,168 @@
package gfs
import (
"encoding/binary"
"fmt"
"math"
"os"
"time"
mmap "github.com/edsrzf/mmap-go"
)
// File is an mmap-backed wind dataset file. The layout is a flat C-order
// row-major float32 array, shape (hour, level, variable, lat, lng), with
// the per-axis sizes coming from Variant.
type File struct {
variant *Variant
mm mmap.MMap
file *os.File
writable bool
// Epoch is the forecast run time (UTC) the file represents.
Epoch time.Time
}
// Variant returns the Variant the file was created with.
func (d *File) Variant() *Variant { return d.variant }
// Open opens an existing dataset file for reading.
func Open(path string, variant *Variant, epoch time.Time) (*File, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open dataset: %w", err)
}
info, err := f.Stat()
if err != nil {
f.Close()
return nil, fmt.Errorf("stat dataset: %w", err)
}
if info.Size() != variant.DatasetSize() {
f.Close()
return nil, fmt.Errorf("dataset should be %d bytes (was %d)", variant.DatasetSize(), info.Size())
}
mm, err := mmap.Map(f, mmap.RDONLY, 0)
if err != nil {
f.Close()
return nil, fmt.Errorf("mmap dataset: %w", err)
}
return &File{variant: variant, mm: mm, file: f, writable: false, Epoch: epoch}, nil
}
// Create creates a new dataset file sized for variant, mmap'd read-write.
func Create(path string, variant *Variant) (*File, error) {
f, err := os.Create(path)
if err != nil {
return nil, fmt.Errorf("create dataset: %w", err)
}
size := variant.DatasetSize()
if err := f.Truncate(size); err != nil {
f.Close()
return nil, fmt.Errorf("truncate dataset: %w", err)
}
mm, err := mmap.MapRegion(f, int(size), mmap.RDWR, 0, 0)
if err != nil {
f.Close()
return nil, fmt.Errorf("mmap dataset: %w", err)
}
return &File{variant: variant, mm: mm, file: f, writable: true}, nil
}
// OpenWritable opens an existing dataset file for read-write access. Used
// when resuming a partial download.
func OpenWritable(path string, variant *Variant) (*File, error) {
f, err := os.OpenFile(path, os.O_RDWR, 0o644)
if err != nil {
return nil, fmt.Errorf("open dataset rw: %w", err)
}
info, err := f.Stat()
if err != nil {
f.Close()
return nil, fmt.Errorf("stat dataset: %w", err)
}
if info.Size() != variant.DatasetSize() {
f.Close()
return nil, fmt.Errorf("dataset should be %d bytes (was %d)", variant.DatasetSize(), info.Size())
}
mm, err := mmap.MapRegion(f, int(info.Size()), mmap.RDWR, 0, 0)
if err != nil {
f.Close()
return nil, fmt.Errorf("mmap dataset: %w", err)
}
return &File{variant: variant, mm: mm, file: f, writable: true}, nil
}
// offset returns the byte offset of the [hour][level][variable][lat][lng] cell.
func (d *File) offset(hour, level, variable, lat, lng int) int64 {
v := d.variant
idx := int64(hour)
idx = idx*int64(v.NumLevels()) + int64(level)
idx = idx*int64(NumVariables) + int64(variable)
idx = idx*int64(v.NumLatitudes()) + int64(lat)
idx = idx*int64(v.NumLongitudes()) + int64(lng)
return idx * int64(ElementSize)
}
// Val reads one cell as a float32.
func (d *File) Val(hour, level, variable, lat, lng int) float32 {
off := d.offset(hour, level, variable, lat, lng)
return math.Float32frombits(binary.LittleEndian.Uint32(d.mm[off : off+4]))
}
// ValByElem reads the float32 at a precomputed flat element index (not a byte
// offset). The wind sampler uses this to read the eight interpolation corners
// after computing their flat indices once via cube strides.
func (d *File) ValByElem(elem int64) float32 {
off := elem * ElementSize
return math.Float32frombits(binary.LittleEndian.Uint32(d.mm[off : off+4]))
}
// SetVal writes one cell. Only valid on writable files.
func (d *File) SetVal(hour, level, variable, lat, lng int, val float32) {
off := d.offset(hour, level, variable, lat, lng)
binary.LittleEndian.PutUint32(d.mm[off:off+4], math.Float32bits(val))
}
// BlitGribData copies one decoded GRIB grid into the dataset, flipping the
// latitude axis from GRIB's north-to-south scan order to our south-to-north
// storage order.
func (d *File) BlitGribData(hourIdx, levelIdx, varIdx int, gribData []float64) error {
v := d.variant
expected := v.NumLatitudes() * v.NumLongitudes()
if len(gribData) != expected {
return fmt.Errorf("grib data has %d values, expected %d", len(gribData), expected)
}
lats := v.NumLatitudes()
lngs := v.NumLongitudes()
for lat := range lats {
for lng := range lngs {
gribIdx := (lats-1-lat)*lngs + lng
d.SetVal(hourIdx, levelIdx, varIdx, lat, lng, float32(gribData[gribIdx]))
}
}
return nil
}
// Flush flushes the mmap to disk.
func (d *File) Flush() error {
if d.mm != nil {
return d.mm.Flush()
}
return nil
}
// Close unmaps and closes the file.
func (d *File) Close() error {
if d.mm != nil {
if err := d.mm.Unmap(); err != nil {
d.file.Close()
return fmt.Errorf("unmap: %w", err)
}
d.mm = nil
}
if d.file != nil {
err := d.file.Close()
d.file = nil
return err
}
return nil
}

View file

@ -0,0 +1,68 @@
package gfs
import "fmt"
// Family is the dataset family ("gfs" or "gefs"). Variants of different
// families have different URL layouts but share the cube format.
type Family int
const (
FamilyGFS Family = iota
FamilyGEFS
)
func (f Family) String() string {
switch f {
case FamilyGEFS:
return "gefs"
default:
return "gfs"
}
}
// HasMember reports whether the family requires a member index in URLs.
func (f Family) HasMember() bool { return f == FamilyGEFS }
// GEFS variant constants.
//
// The 21-member ensemble is gec00 (control) + gep01..gep20 (perturbations).
// NOAA publishes more members today but 21 matches the historical Tawhiri
// configuration and is what the phase 2 spec calls for.
const GEFSMembers = 21
// GefsMemberName returns the file-name token for a GEFS member.
// member=0 → "gec00", member=1..20 → "gep01".."gep20".
func GefsMemberName(member int) string {
if member == 0 {
return "gec00"
}
return fmt.Sprintf("gep%02d", member)
}
// GEFS S3 mirror.
const GEFSS3BaseURL = "https://noaa-gefs-pds.s3.amazonaws.com"
// GefsGribURL returns the S3 URL for a GEFS primary GRIB file.
func GefsGribURL(date string, runHour, member, forecastStep int, resToken string) string {
return fmt.Sprintf("%s/gefs.%s/%02d/atmos/pgrb2ap5/%s.t%02dz.pgrb2a.%s.f%03d",
GEFSS3BaseURL, date, runHour, GefsMemberName(member), runHour, resToken, forecastStep)
}
// GefsGribURLB returns the S3 URL for a GEFS secondary GRIB file.
func GefsGribURLB(date string, runHour, member, forecastStep int, resToken string) string {
return fmt.Sprintf("%s/gefs.%s/%02d/atmos/pgrb2bp5/%s.t%02dz.pgrb2b.%s.f%03d",
GEFSS3BaseURL, date, runHour, GefsMemberName(member), runHour, resToken, forecastStep)
}
// GEFS variants — 0.5° resolution, 3-hour cadence, 192h horizon.
var GEFS0p50_3h = &Variant{
ID: "gefs-0p50-3h",
Family: FamilyGEFS,
ResToken: "0p50",
Resolution: 0.5,
HourStep: 3,
MaxHour: 192,
Pressures: GFS0p50_3h.Pressures,
PressuresPgrb2: GFS0p50_3h.PressuresPgrb2,
PressuresPgrb2b: GFS0p50_3h.PressuresPgrb2b,
}

View file

@ -0,0 +1,191 @@
package gfs
import "fmt"
// Variant describes one configuration of a NOAA dataset family (GFS or GEFS).
//
// The dataset cube is a 5-D float32 array with shape
// (NumHours, NumLevels, NumVariables, NumLatitudes, NumLongitudes) where
// NumVariables and ElementSize are fixed across all GFS variants but the
// other dimensions depend on the resolution and forecast cadence.
type Variant struct {
// ID is a stable identifier ("gfs-0p50-3h", "gefs-0p50-3h", ...).
ID string
// Family identifies the dataset family the variant belongs to.
Family Family
// Resolution token used in NOAA URLs ("0p50", "0p25").
ResToken string
// Grid step in degrees (0.5, 0.25). 180 / Resolution + 1 latitudes and
// 360 / Resolution longitudes.
Resolution float64
HourStep int // hours between forecast steps
MaxHour int // largest forecast hour (inclusive)
// Pressures lists every pressure level in dataset index order, descending.
Pressures []int
// PressuresPgrb2 / PressuresPgrb2b split the pressures between the two
// downloaded GRIB files. Their union must equal Pressures.
PressuresPgrb2 []int
PressuresPgrb2b []int
pressureIndex map[int]int
pressureLevelSet map[int]LevelSet
}
// NumHours returns MaxHour/HourStep + 1.
func (v *Variant) NumHours() int { return v.MaxHour/v.HourStep + 1 }
// NumLevels returns len(Pressures).
func (v *Variant) NumLevels() int { return len(v.Pressures) }
// NumLatitudes returns 180/Resolution + 1.
func (v *Variant) NumLatitudes() int { return int(180.0/v.Resolution) + 1 }
// NumLongitudes returns 360/Resolution.
func (v *Variant) NumLongitudes() int { return int(360.0 / v.Resolution) }
// DatasetSize returns the canonical file size in bytes.
func (v *Variant) DatasetSize() int64 {
return int64(v.NumHours()) * int64(v.NumLevels()) * int64(NumVariables) *
int64(v.NumLatitudes()) * int64(v.NumLongitudes()) * int64(ElementSize)
}
// Hours returns the full list of forecast hours [0, HourStep, ..., MaxHour].
func (v *Variant) Hours() []int {
out := make([]int, 0, v.NumHours())
for h := 0; h <= v.MaxHour; h += v.HourStep {
out = append(out, h)
}
return out
}
// HourIndex returns the dataset time index for an hour, or -1 if invalid.
func (v *Variant) HourIndex(hour int) int {
if hour < 0 || hour > v.MaxHour || hour%v.HourStep != 0 {
return -1
}
return hour / v.HourStep
}
// PressureIndex returns the dataset index for a pressure level in hPa,
// or -1 when the level is unknown to this variant.
func (v *Variant) PressureIndex(hPa int) int {
v.indexLazyInit()
if i, ok := v.pressureIndex[hPa]; ok {
return i
}
return -1
}
// PressureLevelSet returns the GRIB file set carrying a pressure level.
func (v *Variant) PressureLevelSet(hPa int) (LevelSet, bool) {
v.indexLazyInit()
ls, ok := v.pressureLevelSet[hPa]
return ls, ok
}
// VariableIndex maps a GRIB (category, number) pair to a dataset variable index.
func (v *Variant) VariableIndex(parameterCategory, parameterNumber int) int {
switch {
case parameterCategory == 3 && parameterNumber == 5:
return VarHeight
case parameterCategory == 2 && parameterNumber == 2:
return VarWindU
case parameterCategory == 2 && parameterNumber == 3:
return VarWindV
default:
return -1
}
}
// GribURL returns the S3 URL for the primary (pgrb2) GRIB file.
func (v *Variant) GribURL(date string, runHour, forecastStep int) string {
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.%s.f%03d",
S3BaseURL, date, runHour, runHour, v.ResToken, forecastStep)
}
// GribURLB returns the S3 URL for the secondary (pgrb2b) GRIB file.
func (v *Variant) GribURLB(date string, runHour, forecastStep int) string {
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2b.%s.f%03d",
S3BaseURL, date, runHour, runHour, v.ResToken, forecastStep)
}
func (v *Variant) indexLazyInit() {
if v.pressureIndex != nil {
return
}
v.pressureIndex = make(map[int]int, len(v.Pressures))
for i, p := range v.Pressures {
v.pressureIndex[p] = i
}
v.pressureLevelSet = make(map[int]LevelSet, len(v.Pressures))
for _, p := range v.PressuresPgrb2 {
v.pressureLevelSet[p] = LevelSetA
}
for _, p := range v.PressuresPgrb2b {
v.pressureLevelSet[p] = LevelSetB
}
}
// Standard variants -- these mirror what NOAA publishes today.
//
// GFS0p50_3h is the historical Tawhiri default: 0.5° resolution, 3-hour
// forecast cadence, 0..192h horizon, 47 pressure levels split across the
// primary and secondary GRIB files.
//
// GFS0p25_3h mirrors the same 3-hour cadence at 0.25° resolution (the
// horizon is larger in practice but we keep 192h for parity with 0p50).
//
// GFS0p25_1h targets the 1-hourly portion NOAA publishes out to 120h.
var (
GFS0p50_3h = &Variant{
ID: "gfs-0p50-3h",
ResToken: "0p50",
Resolution: 0.5,
HourStep: 3,
MaxHour: 192,
Pressures: []int{1000, 975, 950, 925, 900, 875, 850, 825, 800, 775, 750, 725, 700, 675, 650, 625, 600, 575, 550, 525, 500, 475, 450, 425, 400, 375, 350, 325, 300, 275, 250, 225, 200, 175, 150, 125, 100, 70, 50, 30, 20, 10, 7, 5, 3, 2, 1},
PressuresPgrb2: []int{10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 925, 950, 975, 1000},
PressuresPgrb2b: []int{1, 2, 3, 5, 7, 125, 175, 225, 275, 325, 375, 425, 475, 525, 575, 625, 675, 725, 775, 825, 875},
}
GFS0p25_3h = &Variant{
ID: "gfs-0p25-3h",
ResToken: "0p25",
Resolution: 0.25,
HourStep: 3,
MaxHour: 192,
Pressures: GFS0p50_3h.Pressures,
PressuresPgrb2: GFS0p50_3h.PressuresPgrb2,
PressuresPgrb2b: GFS0p50_3h.PressuresPgrb2b,
}
GFS0p25_1h = &Variant{
ID: "gfs-0p25-1h",
ResToken: "0p25",
Resolution: 0.25,
HourStep: 1,
MaxHour: 120,
Pressures: GFS0p50_3h.Pressures,
PressuresPgrb2: GFS0p50_3h.PressuresPgrb2,
PressuresPgrb2b: GFS0p50_3h.PressuresPgrb2b,
}
)
// VariantByID returns one of the predefined variants by its ID.
func VariantByID(id string) (*Variant, error) {
switch id {
case GFS0p50_3h.ID:
return GFS0p50_3h, nil
case GFS0p25_3h.ID:
return GFS0p25_3h, nil
case GFS0p25_1h.ID:
return GFS0p25_1h, nil
case GEFS0p50_3h.ID:
return GEFS0p50_3h, nil
default:
return nil, fmt.Errorf("unknown variant %q", id)
}
}

View file

@ -0,0 +1,129 @@
package gfs
import (
"time"
"predictor-refactored/internal/numerics"
"predictor-refactored/internal/weather"
)
// Wind is a WindField backed by a GFS dataset file.
//
// The cube is addressed in flat element units with fixed strides so the
// sampler can compute the eight horizontal interpolation corners once and
// reach any (level, variable) by adding constant strides — avoiding the
// five-multiply offset computation per corner per evaluation.
type Wind struct {
file *File
hourAxis numerics.Axis
latAxis numerics.Axis
lngAxis numerics.Axis
hourStride int64 // elements between successive hours
levelStride int64 // elements between successive pressure levels
varStride int64 // elements between successive variables
latStride int64 // elements between successive latitudes
}
// NewWind returns a Wind backed by file. Axes and strides are derived from
// the file's variant geometry.
func NewWind(file *File) *Wind {
v := file.variant
nLat := v.NumLatitudes()
nLng := v.NumLongitudes()
nLev := v.NumLevels()
return &Wind{
file: file,
hourAxis: numerics.Axis{Left: 0, Step: float64(v.HourStep), N: v.NumHours(), Name: "hour"},
latAxis: numerics.Axis{Left: LatStart, Step: v.Resolution, N: nLat, Name: "lat"},
lngAxis: numerics.Axis{Left: LonStart, Step: v.Resolution, N: nLng, Wrap: true, Name: "lng"},
hourStride: int64(nLev) * NumVariables * int64(nLat) * int64(nLng),
levelStride: NumVariables * int64(nLat) * int64(nLng),
varStride: int64(nLat) * int64(nLng),
latStride: int64(nLng),
}
}
// Epoch returns the forecast run time of the underlying file.
func (w *Wind) Epoch() time.Time { return w.file.Epoch }
// Source returns the variant ID (e.g. "gfs-0p50-3h").
func (w *Wind) Source() string { return w.file.variant.ID }
// Close releases the underlying file's resources.
func (w *Wind) Close() error { return w.file.Close() }
// Wind samples the field at the given UNIX time, geographic coordinate, and
// altitude. Vertical interpolation matches Tawhiri: locate the two pressure
// levels whose interpolated geopotential heights bracket alt, then linearly
// interpolate U and V between them.
func (w *Wind) Wind(t, lat, lng, alt float64) (weather.Sample, error) {
hours := (t - float64(w.file.Epoch.Unix())) / 3600.0
bh, err := w.hourAxis.Locate(hours)
if err != nil {
return weather.Sample{}, err
}
bla, err := w.latAxis.Locate(lat)
if err != nil {
return weather.Sample{}, err
}
bln, err := w.lngAxis.Locate(lng)
if err != nil {
return weather.Sample{}, err
}
weights := numerics.TrilinearWeights([3]numerics.Bracket{bh, bla, bln})
// Flat element index of each of the eight horizontal corners, at level 0
// variable 0, in the canonical TrilinearWeights order (hour outer, lng
// inner). Reaching a given (level, variable) corner only adds constant
// strides.
var base [8]int64
hours2 := [2]int64{int64(bh.Lo) * w.hourStride, int64(bh.Hi) * w.hourStride}
lats2 := [2]int64{int64(bla.Lo) * w.latStride, int64(bla.Hi) * w.latStride}
lngs2 := [2]int64{int64(bln.Lo), int64(bln.Hi)}
i := 0
for _, h := range hours2 {
for _, la := range lats2 {
for _, ln := range lngs2 {
base[i] = h + la + ln
i++
}
}
}
sample := func(level int, varIdx int64) float64 {
off := int64(level)*w.levelStride + varIdx*w.varStride
var vals [8]float64
for k := range 8 {
vals[k] = float64(w.file.ValByElem(base[k] + off))
}
return numerics.Dot8(&weights, &vals)
}
// Largest pressure level whose interpolated geopotential height is below alt.
levelIdx := numerics.Bisect(0, w.file.variant.NumLevels()-2, alt, func(level int) float64 {
return sample(level, VarHeight)
})
lowerHGT := sample(levelIdx, VarHeight)
upperHGT := sample(levelIdx+1, VarHeight)
altFrac := 0.5
if lowerHGT != upperHGT {
altFrac = (upperHGT - alt) / (upperHGT - lowerHGT)
}
lowerU := sample(levelIdx, VarWindU)
upperU := sample(levelIdx+1, VarWindU)
lowerV := sample(levelIdx, VarWindV)
upperV := sample(levelIdx+1, VarWindV)
return weather.Sample{
U: lowerU*altFrac + upperU*(1-altFrac),
V: lowerV*altFrac + upperV*(1-altFrac),
AboveModel: altFrac < 0,
}, nil
}

View file

@ -0,0 +1,69 @@
package gfs
import (
"math"
"path/filepath"
"testing"
"time"
)
// testVariant is a tiny cube (2 hours × 3 levels × 3 lat × 4 lng) used to
// exercise the sampler without allocating a multi-gigabyte real dataset.
func testVariant() *Variant {
return &Variant{
ID: "gfs-test",
ResToken: "test",
Resolution: 90, // 180/90+1 = 3 lats, 360/90 = 4 lngs
HourStep: 3,
MaxHour: 3, // 2 hours
Pressures: []int{1000, 500, 100},
PressuresPgrb2: []int{1000, 500, 100},
PressuresPgrb2b: []int{},
}
}
func TestWindSampler(t *testing.T) {
v := testVariant()
path := filepath.Join(t.TempDir(), "cube.bin")
f, err := Create(path, v)
if err != nil {
t.Fatalf("Create: %v", err)
}
// HGT increases with level so the altitude bisection has a gradient;
// U and V are constant so interpolation must return them exactly.
for h := range v.NumHours() {
for lvl := range v.NumLevels() {
for la := range v.NumLatitudes() {
for ln := range v.NumLongitudes() {
f.SetVal(h, lvl, VarHeight, la, ln, float32(lvl*1000))
f.SetVal(h, lvl, VarWindU, la, ln, 7)
f.SetVal(h, lvl, VarWindV, la, ln, 3)
}
}
}
}
f.Flush()
f.Close()
epoch := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
rf, err := Open(path, v, epoch)
if err != nil {
t.Fatalf("Open: %v", err)
}
defer rf.Close()
w := NewWind(rf)
// Query at the dataset epoch, equator, lng 45, altitude 500m (between
// level 0 @ 0m and level 1 @ 1000m).
s, err := w.Wind(float64(epoch.Unix()), 0, 45, 500)
if err != nil {
t.Fatalf("Wind: %v", err)
}
if math.Abs(s.U-7) > 1e-5 || math.Abs(s.V-3) > 1e-5 {
t.Errorf("constant wind not recovered: got U=%v V=%v, want 7,3", s.U, s.V)
}
if s.AboveModel {
t.Errorf("AboveModel should be false at altitude within model range")
}
}

37
internal/weather/types.go Normal file
View file

@ -0,0 +1,37 @@
// Package weather defines the abstract interface trajectory engines use
// to sample atmospheric data, and contains source-specific implementations
// in its subpackages.
package weather
import "time"
// Sample is the result of sampling a wind field at one point.
type Sample struct {
// U is the eastward wind component in m/s.
U float64
// V is the northward wind component in m/s.
V float64
// AboveModel is set when the query altitude was above the highest
// pressure level represented in the underlying dataset. The returned
// U/V values are linear extrapolations and should be treated as unreliable.
AboveModel bool
}
// WindField provides 3D wind data interpolated at arbitrary points.
//
// Implementations must be safe for concurrent use.
type WindField interface {
// Wind samples the field at (t, lat, lng, alt).
//
// t is UNIX seconds. lat is in degrees, -90 to +90. lng is in degrees,
// 0 to 360 (callers must normalize). alt is metres above mean sea level.
//
// Returns an error if any coordinate is outside the field's domain.
Wind(t, lat, lng, alt float64) (Sample, error)
// Epoch returns the time the field is anchored to (forecast run time).
Epoch() time.Time
// Source identifies the underlying dataset for logs and metrics.
Source() string
}

63
internal/windviz/cache.go Normal file
View file

@ -0,0 +1,63 @@
package windviz
import (
"sync"
"time"
)
// Cache is a small bounded cache of rasterized fields keyed by request
// parameters and dataset epoch. It is safe for concurrent use.
//
// Visualization requests repeat heavily (a frontend re-fetches the same
// layer as users pan within a tile), so even a tiny cache removes most
// recomputation. Eviction is simplest-possible: when full, the whole map is
// cleared. Entries also expire after TTL.
type Cache struct {
mu sync.Mutex
entries map[string]cacheEntry
max int
ttl time.Duration
now func() time.Time
}
type cacheEntry struct {
field Field
expires time.Time
}
// NewCache returns a cache holding up to max entries for ttl each.
func NewCache(max int, ttl time.Duration) *Cache {
if max <= 0 {
max = 64
}
if ttl <= 0 {
ttl = 10 * time.Minute
}
return &Cache{
entries: make(map[string]cacheEntry, max),
max: max,
ttl: ttl,
now: time.Now,
}
}
// Get returns the cached field for key, if present and unexpired.
func (c *Cache) Get(key string) (Field, bool) {
c.mu.Lock()
defer c.mu.Unlock()
e, ok := c.entries[key]
if !ok || c.now().After(e.expires) {
return nil, false
}
return e.field, true
}
// Put stores field under key.
func (c *Cache) Put(key string, field Field) {
c.mu.Lock()
defer c.mu.Unlock()
if len(c.entries) >= c.max {
c.entries = make(map[string]cacheEntry, c.max)
}
c.entries[key] = cacheEntry{field: field, expires: c.now().Add(c.ttl)}
}

179
internal/windviz/windviz.go Normal file
View file

@ -0,0 +1,179 @@
// Package windviz rasterizes a weather.WindField into the JSON grid format
// consumed by browser velocity layers such as leaflet-velocity and
// wind-layer (the "gfs.json" / wind-js-server format).
//
// The module is decoupled from any specific dataset: it samples any
// weather.WindField on a regular latitude/longitude grid at a chosen time
// and altitude, downsampling by a configurable step to bound payload size.
package windviz
import (
"fmt"
"time"
"predictor-refactored/internal/weather"
)
// Request describes a wind-field rasterization.
type Request struct {
// Time is the forecast time to sample (UNIX seconds). Sampling outside
// the field's temporal coverage returns an error.
Time float64
// Altitude is the altitude in metres to sample at.
Altitude float64
// Bounding box in degrees. Latitudes in [-90, 90]; longitudes in
// [0, 360). For a global field use 0..360 (the rasterizer drops the
// duplicate 360° column).
MinLat, MaxLat float64
MinLng, MaxLng float64
// Step is the grid resolution in degrees (e.g. 1.0). Smaller is denser.
Step float64
}
// Component is one wind-js-server record: a header plus a flat data grid.
type Component struct {
Header Header `json:"header"`
Data []float64 `json:"data"`
}
// Header is the wind-js-server grid header. Field names and semantics match
// what leaflet-velocity / wind-layer expect.
type Header struct {
ParameterCategory int `json:"parameterCategory"`
ParameterNumber int `json:"parameterNumber"`
ParameterNumberName string `json:"parameterNumberName"`
ParameterUnit string `json:"parameterUnit"`
Nx int `json:"nx"`
Ny int `json:"ny"`
Lo1 float64 `json:"lo1"`
La1 float64 `json:"la1"`
Lo2 float64 `json:"lo2"`
La2 float64 `json:"la2"`
Dx float64 `json:"dx"`
Dy float64 `json:"dy"`
RefTime string `json:"refTime"`
ForecastTime int `json:"forecastTime"`
}
// Field is the two-component (U then V) payload. JSON-encoding a Field
// produces the array the velocity layers consume directly.
type Field []Component
const (
defaultStep = 1.0
minStep = 0.25 // clamp to bound output size
maxCells = 1 << 21
)
// Rasterize samples field over req and returns the U/V grid payload.
//
// Data is laid out in wind-js scan order: row 0 is the northernmost
// latitude (la1), each row runs west→east, longitudes increasing. Per-cell
// sampling errors (e.g. altitude outside the model) are written as 0 rather
// than failing the whole request; a time outside coverage is a hard error.
func Rasterize(field weather.WindField, req Request) (Field, error) {
step := req.Step
if step <= 0 {
step = defaultStep
}
if step < minStep {
step = minStep
}
minLat, maxLat := req.MinLat, req.MaxLat
minLng, maxLng := req.MinLng, req.MaxLng
if minLat == 0 && maxLat == 0 {
minLat, maxLat = -90, 90
}
if minLng == 0 && maxLng == 0 {
minLng, maxLng = 0, 360
}
if maxLat <= minLat {
return nil, fmt.Errorf("invalid bounding box latitude")
}
// Longitudes may arrive in either the [0, 360) or the [-180, 180]
// convention (the latter is what the rest of the API emits). Detect a
// full-globe span first, then fold a regional box's western edge into
// [0, 360); per-cell sampling re-folds via normLng so an eastern edge
// past 360° still reads the correct column.
lngSpan := maxLng - minLng
if lngSpan <= 0 {
return nil, fmt.Errorf("invalid bounding box longitude")
}
global := lngSpan >= 360-1e-9
var nx int
if global {
// Drop the duplicate wrap column so the layer tiles cleanly.
minLng = 0
nx = int(360/step + 0.5)
maxLng = float64(nx-1) * step
} else {
minLng = normLng(minLng)
maxLng = minLng + lngSpan
nx = int(lngSpan/step+0.5) + 1
}
ny := int((maxLat-minLat)/step+0.5) + 1
if nx < 1 || ny < 1 {
return nil, fmt.Errorf("empty grid")
}
if nx*ny > maxCells {
return nil, fmt.Errorf("grid too large (%d cells); increase step or shrink bbox", nx*ny)
}
u := make([]float64, nx*ny)
v := make([]float64, nx*ny)
// Row 0 = north (la1); rows descend in latitude.
for j := range ny {
lat := maxLat - float64(j)*step
for i := range nx {
lng := minLng + float64(i)*step
s, err := field.Wind(req.Time, lat, normLng(lng), req.Altitude)
idx := j*nx + i
if err != nil {
continue // leave as 0
}
u[idx] = s.U
v[idx] = s.V
}
}
refTime := time.Unix(int64(req.Time), 0).UTC().Format("2006-01-02T15:04:05.000Z")
mk := func(num int, name string, data []float64) Component {
return Component{
Header: Header{
ParameterCategory: 2,
ParameterNumber: num,
ParameterNumberName: name,
ParameterUnit: "m.s-1",
Nx: nx,
Ny: ny,
Lo1: minLng,
La1: maxLat,
Lo2: maxLng,
La2: minLat,
Dx: step,
Dy: step,
RefTime: refTime,
ForecastTime: 0,
},
Data: data,
}
}
return Field{
mk(2, "eastward_wind", u),
mk(3, "northward_wind", v),
}, nil
}
// normLng folds a longitude into [0, 360) for sampling.
func normLng(lng float64) float64 {
for lng < 0 {
lng += 360
}
for lng >= 360 {
lng -= 360
}
return lng
}

View file

@ -0,0 +1,96 @@
package windviz
import (
"testing"
"time"
"predictor-refactored/internal/weather"
)
// constWind is a WindField returning a fixed sample everywhere.
type constWind struct {
u, v float64
epoch time.Time
}
func (c constWind) Wind(_ float64, _, _, _ float64) (weather.Sample, error) {
return weather.Sample{U: c.u, V: c.v}, nil
}
func (c constWind) Epoch() time.Time { return c.epoch }
func (c constWind) Source() string { return "test" }
func TestRasterizeGlobalDropsDuplicateColumn(t *testing.T) {
f := constWind{u: 5, v: -3, epoch: time.Unix(0, 0)}
out, err := Rasterize(f, Request{MinLng: 0, MaxLng: 360, Step: 90})
if err != nil {
t.Fatalf("Rasterize: %v", err)
}
if len(out) != 2 {
t.Fatalf("expected 2 components, got %d", len(out))
}
u := out[0]
// 360/90 = 4 columns (no duplicate 360°); lat -90..90 step 90 = 3 rows.
if u.Header.Nx != 4 || u.Header.Ny != 3 {
t.Errorf("grid = %dx%d, want 4x3", u.Header.Nx, u.Header.Ny)
}
if len(u.Data) != 12 {
t.Errorf("data len = %d, want 12", len(u.Data))
}
if u.Header.La1 != 90 || u.Header.La2 != -90 {
t.Errorf("lat range = %v..%v, want 90..-90 (north first)", u.Header.La1, u.Header.La2)
}
if u.Header.Lo1 != 0 || u.Header.Lo2 != 270 {
t.Errorf("lng range = %v..%v, want 0..270", u.Header.Lo1, u.Header.Lo2)
}
for _, d := range u.Data {
if d != 5 {
t.Errorf("U data = %v, want 5", d)
break
}
}
if out[0].Header.ParameterNumber != 2 || out[1].Header.ParameterNumber != 3 {
t.Errorf("component order should be U(2) then V(3)")
}
}
func TestRasterizeSignedLongitudeConvention(t *testing.T) {
f := constWind{u: 1, v: 2, epoch: time.Unix(0, 0)}
// A [-180, 180] global request must be detected as global and tiled
// without a duplicate seam column, identical to a 0..360 request.
signed, err := Rasterize(f, Request{MinLng: -180, MaxLng: 180, Step: 90})
if err != nil {
t.Fatalf("signed-global Rasterize: %v", err)
}
if signed[0].Header.Nx != 4 {
t.Errorf("signed-global nx = %d, want 4 (no duplicate column)", signed[0].Header.Nx)
}
// A western-hemisphere box must not 400; its western edge folds into [0,360).
west, err := Rasterize(f, Request{MinLat: 10, MaxLat: 20, MinLng: -100, MaxLng: -50, Step: 10})
if err != nil {
t.Fatalf("western-box Rasterize: %v", err)
}
if west[0].Header.Lo1 != 260 {
t.Errorf("western-box lo1 = %v, want 260 (=-100 folded)", west[0].Header.Lo1)
}
}
func TestRasterizeStepClamp(t *testing.T) {
f := constWind{epoch: time.Unix(0, 0)}
// step below min gets clamped, not rejected.
if _, err := Rasterize(f, Request{MinLat: -1, MaxLat: 1, MinLng: 0, MaxLng: 2, Step: 0.01}); err != nil {
t.Fatalf("Rasterize with tiny step: %v", err)
}
}
func TestCacheRoundTrip(t *testing.T) {
c := NewCache(2, time.Minute)
if _, ok := c.Get("a"); ok {
t.Errorf("empty cache should miss")
}
c.Put("a", Field{})
if _, ok := c.Get("a"); !ok {
t.Errorf("cache should hit after put")
}
}

Some files were not shown because too many files have changed in this diff Show more