engine refactor

This commit is contained in:
Anatoly Antonov 2026-05-23 00:55:35 +09:00
parent 9e663db9dc
commit 81b8e763bd
37 changed files with 3532 additions and 1639 deletions

View file

@ -3,29 +3,33 @@
//
// Endpoints:
//
// GET /api/v1/admin/datasets list stored epochs
// POST /api/v1/admin/datasets trigger a download
// DELETE /api/v1/admin/datasets/{epoch} delete a stored epoch
// GET /api/v1/admin/jobs list all jobs
// GET /api/v1/admin/jobs/{id} fetch one job
// DELETE /api/v1/admin/jobs/{id} cancel a running job
// GET /api/v1/admin/datasets list stored datasets
// POST /api/v1/admin/datasets trigger a download
// DELETE /api/v1/admin/datasets/{name} delete a stored dataset by filename
// GET /api/v1/admin/jobs list all jobs
// GET /api/v1/admin/jobs/{id} fetch one job
// DELETE /api/v1/admin/jobs/{id} cancel a running job
// GET /api/v1/admin/status service status summary
package admin
import (
"context"
"encoding/json"
"net/http"
"runtime"
"time"
"go.uber.org/zap"
"predictor-refactored/internal/api/httpjson"
"predictor-refactored/internal/datasets"
)
// Handler serves all /api/v1/admin/* endpoints.
type Handler struct {
mgr *datasets.Manager
log *zap.Logger
mgr *datasets.Manager
start time.Time
log *zap.Logger
}
// New wires an admin handler.
@ -33,52 +37,94 @@ func New(mgr *datasets.Manager, log *zap.Logger) *Handler {
if log == nil {
log = zap.NewNop()
}
return &Handler{mgr: mgr, log: log}
return &Handler{mgr: mgr, start: time.Now().UTC(), log: log}
}
// Register installs admin routes on mux. Routes are mounted under
// /api/v1/admin/...
// Register installs admin routes on mux.
func (h *Handler) Register(mux *http.ServeMux) {
mux.HandleFunc("GET /api/v1/admin/datasets", h.listDatasets)
mux.HandleFunc("POST /api/v1/admin/datasets", h.triggerDownload)
mux.HandleFunc("DELETE /api/v1/admin/datasets/{epoch}", h.deleteDataset)
mux.HandleFunc("DELETE /api/v1/admin/datasets/{name}", h.deleteDataset)
mux.HandleFunc("GET /api/v1/admin/jobs", h.listJobs)
mux.HandleFunc("GET /api/v1/admin/jobs/{id}", h.getJob)
mux.HandleFunc("DELETE /api/v1/admin/jobs/{id}", h.cancelJob)
mux.HandleFunc("GET /api/v1/admin/status", h.status)
}
// datasetDTO is the JSON shape of one stored dataset.
type datasetDTO struct {
Filename string `json:"filename"`
Epoch string `json:"epoch"`
Subset *subsetDTO `json:"subset,omitempty"`
Coverage *coverageDTO `json:"coverage,omitempty"`
Loaded bool `json:"loaded"`
}
type subsetDTO struct {
Region *datasets.Region `json:"region,omitempty"`
HourRange *datasets.HourRange `json:"hour_range,omitempty"`
Members []int `json:"members,omitempty"`
}
type coverageDTO struct {
Region datasets.Region `json:"region"`
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
}
// listDatasets handles GET /api/v1/admin/datasets.
func (h *Handler) listDatasets(w http.ResponseWriter, _ *http.Request) {
epochs, err := h.mgr.ListEpochs()
stored, err := h.mgr.ListEpochs()
if err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
active := ""
if a := h.mgr.Active(); a != nil {
active = a.Epoch().UTC().Format(time.RFC3339)
loaded := h.mgr.LoadedDatasets()
loadedByName := make(map[string]datasets.LoadedDatasetInfo, len(loaded))
for _, ld := range loaded {
loadedByName[ld.ID.Filename()] = ld
}
out := struct {
Source string `json:"source"`
Active string `json:"active,omitempty"`
Epochs []string `json:"epochs"`
}{
Source: h.mgr.Source(),
Active: active,
}
for _, e := range epochs {
out.Epochs = append(out.Epochs, e.UTC().Format(time.RFC3339))
Source string `json:"source"`
Datasets []datasetDTO `json:"datasets"`
}{Source: h.mgr.Source(), Datasets: make([]datasetDTO, 0, len(stored))}
for _, id := range stored {
dto := datasetDTO{
Filename: id.Filename(),
Epoch: id.Epoch.UTC().Format(time.RFC3339),
}
if !id.Subset.IsGlobal() {
dto.Subset = &subsetDTO{
Region: id.Subset.Region,
HourRange: id.Subset.HourRange,
Members: id.Subset.Members,
}
}
if ld, ok := loadedByName[id.Filename()]; ok {
dto.Loaded = true
dto.Coverage = &coverageDTO{
Region: ld.Coverage.Region,
StartTime: ld.Coverage.StartTime.UTC().Format(time.RFC3339),
EndTime: ld.Coverage.EndTime.UTC().Format(time.RFC3339),
}
}
out.Datasets = append(out.Datasets, dto)
}
writeJSON(w, http.StatusOK, out)
}
// triggerDownload handles POST /api/v1/admin/datasets.
//
// Body: {"epoch": "2026-03-28T06:00:00Z"} OR {"latest": true}.
// Body:
// {"latest": true} — refresh the latest global dataset
// {"epoch": "2026-03-28T06:00:00Z", "subset": {...}} — explicit dataset
func (h *Handler) triggerDownload(w http.ResponseWriter, r *http.Request) {
var body struct {
Epoch string `json:"epoch,omitempty"`
Latest bool `json:"latest,omitempty"`
Epoch string `json:"epoch,omitempty"`
Latest bool `json:"latest,omitempty"`
Subset *datasets.SubsetSpec `json:"subset,omitempty"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writeError(w, http.StatusBadRequest, "invalid body: "+err.Error())
@ -89,7 +135,6 @@ func (h *Handler) triggerDownload(w http.ResponseWriter, r *http.Request) {
return
}
var epoch time.Time
if body.Latest {
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
defer cancel()
@ -102,29 +147,40 @@ func (h *Handler) triggerDownload(w http.ResponseWriter, r *http.Request) {
return
}
var err error
epoch, err = time.Parse(time.RFC3339, body.Epoch)
epoch, err := time.Parse(time.RFC3339, body.Epoch)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid epoch: "+err.Error())
return
}
jobID := h.mgr.Download(epoch)
id := datasets.DatasetID{Epoch: epoch.UTC()}
if body.Subset != nil {
id.Subset = *body.Subset
}
jobID := h.mgr.Download(id)
writeJSON(w, http.StatusAccepted, map[string]string{"job_id": jobID})
}
// deleteDataset handles DELETE /api/v1/admin/datasets/{epoch}.
// deleteDataset handles DELETE /api/v1/admin/datasets/{name}.
//
// {name} is the dataset filename (DatasetID.Filename()) as returned by GET.
func (h *Handler) deleteDataset(w http.ResponseWriter, r *http.Request) {
rawEpoch := r.PathValue("epoch")
epoch, err := time.Parse(time.RFC3339, rawEpoch)
name := r.PathValue("name")
stored, err := h.mgr.ListEpochs()
if err != nil {
writeError(w, http.StatusBadRequest, "invalid epoch: "+err.Error())
return
}
if err := h.mgr.RemoveEpoch(epoch); err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
w.WriteHeader(http.StatusNoContent)
for _, id := range stored {
if id.Filename() == name {
if err := h.mgr.Remove(id); err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
w.WriteHeader(http.StatusNoContent)
return
}
}
writeError(w, http.StatusNotFound, "dataset not found")
}
// listJobs handles GET /api/v1/admin/jobs.
@ -158,24 +214,59 @@ func (h *Handler) cancelJob(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
}
// status handles GET /api/v1/admin/status — a consolidated dashboard view.
func (h *Handler) status(w http.ResponseWriter, _ *http.Request) {
jobs := h.mgr.ListJobs()
stored, _ := h.mgr.ListEpochs()
loaded := h.mgr.LoadedDatasets()
counts := map[string]int{}
for _, j := range jobs {
counts[string(j.Status)]++
}
var mem runtime.MemStats
runtime.ReadMemStats(&mem)
resp := struct {
Source string `json:"source"`
Uptime string `json:"uptime"`
Goroutines int `json:"goroutines"`
MemoryMB uint64 `json:"memory_mb"`
JobsByStatus map[string]int `json:"jobs_by_status"`
Stored int `json:"stored_datasets"`
Loaded int `json:"loaded_datasets"`
}{
Source: h.mgr.Source(),
Uptime: time.Since(h.start).Round(time.Second).String(),
Goroutines: runtime.NumGoroutine(),
MemoryMB: mem.Alloc / 1024 / 1024,
JobsByStatus: counts,
Stored: len(stored),
Loaded: len(loaded),
}
writeJSON(w, http.StatusOK, resp)
}
type jobDTO struct {
ID string `json:"id"`
Source string `json:"source"`
Epoch string `json:"epoch"`
Status string `json:"status"`
StartedAt string `json:"started_at"`
EndedAt string `json:"ended_at,omitempty"`
Err string `json:"error,omitempty"`
Total int `json:"total_units"`
Done int `json:"done_units"`
Bytes int64 `json:"bytes"`
ID string `json:"id"`
Source string `json:"source"`
Dataset string `json:"dataset"`
Epoch string `json:"epoch"`
Status string `json:"status"`
StartedAt string `json:"started_at"`
EndedAt string `json:"ended_at,omitempty"`
Err string `json:"error,omitempty"`
Total int `json:"total_units"`
Done int `json:"done_units"`
Bytes int64 `json:"bytes"`
}
func toDTO(j datasets.JobInfo) jobDTO {
dto := jobDTO{
ID: j.ID,
Source: j.Source,
Epoch: j.Epoch.UTC().Format(time.RFC3339),
Dataset: j.Dataset.Filename(),
Epoch: j.Dataset.Epoch.UTC().Format(time.RFC3339),
Status: string(j.Status),
StartedAt: j.StartedAt.UTC().Format(time.RFC3339),
Err: j.Err,
@ -189,18 +280,5 @@ func toDTO(j datasets.JobInfo) jobDTO {
return dto
}
func writeJSON(w http.ResponseWriter, status int, body any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(body)
}
func writeError(w http.ResponseWriter, status int, description string) {
writeJSON(w, status, map[string]any{
"error": map[string]string{
"type": http.StatusText(status),
"description": description,
},
})
}
var writeJSON = httpjson.Write
var writeError = httpjson.Error

View file

@ -0,0 +1,63 @@
package async
import (
"encoding/json"
"net/http"
"predictor-refactored/internal/api/httpjson"
"predictor-refactored/internal/api/v2"
)
// Handler implements the /api/v1/predictions{,/{id}} endpoints.
type Handler struct {
mgr *Manager
}
// NewHandler wires a handler.
func NewHandler(mgr *Manager) *Handler { return &Handler{mgr: mgr} }
// Register installs the async routes on mux.
func (h *Handler) Register(mux *http.ServeMux) {
mux.HandleFunc("POST /api/v1/predictions", h.create)
mux.HandleFunc("GET /api/v1/predictions/{id}", h.get)
mux.HandleFunc("DELETE /api/v1/predictions/{id}", h.cancel)
}
func (h *Handler) create(w http.ResponseWriter, r *http.Request) {
var req v2.PredictionRequest
dec := json.NewDecoder(r.Body)
dec.DisallowUnknownFields()
if err := dec.Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid body: "+err.Error())
return
}
info, accepted := h.mgr.Enqueue(req)
if !accepted {
writeJSON(w, http.StatusServiceUnavailable, info)
return
}
w.Header().Set("Location", "/api/v1/predictions/"+info.ID)
writeJSON(w, http.StatusAccepted, info)
}
func (h *Handler) get(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
info, ok := h.mgr.Get(id)
if !ok {
writeError(w, http.StatusNotFound, "prediction job not found")
return
}
writeJSON(w, http.StatusOK, info)
}
func (h *Handler) cancel(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
if !h.mgr.Cancel(id) {
writeError(w, http.StatusConflict, "job not found or already terminal")
return
}
w.WriteHeader(http.StatusNoContent)
}
var writeJSON = httpjson.Write
var writeError = httpjson.Error

View file

@ -0,0 +1,276 @@
// Package async implements the asynchronous prediction endpoints
// (/api/v1/predictions{,/{id}}) and the worker pool that executes them.
//
// Each enqueued request is assigned a job ID; the result is held in
// memory for a configurable TTL after completion.
package async
import (
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"go.uber.org/zap"
"predictor-refactored/internal/api/v2"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/elevation"
"predictor-refactored/internal/metrics"
)
// Status is the lifecycle state of a prediction job.
type Status string
const (
StatusPending Status = "pending"
StatusRunning Status = "running"
StatusComplete Status = "complete"
StatusFailed Status = "failed"
StatusCancelled Status = "cancelled"
)
// JobInfo is the externally-visible snapshot of one prediction job.
type JobInfo struct {
ID string `json:"id"`
Status Status `json:"status"`
CreatedAt time.Time `json:"created_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
Error string `json:"error,omitempty"`
Result *v2.PredictionResponse `json:"result,omitempty"`
}
type job struct {
id string
req v2.PredictionRequest
createdAt time.Time
mu sync.Mutex
status Status
startedAt time.Time
completedAt time.Time
errStr string
result *v2.PredictionResponse
cancel chan struct{}
}
func (j *job) snapshot() JobInfo {
j.mu.Lock()
defer j.mu.Unlock()
info := JobInfo{
ID: j.id,
Status: j.status,
CreatedAt: j.createdAt,
Error: j.errStr,
Result: j.result,
}
if !j.startedAt.IsZero() {
t := j.startedAt
info.StartedAt = &t
}
if !j.completedAt.IsZero() {
t := j.completedAt
info.CompletedAt = &t
}
return info
}
// Manager runs a fixed pool of workers to execute prediction jobs and
// retains their results for the configured TTL.
type Manager struct {
mgr *datasets.Manager
elev *elevation.Dataset
metrics metrics.Sink
log *zap.Logger
queue chan *job
ttl time.Duration
jobsMu sync.RWMutex
jobs map[string]*job
inflight atomic.Int64
closed chan struct{}
wg sync.WaitGroup
}
// Config controls Manager construction.
type Config struct {
// Workers is the maximum concurrent prediction executions.
Workers int
// QueueSize bounds the number of jobs waiting to start.
QueueSize int
// ResultTTL is how long completed/failed jobs are retained in memory.
ResultTTL time.Duration
}
// New constructs a Manager with the given config and starts the workers.
func New(cfg Config, mgr *datasets.Manager, elev *elevation.Dataset, sink metrics.Sink, log *zap.Logger) *Manager {
if cfg.Workers <= 0 {
cfg.Workers = 4
}
if cfg.QueueSize <= 0 {
cfg.QueueSize = 64
}
if cfg.ResultTTL <= 0 {
cfg.ResultTTL = time.Hour
}
if sink == nil {
sink = metrics.Noop()
}
if log == nil {
log = zap.NewNop()
}
m := &Manager{
mgr: mgr, elev: elev, metrics: sink, log: log,
queue: make(chan *job, cfg.QueueSize),
jobs: make(map[string]*job),
ttl: cfg.ResultTTL,
closed: make(chan struct{}),
}
for range cfg.Workers {
m.wg.Add(1)
go m.worker()
}
m.wg.Add(1)
go m.evictor()
return m
}
// Enqueue creates a new job from req and returns its snapshot.
// Returns false when the queue is full.
func (m *Manager) Enqueue(req v2.PredictionRequest) (JobInfo, bool) {
j := &job{
id: uuid.New().String(),
req: req,
createdAt: time.Now().UTC(),
status: StatusPending,
cancel: make(chan struct{}),
}
m.jobsMu.Lock()
m.jobs[j.id] = j
m.jobsMu.Unlock()
select {
case m.queue <- j:
return j.snapshot(), true
default:
// Queue full — mark the job failed and return it.
j.mu.Lock()
j.status = StatusFailed
j.errStr = "prediction queue full"
j.completedAt = time.Now().UTC()
j.mu.Unlock()
return j.snapshot(), false
}
}
// Get returns a job's snapshot.
func (m *Manager) Get(id string) (JobInfo, bool) {
m.jobsMu.RLock()
j, ok := m.jobs[id]
m.jobsMu.RUnlock()
if !ok {
return JobInfo{}, false
}
return j.snapshot(), true
}
// Cancel marks a not-yet-started job as cancelled. Returns false when the
// job is unknown or already terminal.
func (m *Manager) Cancel(id string) bool {
m.jobsMu.RLock()
j, ok := m.jobs[id]
m.jobsMu.RUnlock()
if !ok {
return false
}
j.mu.Lock()
terminal := j.status == StatusComplete || j.status == StatusFailed || j.status == StatusCancelled
if terminal {
j.mu.Unlock()
return false
}
j.status = StatusCancelled
j.completedAt = time.Now().UTC()
j.mu.Unlock()
close(j.cancel)
return true
}
// Inflight returns the count of running jobs.
func (m *Manager) Inflight() int64 { return m.inflight.Load() }
// Close shuts down workers and the evictor.
func (m *Manager) Close() {
close(m.closed)
close(m.queue)
m.wg.Wait()
}
func (m *Manager) worker() {
defer m.wg.Done()
for j := range m.queue {
// Check cancellation before starting.
j.mu.Lock()
cancelled := j.status == StatusCancelled
j.mu.Unlock()
if cancelled {
continue
}
m.inflight.Add(1)
j.mu.Lock()
j.status = StatusRunning
j.startedAt = time.Now().UTC()
j.mu.Unlock()
resp, err := v2.Run(m.mgr, m.elev, j.req)
j.mu.Lock()
j.completedAt = time.Now().UTC()
if err != nil {
j.status = StatusFailed
j.errStr = err.Error()
} else {
j.status = StatusComplete
j.result = resp
}
j.mu.Unlock()
m.inflight.Add(-1)
if err == nil {
m.metrics.Prediction("async", j.completedAt.Sub(j.startedAt), nil)
} else {
m.metrics.Prediction("async", j.completedAt.Sub(j.startedAt), err)
}
}
}
func (m *Manager) evictor() {
defer m.wg.Done()
ticker := time.NewTicker(m.ttl / 4)
defer ticker.Stop()
for {
select {
case <-m.closed:
return
case <-ticker.C:
m.evictExpired()
}
}
}
func (m *Manager) evictExpired() {
now := time.Now().UTC()
m.jobsMu.Lock()
defer m.jobsMu.Unlock()
for id, j := range m.jobs {
j.mu.Lock()
expired := !j.completedAt.IsZero() && now.Sub(j.completedAt) > m.ttl
j.mu.Unlock()
if expired {
delete(m.jobs, id)
}
}
}

View file

@ -0,0 +1,27 @@
// Package httpjson holds the tiny JSON response helpers shared across
// the admin, v2, and async handlers.
package httpjson
import (
"encoding/json"
"net/http"
)
// Write writes body as JSON with the given status code.
func Write(w http.ResponseWriter, status int, body any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(body)
}
// Error writes a standard error JSON body with the given status code.
//
// Shape: {"error": {"type": "...", "description": "..."}}
func Error(w http.ResponseWriter, status int, description string) {
Write(w, status, map[string]any{
"error": map[string]string{
"type": http.StatusText(status),
"description": description,
},
})
}

View file

@ -2,8 +2,8 @@
// (GET /api/v1/prediction). The request/response shapes match the original
// Cambridge University Spaceflight predictor for drop-in compatibility.
//
// Internally the handler builds an engine.Profile from query parameters and
// dispatches it through the same engine path as the new v2 endpoint.
// Internally the handler builds an engine.Profile from query parameters
// and dispatches it through the same engine path as the new v2 endpoint.
package tawhiri
import (
@ -18,11 +18,11 @@ import (
"predictor-refactored/internal/elevation"
"predictor-refactored/internal/engine"
"predictor-refactored/internal/metrics"
"predictor-refactored/internal/weather"
api "predictor-refactored/pkg/rest"
)
// Handler implements api.Handler (the ogen-generated interface for
// performPrediction and readinessCheck).
// Handler implements api.Handler (ogen-generated interface).
type Handler struct {
mgr *datasets.Manager
elev *elevation.Dataset
@ -41,111 +41,49 @@ func New(mgr *datasets.Manager, elev *elevation.Dataset, sink metrics.Sink, log
return &Handler{mgr: mgr, elev: elev, metrics: sink, log: log}
}
// Compile-time check that Handler satisfies api.Handler.
var _ api.Handler = (*Handler)(nil)
// PerformPrediction runs the Tawhiri-style prediction.
func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredictionParams) (*api.PredictionResponse, error) {
func (h *Handler) PerformPrediction(_ context.Context, params api.PerformPredictionParams) (*api.PredictionResponse, error) {
field := h.mgr.Active()
if field == nil {
return nil, newError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up")
}
// Parameters with Tawhiri defaults.
profileKind := "standard_profile"
if v, ok := params.Profile.Get(); ok {
profileKind = string(v)
}
ascentRate := 5.0
if v, ok := params.AscentRate.Get(); ok {
ascentRate = v
}
burstAltitude := 28000.0
if v, ok := params.BurstAltitude.Get(); ok {
burstAltitude = v
}
descentRate := 5.0
if v, ok := params.DescentRate.Get(); ok {
descentRate = v
}
launchAlt := 0.0
if v, ok := params.LaunchAltitude.Get(); ok {
launchAlt = v
}
profileKind := optString(params.Profile, "standard_profile")
ascentRate := optFloat(params.AscentRate, 5.0)
burstAltitude := optFloat(params.BurstAltitude, 28000.0)
descentRate := optFloat(params.DescentRate, 5.0)
launchAlt := optFloat(params.LaunchAltitude, 0.0)
lng := params.LaunchLongitude
if lng < 0 {
lng += 360
}
launchTime := float64(params.LaunchDatetime.Unix())
warnings := &engine.Warnings{}
// Build the profile.
events := engine.NewEventSink()
var stageNames []string
var prof engine.Profile
switch profileKind {
case "standard_profile":
stageNames = []string{"ascent", "descent"}
prof = engine.Profile{
Direction: engine.Forward,
Stages: []*engine.Propagator{
{
Name: "ascent",
Step: 60,
Model: engine.Sum(
engine.ConstantRate(ascentRate),
engine.WindTransport(field, warnings),
),
Constraints: []engine.Constraint{engine.MaxAltitude{Limit: burstAltitude, On: engine.ActionStop}},
},
{
Name: "descent",
Step: 60,
Model: engine.Sum(
engine.ParachuteDescent(descentRate),
engine.WindTransport(field, warnings),
),
Constraints: descentConstraints(h.elev),
},
},
}
prof = standardProfile(field, h.elev, events, ascentRate, burstAltitude, descentRate)
case "float_profile":
floatAlt := 25000.0
if v, ok := params.FloatAltitude.Get(); ok {
floatAlt = v
}
floatAlt := optFloat(params.FloatAltitude, 25000.0)
stopTime := params.LaunchDatetime.Add(24 * time.Hour)
if v, ok := params.StopDatetime.Get(); ok {
stopTime = v
}
stageNames = []string{"ascent", "float"}
prof = engine.Profile{
Direction: engine.Forward,
Stages: []*engine.Propagator{
{
Name: "ascent",
Step: 60,
Model: engine.Sum(
engine.ConstantRate(ascentRate),
engine.WindTransport(field, warnings),
),
Constraints: []engine.Constraint{engine.MaxAltitude{Limit: floatAlt, On: engine.ActionStop}},
},
{
Name: "float",
Step: 60,
Model: engine.WindTransport(field, warnings),
Constraints: []engine.Constraint{engine.MaxTime{Limit: float64(stopTime.Unix()), On: engine.ActionStop}},
},
},
}
prof = floatProfile(field, events, ascentRate, floatAlt, stopTime)
default:
return nil, newError(http.StatusBadRequest, "unknown profile: "+profileKind)
}
started := time.Now().UTC()
results := prof.Run(launchTime, engine.State{Lat: params.LaunchLatitude, Lng: lng, Altitude: launchAlt})
results := prof.Run(launchTime, engine.State{Lat: params.LaunchLatitude, Lng: lng, Altitude: launchAlt}, events)
completed := time.Now().UTC()
h.metrics.Prediction(profileKind, completed.Sub(started), nil)
@ -161,30 +99,7 @@ func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredi
if i < len(stageNames) {
stageName = stageNames[i]
}
stageEnum := api.PredictionResponsePredictionItemStageAscent
switch stageName {
case "descent":
stageEnum = api.PredictionResponsePredictionItemStageDescent
case "float":
stageEnum = api.PredictionResponsePredictionItemStageFloat
}
traj := make([]api.PredictionResponsePredictionItemTrajectoryItem, 0, len(r.Points))
for _, pt := range r.Points {
ptLng := pt.Lng
if ptLng > 180 {
ptLng -= 360
}
traj = append(traj, api.PredictionResponsePredictionItemTrajectoryItem{
Datetime: time.Unix(int64(pt.Time), 0).UTC(),
Latitude: pt.Lat,
Longitude: ptLng,
Altitude: pt.Altitude,
})
}
resp.Prediction = append(resp.Prediction, api.PredictionResponsePredictionItem{
Stage: stageEnum,
Trajectory: traj,
})
resp.Prediction = append(resp.Prediction, buildPredictionItem(stageName, r))
}
resp.Request = api.NewOptPredictionResponseRequest(api.PredictionResponseRequest{
@ -195,7 +110,8 @@ func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredi
LaunchAltitude: params.LaunchAltitude,
})
if warns := warnings.ToMap(); len(warns) > 0 {
if ev := events.Snapshot(); len(ev) > 0 {
// Preserve the OpenAPI-defined Warnings shape (open object).
resp.Warnings = api.NewOptPredictionResponseWarnings(api.PredictionResponseWarnings{})
}
@ -207,13 +123,78 @@ func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredi
return resp, nil
}
// descentConstraints returns the descent termination set: TerrainContact if an
// elevation dataset is loaded, MinAltitude(0) otherwise.
func descentConstraints(elev *elevation.Dataset) []engine.Constraint {
// standardProfile constructs the ascent → descent profile.
func standardProfile(field weather.WindField, elev *elevation.Dataset, events *engine.EventSink, ascentRate, burstAltitude, descentRate float64) engine.Profile {
wind := engine.WindTransport(field, events)
descentTerm := []engine.Constraint{engine.Altitude{Op: engine.OpLessEqual, Limit: 0, On: engine.ActionStop}}
if elev != nil {
return []engine.Constraint{engine.TerrainContact{Provider: elev, On: engine.ActionStop}}
descentTerm = []engine.Constraint{engine.TerrainContact{Provider: elev, On: engine.ActionStop}}
}
return []engine.Constraint{engine.MinAltitude{Limit: 0, On: engine.ActionStop}}
return engine.Profile{
Direction: engine.Forward,
Stages: []*engine.Propagator{
{
Name: "ascent",
Step: 60,
Model: engine.Sum(engine.ConstantRate(ascentRate), wind),
Constraints: []engine.Constraint{engine.Altitude{Op: engine.OpGreaterEqual, Limit: burstAltitude, On: engine.ActionStop}},
},
{
Name: "descent",
Step: 60,
Model: engine.Sum(engine.ParachuteDescent(descentRate), wind),
Constraints: descentTerm,
},
},
}
}
// floatProfile constructs the ascent → float profile.
func floatProfile(field weather.WindField, events *engine.EventSink, ascentRate, floatAlt float64, stopTime time.Time) engine.Profile {
wind := engine.WindTransport(field, events)
return engine.Profile{
Direction: engine.Forward,
Stages: []*engine.Propagator{
{
Name: "ascent",
Step: 60,
Model: engine.Sum(engine.ConstantRate(ascentRate), wind),
Constraints: []engine.Constraint{engine.Altitude{Op: engine.OpGreaterEqual, Limit: floatAlt, On: engine.ActionStop}},
},
{
Name: "float",
Step: 60,
Model: wind,
Constraints: []engine.Constraint{engine.Time{Op: engine.OpGreater, Limit: float64(stopTime.Unix()), On: engine.ActionStop}},
},
},
}
}
func buildPredictionItem(stageName string, r engine.Result) api.PredictionResponsePredictionItem {
var stageEnum api.PredictionResponsePredictionItemStage
switch stageName {
case "descent":
stageEnum = api.PredictionResponsePredictionItemStageDescent
case "float":
stageEnum = api.PredictionResponsePredictionItemStageFloat
default:
stageEnum = api.PredictionResponsePredictionItemStageAscent
}
traj := make([]api.PredictionResponsePredictionItemTrajectoryItem, 0, len(r.Points))
for _, pt := range r.Points {
ptLng := pt.Lng
if ptLng > 180 {
ptLng -= 360
}
traj = append(traj, api.PredictionResponsePredictionItemTrajectoryItem{
Datetime: time.Unix(int64(pt.Time), 0).UTC(),
Latitude: pt.Lat,
Longitude: ptLng,
Altitude: pt.Altitude,
})
}
return api.PredictionResponsePredictionItem{Stage: stageEnum, Trajectory: traj}
}
// ReadinessCheck reports whether a dataset is currently loaded.
@ -250,3 +231,21 @@ func newError(status int, description string) *api.ErrorStatusCode {
},
}
}
// optString returns the option's value if set, else fallback.
func optString[T ~string](o interface {
Get() (T, bool)
}, fallback string) string {
if v, ok := o.Get(); ok {
return string(v)
}
return fallback
}
// optFloat returns the option's float64 value if set, else fallback.
func optFloat(o api.OptFloat64, fallback float64) float64 {
if v, ok := o.Get(); ok {
return v
}
return fallback
}

View file

@ -15,6 +15,7 @@ import (
"go.uber.org/zap"
"predictor-refactored/internal/api/admin"
"predictor-refactored/internal/api/async"
"predictor-refactored/internal/api/middleware"
"predictor-refactored/internal/api/tawhiri"
v2 "predictor-refactored/internal/api/v2"
@ -33,12 +34,13 @@ type Server struct {
// Deps are the runtime dependencies the API layer needs.
type Deps struct {
Manager *datasets.Manager
Elevation *elevation.Dataset
Metrics metrics.Sink
Manager *datasets.Manager
Elevation *elevation.Dataset
Metrics metrics.Sink
MetricsHandler http.Handler // optional; mounted at MetricsPath when non-nil
MetricsPath string
Log *zap.Logger
AsyncManager *async.Manager // optional; mounts /api/v1/predictions when non-nil
Log *zap.Logger
}
// New wires the HTTP server. The returned Server is not yet started.
@ -68,6 +70,12 @@ func New(port int, d Deps) (*Server, error) {
adminH := admin.New(d.Manager, d.Log)
adminH.Register(mux)
// Async prediction endpoints (optional).
if d.AsyncManager != nil {
asyncH := async.NewHandler(d.AsyncManager)
asyncH.Register(mux)
}
// Metrics endpoint.
if d.MetricsHandler != nil && d.MetricsPath != "" {
mux.Handle(d.MetricsPath, d.MetricsHandler)

View file

@ -8,6 +8,7 @@ import (
"go.uber.org/zap"
"predictor-refactored/internal/api/httpjson"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/elevation"
"predictor-refactored/internal/engine"
@ -46,85 +47,109 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
writeError(w, http.StatusBadRequest, "invalid request body: "+err.Error())
return
}
if err := validateRequest(req); err != nil {
writeError(w, http.StatusBadRequest, err.Error())
return
}
field := h.mgr.Active()
if field == nil {
writeError(w, http.StatusServiceUnavailable, "no dataset loaded, service is starting up")
resp, err := Run(h.mgr, h.elev, req)
if err != nil {
if perr, ok := err.(*PredictionError); ok {
writeError(w, perr.Status, perr.Description)
return
}
writeError(w, http.StatusInternalServerError, err.Error())
return
}
h.metrics.Prediction("v2", resp.CompletedAt.Sub(resp.StartedAt), nil)
h.log.Info("v2 prediction complete",
zap.Int("stages", len(resp.Stages)),
zap.Duration("elapsed", resp.CompletedAt.Sub(resp.StartedAt)))
writeJSON(w, http.StatusOK, resp)
}
// PredictionError carries an HTTP status alongside the message so async
// callers can map the failure back to a useful HTTP response.
type PredictionError struct {
Status int
Description string
}
func (e *PredictionError) Error() string { return e.Description }
// Run executes a PredictionRequest against the manager's active wind field.
// Shared between the sync /api/v2/prediction handler and the async
// /api/v1/predictions worker.
func Run(mgr *datasets.Manager, elev *elevation.Dataset, req PredictionRequest) (*PredictionResponse, error) {
field := mgr.Active()
if field == nil {
return nil, &PredictionError{Status: http.StatusServiceUnavailable, Description: "no dataset loaded, service is starting up"}
}
// Normalize longitude to [0, 360) for internal use.
lng := req.Launch.Longitude
if lng < 0 {
lng += 360
}
warnings := &engine.Warnings{}
var terrain engine.TerrainProvider
if h.elev != nil {
terrain = h.elev
events := engine.NewEventSink()
deps := engine.BuildDeps{Wind: field, Events: events}
if elev != nil {
deps.Terrain = elev
}
prof, err := buildProfile(req, field, terrain, warnings)
prof, err := buildProfile(req, deps)
if err != nil {
writeError(w, http.StatusBadRequest, err.Error())
return
return nil, &PredictionError{Status: http.StatusBadRequest, Description: err.Error()}
}
started := time.Now().UTC()
results := prof.Run(float64(req.Launch.Time.Unix()), engine.State{
Lat: req.Launch.Latitude,
Lng: lng,
Altitude: req.Launch.Altitude,
})
Lat: req.Launch.Latitude, Lng: lng, Altitude: req.Launch.Altitude,
}, events)
completed := time.Now().UTC()
h.metrics.Prediction("v2", completed.Sub(started), nil)
resp := PredictionResponse{
resp := &PredictionResponse{
Stages: make([]StageResult, 0, len(results)),
Events: events.Snapshot(),
StartedAt: started,
CompletedAt: completed,
Dataset: DatasetInfo{
Source: field.Source(),
Epoch: field.Epoch(),
},
Dataset: DatasetInfo{Source: field.Source(), Epoch: field.Epoch()},
}
for _, r := range results {
stage := StageResult{
Name: r.Propagator,
Outcome: outcomeString(r.Outcome),
}
if r.Constraint != nil {
stage.Constraint = r.Constraint.Name()
}
stage.Trajectory = make([]TrajectoryPoint, len(r.Points))
for i, pt := range r.Points {
ptLng := pt.Lng
if ptLng > 180 {
ptLng -= 360
}
stage.Trajectory[i] = TrajectoryPoint{
Time: time.Unix(int64(pt.Time), 0).UTC(),
Latitude: pt.Lat,
Longitude: ptLng,
Altitude: pt.Altitude,
}
}
resp.Stages = append(resp.Stages, stage)
}
if warns := warnings.ToMap(); len(warns) > 0 {
resp.Warnings = warns
resp.Stages = append(resp.Stages, toStageResult(r))
}
return resp, nil
}
h.log.Info("v2 prediction complete",
zap.Int("stages", len(results)),
zap.Duration("elapsed", completed.Sub(started)))
writeJSON(w, http.StatusOK, resp)
func toStageResult(r engine.Result) StageResult {
stage := StageResult{
Name: r.Propagator,
Outcome: r.Outcome.String(),
Events: r.Events,
}
if r.Constraint != nil {
stage.Constraint = r.ConstraintName
stage.Termination = &TerminationInfo{
ViolationTime: time.Unix(int64(r.ViolationTime), 0).UTC(),
ViolationState: r.ViolationState,
RefinedTime: time.Unix(int64(r.RefinedTime), 0).UTC(),
RefinedState: r.RefinedState,
}
}
stage.Trajectory = make([]TrajectoryPoint, len(r.Points))
for i, pt := range r.Points {
ptLng := pt.Lng
if ptLng > 180 {
ptLng -= 360
}
stage.Trajectory[i] = TrajectoryPoint{
Time: time.Unix(int64(pt.Time), 0).UTC(),
Latitude: pt.Lat,
Longitude: ptLng,
Altitude: pt.Altitude,
}
}
return stage
}
func validateRequest(req PredictionRequest) error {
@ -148,26 +173,5 @@ func validateRequest(req PredictionRequest) error {
return nil
}
func outcomeString(o engine.Outcome) string {
switch o {
case engine.OutcomeStopped:
return "stopped"
case engine.OutcomeFallback:
return "fallback"
default:
return "continued"
}
}
func writeError(w http.ResponseWriter, status int, description string) {
writeJSON(w, status, ErrorResponse{Error: ErrorBody{
Type: http.StatusText(status),
Description: description,
}})
}
func writeJSON(w http.ResponseWriter, status int, body any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(body)
}
var writeJSON = httpjson.Write
var writeError = httpjson.Error

View file

@ -4,14 +4,11 @@ import (
"fmt"
"predictor-refactored/internal/engine"
"predictor-refactored/internal/weather"
)
// buildProfile translates a PredictionRequest into an engine.Profile.
//
// elev may be nil when no terrain dataset is loaded; TerrainContact constraints
// will return an error in that case.
func buildProfile(req PredictionRequest, field weather.WindField, elev engine.TerrainProvider, warnings *engine.Warnings) (engine.Profile, error) {
// buildProfile translates a PredictionRequest into an engine.Profile via
// the engine registry.
func buildProfile(req PredictionRequest, deps engine.BuildDeps) (engine.Profile, error) {
if len(req.Profile) == 0 {
return engine.Profile{}, fmt.Errorf("profile must contain at least one stage")
}
@ -37,24 +34,27 @@ func buildProfile(req PredictionRequest, field weather.WindField, elev engine.Te
props := make([]*engine.Propagator, len(req.Profile))
for i, stage := range req.Profile {
model, err := buildModel(stage.Model, field, warnings)
if err != nil {
return engine.Profile{}, fmt.Errorf("stage %q: %w", stage.Name, err)
if stage.Name == "" {
return engine.Profile{}, fmt.Errorf("stage %d: name is required", i)
}
constraints, err := buildConstraints(stage.Constraints, elev)
built, err := engine.BuildModel(stage.Model, deps)
if err != nil {
return engine.Profile{}, fmt.Errorf("stage %q model: %w", stage.Name, err)
}
constraints, err := buildConstraintList(stage.Constraints, deps)
if err != nil {
return engine.Profile{}, fmt.Errorf("stage %q: %w", stage.Name, err)
}
props[i] = &engine.Propagator{
Name: stage.Name,
Step: step,
Model: model,
Model: built.Model,
BuildModel: built.Build,
Constraints: constraints,
Tolerance: tol,
}
}
// Wire fallbacks once all stages exist.
for i, stage := range req.Profile {
if stage.FallbackIndex == nil {
continue
@ -66,80 +66,22 @@ func buildProfile(req PredictionRequest, field weather.WindField, elev engine.Te
props[i].Fallback = props[idx]
}
return engine.Profile{Stages: props, Direction: dir}, nil
}
func buildModel(spec ModelSpec, field weather.WindField, warnings *engine.Warnings) (engine.Model, error) {
var base engine.Model
switch spec.Type {
case "constant_rate":
base = engine.ConstantRate(spec.Rate)
case "parachute_descent":
if spec.SeaLevelRate <= 0 {
return nil, fmt.Errorf("parachute_descent requires positive sea_level_rate")
}
base = engine.ParachuteDescent(spec.SeaLevelRate)
case "piecewise":
segs := make([]engine.RateSegment, len(spec.Segments))
for i, s := range spec.Segments {
segs[i] = engine.RateSegment{Until: s.Until, Rate: s.Rate}
}
base = engine.Piecewise(segs)
case "wind":
if field == nil {
return nil, fmt.Errorf("wind model requires a loaded dataset")
}
return engine.WindTransport(field, warnings), nil
default:
return nil, fmt.Errorf("unknown model type %q", spec.Type)
globals, err := buildConstraintList(req.Globals, deps)
if err != nil {
return engine.Profile{}, fmt.Errorf("globals: %w", err)
}
if spec.IncludeWind {
if field == nil {
return nil, fmt.Errorf("include_wind requires a loaded dataset")
}
return engine.Sum(base, engine.WindTransport(field, warnings)), nil
}
return base, nil
return engine.Profile{Stages: props, Direction: dir, Globals: globals}, nil
}
func buildConstraints(specs []ConstraintSpec, elev engine.TerrainProvider) ([]engine.Constraint, error) {
func buildConstraintList(specs []engine.ConstraintSpec, deps engine.BuildDeps) ([]engine.Constraint, error) {
out := make([]engine.Constraint, 0, len(specs))
for _, spec := range specs {
action, err := parseAction(spec.Action)
for i, spec := range specs {
c, err := engine.BuildConstraint(spec, deps)
if err != nil {
return nil, err
}
var c engine.Constraint
switch spec.Type {
case "max_altitude":
c = engine.MaxAltitude{Limit: spec.Limit, On: action}
case "min_altitude":
c = engine.MinAltitude{Limit: spec.Limit, On: action}
case "max_time":
c = engine.MaxTime{Limit: spec.Limit, On: action}
case "terrain_contact":
if elev == nil {
return nil, fmt.Errorf("terrain_contact requires an elevation dataset")
}
c = engine.TerrainContact{Provider: elev, On: action}
default:
return nil, fmt.Errorf("unknown constraint type %q", spec.Type)
return nil, fmt.Errorf("constraint[%d]: %w", i, err)
}
out = append(out, c)
}
return out, nil
}
func parseAction(s string) (engine.Action, error) {
switch s {
case "", "stop":
return engine.ActionStop, nil
case "fallback":
return engine.ActionFallback, nil
case "clip":
return engine.ActionClip, nil
default:
return 0, fmt.Errorf("unknown constraint action %q", s)
}
}

View file

@ -1,18 +1,25 @@
// Package v2 implements the new primary prediction endpoint, which accepts a
// user-defined profile (chain of propagators with optional constraints) and
// returns the resulting trajectory.
// Package v2 implements the profile-driven prediction endpoint.
//
// Endpoint: POST /api/v2/prediction
//
// The request schema is built on the engine package's ConstraintSpec and
// ModelSpec, so adding new model or constraint types in the engine requires
// no changes here — they become available automatically via the registry.
package v2
import "time"
import (
"time"
// PredictionRequest is the request body for POST /api/v2/prediction.
"predictor-refactored/internal/engine"
)
// PredictionRequest is the body of POST /api/v2/prediction.
type PredictionRequest struct {
Launch Launch `json:"launch"`
Profile []Stage `json:"profile"`
Options Options `json:"options,omitempty"`
Direction string `json:"direction,omitempty"` // "forward" (default) or "reverse"
Launch Launch `json:"launch"`
Profile []StageSpec `json:"profile"`
Globals []engine.ConstraintSpec `json:"globals,omitempty"`
Options Options `json:"options,omitempty"`
Direction string `json:"direction,omitempty"` // "forward" (default) or "reverse"
}
// Launch is the initial state of the balloon (or, for reverse predictions,
@ -24,68 +31,47 @@ type Launch struct {
Altitude float64 `json:"altitude"`
}
// Stage is one entry in the propagator chain.
type Stage struct {
Name string `json:"name"`
Model ModelSpec `json:"model"`
Constraints []ConstraintSpec `json:"constraints,omitempty"`
// FallbackIndex, when set, points to another stage in the same profile to
// transfer to on ActionFallback constraints. Optional.
// StageSpec is one entry in the propagator chain.
type StageSpec struct {
Name string `json:"name"`
Model engine.ModelSpec `json:"model"`
Constraints []engine.ConstraintSpec `json:"constraints,omitempty"`
// FallbackIndex, when set, points to another stage in the same profile
// to transfer to on ActionFallback constraints.
FallbackIndex *int `json:"fallback_index,omitempty"`
}
// ModelSpec describes the per-stage propagation model.
type ModelSpec struct {
// Type selects the model: "constant_rate", "parachute_descent", "piecewise", "wind".
Type string `json:"type"`
// Rate (m/s) for constant_rate.
Rate float64 `json:"rate,omitempty"`
// SeaLevelRate (m/s, positive) for parachute_descent.
SeaLevelRate float64 `json:"sea_level_rate,omitempty"`
// Segments for piecewise.
Segments []PiecewiseSegment `json:"segments,omitempty"`
// IncludeWind sums a WindTransport model into the resulting derivative,
// allowing the same stage to model both vertical motion and wind drift.
IncludeWind bool `json:"include_wind"`
}
// PiecewiseSegment is one entry in a piecewise rate schedule.
type PiecewiseSegment struct {
Until float64 `json:"until"` // UNIX seconds; segment applies for t < Until
Rate float64 `json:"rate"` // m/s
}
// ConstraintSpec describes one constraint attached to a stage.
type ConstraintSpec struct {
// Type: "max_altitude", "min_altitude", "max_time", "terrain_contact".
Type string `json:"type"`
// Limit is interpreted per Type: metres for altitude, UNIX seconds for time.
Limit float64 `json:"limit,omitempty"`
// Action: "stop" (default), "fallback", "clip".
Action string `json:"action,omitempty"`
}
// Options tweaks the integrator behaviour.
// Options tweaks integrator behaviour.
type Options struct {
StepSeconds float64 `json:"step_seconds,omitempty"`
Tolerance float64 `json:"tolerance,omitempty"`
}
// PredictionResponse is the response body for POST /api/v2/prediction.
// PredictionResponse is the body of a successful POST response.
type PredictionResponse struct {
Stages []StageResult `json:"stages"`
Warnings map[string]any `json:"warnings,omitempty"`
Dataset DatasetInfo `json:"dataset"`
StartedAt time.Time `json:"started_at"`
CompletedAt time.Time `json:"completed_at"`
Stages []StageResult `json:"stages"`
Events []engine.EventSummary `json:"events,omitempty"`
Dataset DatasetInfo `json:"dataset"`
StartedAt time.Time `json:"started_at"`
CompletedAt time.Time `json:"completed_at"`
}
// StageResult is the outcome of one stage.
type StageResult struct {
Name string `json:"name"`
Outcome string `json:"outcome"` // "stopped" | "fallback" | "continued"
Constraint string `json:"constraint,omitempty"`
Trajectory []TrajectoryPoint `json:"trajectory"`
Name string `json:"name"`
Outcome string `json:"outcome"`
Constraint string `json:"constraint,omitempty"`
Termination *TerminationInfo `json:"termination,omitempty"`
Events []engine.EventSummary `json:"events,omitempty"`
Trajectory []TrajectoryPoint `json:"trajectory"`
}
// TerminationInfo exposes the violation+refinement detail from the engine.
type TerminationInfo struct {
ViolationTime time.Time `json:"violation_time"`
ViolationState engine.State `json:"violation_state"`
RefinedTime time.Time `json:"refined_time"`
RefinedState engine.State `json:"refined_state"`
}
// TrajectoryPoint is one sampled point of the trajectory.
@ -96,13 +82,13 @@ type TrajectoryPoint struct {
Altitude float64 `json:"altitude"`
}
// DatasetInfo identifies the dataset the prediction was computed against.
// DatasetInfo identifies the wind dataset used.
type DatasetInfo struct {
Source string `json:"source"`
Epoch time.Time `json:"epoch"`
}
// ErrorResponse is the JSON error shape used by both v2 and admin endpoints.
// ErrorResponse is the JSON error shape.
type ErrorResponse struct {
Error ErrorBody `json:"error"`
}

View file

@ -28,6 +28,12 @@ type Config struct {
// HTTPConfig configures the HTTP server.
type HTTPConfig struct {
Port int `yaml:"port"`
// AsyncWorkers caps concurrent prediction executions for the async endpoint.
AsyncWorkers int `yaml:"async_workers"`
// AsyncQueueSize bounds the async pending queue.
AsyncQueueSize int `yaml:"async_queue_size"`
// AsyncResultTTL is how long completed async results are retained.
AsyncResultTTL time.Duration `yaml:"async_result_ttl"`
}
// DataConfig configures dataset and elevation storage.
@ -60,11 +66,16 @@ type LogConfig struct {
// Defaults returns a Config with reasonable default values.
func Defaults() Config {
return Config{
HTTP: HTTPConfig{Port: 8080},
HTTP: HTTPConfig{
Port: 8080,
AsyncWorkers: 4,
AsyncQueueSize: 64,
AsyncResultTTL: time.Hour,
},
Data: DataConfig{
Dir: "/tmp/predictor-data",
ElevationPath: "/srv/ruaumoko-dataset",
Source: "noaa-gfs-0p50",
Source: "gfs-0p50-3h",
},
Download: DownloadConfig{
Parallel: 8,

View file

@ -0,0 +1,151 @@
// Package gefs implements datasets.Source for NOAA GEFS (Global Ensemble
// Forecast System) forecasts.
//
// Each ensemble member is treated as its own dataset, selected via
// DatasetID.Subset.Members. The download skeleton (HTTP, idx parsing,
// parallel blit) lives in internal/datasets/grib; this package only
// supplies GEFS-specific URL templating and member resolution.
package gefs
import (
"context"
"fmt"
"net/http"
"time"
"go.uber.org/zap"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/datasets/grib"
"predictor-refactored/internal/weather"
wgfs "predictor-refactored/internal/weather/gfs"
)
// Source is the GEFS implementation of datasets.Source.
type Source struct {
Variant *wgfs.Variant
Parallel int
Client *http.Client
Log *zap.Logger
}
// NewSource returns a default Source over variant. If variant is nil,
// GEFS 0.5° 3-hour is used.
func NewSource(variant *wgfs.Variant, log *zap.Logger) *Source {
if variant == nil {
variant = wgfs.GEFS0p50_3h
}
return &Source{
Variant: variant,
Parallel: 8,
Client: &http.Client{Timeout: 2 * time.Minute},
Log: log,
}
}
func (s *Source) ID() string { return s.Variant.ID }
func (s *Source) downloader() *grib.Downloader {
return &grib.Downloader{
Variant: s.Variant,
URLs: s.url,
Parallel: s.Parallel,
Client: s.Client,
Log: s.Log,
}
}
// url generates the GEFS URL for (date, runHour, member, step, levelSet).
func (s *Source) url(date string, runHour, member, step int, ls wgfs.LevelSet) string {
if ls == wgfs.LevelSetB {
return wgfs.GefsGribURLB(date, runHour, member, step, s.Variant.ResToken)
}
return wgfs.GefsGribURL(date, runHour, member, step, s.Variant.ResToken)
}
// LatestEpoch HEAD-checks the control member's final forecast hour.
func (s *Source) LatestEpoch(ctx context.Context) (time.Time, error) {
now := time.Now().UTC()
hour := now.Hour() - (now.Hour() % 6)
current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC)
client := s.Client
if client == nil {
client = &http.Client{Timeout: 2 * time.Minute}
}
log := s.Log
if log == nil {
log = zap.NewNop()
}
for range 8 {
date := current.Format("20060102")
url := wgfs.GefsGribURL(date, current.Hour(), 0, s.Variant.MaxHour, s.Variant.ResToken) + ".idx"
req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
if err == nil {
resp, err := client.Do(req)
if err == nil {
resp.Body.Close()
if resp.StatusCode == http.StatusOK {
log.Info("latest GEFS run discovered",
zap.Time("run", current),
zap.String("verified_url", url))
return current, nil
}
}
}
current = current.Add(-6 * time.Hour)
}
return time.Time{}, fmt.Errorf("no recent GEFS run found")
}
// Coverage returns the extent of id.
func (s *Source) Coverage(id datasets.DatasetID) datasets.Coverage {
v := s.Variant
cov := datasets.Coverage{
Region: datasets.Region{MinLat: -90, MaxLat: 90, MinLng: 0, MaxLng: 360},
StartTime: id.Epoch,
EndTime: id.Epoch.Add(time.Duration(v.MaxHour) * time.Hour),
}
if r := id.Subset.Region; r != nil {
cov.Region = *r
}
if h := id.Subset.HourRange; h != nil {
cov.StartTime = id.Epoch.Add(time.Duration(h.MinHour) * time.Hour)
cov.EndTime = id.Epoch.Add(time.Duration(h.MaxHour) * time.Hour)
}
return cov
}
// Open loads a stored GEFS dataset as a WindField.
func (s *Source) Open(_ context.Context, id datasets.DatasetID, store datasets.Storage) (weather.WindField, error) {
if !store.Exists(id) {
return nil, fmt.Errorf("dataset %s not found", id.Filename())
}
file, err := wgfs.Open(store.Path(id), s.Variant, id.Epoch.UTC())
if err != nil {
return nil, err
}
return wgfs.NewWind(file), nil
}
// memberOf extracts the single member index encoded by id.Subset.Members.
func memberOf(id datasets.DatasetID) (int, error) {
if len(id.Subset.Members) != 1 {
return 0, fmt.Errorf("gefs dataset id must specify exactly one member (got %v)", id.Subset.Members)
}
m := id.Subset.Members[0]
if m < 0 || m >= wgfs.GEFSMembers {
return 0, fmt.Errorf("gefs member %d out of range [0, %d)", m, wgfs.GEFSMembers)
}
return m, nil
}
// Download fetches one ensemble member's dataset.
func (s *Source) Download(ctx context.Context, id datasets.DatasetID, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error {
member, err := memberOf(id)
if err != nil {
return err
}
return s.downloader().Run(ctx, id, member, store, prog, throttle)
}

View file

@ -1,85 +1,96 @@
// Package gfs implements datasets.Source for NOAA GFS 0.5-degree forecasts.
// Package gfs implements datasets.Source for NOAA GFS forecasts.
//
// The package serves multiple GFS variants (0.5° 3-hour, 0.25° 3-hour,
// 0.25° 1-hour); the variant is selected at construction time. The
// download skeleton (HTTP, idx parsing, parallel blit) lives in
// internal/datasets/grib; this package only supplies URL templating and
// the Source-interface plumbing.
package gfs
import (
"context"
"errors"
"fmt"
"io"
"math"
"net/http"
"os"
"sync"
"time"
"github.com/nilsmagnus/grib/griblib"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/datasets/grib"
"predictor-refactored/internal/weather"
wgfs "predictor-refactored/internal/weather/gfs"
)
// Source is the GFS implementation of datasets.Source.
type Source struct {
Parallel int // max concurrent step downloads
Client *http.Client // optional; defaults to a 2-minute-timeout client
Variant *wgfs.Variant
Parallel int
Client *http.Client
Log *zap.Logger
}
// NewSource returns a default Source.
func NewSource(log *zap.Logger) *Source {
// NewSource returns a default Source over variant. If variant is nil,
// GFS 0.5° 3-hour is used (the historical Tawhiri default).
func NewSource(variant *wgfs.Variant, log *zap.Logger) *Source {
if variant == nil {
variant = wgfs.GFS0p50_3h
}
return &Source{
Variant: variant,
Parallel: 8,
Client: &http.Client{Timeout: 2 * time.Minute},
Log: log,
}
}
// ID returns the source identifier.
func (s *Source) ID() string { return "noaa-gfs-0p50" }
// ID returns the variant's ID.
func (s *Source) ID() string { return s.Variant.ID }
func (s *Source) log() *zap.Logger {
if s.Log == nil {
return zap.NewNop()
func (s *Source) downloader() *grib.Downloader {
return &grib.Downloader{
Variant: s.Variant,
URLs: s.url,
Parallel: s.Parallel,
Client: s.Client,
Log: s.Log,
}
return s.Log
}
func (s *Source) client() *http.Client {
if s.Client == nil {
return &http.Client{Timeout: 2 * time.Minute}
// url generates the GFS URL for one (date, runHour, _, step, levelSet).
// member is unused for GFS.
func (s *Source) url(date string, runHour, _, step int, ls wgfs.LevelSet) string {
if ls == wgfs.LevelSetB {
return s.Variant.GribURLB(date, runHour, step)
}
return s.Client
return s.Variant.GribURL(date, runHour, step)
}
func (s *Source) parallel() int {
if s.Parallel <= 0 {
return 8
}
return s.Parallel
}
// LatestEpoch returns the most recent run NOAA has finished publishing,
// determined by HEAD-ing the .idx for the final forecast hour. Walks back
// up to 8 runs (48 hours) before giving up.
// LatestEpoch returns the most recent run NOAA has finished publishing.
func (s *Source) LatestEpoch(ctx context.Context) (time.Time, error) {
now := time.Now().UTC()
hour := now.Hour() - (now.Hour() % 6)
current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC)
client := s.Client
if client == nil {
client = &http.Client{Timeout: 2 * time.Minute}
}
log := s.Log
if log == nil {
log = zap.NewNop()
}
for range 8 {
date := current.Format("20060102")
url := wgfs.GribURL(date, current.Hour(), wgfs.MaxHour) + ".idx"
url := s.Variant.GribURL(date, current.Hour(), s.Variant.MaxHour) + ".idx"
req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
if err == nil {
resp, err := s.client().Do(req)
resp, err := client.Do(req)
if err == nil {
resp.Body.Close()
if resp.StatusCode == http.StatusOK {
s.log().Info("latest GFS run discovered",
log.Info("latest run discovered",
zap.String("variant", s.Variant.ID),
zap.Time("run", current),
zap.String("verified_url", url))
return current, nil
@ -88,343 +99,40 @@ func (s *Source) LatestEpoch(ctx context.Context) (time.Time, error) {
}
current = current.Add(-6 * time.Hour)
}
return time.Time{}, fmt.Errorf("no recent GFS run found (checked 8 runs)")
return time.Time{}, fmt.Errorf("no recent %s run found (checked 8 runs)", s.Variant.ID)
}
// Coverage returns the geographic and temporal extent of id.
func (s *Source) Coverage(id datasets.DatasetID) datasets.Coverage {
v := s.Variant
cov := datasets.Coverage{
Region: datasets.Region{MinLat: -90, MaxLat: 90, MinLng: 0, MaxLng: 360},
StartTime: id.Epoch,
EndTime: id.Epoch.Add(time.Duration(v.MaxHour) * time.Hour),
}
if r := id.Subset.Region; r != nil {
cov.Region = *r
}
if h := id.Subset.HourRange; h != nil {
cov.StartTime = id.Epoch.Add(time.Duration(h.MinHour) * time.Hour)
cov.EndTime = id.Epoch.Add(time.Duration(h.MaxHour) * time.Hour)
}
return cov
}
// Open loads a stored dataset as a WindField.
func (s *Source) Open(_ context.Context, epoch time.Time, store datasets.Storage) (weather.WindField, error) {
if !store.Exists(epoch) {
return nil, fmt.Errorf("epoch %s not found", epoch.Format(time.RFC3339))
func (s *Source) Open(_ context.Context, id datasets.DatasetID, store datasets.Storage) (weather.WindField, error) {
if !store.Exists(id) {
return nil, fmt.Errorf("dataset %s not found", id.Filename())
}
file, err := wgfs.Open(store.Path(epoch), epoch.UTC())
file, err := wgfs.Open(store.Path(id), s.Variant, id.Epoch.UTC())
if err != nil {
return nil, err
}
return wgfs.NewWind(file), nil
}
// neededVariables is the GRIB variable set we extract.
var neededVariables = map[string]bool{"HGT": true, "UGRD": true, "VGRD": true}
// Download fetches the full dataset for epoch in parallel, resuming any
// previously-completed work units. Honours ctx cancellation and prog
// (which may be nil).
func (s *Source) Download(ctx context.Context, epoch time.Time, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error {
if prog == nil {
prog = noopSink{}
}
handle, err := store.BeginWrite(epoch)
if err != nil {
return fmt.Errorf("begin write: %w", err)
}
manifest := handle.Manifest()
// Open or create the temp file. If a previous attempt left a partial
// file of the right size, reuse it (resume); otherwise Create.
file, err := openOrCreateCube(handle.Path())
if err != nil {
_ = handle.Abort()
return err
}
date := epoch.UTC().Format("20060102")
runHour := epoch.UTC().Hour()
steps := wgfs.Hours()
totalUnits := len(steps) * 2
prog.SetTotal(totalUnits)
// Pre-count already-done units so progress is accurate on resume.
for _, u := range manifest.Units() {
_ = u
prog.StepComplete()
}
start := time.Now()
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(s.parallel())
// fileMu serialises concurrent BlitGribData calls because the underlying
// mmap is shared and SetVal isn't atomic.
var fileMu sync.Mutex
for _, step := range steps {
hourIdx := wgfs.HourIndex(step)
if hourIdx < 0 {
continue
}
for _, ls := range []wgfs.LevelSet{wgfs.LevelSetA, wgfs.LevelSetB} {
unit := unitKey(step, ls)
if manifest.Has(unit) {
continue
}
g.Go(func() error {
var url string
switch ls {
case wgfs.LevelSetA:
url = wgfs.GribURL(date, runHour, step)
case wgfs.LevelSetB:
url = wgfs.GribURLB(date, runHour, step)
}
if err := s.downloadAndBlit(ctx, file, &fileMu, url, hourIdx, ls, prog, throttle); err != nil {
return fmt.Errorf("step %d %s: %w", step, levelSetLabel(ls), err)
}
if err := manifest.Mark(unit); err != nil {
return fmt.Errorf("mark unit: %w", err)
}
prog.StepComplete()
return nil
})
}
}
if err := g.Wait(); err != nil {
_ = file.Close()
// Don't Abort on context cancellation — preserve progress for resume.
if errors.Is(err, context.Canceled) {
return err
}
// Other errors: abort if no progress was made; otherwise leave for resume.
if len(manifest.Units()) == 0 {
_ = handle.Abort()
}
return err
}
if err := file.Flush(); err != nil {
_ = file.Close()
return fmt.Errorf("flush: %w", err)
}
if err := file.Close(); err != nil {
return fmt.Errorf("close: %w", err)
}
if err := handle.Commit(); err != nil {
return fmt.Errorf("commit: %w", err)
}
s.log().Info("download complete",
zap.Time("epoch", epoch),
zap.Duration("elapsed", time.Since(start)))
return nil
// Download fetches the dataset for id. GFS ignores Subset.Members.
func (s *Source) Download(ctx context.Context, id datasets.DatasetID, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error {
return s.downloader().Run(ctx, id, 0, store, prog, throttle)
}
// openOrCreateCube returns a writable cube file at path, creating it if the
// file does not exist or has the wrong size.
func openOrCreateCube(path string) (*wgfs.File, error) {
info, err := os.Stat(path)
if err == nil && info.Size() == wgfs.DatasetSize {
return wgfs.OpenWritable(path)
}
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("stat cube: %w", err)
}
// Wrong-size or missing — truncate-create.
return wgfs.Create(path)
}
// downloadAndBlit fetches and decodes one (URL, level-set) chunk and writes
// it into the dataset.
func (s *Source) downloadAndBlit(
ctx context.Context,
file *wgfs.File,
fileMu *sync.Mutex,
baseURL string,
hourIdx int,
ls wgfs.LevelSet,
prog datasets.ProgressSink,
throttle datasets.Throttle,
) error {
idxBody, err := s.httpGet(ctx, baseURL+".idx", throttle, prog)
if err != nil {
return fmt.Errorf("idx: %w", err)
}
entries := ParseIdx(idxBody)
filtered := FilterIdx(entries, neededVariables)
var relevant []IdxEntry
for _, e := range filtered {
set, ok := wgfs.PressureLevelSet(e.LevelMB)
if ok && set == ls {
relevant = append(relevant, e)
}
}
if len(relevant) == 0 {
return nil
}
ranges := EntriesToRanges(relevant)
tmp, err := os.CreateTemp("", "gfs-msg-*.tmp")
if err != nil {
return fmt.Errorf("temp: %w", err)
}
tmpPath := tmp.Name()
defer os.Remove(tmpPath)
for _, r := range ranges {
body, err := s.httpGetRange(ctx, baseURL, r.Start, r.End, throttle, prog)
if err != nil {
tmp.Close()
return fmt.Errorf("range %d-%d: %w", r.Start, r.End, err)
}
if _, err := tmp.Write(body); err != nil {
tmp.Close()
return fmt.Errorf("write tmp: %w", err)
}
}
if err := tmp.Close(); err != nil {
return err
}
f, err := os.Open(tmpPath)
if err != nil {
return err
}
messages, err := griblib.ReadMessages(f)
f.Close()
if err != nil {
return fmt.Errorf("read grib: %w", err)
}
for _, msg := range messages {
if msg.Section4.ProductDefinitionTemplateNumber != 0 {
continue
}
p := msg.Section4.ProductDefinitionTemplate
varIdx := wgfs.VariableIndex(int(p.ParameterCategory), int(p.ParameterNumber))
if varIdx < 0 {
continue
}
if p.FirstSurface.Type != 100 { // isobaric only
continue
}
pressureMB := int(math.Round(float64(p.FirstSurface.Value) / 100.0))
levelIdx := wgfs.PressureIndex(pressureMB)
if levelIdx < 0 {
continue
}
data := msg.Data()
fileMu.Lock()
err := file.BlitGribData(hourIdx, levelIdx, varIdx, data)
fileMu.Unlock()
if err != nil {
return fmt.Errorf("blit: %w", err)
}
}
return nil
}
// httpGet downloads a URL body with 3 retries and optional throttling.
func (s *Source) httpGet(ctx context.Context, url string, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
var lastErr error
for attempt := range 3 {
if attempt > 0 {
select {
case <-time.After(time.Duration(attempt*2) * time.Second):
case <-ctx.Done():
return nil, ctx.Err()
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := s.client().Do(req)
if err != nil {
lastErr = err
continue
}
body, err := readThrottled(ctx, resp.Body, throttle, prog)
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url)
continue
}
if err != nil {
lastErr = err
continue
}
return body, nil
}
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}
// httpGetRange downloads an inclusive byte range with 3 retries and throttling.
func (s *Source) httpGetRange(ctx context.Context, url string, start, end int64, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
var lastErr error
for attempt := range 3 {
if attempt > 0 {
select {
case <-time.After(time.Duration(attempt*2) * time.Second):
case <-ctx.Done():
return nil, ctx.Err()
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
resp, err := s.client().Do(req)
if err != nil {
lastErr = err
continue
}
body, err := readThrottled(ctx, resp.Body, throttle, prog)
resp.Body.Close()
if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("HTTP %d for range %d-%d of %s", resp.StatusCode, start, end, url)
continue
}
if err != nil {
lastErr = err
continue
}
return body, nil
}
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}
// readThrottled reads r into memory, consulting throttle (if non-nil) before
// each chunk and reporting bytes to prog.
func readThrottled(ctx context.Context, r io.Reader, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
buf := make([]byte, 0, 64*1024)
chunk := make([]byte, 32*1024)
for {
if throttle != nil {
if err := throttle.Wait(ctx, len(chunk)); err != nil {
return nil, err
}
}
n, err := r.Read(chunk)
if n > 0 {
buf = append(buf, chunk[:n]...)
prog.Bytes(int64(n))
}
if errors.Is(err, io.EOF) {
return buf, nil
}
if err != nil {
return nil, err
}
}
}
func unitKey(step int, ls wgfs.LevelSet) string {
return fmt.Sprintf("step%03d-%s", step, levelSetLabel(ls))
}
func levelSetLabel(ls wgfs.LevelSet) string {
if ls == wgfs.LevelSetB {
return "B"
}
return "A"
}
// noopSink discards progress events.
type noopSink struct{}
func (noopSink) SetTotal(int) {}
func (noopSink) StepComplete() {}
func (noopSink) Bytes(int64) {}

View file

@ -0,0 +1,369 @@
package grib
import (
"context"
"errors"
"fmt"
"io"
"math"
"net/http"
"os"
"sync"
"time"
"github.com/nilsmagnus/grib/griblib"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"predictor-refactored/internal/datasets"
wgfs "predictor-refactored/internal/weather/gfs"
)
// URLFunc returns the GRIB URL for one (date, runHour, member, step, levelSet).
// Sources that don't have members (GFS) ignore the member argument.
type URLFunc func(date string, runHour, member, step int, ls wgfs.LevelSet) string
// Downloader is the generic GRIB-cube downloader.
//
// A Source plugs in its variant, URL templating, and member-resolution
// logic; the Downloader runs the parallel idx fetch, byte-range download,
// GRIB decode, and blit loop with manifest-based resume.
type Downloader struct {
Variant *wgfs.Variant
URLs URLFunc
Parallel int
Client *http.Client
Log *zap.Logger
}
func (d *Downloader) log() *zap.Logger {
if d.Log == nil {
return zap.NewNop()
}
return d.Log
}
func (d *Downloader) client() *http.Client {
if d.Client == nil {
return &http.Client{Timeout: 2 * time.Minute}
}
return d.Client
}
func (d *Downloader) parallel() int {
if d.Parallel <= 0 {
return 8
}
return d.Parallel
}
// neededVariables is the GRIB variable set every source extracts.
var neededVariables = map[string]bool{"HGT": true, "UGRD": true, "VGRD": true}
// Run downloads the dataset for id, member into store. The caller may
// pass member=0 for non-ensemble sources.
func (d *Downloader) Run(ctx context.Context, id datasets.DatasetID, member int, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error {
if prog == nil {
prog = noopSink{}
}
handle, err := store.BeginWrite(id)
if err != nil {
return fmt.Errorf("begin write: %w", err)
}
manifest := handle.Manifest()
file, err := openOrCreateCube(handle.Path(), d.Variant)
if err != nil {
_ = handle.Abort()
return err
}
epoch := id.Epoch.UTC()
date := epoch.Format("20060102")
runHour := epoch.Hour()
steps := d.Variant.Hours()
if hr := id.Subset.HourRange; hr != nil {
filtered := steps[:0]
for _, step := range steps {
if step >= hr.MinHour && step <= hr.MaxHour {
filtered = append(filtered, step)
}
}
steps = filtered
}
prog.SetTotal(len(steps) * 2)
for range manifest.Units() {
prog.StepComplete()
}
start := time.Now()
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(d.parallel())
var fileMu sync.Mutex
for _, step := range steps {
hourIdx := d.Variant.HourIndex(step)
if hourIdx < 0 {
continue
}
for _, ls := range []wgfs.LevelSet{wgfs.LevelSetA, wgfs.LevelSetB} {
unit := unitKey(step, ls)
if manifest.Has(unit) {
continue
}
g.Go(func() error {
url := d.URLs(date, runHour, member, step, ls)
if err := d.downloadAndBlit(ctx, file, &fileMu, url, hourIdx, ls, prog, throttle); err != nil {
return fmt.Errorf("step %d %s: %w", step, levelSetLabel(ls), err)
}
if err := manifest.Mark(unit); err != nil {
return fmt.Errorf("mark unit: %w", err)
}
prog.StepComplete()
return nil
})
}
}
if err := g.Wait(); err != nil {
_ = file.Close()
if errors.Is(err, context.Canceled) {
return err
}
if len(manifest.Units()) == 0 {
_ = handle.Abort()
}
return err
}
if err := file.Flush(); err != nil {
_ = file.Close()
return fmt.Errorf("flush: %w", err)
}
if err := file.Close(); err != nil {
return fmt.Errorf("close: %w", err)
}
if err := handle.Commit(); err != nil {
return fmt.Errorf("commit: %w", err)
}
d.log().Info("download complete",
zap.String("variant", d.Variant.ID),
zap.Time("epoch", epoch),
zap.Duration("elapsed", time.Since(start)))
return nil
}
// openOrCreateCube opens an existing cube at path if it matches variant's
// expected size, else truncate-creates a new one.
func openOrCreateCube(path string, variant *wgfs.Variant) (*wgfs.File, error) {
info, err := os.Stat(path)
if err == nil && info.Size() == variant.DatasetSize() {
return wgfs.OpenWritable(path, variant)
}
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("stat cube: %w", err)
}
return wgfs.Create(path, variant)
}
// downloadAndBlit fetches and decodes one (URL, level-set) chunk.
func (d *Downloader) downloadAndBlit(
ctx context.Context,
file *wgfs.File,
fileMu *sync.Mutex,
baseURL string,
hourIdx int,
ls wgfs.LevelSet,
prog datasets.ProgressSink,
throttle datasets.Throttle,
) error {
idxBody, err := d.httpGet(ctx, baseURL+".idx", throttle, prog)
if err != nil {
return fmt.Errorf("idx: %w", err)
}
entries := ParseIdx(idxBody)
filtered := FilterIdx(entries, neededVariables)
var relevant []IdxEntry
for _, e := range filtered {
set, ok := d.Variant.PressureLevelSet(e.LevelMB)
if ok && set == ls {
relevant = append(relevant, e)
}
}
if len(relevant) == 0 {
return nil
}
ranges := EntriesToRanges(relevant)
tmp, err := os.CreateTemp("", "grib-msg-*.tmp")
if err != nil {
return fmt.Errorf("temp: %w", err)
}
tmpPath := tmp.Name()
defer os.Remove(tmpPath)
for _, r := range ranges {
body, err := d.httpGetRange(ctx, baseURL, r.Start, r.End, throttle, prog)
if err != nil {
tmp.Close()
return fmt.Errorf("range: %w", err)
}
if _, err := tmp.Write(body); err != nil {
tmp.Close()
return fmt.Errorf("write tmp: %w", err)
}
}
if err := tmp.Close(); err != nil {
return err
}
f, err := os.Open(tmpPath)
if err != nil {
return err
}
messages, err := griblib.ReadMessages(f)
f.Close()
if err != nil {
return fmt.Errorf("read grib: %w", err)
}
for _, msg := range messages {
if msg.Section4.ProductDefinitionTemplateNumber != 0 {
continue
}
p := msg.Section4.ProductDefinitionTemplate
varIdx := d.Variant.VariableIndex(int(p.ParameterCategory), int(p.ParameterNumber))
if varIdx < 0 {
continue
}
if p.FirstSurface.Type != 100 {
continue
}
pressureMB := int(math.Round(float64(p.FirstSurface.Value) / 100.0))
levelIdx := d.Variant.PressureIndex(pressureMB)
if levelIdx < 0 {
continue
}
data := msg.Data()
fileMu.Lock()
err := file.BlitGribData(hourIdx, levelIdx, varIdx, data)
fileMu.Unlock()
if err != nil {
return fmt.Errorf("blit: %w", err)
}
}
return nil
}
func (d *Downloader) httpGet(ctx context.Context, url string, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
var lastErr error
for attempt := range 3 {
if attempt > 0 {
select {
case <-time.After(time.Duration(attempt*2) * time.Second):
case <-ctx.Done():
return nil, ctx.Err()
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := d.client().Do(req)
if err != nil {
lastErr = err
continue
}
body, err := readThrottled(ctx, resp.Body, throttle, prog)
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url)
continue
}
if err != nil {
lastErr = err
continue
}
return body, nil
}
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}
func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
var lastErr error
for attempt := range 3 {
if attempt > 0 {
select {
case <-time.After(time.Duration(attempt*2) * time.Second):
case <-ctx.Done():
return nil, ctx.Err()
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
resp, err := d.client().Do(req)
if err != nil {
lastErr = err
continue
}
body, err := readThrottled(ctx, resp.Body, throttle, prog)
resp.Body.Close()
if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("HTTP %d for range", resp.StatusCode)
continue
}
if err != nil {
lastErr = err
continue
}
return body, nil
}
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}
func readThrottled(ctx context.Context, r io.Reader, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) {
buf := make([]byte, 0, 64*1024)
chunk := make([]byte, 32*1024)
for {
if throttle != nil {
if err := throttle.Wait(ctx, len(chunk)); err != nil {
return nil, err
}
}
n, err := r.Read(chunk)
if n > 0 {
buf = append(buf, chunk[:n]...)
prog.Bytes(int64(n))
}
if errors.Is(err, io.EOF) {
return buf, nil
}
if err != nil {
return nil, err
}
}
}
func unitKey(step int, ls wgfs.LevelSet) string {
return fmt.Sprintf("step%03d-%s", step, levelSetLabel(ls))
}
func levelSetLabel(ls wgfs.LevelSet) string {
if ls == wgfs.LevelSetB {
return "B"
}
return "A"
}
type noopSink struct{}
func (noopSink) SetTotal(int) {}
func (noopSink) StepComplete() {}
func (noopSink) Bytes(int64) {}

View file

@ -1,4 +1,8 @@
package gfs
// Package grib contains the GRIB-cube download skeleton shared by every
// NOAA source (GFS, GEFS, future families). It exposes the .idx parser,
// HTTP helpers, and a parallel download loop; source-specific URL
// templating is injected by the caller.
package grib
import (
"fmt"

View file

@ -1,4 +1,4 @@
package gfs
package grib
import "testing"

View file

@ -27,23 +27,22 @@ const (
// JobInfo is the externally-visible snapshot of a download job.
type JobInfo struct {
ID string
Source string
Epoch time.Time
Status JobStatus
StartedAt time.Time
EndedAt *time.Time
Err string
Total int
Done int
Bytes int64
ID string
Source string
Dataset DatasetID
Status JobStatus
StartedAt time.Time
EndedAt *time.Time
Err string
Total int
Done int
Bytes int64
}
// jobEntry is the Manager's mutable record for one job.
type jobEntry struct {
id string
source string
epoch time.Time
dataset DatasetID
startedAt time.Time
cancel context.CancelFunc
@ -60,7 +59,7 @@ type jobEntry struct {
func (e *jobEntry) snapshot() JobInfo {
e.mu.Lock()
info := JobInfo{
ID: e.id, Source: e.source, Epoch: e.epoch,
ID: e.id, Source: e.source, Dataset: e.dataset,
StartedAt: e.startedAt, Status: e.status, Err: e.errStr,
}
if !e.endedAt.IsZero() {
@ -74,14 +73,20 @@ func (e *jobEntry) snapshot() JobInfo {
return info
}
// jobProgress is the ProgressSink wired into a jobEntry.
type jobProgress struct{ e *jobEntry }
func (p jobProgress) SetTotal(n int) { p.e.total.Store(int64(n)) }
func (p jobProgress) StepComplete() { p.e.done.Add(1) }
func (p jobProgress) Bytes(n int64) { p.e.bytes.Add(n) }
// Manager coordinates dataset downloads and exposes the active WindField.
// loadedDataset bundles a loaded WindField with its identity and coverage.
type loadedDataset struct {
ID DatasetID
Field weather.WindField
Coverage Coverage
}
// Manager coordinates dataset downloads and exposes the active WindFields.
type Manager struct {
src Source
store Storage
@ -89,18 +94,15 @@ type Manager struct {
log *zap.Logger
activeMu sync.RWMutex
active weather.WindField
active []loadedDataset
jobsMu sync.RWMutex
jobs map[string]*jobEntry
// inFlight maps an epoch's RFC3339 representation to its jobID, enforcing
// single-flight per epoch.
inFlight sync.Map
inFlight sync.Map // key: dataset filename, value: jobID
}
// New returns a Manager wiring source, store, and an optional throttle.
// A nil log uses zap.NewNop().
// New wires a Manager.
func New(src Source, store Storage, throttle Throttle, log *zap.Logger) *Manager {
if log == nil {
log = zap.NewNop()
@ -119,18 +121,65 @@ func New(src Source, store Storage, throttle Throttle, log *zap.Logger) *Manager
// Source returns the underlying source ID.
func (m *Manager) Source() string { return m.src.ID() }
// Active returns the currently-loaded WindField, or nil.
// Active returns the currently-loaded global WindField (the dataset with
// IsGlobal subset, most recently loaded). Returns nil if no global
// dataset is loaded; in cluster setups with only regional subsets, callers
// should use SelectFor.
func (m *Manager) Active() weather.WindField {
m.activeMu.RLock()
defer m.activeMu.RUnlock()
return m.active
for _, d := range m.active {
if d.ID.Subset.IsGlobal() {
return d.Field
}
}
if len(m.active) > 0 {
return m.active[0].Field
}
return nil
}
// Ready reports whether a dataset is currently loaded.
// Ready reports whether at least one dataset is loaded.
func (m *Manager) Ready() bool { return m.Active() != nil }
// ListEpochs returns all stored dataset epochs, newest first.
func (m *Manager) ListEpochs() ([]time.Time, error) { return m.store.List() }
// SelectFor returns a loaded WindField whose coverage contains (t, lat, lng).
// Returns nil when no loaded dataset covers the query.
func (m *Manager) SelectFor(t time.Time, lat, lng float64) weather.WindField {
m.activeMu.RLock()
defer m.activeMu.RUnlock()
for _, d := range m.active {
if d.Coverage.Covers(t, lat, lng) {
return d.Field
}
}
// Fallback: any global dataset is permissive about region.
for _, d := range m.active {
if d.ID.Subset.IsGlobal() {
return d.Field
}
}
return nil
}
// LoadedDatasets returns snapshots of every currently-loaded dataset.
func (m *Manager) LoadedDatasets() []LoadedDatasetInfo {
m.activeMu.RLock()
defer m.activeMu.RUnlock()
out := make([]LoadedDatasetInfo, 0, len(m.active))
for _, d := range m.active {
out = append(out, LoadedDatasetInfo{ID: d.ID, Coverage: d.Coverage})
}
return out
}
// LoadedDatasetInfo is a serializable snapshot of one active dataset.
type LoadedDatasetInfo struct {
ID DatasetID
Coverage Coverage
}
// ListEpochs returns all stored datasets, newest first.
func (m *Manager) ListEpochs() ([]DatasetID, error) { return m.store.List() }
// ListJobs returns snapshots of every job recorded since startup.
func (m *Manager) ListJobs() []JobInfo {
@ -143,7 +192,7 @@ func (m *Manager) ListJobs() []JobInfo {
return out
}
// GetJob returns the snapshot for a job, or false if id is unknown.
// GetJob returns the snapshot for a job.
func (m *Manager) GetJob(id string) (JobInfo, bool) {
m.jobsMu.RLock()
e, ok := m.jobs[id]
@ -154,8 +203,7 @@ func (m *Manager) GetJob(id string) (JobInfo, bool) {
return e.snapshot(), true
}
// CancelJob cancels a running job. Returns false if id is unknown or the
// job is already terminal.
// CancelJob cancels a running job.
func (m *Manager) CancelJob(id string) bool {
m.jobsMu.RLock()
e, ok := m.jobs[id]
@ -173,28 +221,31 @@ func (m *Manager) CancelJob(id string) bool {
return true
}
// RemoveEpoch deletes a stored dataset. If epoch is currently active, the
// active field is cleared.
func (m *Manager) RemoveEpoch(epoch time.Time) error {
epoch = epoch.UTC()
if active := m.Active(); active != nil && active.Epoch().Equal(epoch) {
m.activeMu.Lock()
m.active = nil
m.activeMu.Unlock()
// Remove deletes a stored dataset. If the dataset is currently loaded,
// it is unloaded first.
func (m *Manager) Remove(id DatasetID) error {
m.activeMu.Lock()
out := m.active[:0]
var removed *loadedDataset
for i := range m.active {
d := m.active[i]
if d.ID.Equals(id) {
removed = &d
continue
}
out = append(out, d)
}
return m.store.Remove(epoch)
m.active = out
m.activeMu.Unlock()
if removed != nil {
closeField(removed.Field, m.log)
}
return m.store.Remove(id)
}
// Download starts (or resumes) a download job for epoch in the background.
// Returns the JobID. If a job for the same epoch is already running, its
// existing JobID is returned.
//
// If the dataset is already present on disk, a synthetic completed JobInfo
// is recorded and its JobID returned.
func (m *Manager) Download(epoch time.Time) string {
epoch = epoch.UTC()
key := epoch.Format(time.RFC3339)
// Download starts (or resumes) a download job for id in the background.
func (m *Manager) Download(id DatasetID) string {
key := id.Filename()
if existing, ok := m.inFlight.Load(key); ok {
return existing.(string)
}
@ -209,7 +260,7 @@ func (m *Manager) Download(epoch time.Time) string {
e := &jobEntry{
id: jobID,
source: m.src.ID(),
epoch: epoch,
dataset: id,
startedAt: now,
status: JobPending,
cancel: cancel,
@ -218,8 +269,7 @@ func (m *Manager) Download(epoch time.Time) string {
m.jobs[jobID] = e
m.jobsMu.Unlock()
if m.store.Exists(epoch) {
// Skip the download but still record the job for traceability.
if m.store.Exists(id) {
go m.completeShortCircuit(ctx, e)
return jobID
}
@ -227,46 +277,54 @@ func (m *Manager) Download(epoch time.Time) string {
return jobID
}
// LoadEpoch swaps the active WindField to epoch's stored dataset.
func (m *Manager) LoadEpoch(ctx context.Context, epoch time.Time) error {
epoch = epoch.UTC()
if !m.store.Exists(epoch) {
return fmt.Errorf("epoch %s not present on disk", epoch.Format(time.RFC3339))
// Load swaps in id's stored dataset, making it available to predictions.
func (m *Manager) Load(ctx context.Context, id DatasetID) error {
if !m.store.Exists(id) {
return fmt.Errorf("dataset %s not present on disk", id.Filename())
}
field, err := m.src.Open(ctx, epoch, m.store)
field, err := m.src.Open(ctx, id, m.store)
if err != nil {
return fmt.Errorf("open epoch: %w", err)
return fmt.Errorf("open dataset: %w", err)
}
m.swapActive(field)
cov := m.src.Coverage(id)
m.activeMu.Lock()
// Replace any previously-loaded dataset with the same ID.
for i := range m.active {
if m.active[i].ID.Equals(id) {
closeField(m.active[i].Field, m.log)
m.active[i] = loadedDataset{ID: id, Field: field, Coverage: cov}
m.activeMu.Unlock()
return nil
}
}
m.active = append(m.active, loadedDataset{ID: id, Field: field, Coverage: cov})
m.activeMu.Unlock()
m.log.Info("loaded dataset",
zap.Time("epoch", epoch),
zap.String("filename", id.Filename()),
zap.String("source", m.src.ID()))
return nil
}
// Refresh ensures the most recent upstream dataset is downloaded and active.
//
// If the freshest stored dataset is newer than retentionTTL old, no upstream
// check is performed. Otherwise the source's LatestEpoch is consulted; if it
// is newer than the active dataset, a download is started and on completion
// the new dataset becomes active.
// Refresh ensures the freshest global dataset is downloaded and active.
//
// Returns the JobID started, or empty string when nothing was scheduled.
func (m *Manager) Refresh(ctx context.Context, freshnessTTL time.Duration) (string, error) {
if active := m.Active(); active != nil && time.Since(active.Epoch()) < freshnessTTL {
if a := m.activeGlobal(); a != nil && time.Since(a.ID.Epoch) < freshnessTTL {
return "", nil
}
// Try loading the freshest existing dataset before going to the network.
if epochs, err := m.store.List(); err == nil {
for _, e := range epochs {
if time.Since(e) > freshnessTTL {
if datasets, err := m.store.List(); err == nil {
for _, id := range datasets {
if !id.Subset.IsGlobal() {
continue
}
if active := m.Active(); active != nil && active.Epoch().Equal(e) {
if time.Since(id.Epoch) > freshnessTTL {
continue
}
if a := m.activeGlobal(); a != nil && a.ID.Equals(id) {
return "", nil
}
if err := m.LoadEpoch(ctx, e); err == nil {
if err := m.Load(ctx, id); err == nil {
return "", nil
}
}
@ -276,37 +334,50 @@ func (m *Manager) Refresh(ctx context.Context, freshnessTTL time.Duration) (stri
if err != nil {
return "", fmt.Errorf("latest epoch: %w", err)
}
if active := m.Active(); active != nil && !latest.After(active.Epoch()) {
id := DatasetID{Epoch: latest}
if a := m.activeGlobal(); a != nil && !latest.After(a.ID.Epoch) {
return "", nil
}
jobID := m.Download(latest)
// Spawn a watcher that loads the dataset on successful completion.
go func() {
for {
info, ok := m.GetJob(jobID)
if !ok {
return
}
switch info.Status {
case JobComplete:
if err := m.LoadEpoch(context.Background(), latest); err != nil {
m.log.Error("load after download", zap.Error(err))
}
return
case JobFailed, JobCancelled:
return
}
time.Sleep(2 * time.Second)
}
}()
jobID := m.Download(id)
go m.loadAfterCompletion(jobID, id)
return jobID, nil
}
// runDownload executes one Source.Download invocation and records its outcome.
// activeGlobal returns the currently-loaded global dataset, if any.
func (m *Manager) activeGlobal() *loadedDataset {
m.activeMu.RLock()
defer m.activeMu.RUnlock()
for i := range m.active {
if m.active[i].ID.Subset.IsGlobal() {
d := m.active[i]
return &d
}
}
return nil
}
func (m *Manager) loadAfterCompletion(jobID string, id DatasetID) {
for {
info, ok := m.GetJob(jobID)
if !ok {
return
}
switch info.Status {
case JobComplete:
if err := m.Load(context.Background(), id); err != nil {
m.log.Error("load after download", zap.Error(err))
}
return
case JobFailed, JobCancelled:
return
}
time.Sleep(2 * time.Second)
}
}
func (m *Manager) runDownload(ctx context.Context, e *jobEntry) {
defer m.inFlight.Delete(e.epoch.Format(time.RFC3339))
defer m.inFlight.Delete(e.dataset.Filename())
e.mu.Lock()
e.status = JobRunning
@ -314,9 +385,9 @@ func (m *Manager) runDownload(ctx context.Context, e *jobEntry) {
m.log.Info("download started",
zap.String("job", e.id),
zap.Time("epoch", e.epoch))
zap.String("dataset", e.dataset.Filename()))
err := m.src.Download(ctx, e.epoch, m.store, jobProgress{e: e}, m.throttle)
err := m.src.Download(ctx, e.dataset, m.store, jobProgress{e: e}, m.throttle)
now := time.Now().UTC()
e.mu.Lock()
@ -339,10 +410,9 @@ func (m *Manager) runDownload(ctx context.Context, e *jobEntry) {
zap.NamedError("err", err))
}
// completeShortCircuit records a job as complete without performing any work.
func (m *Manager) completeShortCircuit(ctx context.Context, e *jobEntry) {
_ = ctx
defer m.inFlight.Delete(e.epoch.Format(time.RFC3339))
defer m.inFlight.Delete(e.dataset.Filename())
now := time.Now().UTC()
e.mu.Lock()
e.status = JobComplete
@ -350,20 +420,6 @@ func (m *Manager) completeShortCircuit(ctx context.Context, e *jobEntry) {
e.mu.Unlock()
}
// swapActive replaces the active field and closes the previous one if it
// implements io.Closer.
func (m *Manager) swapActive(f weather.WindField) {
m.activeMu.Lock()
old := m.active
m.active = f
m.activeMu.Unlock()
if c, ok := old.(interface{ Close() error }); ok && c != nil {
if err := c.Close(); err != nil {
m.log.Warn("close old dataset", zap.Error(err))
}
}
}
// Close releases all resources, cancelling any in-flight jobs.
func (m *Manager) Close() error {
m.jobsMu.Lock()
@ -373,11 +429,18 @@ func (m *Manager) Close() error {
m.jobsMu.Unlock()
m.activeMu.Lock()
active := m.active
for _, d := range m.active {
closeField(d.Field, m.log)
}
m.active = nil
m.activeMu.Unlock()
if c, ok := active.(interface{ Close() error }); ok && c != nil {
return c.Close()
}
return nil
}
func closeField(f weather.WindField, log *zap.Logger) {
if c, ok := f.(interface{ Close() error }); ok && c != nil {
if err := c.Close(); err != nil && log != nil {
log.Warn("close dataset", zap.Error(err))
}
}
}

View file

@ -14,15 +14,16 @@ import (
//
// Layout under Root:
//
// <epoch>.bin — committed dataset (binary cube)
// <epoch>.bin.downloading — in-progress dataset
// <epoch>.bin.manifest.json — manifest of completed work units
// <filename>.bin — committed dataset
// <filename>.bin.downloading — in-progress dataset
// <filename>.bin.manifest.json — completed work units
//
// The .bin suffix exists to differentiate from sidecars in directory listings;
// epoch is formatted as "20060102T150405Z" (UTC).
// where <filename> is DatasetID.Filename() — typically
// "20060102T150405Z" for the global subset or
// "20060102T150405Z_r-10.10.-30.30_h0.72" for a subset.
type LocalStore struct {
Root string
Source string // source ID, recorded for safety but currently advisory
Source string
Extension string // default ".bin"
}
@ -37,8 +38,6 @@ func NewLocalStore(root, sourceID string) (*LocalStore, error) {
// SourceID returns the source ID this store is configured for.
func (s *LocalStore) SourceID() string { return s.Source }
const epochFormat = "20060102T150405Z"
func (s *LocalStore) ext() string {
if s.Extension == "" {
return ".bin"
@ -46,32 +45,32 @@ func (s *LocalStore) ext() string {
return s.Extension
}
// Path returns the canonical path for an epoch's committed dataset file.
func (s *LocalStore) Path(epoch time.Time) string {
return filepath.Join(s.Root, epoch.UTC().Format(epochFormat)+s.ext())
// Path returns the canonical path for id's committed dataset.
func (s *LocalStore) Path(id DatasetID) string {
return filepath.Join(s.Root, id.Filename()+s.ext())
}
func (s *LocalStore) tempPath(epoch time.Time) string {
return s.Path(epoch) + ".downloading"
func (s *LocalStore) tempPath(id DatasetID) string {
return s.Path(id) + ".downloading"
}
func (s *LocalStore) manifestPath(epoch time.Time) string {
return s.Path(epoch) + ".manifest.json"
func (s *LocalStore) manifestPath(id DatasetID) string {
return s.Path(id) + ".manifest.json"
}
// Exists reports whether a committed dataset for epoch is present.
func (s *LocalStore) Exists(epoch time.Time) bool {
info, err := os.Stat(s.Path(epoch))
// Exists reports whether a committed dataset for id is present.
func (s *LocalStore) Exists(id DatasetID) bool {
info, err := os.Stat(s.Path(id))
return err == nil && !info.IsDir()
}
// List returns all committed epochs, newest first.
func (s *LocalStore) List() ([]time.Time, error) {
// List returns all committed dataset IDs, newest first.
func (s *LocalStore) List() ([]DatasetID, error) {
entries, err := os.ReadDir(s.Root)
if err != nil {
return nil, fmt.Errorf("read store: %w", err)
}
var out []time.Time
var out []DatasetID
ext := s.ext()
for _, e := range entries {
if e.IsDir() {
@ -82,24 +81,47 @@ func (s *LocalStore) List() ([]time.Time, error) {
continue
}
stem := strings.TrimSuffix(name, ext)
// skip in-progress files (their stem already has .bin.downloading...)
// Skip in-progress files (their stem ends in .downloading or .manifest)
if strings.Contains(stem, ".") {
continue
}
t, err := time.Parse(epochFormat, stem)
if err != nil {
id, ok := parseFilename(stem)
if !ok {
continue
}
out = append(out, t.UTC())
out = append(out, id)
}
sort.Slice(out, func(i, j int) bool { return out[i].After(out[j]) })
sort.Slice(out, func(i, j int) bool {
if !out[i].Epoch.Equal(out[j].Epoch) {
return out[i].Epoch.After(out[j].Epoch)
}
return out[i].Subset.Key() < out[j].Subset.Key()
})
return out, nil
}
// Remove deletes the committed dataset and any sidecar files for epoch.
func (s *LocalStore) Remove(epoch time.Time) error {
// parseFilename inverts DatasetID.Filename(). The subset portion is not
// fully reversible (Key encoding is one-way for floats), so List returns
// IDs whose Subset is zero — the storage layer treats names as opaque
// identifiers. Callers wanting structured subset metadata should keep an
// out-of-band record.
func parseFilename(stem string) (DatasetID, bool) {
parts := strings.SplitN(stem, "_", 2)
epoch, err := time.Parse("20060102T150405Z", parts[0])
if err != nil {
return DatasetID{}, false
}
id := DatasetID{Epoch: epoch.UTC()}
// Subset key is opaque on disk; we don't reconstruct its parameters
// here. Admin callers track subset specs separately when they need
// the structured form.
return id, true
}
// Remove deletes the committed dataset and any sidecar files for id.
func (s *LocalStore) Remove(id DatasetID) error {
var errs []error
for _, p := range []string{s.Path(epoch), s.tempPath(epoch), s.manifestPath(epoch)} {
for _, p := range []string{s.Path(id), s.tempPath(id), s.manifestPath(id)} {
if err := os.Remove(p); err != nil && !errors.Is(err, os.ErrNotExist) {
errs = append(errs, err)
}
@ -110,55 +132,46 @@ func (s *LocalStore) Remove(epoch time.Time) error {
return nil
}
// BeginWrite opens or resumes a TempHandle for epoch.
//
// If a partial download is already present, its file and manifest are reused
// so the new download picks up where the previous one stopped.
func (s *LocalStore) BeginWrite(epoch time.Time) (TempHandle, error) {
man, err := LoadManifest(s.manifestPath(epoch))
// BeginWrite opens or resumes a TempHandle for id.
func (s *LocalStore) BeginWrite(id DatasetID) (TempHandle, error) {
man, err := LoadManifest(s.manifestPath(id))
if err != nil {
return nil, err
}
return &localHandle{
store: s,
epoch: epoch,
manifest: man,
}, nil
return &localHandle{store: s, id: id, manifest: man}, nil
}
type localHandle struct {
store *LocalStore
epoch time.Time
id DatasetID
manifest *Manifest
closed bool
}
func (h *localHandle) Path() string { return h.store.tempPath(h.epoch) }
func (h *localHandle) Path() string { return h.store.tempPath(h.id) }
func (h *localHandle) Manifest() *Manifest { return h.manifest }
// Commit promotes the temp file to its final path and removes the manifest.
func (h *localHandle) Commit() error {
if h.closed {
return nil
}
h.closed = true
if err := os.Rename(h.store.tempPath(h.epoch), h.store.Path(h.epoch)); err != nil {
if err := os.Rename(h.store.tempPath(h.id), h.store.Path(h.id)); err != nil {
return fmt.Errorf("commit rename: %w", err)
}
if err := os.Remove(h.store.manifestPath(h.epoch)); err != nil && !errors.Is(err, os.ErrNotExist) {
if err := os.Remove(h.store.manifestPath(h.id)); err != nil && !errors.Is(err, os.ErrNotExist) {
return fmt.Errorf("commit remove manifest: %w", err)
}
return nil
}
// Abort removes the in-progress file and manifest.
func (h *localHandle) Abort() error {
if h.closed {
return nil
}
h.closed = true
var firstErr error
for _, p := range []string{h.store.tempPath(h.epoch), h.store.manifestPath(h.epoch)} {
for _, p := range []string{h.store.tempPath(h.id), h.store.manifestPath(h.id)} {
if err := os.Remove(p); err != nil && !errors.Is(err, os.ErrNotExist) && firstErr == nil {
firstErr = err
}

View file

@ -2,7 +2,6 @@ package datasets
import (
"os"
"path/filepath"
"testing"
"time"
)
@ -14,8 +13,8 @@ func TestLocalStoreBeginWriteResume(t *testing.T) {
t.Fatalf("NewLocalStore: %v", err)
}
epoch := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
h, err := store.BeginWrite(epoch)
id := DatasetID{Epoch: time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)}
h, err := store.BeginWrite(id)
if err != nil {
t.Fatalf("BeginWrite: %v", err)
}
@ -27,7 +26,7 @@ func TestLocalStoreBeginWriteResume(t *testing.T) {
}
// Re-open should see the previous manifest entry.
h2, err := store.BeginWrite(epoch)
h2, err := store.BeginWrite(id)
if err != nil {
t.Fatalf("BeginWrite resume: %v", err)
}
@ -35,48 +34,59 @@ func TestLocalStoreBeginWriteResume(t *testing.T) {
t.Errorf("resumed manifest missing step000-A; units = %v", h2.Manifest().Units())
}
// Commit promotes the temp file and removes the manifest.
if err := h2.Commit(); err != nil {
t.Fatalf("Commit: %v", err)
}
if !store.Exists(epoch) {
if !store.Exists(id) {
t.Errorf("Exists after commit returned false")
}
if _, err := os.Stat(filepath.Join(dir, store.manifestPath(epoch))); !os.IsNotExist(err) {
if _, err := os.Stat(store.manifestPath(id)); !os.IsNotExist(err) {
t.Errorf("manifest should be removed, got err=%v", err)
}
// Listing finds the committed epoch.
epochs, err := store.List()
stored, err := store.List()
if err != nil {
t.Fatalf("List: %v", err)
}
if len(epochs) != 1 || !epochs[0].Equal(epoch) {
t.Errorf("List = %v, want [%v]", epochs, epoch)
if len(stored) != 1 || !stored[0].Epoch.Equal(id.Epoch) {
t.Errorf("List = %v, want one item with epoch %v", stored, id.Epoch)
}
// Remove cleans up.
if err := store.Remove(epoch); err != nil {
if err := store.Remove(id); err != nil {
t.Fatalf("Remove: %v", err)
}
if store.Exists(epoch) {
if store.Exists(id) {
t.Errorf("Exists after remove returned true")
}
}
func TestLocalStoreAbort(t *testing.T) {
func TestLocalStoreSubsetPath(t *testing.T) {
dir := t.TempDir()
store, _ := NewLocalStore(dir, "gfs-test")
epoch := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
h, _ := store.BeginWrite(epoch)
os.WriteFile(h.Path(), []byte("x"), 0o644)
h.Manifest().Mark("step000-A")
if err := h.Abort(); err != nil {
t.Fatalf("Abort: %v", err)
regional := DatasetID{
Epoch: epoch,
Subset: SubsetSpec{
Region: &Region{MinLat: -10, MaxLat: 10, MinLng: 0, MaxLng: 30},
HourRange: &HourRange{MinHour: 0, MaxHour: 72},
},
}
if _, err := os.Stat(h.Path()); !os.IsNotExist(err) {
t.Errorf("temp file should be removed after abort, got %v", err)
global := DatasetID{Epoch: epoch}
if store.Path(global) == store.Path(regional) {
t.Errorf("global and regional should have distinct paths")
}
}
func TestSubsetSpecCoverage(t *testing.T) {
r := Region{MinLat: -10, MaxLat: 10, MinLng: 350, MaxLng: 10} // wraps antimeridian
s := SubsetSpec{Region: &r}
if !s.IncludesLatLng(0, 0) {
t.Errorf("(0,0) should be inside antimeridian region")
}
if !s.IncludesLatLng(0, 359) {
t.Errorf("(0,359) should be inside antimeridian region")
}
if s.IncludesLatLng(0, 180) {
t.Errorf("(0,180) should be outside antimeridian region")
}
}

156
internal/datasets/subset.go Normal file
View file

@ -0,0 +1,156 @@
package datasets
import (
"fmt"
"slices"
"strings"
"time"
)
// SubsetSpec describes which portion of a dataset to download.
//
// A zero-value SubsetSpec means "the full dataset". The Region and
// HourRange fields independently restrict what is fetched and stored.
type SubsetSpec struct {
// Region restricts the geographic extent. nil means global.
Region *Region `json:"region,omitempty"`
// HourRange restricts the forecast horizon. nil means the source's
// full horizon (e.g. 0..192h for GFS 0.5°).
HourRange *HourRange `json:"hour_range,omitempty"`
// Members restricts ensemble members for sources that support them (GEFS).
// nil means all available members.
Members []int `json:"members,omitempty"`
}
// Region is an axis-aligned geographic bounding box.
//
// Longitudes are in [0, 360); a box crossing the antimeridian has
// MinLng > MaxLng.
type Region struct {
MinLat float64 `json:"min_lat"`
MaxLat float64 `json:"max_lat"`
MinLng float64 `json:"min_lng"`
MaxLng float64 `json:"max_lng"`
}
// HourRange is an inclusive forecast-hour range.
type HourRange struct {
MinHour int `json:"min_hour"`
MaxHour int `json:"max_hour"`
}
// IsGlobal reports whether the spec selects the entire dataset.
func (s SubsetSpec) IsGlobal() bool {
return s.Region == nil && s.HourRange == nil && len(s.Members) == 0
}
// IncludesLatLng reports whether (lat, lng) lies inside the spec's Region,
// or the spec has no Region.
func (s SubsetSpec) IncludesLatLng(lat, lng float64) bool {
if s.Region == nil {
return true
}
r := s.Region
if lat < r.MinLat || lat > r.MaxLat {
return false
}
if r.MinLng <= r.MaxLng {
return lng >= r.MinLng && lng <= r.MaxLng
}
// Wraps the antimeridian.
return lng >= r.MinLng || lng <= r.MaxLng
}
// IncludesHour reports whether the forecast hour is in range.
func (s SubsetSpec) IncludesHour(h int) bool {
if s.HourRange == nil {
return true
}
return h >= s.HourRange.MinHour && h <= s.HourRange.MaxHour
}
// IncludesMember reports whether the ensemble member is in range.
func (s SubsetSpec) IncludesMember(m int) bool {
if len(s.Members) == 0 {
return true
}
return slices.Contains(s.Members, m)
}
// Key returns a deterministic short identifier for the spec. The empty
// string represents the global subset.
func (s SubsetSpec) Key() string {
if s.IsGlobal() {
return ""
}
var b strings.Builder
if s.Region != nil {
fmt.Fprintf(&b, "r%g.%g.%g.%g", s.Region.MinLat, s.Region.MaxLat, s.Region.MinLng, s.Region.MaxLng)
}
if s.HourRange != nil {
if b.Len() > 0 {
b.WriteByte('_')
}
fmt.Fprintf(&b, "h%d.%d", s.HourRange.MinHour, s.HourRange.MaxHour)
}
if len(s.Members) > 0 {
if b.Len() > 0 {
b.WriteByte('_')
}
fmt.Fprintf(&b, "m")
for i, m := range s.Members {
if i > 0 {
b.WriteByte('.')
}
fmt.Fprintf(&b, "%d", m)
}
}
return b.String()
}
// DatasetID identifies one storable dataset.
type DatasetID struct {
Epoch time.Time
Subset SubsetSpec
}
// Equals reports whether two DatasetIDs refer to the same dataset.
// DatasetID is not comparable with == because SubsetSpec contains slices.
func (id DatasetID) Equals(other DatasetID) bool {
return id.Epoch.Equal(other.Epoch) && id.Subset.Key() == other.Subset.Key()
}
// Filename returns the canonical filename stem for the dataset. The
// extension is appended by the Storage implementation.
func (id DatasetID) Filename() string {
stem := id.Epoch.UTC().Format("20060102T150405Z")
if k := id.Subset.Key(); k != "" {
return stem + "_" + k
}
return stem
}
// Coverage is the spatial and temporal extent of a loaded dataset, used by
// the Manager to select which dataset can serve a given query.
type Coverage struct {
Region Region `json:"region"`
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
}
// Covers reports whether (t, lat, lng) lies inside the coverage.
func (c Coverage) Covers(t time.Time, lat, lng float64) bool {
if t.Before(c.StartTime) || t.After(c.EndTime) {
return false
}
r := c.Region
if lat < r.MinLat || lat > r.MaxLat {
return false
}
if r.MinLng <= r.MaxLng {
return lng >= r.MinLng && lng <= r.MaxLng
}
return lng >= r.MinLng || lng <= r.MaxLng
}

View file

@ -11,87 +11,75 @@ import (
//
// Implementations download dataset files in a transactional, resumable
// manner and load them as weather.WindField. A Source must be safe for
// concurrent use across multiple Manager calls.
// concurrent use across many Manager calls.
type Source interface {
// ID is a stable identifier, e.g. "noaa-gfs-0p50".
// ID is a stable identifier, e.g. "gfs-0p50-3h".
ID() string
// LatestEpoch returns the most recent dataset epoch this source can provide.
LatestEpoch(ctx context.Context) (time.Time, error)
// Download fetches the dataset for epoch into store. Sources must honour
// any partial progress recorded in store's manifest and skip
// already-completed work, so re-invocation after a crash resumes cleanly.
// Download fetches the dataset identified by id into store. Sources
// must honour any partial progress recorded in store's manifest and
// skip already-completed work so re-invocation after a crash resumes
// cleanly.
//
// prog receives progress events; nil is acceptable.
// throttle, if non-nil, is consulted before each network read for
// bandwidth limiting; nil means no throttling.
Download(ctx context.Context, epoch time.Time, store Storage, prog ProgressSink, throttle Throttle) error
Download(ctx context.Context, id DatasetID, store Storage, prog ProgressSink, throttle Throttle) error
// Open loads epoch's stored dataset and returns it as a WindField.
Open(ctx context.Context, epoch time.Time, store Storage) (weather.WindField, error)
// Open loads id's stored dataset and returns it as a WindField.
Open(ctx context.Context, id DatasetID, store Storage) (weather.WindField, error)
// Coverage returns the geographical/temporal extent of a downloaded
// dataset. Used by the Manager to decide which loaded dataset can
// serve a given prediction query.
Coverage(id DatasetID) Coverage
}
// Storage abstracts the on-disk location of dataset files and their manifests.
//
// Atomicity: only datasets promoted via TempHandle.Commit appear in Exists or
// List. Aborted or in-progress downloads are invisible to readers.
// Atomicity: only datasets promoted via TempHandle.Commit appear in Exists
// or List. Aborted or in-progress downloads are invisible to readers.
type Storage interface {
// SourceID identifies the data source these files belong to. Mixing
// sources in one Storage is not supported.
// SourceID identifies the data source these files belong to.
SourceID() string
// Path returns the canonical local path for epoch's dataset. The path
// is valid even when the dataset has not been written.
Path(epoch time.Time) string
// Path returns the canonical local path for id's dataset.
Path(id DatasetID) string
// Exists reports whether a committed dataset for epoch is present.
Exists(epoch time.Time) bool
// Exists reports whether a committed dataset for id is present.
Exists(id DatasetID) bool
// List returns all committed epochs available, newest first.
List() ([]time.Time, error)
// List returns all committed dataset IDs available, newest first.
List() ([]DatasetID, error)
// Remove deletes the dataset and any sidecar manifest for epoch.
Remove(epoch time.Time) error
// Remove deletes the dataset and any sidecar manifest for id.
Remove(id DatasetID) error
// BeginWrite opens (or resumes) a transactional handle for downloading
// epoch's dataset. Callers must Commit or Abort the returned handle.
BeginWrite(epoch time.Time) (TempHandle, error)
// id's dataset.
BeginWrite(id DatasetID) (TempHandle, error)
}
// TempHandle is the storage state for one in-progress download.
type TempHandle interface {
// Path returns the path of the in-progress file. Sources write directly here.
Path() string
// Manifest is the tracker of completed work units for resume support.
Manifest() *Manifest
// Commit promotes the temp file to its canonical location and removes
// the manifest. Subsequent calls are no-ops.
Commit() error
// Abort discards the temp file and manifest. Subsequent calls are no-ops.
Abort() error
}
// ProgressSink receives progress events during a download.
//
// All methods are safe to call concurrently.
type ProgressSink interface {
// SetTotal sets the total number of work units this download expects.
// May be called multiple times if discovery happens incrementally.
SetTotal(n int)
// StepComplete records one work unit as completed.
StepComplete()
// Bytes records n bytes received from the network.
Bytes(n int64)
}
// Throttle is an optional bandwidth limiter consulted by sources before
// each network read.
type Throttle interface {
// Wait blocks until n bytes can be consumed from the budget,
// or returns ctx's error if the context is cancelled while waiting.
Wait(ctx context.Context, n int) error
}

View file

@ -1,40 +1,42 @@
package engine
// MaxAltitude triggers when altitude rises above Limit (in metres).
// Used as the burst condition for ascent stages.
type MaxAltitude struct {
import (
"fmt"
"math"
)
// Altitude triggers when the balloon altitude satisfies Op against Limit.
//
// Examples:
//
// Altitude{Op: OpGreaterEqual, Limit: 30000} — burst at 30 km
// Altitude{Op: OpLessEqual, Limit: 0} — sea-level descent termination
type Altitude struct {
Op Operator
Limit float64
On Action
}
func (c MaxAltitude) Name() string { return "max_altitude" }
func (c MaxAltitude) Violated(_ float64, s State) bool { return s.Altitude >= c.Limit }
func (c MaxAltitude) Action() Action { return c.On }
func (c Altitude) Name() string {
return fmt.Sprintf("altitude %s %g", c.Op, c.Limit)
}
func (c Altitude) Violated(_ float64, s State) bool { return c.Op.Test(s.Altitude, c.Limit) }
func (c Altitude) Action() Action { return c.On }
// MinAltitude triggers when altitude falls at or below Limit (in metres).
// With Limit=0 this is the "sea level" terminator.
type MinAltitude struct {
// Time triggers when the integration time t (UNIX seconds) satisfies Op
// against Limit.
type Time struct {
Op Operator
Limit float64
On Action
}
func (c MinAltitude) Name() string { return "min_altitude" }
func (c MinAltitude) Violated(_ float64, s State) bool { return s.Altitude <= c.Limit }
func (c MinAltitude) Action() Action { return c.On }
func (c Time) Name() string { return fmt.Sprintf("time %s %g", c.Op, c.Limit) }
func (c Time) Violated(t float64, _ State) bool { return c.Op.Test(t, c.Limit) }
func (c Time) Action() Action { return c.On }
// MaxTime triggers when t exceeds Limit (UNIX seconds). Used as a stop
// condition for float profiles.
type MaxTime struct {
Limit float64
On Action
}
func (c MaxTime) Name() string { return "max_time" }
func (c MaxTime) Violated(t float64, _ State) bool { return t > c.Limit }
func (c MaxTime) Action() Action { return c.On }
// TerrainContact triggers when altitude has dropped at or below ground level.
// Equivalent to Tawhiri's elevation termination.
// TerrainContact triggers when the ground elevation exceeds the balloon's
// altitude — i.e. the balloon has hit the ground.
type TerrainContact struct {
Provider TerrainProvider
On Action
@ -45,3 +47,103 @@ func (c TerrainContact) Violated(_ float64, s State) bool {
return c.Provider.Elevation(s.Lat, s.Lng) > s.Altitude
}
func (c TerrainContact) Action() Action { return c.On }
// PolygonMode selects whether Polygon fires when the balloon is inside or
// outside the configured polygon.
type PolygonMode int
const (
// PolygonInside fires when (lat, lng) lies inside the polygon — useful
// for "must not enter restricted airspace".
PolygonInside PolygonMode = iota
// PolygonOutside fires when (lat, lng) lies outside the polygon —
// useful for "must remain over the test range".
PolygonOutside
)
// PolygonVertex is one vertex of a geographic polygon. Latitudes are in
// degrees [-90, 90]; longitudes in degrees [0, 360) or [-180, 180]
// (callers normalise — see Polygon.Violated).
type PolygonVertex struct {
Lat float64
Lng float64
}
// Polygon is a constraint over a geographic polygon. The polygon is
// considered closed (last vertex connects to the first) and is interpreted
// in plate-carrée (rectangular lat/lng) coordinates with longitude
// wrap-around handling.
//
// Edges crossing the 180/-180 antimeridian are split via longitude
// normalisation against the polygon's centroid: callers that need
// great-circle accuracy should clip their polygon along the antimeridian
// before submitting.
type Polygon struct {
Vertices []PolygonVertex
Mode PolygonMode
On Action
// Label, if set, is returned by Name. Defaults to "polygon_inside" or
// "polygon_outside" based on Mode.
Label string
}
func (c Polygon) Name() string {
if c.Label != "" {
return c.Label
}
if c.Mode == PolygonOutside {
return "polygon_outside"
}
return "polygon_inside"
}
func (c Polygon) Action() Action { return c.On }
// Violated reports whether the state satisfies the polygon-containment rule.
func (c Polygon) Violated(_ float64, s State) bool {
if len(c.Vertices) < 3 {
return false
}
in := pointInPolygon(s.Lat, s.Lng, c.Vertices)
if c.Mode == PolygonInside {
return in
}
return !in
}
// pointInPolygon implements the ray-casting algorithm in lat/lng space.
//
// All vertices and the query point are normalised to within 180° of
// verts[0] before testing, so a polygon spanning the antimeridian is
// handled correctly as long as the polygon itself spans no more than 180°
// in longitude.
func pointInPolygon(lat, lng float64, verts []PolygonVertex) bool {
if len(verts) == 0 {
return false
}
ref := verts[0].Lng
qx := normLng(lng, ref)
inside := false
n := len(verts)
for i, j := 0, n-1; i < n; j, i = i, i+1 {
yi, yj := verts[i].Lat, verts[j].Lat
xi := normLng(verts[i].Lng, ref)
xj := normLng(verts[j].Lng, ref)
if (yi > lat) != (yj > lat) {
xIntersect := (xj-xi)*(lat-yi)/(yj-yi) + xi
if qx < xIntersect {
inside = !inside
}
}
}
return inside
}
// normLng rewrites v so that it lies within 180° of ref. With ref=10 and
// v=350, normLng returns -10.
func normLng(v, ref float64) float64 {
diff := math.Mod(v-ref+540, 360) - 180
return ref + diff
}

View file

@ -8,8 +8,7 @@ import (
"predictor-refactored/internal/weather"
)
// noWind is a WindField that always returns zero wind. Lets us test
// integration of vertical-only profiles deterministically.
// noWind is a WindField that always returns zero wind.
type noWind struct{ epoch time.Time }
func (n noWind) Wind(_ float64, _, _, _ float64) (weather.Sample, error) {
@ -31,19 +30,23 @@ func TestConstantAscentToBurst(t *testing.T) {
Name: "ascent",
Step: 60,
Model: Sum(ConstantRate(rate), WindTransport(noWind{}, nil)),
Constraints: []Constraint{MaxAltitude{Limit: burst, On: ActionStop}},
Constraints: []Constraint{Altitude{Op: OpGreaterEqual, Limit: burst, On: ActionStop}},
}
prof := Profile{Stages: []*Propagator{ascend}, Direction: Forward}
results := prof.Run(0, State{Lat: 0, Lng: 0, Altitude: 0})
results := prof.Run(0, State{Lat: 0, Lng: 0, Altitude: 0}, NewEventSink())
if len(results) != 1 || results[0].Outcome != OutcomeStopped {
t.Fatalf("expected one stopped stage, got %+v", results)
}
if results[0].ConstraintName == "" {
t.Errorf("ConstraintName not populated")
}
if results[0].RefinedState.Altitude == 0 {
t.Errorf("RefinedState not populated")
}
last := results[0].Points[len(results[0].Points)-1]
// Refinement tolerance is 0.01 in parameter space over a 60s step, so the
// returned point sits within ±0.6s × rate ≈ ±3m of the boundary.
if math.Abs(last.Altitude-burst) > 5 {
t.Errorf("burst altitude = %v, want within 5m of %v", last.Altitude, burst)
}
@ -67,12 +70,12 @@ func TestProfileWithFallback(t *testing.T) {
Name: "ascent",
Step: 60,
Model: ConstantRate(rate),
Constraints: []Constraint{MaxAltitude{Limit: burst, On: ActionFallback}},
Constraints: []Constraint{Altitude{Op: OpGreaterEqual, Limit: burst, On: ActionFallback}},
Fallback: descent,
}
prof := Profile{Stages: []*Propagator{ascend}, Direction: Forward}
results := prof.Run(0, State{Altitude: 0})
results := prof.Run(0, State{Altitude: 0}, NewEventSink())
if len(results) != 2 {
t.Fatalf("expected 2 results (ascent then descent fallback), got %d", len(results))
@ -91,16 +94,14 @@ func TestProfileWithFallback(t *testing.T) {
}
func TestReverseDirection(t *testing.T) {
// Start at altitude 100m with downward rate; integrating reverse should
// give increasing altitude.
desc := &Propagator{
Name: "rewind",
Step: 1,
Model: ConstantRate(-1), // forward: alt decreases at 1 m/s
Constraints: []Constraint{MaxAltitude{Limit: 200, On: ActionStop}},
Model: ConstantRate(-1),
Constraints: []Constraint{Altitude{Op: OpGreaterEqual, Limit: 200, On: ActionStop}},
}
prof := Profile{Stages: []*Propagator{desc}, Direction: Reverse}
results := prof.Run(0, State{Altitude: 100})
results := prof.Run(0, State{Altitude: 100}, NewEventSink())
last := results[0].Points[len(results[0].Points)-1]
if math.Abs(last.Altitude-200) > 1 {
@ -129,6 +130,33 @@ func TestPiecewiseRate(t *testing.T) {
}
}
func TestPiecewiseReferenceResolution(t *testing.T) {
// Build via the registry with propagator_start segments.
spec := ModelSpec{
Type: "piecewise",
Segments: []PiecewiseSegmentSpec{
{Until: 100, Rate: 5, Reference: "propagator_start"},
{Until: 200, Rate: 3, Reference: "propagator_start"},
},
}
built, err := BuildModel(spec, BuildDeps{})
if err != nil {
t.Fatalf("BuildModel: %v", err)
}
if built.Build == nil {
t.Fatalf("expected lazy build for propagator_start references")
}
ctx := StageContext{ProfileStart: 1000, PropagatorStart: 5000}
m := built.Build(ctx)
// Until=100 from propagator_start=5000 → absolute 5100.
if r := m(5050, State{}); r.Altitude != 5 {
t.Errorf("rate at t=5050 = %v, want 5", r.Altitude)
}
if r := m(5150, State{}); r.Altitude != 3 {
t.Errorf("rate at t=5150 = %v, want 3", r.Altitude)
}
}
// fixedWind returns a constant wind sample.
type fixedWind struct{ u, v float64 }
@ -139,12 +167,8 @@ func (fixedWind) Epoch() time.Time { return time.Unix(0, 0) }
func (fixedWind) Source() string { return "test-fixed" }
func TestWindTransportUnitConversion(t *testing.T) {
// Pure eastward wind of 10 m/s at the equator at sea level.
// Expected dlng/dt = (180/pi) * 10 / (6371009 * cos(0)) ≈ 0.00008991 deg/s.
// Expected dlat/dt = 0.
wind := WindTransport(fixedWind{u: 10, v: 0}, nil)
d := wind(0, State{Lat: 0, Lng: 0, Altitude: 0})
wantLng := (180.0 / math.Pi) * 10.0 / 6371009.0
if math.Abs(d.Lng-wantLng) > 1e-12 {
t.Errorf("dlng = %v, want %v", d.Lng, wantLng)
@ -153,7 +177,6 @@ func TestWindTransportUnitConversion(t *testing.T) {
t.Errorf("dlat = %v, want 0 for u=10 v=0", d.Lat)
}
// Pure northward at 60° latitude: dlat = (180/pi) * v / R, dlng = 0.
wind2 := WindTransport(fixedWind{u: 0, v: 5}, nil)
d = wind2(0, State{Lat: 60, Lng: 0, Altitude: 0})
wantLat := (180.0 / math.Pi) * 5.0 / 6371009.0
@ -162,8 +185,28 @@ func TestWindTransportUnitConversion(t *testing.T) {
}
}
// aboveModelWind reports AboveModel on every sample. Used to verify event emission.
type aboveModelWind struct{}
func (aboveModelWind) Wind(_ float64, _, _, _ float64) (weather.Sample, error) {
return weather.Sample{AboveModel: true}, nil
}
func (aboveModelWind) Epoch() time.Time { return time.Unix(0, 0) }
func (aboveModelWind) Source() string { return "above" }
func TestWindTransportEmitsAboveModel(t *testing.T) {
sink := NewEventSink()
wind := WindTransport(aboveModelWind{}, sink)
for range 3 {
_ = wind(0, State{})
}
events := sink.Snapshot()
if len(events) != 1 || events[0].Type != "above_model" || events[0].Count != 3 {
t.Errorf("expected one above_model event with count=3, got %+v", events)
}
}
func TestStateAddWrapsLongitude(t *testing.T) {
// Demonstrates state algebra used by the integrator and refinement.
s := stateAdd(State{Lat: 0, Lng: 350, Altitude: 0}, 1, State{Lng: 20})
if math.Abs(s.Lng-10) > 1e-9 {
t.Errorf("addState wrap: lng = %v, want 10", s.Lng)
@ -174,3 +217,39 @@ func TestStateAddWrapsLongitude(t *testing.T) {
t.Errorf("lerpState lng wrap: %v, want 0 or 360", mid.Lng)
}
}
func TestPolygonInside(t *testing.T) {
// Unit square at the equator.
square := []PolygonVertex{
{Lat: -1, Lng: -1},
{Lat: -1, Lng: 1},
{Lat: 1, Lng: 1},
{Lat: 1, Lng: -1},
}
c := Polygon{Vertices: square, Mode: PolygonInside, On: ActionStop}
if !c.Violated(0, State{Lat: 0, Lng: 0}) {
t.Errorf("origin should be inside the square")
}
if c.Violated(0, State{Lat: 5, Lng: 0}) {
t.Errorf("(5, 0) should be outside the square")
}
}
func TestPolygonOutsideAntimeridian(t *testing.T) {
// A polygon centred near the antimeridian, spanning lng 170..-170
// (i.e. lng 170..190 in [0, 360) form).
poly := []PolygonVertex{
{Lat: -10, Lng: 170},
{Lat: -10, Lng: 190},
{Lat: 10, Lng: 190},
{Lat: 10, Lng: 170},
}
c := Polygon{Vertices: poly, Mode: PolygonInside, On: ActionStop}
// A point at the antimeridian.
if !c.Violated(0, State{Lat: 0, Lng: 180}) {
t.Errorf("(0, 180) should be inside the antimeridian polygon")
}
if c.Violated(0, State{Lat: 0, Lng: 0}) {
t.Errorf("(0, 0) should be outside")
}
}

89
internal/engine/events.go Normal file
View file

@ -0,0 +1,89 @@
package engine
import "sync"
// Event is a non-fatal observation made during integration.
//
// Events generalise the warnings counter from the original Tawhiri port:
// any model or constraint can emit them, the EventSink aggregates by Type,
// and each Result carries a summary slice for the API to surface.
type Event struct {
Type string // short identifier, e.g. "above_model"
Time float64 // UNIX seconds when the event was emitted
State State
Message string
}
// EventSummary is the per-type aggregation of repeated emissions.
type EventSummary struct {
Type string `json:"type"`
Count int64 `json:"count"`
FirstTime float64 `json:"first_time"`
LastTime float64 `json:"last_time"`
FirstState State `json:"first_state"`
LastState State `json:"last_state"`
Message string `json:"message"`
}
// EventSink collects events from models and the integrator, aggregating
// duplicate types into a single EventSummary. Safe for concurrent use.
type EventSink struct {
mu sync.Mutex
summaries map[string]*EventSummary
}
// NewEventSink returns an empty sink.
func NewEventSink() *EventSink { return &EventSink{summaries: make(map[string]*EventSummary)} }
// Emit records one occurrence of typ at (t, s) with the provided message.
// Subsequent emits with the same typ update LastTime/LastState and Count.
func (s *EventSink) Emit(typ string, t float64, state State, message string) {
if s == nil {
return
}
s.mu.Lock()
defer s.mu.Unlock()
sum, ok := s.summaries[typ]
if !ok {
s.summaries[typ] = &EventSummary{
Type: typ, Count: 1,
FirstTime: t, LastTime: t,
FirstState: state, LastState: state,
Message: message,
}
return
}
sum.Count++
sum.LastTime = t
sum.LastState = state
if sum.Message == "" && message != "" {
sum.Message = message
}
}
// Snapshot returns a stable copy of every summary in deterministic order
// (sorted by Type).
func (s *EventSink) Snapshot() []EventSummary {
if s == nil {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
out := make([]EventSummary, 0, len(s.summaries))
for _, sum := range s.summaries {
out = append(out, *sum)
}
sortEventSummaries(out)
return out
}
func sortEventSummaries(s []EventSummary) {
// Insertion sort: usually one or two entries.
for i := 1; i < len(s); i++ {
j := i
for j > 0 && s[j-1].Type > s[j].Type {
s[j-1], s[j] = s[j], s[j-1]
j--
}
}
}

View file

@ -3,14 +3,13 @@ package engine
import (
"math"
"sort"
"sync/atomic"
"predictor-refactored/internal/weather"
)
// Sum composes models by summing their derivatives at each evaluation point.
//
// Useful for combining e.g. a vertical-rate model with a horizontal wind model
// Useful for combining a vertical-rate model with a horizontal wind model
// into a single propagator. Equivalent to Tawhiri's LinearModel.
func Sum(models ...Model) Model {
if len(models) == 1 {
@ -29,18 +28,16 @@ func Sum(models ...Model) Model {
}
// ConstantRate returns a model with a constant vertical velocity (m/s).
// A positive rate is upward (ascent); a negative rate is downward.
// Positive rates are upward.
func ConstantRate(rate float64) Model {
return func(_ float64, _ State) State {
return State{Altitude: rate}
}
return func(_ float64, _ State) State { return State{Altitude: rate} }
}
// ParachuteDescent returns a model where vertical velocity grows with altitude
// because thinner air provides less drag.
// ParachuteDescent returns a model where vertical velocity grows with
// altitude because thinner air provides less drag. seaLevelRate is the
// descent speed at sea level (m/s, positive).
//
// seaLevelRate is the descent speed at sea level (m/s, positive number).
// The terminal velocity at altitude is computed as
// Terminal velocity at altitude is computed as
//
// v = -k / sqrt(rho(alt)), k = seaLevelRate * 1.1045,
//
@ -52,9 +49,9 @@ func ParachuteDescent(seaLevelRate float64) Model {
}
}
// nasaDensity returns air density (kg/m^3) for the given altitude in metres,
// using the NASA simple atmosphere model. See
// https://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html.
// nasaDensity returns air density (kg/m^3) for an altitude in metres,
// using the NASA simple atmosphere model.
// See https://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html.
func nasaDensity(alt float64) float64 {
var temp, pressure float64
switch {
@ -71,22 +68,17 @@ func nasaDensity(alt float64) float64 {
return pressure / (0.2869 * (temp + 273.1))
}
// RateSegment is one entry in a Piecewise rate schedule.
// RateSegment is one entry in a Piecewise rate schedule. Until is the UNIX
// timestamp at which this segment ends — the model emits the segment's
// Rate for all t < Until. The final segment's Rate is held indefinitely.
type RateSegment struct {
// Until is the UNIX timestamp at which this segment ends.
// The model applies the segment's Rate for all t < Until.
Until float64
// Rate is the vertical velocity (m/s) during the segment. Positive is up.
Rate float64
Rate float64
}
// Piecewise returns a model that produces a piecewise-constant vertical rate
// over a sequence of time intervals.
//
// Segments are searched by their Until field; the first segment whose Until
// exceeds t supplies the active rate. For t at or after the last Until, the
// final segment's Rate is held indefinitely. Input is sorted ascending by
// Until on construction.
// Piecewise returns a model that produces a piecewise-constant vertical
// rate over a sequence of intervals. The input is sorted ascending by
// Until on construction; later segments shadow earlier ones.
func Piecewise(segments []RateSegment) Model {
if len(segments) == 0 {
return ConstantRate(0)
@ -104,33 +96,13 @@ func Piecewise(segments []RateSegment) Model {
}
}
// Warnings aggregates non-fatal conditions encountered during integration.
type Warnings struct {
// AltitudeTooHigh counts evaluations where the wind sampler reported
// that altitude was above the highest pressure level of the dataset.
AltitudeTooHigh atomic.Int64
}
// ToMap returns warnings as a map suitable for JSON output. Only counters
// that have fired are included.
func (w *Warnings) ToMap() map[string]any {
out := make(map[string]any)
if n := w.AltitudeTooHigh.Load(); n > 0 {
out["altitude_too_high"] = map[string]any{
"count": n,
"description": "altitude exceeded the highest pressure level of the wind dataset; samples were extrapolated",
}
}
return out
}
// WindTransport returns a model that moves laterally at the wind velocity
// sampled from field. The vertical component of the returned derivative is
// zero. Wind units are converted from m/s to deg/s on Earth's surface.
// sampled from field. Vertical component is zero. Wind components in m/s
// are converted to deg/s on Earth's surface using R = 6371009 m.
//
// If warnings is non-nil, the AltitudeTooHigh counter is incremented for any
// sample where the wind field reported altitude above the model top.
func WindTransport(field weather.WindField, warnings *Warnings) Model {
// If events is non-nil, an "above_model" event is emitted whenever the
// wind field reports altitude above the highest pressure level.
func WindTransport(field weather.WindField, events *EventSink) Model {
const earthR = 6371009.0
const piOver180 = math.Pi / 180.0
const degPerRad = 180.0 / math.Pi
@ -139,8 +111,9 @@ func WindTransport(field weather.WindField, warnings *Warnings) Model {
if err != nil {
return State{}
}
if sample.AboveModel && warnings != nil {
warnings.AltitudeTooHigh.Add(1)
if sample.AboveModel && events != nil {
events.Emit("above_model", t, s,
"altitude exceeded the highest pressure level of the wind dataset; samples extrapolated")
}
r := earthR + s.Altitude
return State{

View file

@ -0,0 +1,69 @@
package engine
import "fmt"
// Operator is a scalar comparison used by generalised constraints like
// Altitude and Time. A constraint fires when its Operator.Test(value, limit)
// returns true.
type Operator int
const (
OpLess Operator = iota // value < limit
OpLessEqual // value ≤ limit
OpGreater // value > limit
OpGreaterEqual // value ≥ limit
OpEqual // value == limit
)
// Test evaluates op(value, limit).
func (o Operator) Test(value, limit float64) bool {
switch o {
case OpLess:
return value < limit
case OpLessEqual:
return value <= limit
case OpGreater:
return value > limit
case OpGreaterEqual:
return value >= limit
case OpEqual:
return value == limit
}
return false
}
// String returns the symbol "<", "<=", ">", ">=", "==".
func (o Operator) String() string {
switch o {
case OpLess:
return "<"
case OpLessEqual:
return "<="
case OpGreater:
return ">"
case OpGreaterEqual:
return ">="
case OpEqual:
return "=="
}
return "?"
}
// ParseOperator maps a textual operator to its Operator constant.
// Accepts "<", "<=", "le", ">", ">=", "ge", "==", "eq".
func ParseOperator(s string) (Operator, error) {
switch s {
case "<", "lt":
return OpLess, nil
case "<=", "le":
return OpLessEqual, nil
case ">", "gt":
return OpGreater, nil
case ">=", "ge":
return OpGreaterEqual, nil
case "==", "eq":
return OpEqual, nil
default:
return 0, fmt.Errorf("unknown operator %q", s)
}
}

View file

@ -3,21 +3,26 @@ package engine
// Profile is an ordered chain of propagators executed sequentially. Each
// propagator picks up where the previous one finished.
type Profile struct {
// Stages are run in order. For Direction=Reverse they are still iterated
// from index 0 onwards, but each propagator integrates with negative dt.
// Stages are run in order. For Direction=Reverse they are still
// iterated from index 0 onwards but each propagator integrates with
// negative dt.
Stages []*Propagator
// Direction controls the sign of dt across the whole profile.
// Direction controls the sign of dt across the profile.
Direction Direction
// Globals are constraints evaluated alongside each stage's local Constraints.
// Useful for profile-wide bounds like "stop after N hours total".
// Globals are constraints evaluated alongside each stage's local
// Constraints. Useful for profile-wide bounds like "stop after N hours".
Globals []Constraint
}
// Run executes the profile from the given launch point. Returns one Result
// per executed stage, including any Fallback chains that were activated.
func (p *Profile) Run(t0 float64, launch State) []Result {
// Run executes the profile from the given launch point. Returns one
// Result per executed stage, including any Fallback chains that were
// activated. The supplied EventSink is shared across stages and aggregates
// non-fatal observations.
//
// events may be nil; pass NewEventSink() to capture observations.
func (p *Profile) Run(t0 float64, launch State, events *EventSink) []Result {
if p.Direction == 0 {
p.Direction = Forward
}
@ -27,28 +32,36 @@ func (p *Profile) Run(t0 float64, launch State) []Result {
for i := 0; i < len(p.Stages); i++ {
stage := p.Stages[i]
res := stage.run(t, s, p.Direction, p.Globals)
ctx := StageContext{
ProfileStart: t0,
PropagatorStart: t,
Launch: launch,
PropagatorState: s,
Direction: p.Direction,
}
res := stage.run(ctx, t, s, p.Globals, events)
results = append(results, res)
last := res.Points[len(res.Points)-1]
t = last.Time
s = State{Lat: last.Lat, Lng: last.Lng, Altitude: last.Altitude}
// Follow Fallback chains until none remains. Each fallback consumes
// from the same point the previous stage stopped at.
// Follow Fallback chains until none remains.
for res.Outcome == OutcomeFallback && stage.Fallback != nil {
stage = stage.Fallback
res = stage.run(t, s, p.Direction, p.Globals)
ctx = StageContext{
ProfileStart: t0,
PropagatorStart: t,
Launch: launch,
PropagatorState: s,
Direction: p.Direction,
}
res = stage.run(ctx, t, s, p.Globals, events)
results = append(results, res)
last = res.Points[len(res.Points)-1]
t = last.Time
s = State{Lat: last.Lat, Lng: last.Lng, Altitude: last.Altitude}
}
// If a propagator's stop fired (not a fallback), end the profile.
if res.Outcome == OutcomeStopped {
continue
}
}
return results

View file

@ -7,71 +7,58 @@ import (
// Propagator advances state under one Model, checking a set of Constraints
// after every integration step.
//
// When a constraint fires, the propagator binary-search refines the violation
// point and emits it as its final trajectory point. The Action of the
// triggering constraint controls what the surrounding Profile does next:
// stop the profile, transfer to Fallback, or clip and continue.
// When a constraint fires, the propagator binary-search refines the
// violation point and emits it as its final trajectory point. The Action of
// the triggering constraint controls what the surrounding Profile does
// next: stop the profile, transfer to Fallback, or clip and continue.
type Propagator struct {
// Name identifies the propagator in trajectory metadata.
// Name identifies the propagator in trajectory metadata. Optional —
// callers using sequential profile chains may leave it empty.
Name string
// Step is the magnitude of the integration step in seconds (always positive).
// The Profile flips its sign for Reverse direction.
Step float64
// Model produces the per-second time derivative of state.
Model Model
// Model is the per-second derivative function used for integration.
// One of Model or BuildModel must be non-nil. If both are set, BuildModel
// takes precedence (it is invoked once per stage with a StageContext).
Model Model
BuildModel func(ctx StageContext) Model
// Constraints are evaluated after each step. Any fired constraint stops
// the propagator at the refined point; the first one in this slice wins
// on ties.
Constraints []Constraint
// Constraints are evaluated after each step. The first violation wins.
Constraints []Constraint
BuildConstraints func(ctx StageContext) []Constraint
// Fallback is the propagator to switch to when a constraint with
// ActionFallback fires. Optional.
Fallback *Propagator
// Tolerance is the binary-search refinement tolerance in parameter space
// (default 0.01, matching Tawhiri).
// Tolerance is the binary-search refinement tolerance in parameter
// space (default 0.01, matching Tawhiri).
Tolerance float64
}
// Outcome describes how a propagator's run ended.
type Outcome int
const (
// OutcomeStopped means a Constraint with ActionStop fired and the profile
// should end here.
OutcomeStopped Outcome = iota
// OutcomeFallback means a Constraint with ActionFallback fired and the
// profile should transfer to the propagator's Fallback chain.
OutcomeFallback
// OutcomeContinued means no constraint fired before the time horizon was
// reached. In practice this is only seen when a propagator runs unbounded,
// which means the profile is misconfigured.
OutcomeContinued
)
// Result is the output of running one propagator.
type Result struct {
Propagator string
Points []TrajectoryPoint
Outcome Outcome
// Constraint is the constraint that fired, or nil if Outcome == OutcomeContinued.
Constraint Constraint
}
// run integrates the model from (t0, s0) in direction dir, returning a Result.
// globals are constraints injected by the Profile and checked alongside the
// propagator's local Constraints.
func (p *Propagator) run(t0 float64, s0 State, dir Direction, globals []Constraint) Result {
dt := p.Step * float64(dir)
// propagator's local Constraints. events receives non-fatal observations.
func (p *Propagator) run(ctx StageContext, t0 float64, s0 State, globals []Constraint, events *EventSink) Result {
dt := p.Step * float64(ctx.Direction)
tol := p.Tolerance
if tol == 0 {
tol = 0.01
}
deriv := numerics.Deriv[State](func(t float64, s State) State { return p.Model(t, s) })
model := p.Model
if p.BuildModel != nil {
model = p.BuildModel(ctx)
}
constraints := p.Constraints
if p.BuildConstraints != nil {
constraints = p.BuildConstraints(ctx)
}
deriv := numerics.Deriv[State](func(t float64, s State) State { return model(t, s) })
add := numerics.VecAdd[State](stateAdd)
lerp := numerics.VecLerp[State](stateLerp)
@ -90,39 +77,50 @@ func (p *Propagator) run(t0 float64, s0 State, dir Direction, globals []Constrai
s2 := numerics.RK4Step(t, s, dt, deriv, add)
t2 := t + dt
if c, fired := firstFiring(p.Constraints, globals, t2, s2); fired {
trig := numerics.Trigger[State](func(tt float64, ss State) bool { return c.Violated(tt, ss) })
t3, s3 := numerics.RefineTrigger(t, s, t2, s2, trig, lerp, tol)
switch c.Action() {
case ActionClip:
s3 = clipToConstraint(c, s3)
out.Points = append(out.Points, TrajectoryPoint{
Time: t3, Lat: s3.Lat, Lng: s3.Lng, Altitude: s3.Altitude,
})
t, s = t3, s3
continue
case ActionFallback:
out.Points = append(out.Points, TrajectoryPoint{
Time: t3, Lat: s3.Lat, Lng: s3.Lng, Altitude: s3.Altitude,
})
out.Outcome = OutcomeFallback
out.Constraint = c
return out
default: // ActionStop
out.Points = append(out.Points, TrajectoryPoint{
Time: t3, Lat: s3.Lat, Lng: s3.Lng, Altitude: s3.Altitude,
})
out.Outcome = OutcomeStopped
out.Constraint = c
return out
}
c, fired := firstFiring(constraints, globals, t2, s2)
if !fired {
t, s = t2, s2
out.Points = append(out.Points, TrajectoryPoint{
Time: t, Lat: s.Lat, Lng: s.Lng, Altitude: s.Altitude,
})
continue
}
t, s = t2, s2
out.Points = append(out.Points, TrajectoryPoint{
Time: t, Lat: s.Lat, Lng: s.Lng, Altitude: s.Altitude,
})
// Record the unrefined violation.
out.ViolationTime = t2
out.ViolationState = s2
trig := numerics.Trigger[State](func(tt float64, ss State) bool { return c.Violated(tt, ss) })
t3, s3 := numerics.RefineTrigger(t, s, t2, s2, trig, lerp, tol)
out.RefinedTime = t3
out.RefinedState = s3
out.Constraint = c
out.ConstraintName = c.Name()
switch c.Action() {
case ActionClip:
s3 = clipToConstraint(c, s3)
out.RefinedState = s3
out.Points = append(out.Points, TrajectoryPoint{
Time: t3, Lat: s3.Lat, Lng: s3.Lng, Altitude: s3.Altitude,
})
t, s = t3, s3
continue
case ActionFallback:
out.Points = append(out.Points, TrajectoryPoint{
Time: t3, Lat: s3.Lat, Lng: s3.Lng, Altitude: s3.Altitude,
})
out.Outcome = OutcomeFallback
out.Events = events.Snapshot()
return out
default: // ActionStop
out.Points = append(out.Points, TrajectoryPoint{
Time: t3, Lat: s3.Lat, Lng: s3.Lng, Altitude: s3.Altitude,
})
out.Outcome = OutcomeStopped
out.Events = events.Snapshot()
return out
}
}
}
@ -142,15 +140,12 @@ func firstFiring(local, globals []Constraint, t float64, s State) (Constraint, b
return nil, false
}
// clipToConstraint adjusts s so that the given constraint is exactly satisfied
// (not violated). Implemented for constraints with a well-defined boundary;
// others fall through unchanged.
// clipToConstraint adjusts s so that the given constraint is exactly
// satisfied (not violated). Defined only for constraints with a
// well-defined coordinate boundary; others fall through unchanged.
func clipToConstraint(c Constraint, s State) State {
switch v := c.(type) {
case MaxAltitude:
s.Altitude = v.Limit
case MinAltitude:
s.Altitude = v.Limit
if alt, ok := c.(Altitude); ok {
s.Altitude = alt.Limit
}
return s
}

287
internal/engine/registry.go Normal file
View file

@ -0,0 +1,287 @@
package engine
import (
"fmt"
"sync"
"predictor-refactored/internal/weather"
)
// ConstraintSpec is the source-agnostic JSON-shape used to declare a
// constraint. The Type field is the registry key; remaining fields are
// extracted by the registered factory.
type ConstraintSpec struct {
Type string `json:"type"`
Action string `json:"action,omitempty"`
// Op is the comparison operator for scalar constraints (altitude, time).
Op string `json:"op,omitempty"`
Limit float64 `json:"limit,omitempty"`
// Vertices and Mode are used by the polygon constraint.
Vertices []PolygonVertex `json:"vertices,omitempty"`
Mode string `json:"mode,omitempty"`
// Label is an optional human-readable identifier surfaced via Name().
Label string `json:"label,omitempty"`
}
// ModelSpec is the source-agnostic JSON shape used to declare a model.
type ModelSpec struct {
Type string `json:"type"`
// Rate (m/s) for constant_rate.
Rate float64 `json:"rate,omitempty"`
// SeaLevelRate (m/s, positive) for parachute_descent.
SeaLevelRate float64 `json:"sea_level_rate,omitempty"`
// Segments for piecewise.
Segments []PiecewiseSegmentSpec `json:"segments,omitempty"`
// IncludeWind sums a WindTransport model into the resulting derivative.
IncludeWind bool `json:"include_wind,omitempty"`
}
// PiecewiseSegmentSpec is one entry in a piecewise rate schedule.
//
// Reference selects how the Until field is interpreted:
//
// - "absolute" (default): UNIX seconds.
// - "profile_start": seconds since the profile's launch time.
// - "propagator_start": seconds since this propagator began running.
type PiecewiseSegmentSpec struct {
Until float64 `json:"until"`
Rate float64 `json:"rate"`
Reference string `json:"reference,omitempty"`
}
// BuildDeps bundle the runtime dependencies factories may consult.
type BuildDeps struct {
Wind weather.WindField
Terrain TerrainProvider
Events *EventSink
}
// ConstraintFactory builds one Constraint from a spec.
type ConstraintFactory func(spec ConstraintSpec, deps BuildDeps) (Constraint, error)
// ModelFactory builds one model from a spec. The returned Built is held by
// a Propagator; if Build is set, it is invoked lazily by the profile
// runner before every stage so it can capture per-stage start times.
type ModelFactory func(spec ModelSpec, deps BuildDeps) (BuiltModel, error)
// BuiltModel is either an eager Model, a lazy Build, or both. The profile
// runner prefers Build when present.
type BuiltModel struct {
Model Model
Build func(ctx StageContext) Model
}
var (
regMu sync.RWMutex
constraintFactories = map[string]ConstraintFactory{}
modelFactories = map[string]ModelFactory{}
)
// RegisterConstraint installs a factory for typeName. Subsequent calls
// overwrite the previous factory.
func RegisterConstraint(typeName string, f ConstraintFactory) {
regMu.Lock()
defer regMu.Unlock()
constraintFactories[typeName] = f
}
// RegisterModel installs a model factory.
func RegisterModel(typeName string, f ModelFactory) {
regMu.Lock()
defer regMu.Unlock()
modelFactories[typeName] = f
}
// BuildConstraint dispatches spec to its registered factory.
func BuildConstraint(spec ConstraintSpec, deps BuildDeps) (Constraint, error) {
regMu.RLock()
f, ok := constraintFactories[spec.Type]
regMu.RUnlock()
if !ok {
return nil, fmt.Errorf("unknown constraint type %q", spec.Type)
}
return f(spec, deps)
}
// BuildModel dispatches spec to its registered factory.
func BuildModel(spec ModelSpec, deps BuildDeps) (BuiltModel, error) {
regMu.RLock()
f, ok := modelFactories[spec.Type]
regMu.RUnlock()
if !ok {
return BuiltModel{}, fmt.Errorf("unknown model type %q", spec.Type)
}
return f(spec, deps)
}
// RegisteredConstraints returns the names of every registered constraint type.
func RegisteredConstraints() []string {
regMu.RLock()
defer regMu.RUnlock()
out := make([]string, 0, len(constraintFactories))
for k := range constraintFactories {
out = append(out, k)
}
return out
}
// RegisteredModels returns the names of every registered model type.
func RegisteredModels() []string {
regMu.RLock()
defer regMu.RUnlock()
out := make([]string, 0, len(modelFactories))
for k := range modelFactories {
out = append(out, k)
}
return out
}
// --- Built-in registrations ------------------------------------------------
func init() {
RegisterConstraint("altitude", buildAltitude)
RegisterConstraint("time", buildTime)
RegisterConstraint("terrain_contact", buildTerrainContact)
RegisterConstraint("polygon", buildPolygon)
RegisterModel("constant_rate", buildConstantRate)
RegisterModel("parachute_descent", buildParachuteDescent)
RegisterModel("piecewise", buildPiecewise)
RegisterModel("wind", buildWind)
}
func buildAltitude(spec ConstraintSpec, _ BuildDeps) (Constraint, error) {
op, err := ParseOperator(spec.Op)
if err != nil {
return nil, fmt.Errorf("altitude: %w", err)
}
act, err := ParseAction(spec.Action)
if err != nil {
return nil, fmt.Errorf("altitude: %w", err)
}
return Altitude{Op: op, Limit: spec.Limit, On: act}, nil
}
func buildTime(spec ConstraintSpec, _ BuildDeps) (Constraint, error) {
op, err := ParseOperator(spec.Op)
if err != nil {
return nil, fmt.Errorf("time: %w", err)
}
act, err := ParseAction(spec.Action)
if err != nil {
return nil, fmt.Errorf("time: %w", err)
}
return Time{Op: op, Limit: spec.Limit, On: act}, nil
}
func buildTerrainContact(spec ConstraintSpec, deps BuildDeps) (Constraint, error) {
if deps.Terrain == nil {
return nil, fmt.Errorf("terrain_contact requires a terrain provider")
}
act, err := ParseAction(spec.Action)
if err != nil {
return nil, fmt.Errorf("terrain_contact: %w", err)
}
return TerrainContact{Provider: deps.Terrain, On: act}, nil
}
func buildPolygon(spec ConstraintSpec, _ BuildDeps) (Constraint, error) {
if len(spec.Vertices) < 3 {
return nil, fmt.Errorf("polygon requires at least 3 vertices")
}
act, err := ParseAction(spec.Action)
if err != nil {
return nil, fmt.Errorf("polygon: %w", err)
}
mode := PolygonInside
switch spec.Mode {
case "", "inside":
mode = PolygonInside
case "outside":
mode = PolygonOutside
default:
return nil, fmt.Errorf("polygon: unknown mode %q", spec.Mode)
}
return Polygon{Vertices: spec.Vertices, Mode: mode, On: act, Label: spec.Label}, nil
}
func buildConstantRate(spec ModelSpec, _ BuildDeps) (BuiltModel, error) {
return BuiltModel{Model: ConstantRate(spec.Rate)}, nil
}
func buildParachuteDescent(spec ModelSpec, _ BuildDeps) (BuiltModel, error) {
if spec.SeaLevelRate <= 0 {
return BuiltModel{}, fmt.Errorf("parachute_descent requires positive sea_level_rate")
}
return BuiltModel{Model: ParachuteDescent(spec.SeaLevelRate)}, nil
}
func buildWind(_ ModelSpec, deps BuildDeps) (BuiltModel, error) {
if deps.Wind == nil {
return BuiltModel{}, fmt.Errorf("wind model requires a loaded wind field")
}
return BuiltModel{Model: WindTransport(deps.Wind, deps.Events)}, nil
}
func buildPiecewise(spec ModelSpec, deps BuildDeps) (BuiltModel, error) {
needsCtx := false
for _, seg := range spec.Segments {
if seg.Reference == "propagator_start" {
needsCtx = true
break
}
}
if !needsCtx {
// Eager build: resolve any "profile_start" relative segments using
// the launch time we know at build time only when we have one.
// Without context, treat profile_start the same as absolute (the
// caller is expected to pre-resolve), and absolute as absolute.
segs := make([]RateSegment, 0, len(spec.Segments))
for _, s := range spec.Segments {
if s.Reference == "profile_start" {
return BuiltModel{}, fmt.Errorf("piecewise: profile_start reference requires a stage context — supply via lazy build")
}
segs = append(segs, RateSegment{Until: s.Until, Rate: s.Rate})
}
base := Piecewise(segs)
return BuiltModel{Model: maybeAddWind(base, spec.IncludeWind, deps)}, nil
}
// Lazy build — captures spec into a closure.
return BuiltModel{
Build: func(ctx StageContext) Model {
segs := resolveSegments(spec.Segments, ctx)
base := Piecewise(segs)
return maybeAddWind(base, spec.IncludeWind, deps)
},
}, nil
}
// resolveSegments converts spec segments to engine.RateSegment using the
// stage context to resolve relative references.
func resolveSegments(in []PiecewiseSegmentSpec, ctx StageContext) []RateSegment {
out := make([]RateSegment, 0, len(in))
for _, s := range in {
var until float64
switch s.Reference {
case "", "absolute":
until = s.Until
case "profile_start":
until = ctx.ProfileStart + s.Until
case "propagator_start":
until = ctx.PropagatorStart + s.Until
}
out = append(out, RateSegment{Until: until, Rate: s.Rate})
}
return out
}
// maybeAddWind sums a WindTransport model into base when the spec asks for it.
func maybeAddWind(base Model, includeWind bool, deps BuildDeps) Model {
if !includeWind {
return base
}
if deps.Wind == nil {
return base
}
return Sum(base, WindTransport(deps.Wind, deps.Events))
}

View file

@ -1,27 +1,27 @@
// Package engine is the trajectory calculation engine. It composes
// propagators (model-driven integrators) into profiles (ordered chains) and
// runs them over a wind field.
// propagators (model-driven integrators) into profiles (ordered chains)
// over a wind field.
//
// The engine has no direct dependency on any specific data source: wind data
// is consumed through weather.WindField and terrain data through any type
// satisfying TerrainProvider.
// The engine has no direct dependency on any specific data source: wind
// data is consumed through weather.WindField and terrain data through
// any type satisfying TerrainProvider.
package engine
// State holds the spatial state of the balloon. When returned by a Model
// the same struct is interpreted as the per-second time derivative of state.
// the same struct is interpreted as the per-second time derivative.
type State struct {
// Lat is degrees latitude in [-90, 90] (or deg/s when returned as a derivative).
Lat float64
// Lng is degrees longitude in [0, 360) (or deg/s as a derivative).
Lng float64
// Altitude is metres above mean sea level (or m/s as a derivative).
Altitude float64
// Lat is degrees latitude in [-90, 90].
Lat float64 `json:"lat"`
// Lng is degrees longitude in [0, 360).
Lng float64 `json:"lng"`
// Altitude is metres above mean sea level.
Altitude float64 `json:"altitude"`
}
// Model returns the time derivative of state at (t, s).
//
// The derivative is direction-independent; the integrator applies the sign
// of dt for reverse propagation.
// The derivative is direction-independent; the integrator applies the
// sign of dt for reverse propagation.
type Model func(t float64, s State) State
// TrajectoryPoint is one sampled point of an integration result.
@ -32,9 +32,7 @@ type TrajectoryPoint struct {
Altitude float64
}
// Direction is the time direction of integration. Forward (+1) integrates
// from launch to landing; Reverse (-1) integrates from a known landing back
// to a candidate launch point.
// Direction is the time direction of integration.
type Direction int8
const (
@ -42,28 +40,39 @@ const (
Reverse Direction = -1
)
// Action describes what the profile runner should do when a Constraint
// reports a violation.
// Action is what the profile runner does on a constraint violation.
type Action int
const (
// ActionStop ends the current propagator at the (refined) violation point.
// This matches the only behaviour available in the reference Tawhiri solver.
// ActionStop ends the current propagator at the refined violation point.
ActionStop Action = iota
// ActionFallback ends the current propagator and starts its Fallback
// propagator from the violation point. Useful for "if max altitude is
// reached during ascent, switch to descent" profiles.
// propagator from the refined violation point.
ActionFallback
// ActionClip clips the violated coordinate to the boundary and continues
// integration. Useful for soft constraints such as "max altitude floor".
// integration.
ActionClip
)
// Constraint reports when integration should stop, branch, or clip.
//
// A constraint is direction-agnostic: it reads state and decides. The profile
// runner is responsible for refining the trigger point via binary search and
// dispatching the configured Action.
// ParseAction maps "stop" | "fallback" | "clip" to an Action.
func ParseAction(s string) (Action, error) {
switch s {
case "", "stop":
return ActionStop, nil
case "fallback":
return ActionFallback, nil
case "clip":
return ActionClip, nil
default:
return 0, errUnknownAction(s)
}
}
type errUnknownAction string
func (e errUnknownAction) Error() string { return "unknown constraint action " + string(e) }
// Constraint defines a stopping, branching, or clipping condition.
type Constraint interface {
// Name identifies the constraint in logs and result metadata.
Name() string
@ -74,7 +83,79 @@ type Constraint interface {
}
// TerrainProvider returns ground elevation in metres at a coordinate.
// Implementations must be safe for concurrent use.
type TerrainProvider interface {
Elevation(lat, lng float64) float64
}
// StageContext is provided to Propagator.BuildModel and BuildConstraints by
// the profile runner immediately before each stage executes.
type StageContext struct {
// ProfileStart is the UNIX timestamp of the profile's initial launch.
ProfileStart float64
// PropagatorStart is the UNIX timestamp at which this propagator begins
// running — equal to ProfileStart for the first stage; the end-time of
// the previous stage thereafter.
PropagatorStart float64
// Launch is the profile's initial state.
Launch State
// PropagatorState is the state at which this propagator begins.
PropagatorState State
// Direction is the integration direction the profile is configured with.
Direction Direction
}
// Outcome describes how a propagator's run ended.
type Outcome int
const (
// OutcomeStopped means a Constraint with ActionStop fired.
OutcomeStopped Outcome = iota
// OutcomeFallback means a Constraint with ActionFallback fired.
OutcomeFallback
// OutcomeContinued means the propagator finished without a constraint
// firing — only seen when a propagator is misconfigured to run unbounded.
OutcomeContinued
)
// String renders the outcome as a stable string for API serialisation.
func (o Outcome) String() string {
switch o {
case OutcomeStopped:
return "stopped"
case OutcomeFallback:
return "fallback"
default:
return "continued"
}
}
// Result is the output of running one propagator.
type Result struct {
// Propagator is the propagator's Name.
Propagator string
// Points is the emitted trajectory.
Points []TrajectoryPoint
// Outcome describes how the propagator terminated.
Outcome Outcome
// Constraint is the constraint that fired, or nil if Outcome is OutcomeContinued.
Constraint Constraint
// ConstraintName captures Constraint.Name() at fire time so callers can
// serialise the result after the Constraint has been garbage collected.
ConstraintName string
// ViolationTime / ViolationState describe the first integration step at
// which the constraint reported a violation, before binary-search refinement.
ViolationTime float64
ViolationState State
// RefinedTime / RefinedState describe the refined violation point that
// appears as the propagator's last trajectory point.
RefinedTime float64
RefinedState State
// Events is the aggregated set of non-fatal observations from this stage.
Events []EventSummary
}

View file

@ -1,34 +1,28 @@
package gfs
import "fmt"
// Cross-variant constants. Per-variant geometry (latitudes, longitudes,
// pressure levels, hour step, max hour, URL token) lives on the Variant
// type; see variant.go.
// Dataset shape: (hour, pressure_level, variable, latitude, longitude).
// Matches the cube layout used by the reference Tawhiri implementation.
const (
NumHours = 65 // 0, 3, 6, ..., 192 hours forecast
NumLevels = 47 // pressure levels
NumVariables = 3 // geopotential height, U-wind, V-wind
NumLatitudes = 361 // -90.0 to +90.0 inclusive in 0.5° steps
NumLongitudes = 720 // 0.0 to 359.5 in 0.5° steps
// NumVariables is the number of dataset variables: HGT, UGRD, VGRD.
NumVariables = 3
// ElementSize is the cell size in bytes (float32).
ElementSize = 4
HourStep = 3
MaxHour = 192
Resolution = 0.5
LatStart = -90.0
LonStart = 0.0
// LatStart is the first latitude in the cube (south to north).
LatStart = -90.0
// LonStart is the first longitude in the cube (0..360 east).
LonStart = 0.0
// Variable indices within the cube's 3rd axis.
VarHeight = 0
VarWindU = 1
VarWindV = 2
ElementSize = 4 // float32
// DatasetSize is the canonical file size: every grid cell × element size.
DatasetSize int64 = int64(NumHours) * int64(NumLevels) * int64(NumVariables) *
int64(NumLatitudes) * int64(NumLongitudes) * int64(ElementSize)
)
// LevelSet identifies which GRIB file (primary/secondary) carries a level.
// LevelSet identifies which GRIB file (primary or secondary) carries a
// pressure level.
type LevelSet int
const (
@ -36,106 +30,5 @@ const (
LevelSetB // pgrb2b — secondary file
)
// Pressures lists the 47 pressure levels (hPa) in dataset index order,
// descending from surface to top of atmosphere.
var Pressures = [NumLevels]int{
1000, 975, 950, 925, 900, 875, 850, 825, 800, 775,
750, 725, 700, 675, 650, 625, 600, 575, 550, 525,
500, 475, 450, 425, 400, 375, 350, 325, 300, 275,
250, 225, 200, 175, 150, 125, 100, 70, 50, 30,
20, 10, 7, 5, 3, 2, 1,
}
// PressuresPgrb2 lists the levels carried by the primary GRIB file.
var PressuresPgrb2 = []int{
10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 350, 400,
450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 925,
950, 975, 1000,
}
// PressuresPgrb2b lists the levels carried by the secondary GRIB file.
var PressuresPgrb2b = []int{
1, 2, 3, 5, 7, 125, 175, 225, 275, 325, 375, 425,
475, 525, 575, 625, 675, 725, 775, 825, 875,
}
var pressureIndex map[int]int
var pressureLevelSet map[int]LevelSet
func init() {
pressureIndex = make(map[int]int, NumLevels)
for i, p := range Pressures {
pressureIndex[p] = i
}
pressureLevelSet = make(map[int]LevelSet, NumLevels)
for _, p := range PressuresPgrb2 {
pressureLevelSet[p] = LevelSetA
}
for _, p := range PressuresPgrb2b {
pressureLevelSet[p] = LevelSetB
}
}
// PressureIndex returns the dataset index for a pressure level in hPa,
// or -1 when the level is unknown.
func PressureIndex(hPa int) int {
idx, ok := pressureIndex[hPa]
if !ok {
return -1
}
return idx
}
// PressureLevelSet returns the GRIB file set carrying a pressure level.
func PressureLevelSet(hPa int) (LevelSet, bool) {
ls, ok := pressureLevelSet[hPa]
return ls, ok
}
// HourIndex returns the dataset time index for a forecast hour, or -1 when
// the hour is outside the range or not a multiple of HourStep.
func HourIndex(hour int) int {
if hour < 0 || hour > MaxHour || hour%HourStep != 0 {
return -1
}
return hour / HourStep
}
// Hours returns the full list of forecast hours, [0, 3, 6, ..., MaxHour].
func Hours() []int {
out := make([]int, 0, NumHours)
for h := 0; h <= MaxHour; h += HourStep {
out = append(out, h)
}
return out
}
// VariableIndex maps a GRIB (category, number) pair to a dataset variable
// index, returning -1 for parameters this dataset does not store.
func VariableIndex(parameterCategory, parameterNumber int) int {
switch {
case parameterCategory == 3 && parameterNumber == 5:
return VarHeight
case parameterCategory == 2 && parameterNumber == 2:
return VarWindU
case parameterCategory == 2 && parameterNumber == 3:
return VarWindV
default:
return -1
}
}
// S3 URL configuration for NOAA GFS data on the public S3 mirror.
// S3BaseURL is the public NOAA S3 mirror.
const S3BaseURL = "https://noaa-gfs-bdp-pds.s3.amazonaws.com"
// GribURL returns the S3 URL for a primary (pgrb2) GRIB file.
func GribURL(date string, runHour, forecastStep int) string {
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d",
S3BaseURL, date, runHour, runHour, forecastStep)
}
// GribURLB returns the S3 URL for a secondary (pgrb2b) GRIB file.
func GribURLB(date string, runHour, forecastStep int) string {
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2b.0p50.f%03d",
S3BaseURL, date, runHour, runHour, forecastStep)
}

View file

@ -11,8 +11,10 @@ import (
)
// File is an mmap-backed wind dataset file. The layout is a flat C-order
// row-major array of float32 values, shape (hour, level, variable, lat, lng).
// row-major float32 array, shape (hour, level, variable, lat, lng), with
// the per-axis sizes coming from Variant.
type File struct {
variant *Variant
mm mmap.MMap
file *os.File
writable bool
@ -20,8 +22,11 @@ type File struct {
Epoch time.Time
}
// Variant returns the Variant the file was created with.
func (d *File) Variant() *Variant { return d.variant }
// Open opens an existing dataset file for reading.
func Open(path string, epoch time.Time) (*File, error) {
func Open(path string, variant *Variant, epoch time.Time) (*File, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open dataset: %w", err)
@ -31,39 +36,40 @@ func Open(path string, epoch time.Time) (*File, error) {
f.Close()
return nil, fmt.Errorf("stat dataset: %w", err)
}
if info.Size() != DatasetSize {
if info.Size() != variant.DatasetSize() {
f.Close()
return nil, fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, info.Size())
return nil, fmt.Errorf("dataset should be %d bytes (was %d)", variant.DatasetSize(), info.Size())
}
mm, err := mmap.Map(f, mmap.RDONLY, 0)
if err != nil {
f.Close()
return nil, fmt.Errorf("mmap dataset: %w", err)
}
return &File{mm: mm, file: f, writable: false, Epoch: epoch}, nil
return &File{variant: variant, mm: mm, file: f, writable: false, Epoch: epoch}, nil
}
// Create creates a new dataset file of the canonical size, mmap'd read-write.
func Create(path string) (*File, error) {
// Create creates a new dataset file sized for variant, mmap'd read-write.
func Create(path string, variant *Variant) (*File, error) {
f, err := os.Create(path)
if err != nil {
return nil, fmt.Errorf("create dataset: %w", err)
}
if err := f.Truncate(DatasetSize); err != nil {
size := variant.DatasetSize()
if err := f.Truncate(size); err != nil {
f.Close()
return nil, fmt.Errorf("truncate dataset: %w", err)
}
mm, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0)
mm, err := mmap.MapRegion(f, int(size), mmap.RDWR, 0, 0)
if err != nil {
f.Close()
return nil, fmt.Errorf("mmap dataset: %w", err)
}
return &File{mm: mm, file: f, writable: true}, nil
return &File{variant: variant, mm: mm, file: f, writable: true}, nil
}
// OpenWritable opens an existing dataset file for read-write access.
// Used when resuming a partial download.
func OpenWritable(path string) (*File, error) {
// OpenWritable opens an existing dataset file for read-write access. Used
// when resuming a partial download.
func OpenWritable(path string, variant *Variant) (*File, error) {
f, err := os.OpenFile(path, os.O_RDWR, 0o644)
if err != nil {
return nil, fmt.Errorf("open dataset rw: %w", err)
@ -73,51 +79,55 @@ func OpenWritable(path string) (*File, error) {
f.Close()
return nil, fmt.Errorf("stat dataset: %w", err)
}
if info.Size() != DatasetSize {
if info.Size() != variant.DatasetSize() {
f.Close()
return nil, fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, info.Size())
return nil, fmt.Errorf("dataset should be %d bytes (was %d)", variant.DatasetSize(), info.Size())
}
mm, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0)
mm, err := mmap.MapRegion(f, int(info.Size()), mmap.RDWR, 0, 0)
if err != nil {
f.Close()
return nil, fmt.Errorf("mmap dataset: %w", err)
}
return &File{mm: mm, file: f, writable: true}, nil
return &File{variant: variant, mm: mm, file: f, writable: true}, nil
}
// offset returns the byte offset of the [hour][level][variable][lat][lng] cell.
func offset(hour, level, variable, lat, lng int) int64 {
func (d *File) offset(hour, level, variable, lat, lng int) int64 {
v := d.variant
idx := int64(hour)
idx = idx*int64(NumLevels) + int64(level)
idx = idx*int64(v.NumLevels()) + int64(level)
idx = idx*int64(NumVariables) + int64(variable)
idx = idx*int64(NumLatitudes) + int64(lat)
idx = idx*int64(NumLongitudes) + int64(lng)
idx = idx*int64(v.NumLatitudes()) + int64(lat)
idx = idx*int64(v.NumLongitudes()) + int64(lng)
return idx * int64(ElementSize)
}
// Val reads one cell as a float32.
func (d *File) Val(hour, level, variable, lat, lng int) float32 {
off := offset(hour, level, variable, lat, lng)
off := d.offset(hour, level, variable, lat, lng)
return math.Float32frombits(binary.LittleEndian.Uint32(d.mm[off : off+4]))
}
// SetVal writes one cell. Only valid on writable files.
func (d *File) SetVal(hour, level, variable, lat, lng int, val float32) {
off := offset(hour, level, variable, lat, lng)
off := d.offset(hour, level, variable, lat, lng)
binary.LittleEndian.PutUint32(d.mm[off:off+4], math.Float32bits(val))
}
// BlitGribData copies one decoded GRIB grid into the dataset, flipping the
// latitude axis from GRIB's north-to-south scan order to our south-to-north
// storage order. gribData must be 361*720 = 259920 float64 values.
// storage order.
func (d *File) BlitGribData(hourIdx, levelIdx, varIdx int, gribData []float64) error {
expected := NumLatitudes * NumLongitudes
v := d.variant
expected := v.NumLatitudes() * v.NumLongitudes()
if len(gribData) != expected {
return fmt.Errorf("grib data has %d values, expected %d", len(gribData), expected)
}
for lat := range NumLatitudes {
for lng := range NumLongitudes {
gribIdx := (360-lat)*NumLongitudes + lng
lats := v.NumLatitudes()
lngs := v.NumLongitudes()
for lat := range lats {
for lng := range lngs {
gribIdx := (lats-1-lat)*lngs + lng
d.SetVal(hourIdx, levelIdx, varIdx, lat, lng, float32(gribData[gribIdx]))
}
}

View file

@ -0,0 +1,68 @@
package gfs
import "fmt"
// Family is the dataset family ("gfs" or "gefs"). Variants of different
// families have different URL layouts but share the cube format.
type Family int
const (
FamilyGFS Family = iota
FamilyGEFS
)
func (f Family) String() string {
switch f {
case FamilyGEFS:
return "gefs"
default:
return "gfs"
}
}
// HasMember reports whether the family requires a member index in URLs.
func (f Family) HasMember() bool { return f == FamilyGEFS }
// GEFS variant constants.
//
// The 21-member ensemble is gec00 (control) + gep01..gep20 (perturbations).
// NOAA publishes more members today but 21 matches the historical Tawhiri
// configuration and is what the phase 2 spec calls for.
const GEFSMembers = 21
// GefsMemberName returns the file-name token for a GEFS member.
// member=0 → "gec00", member=1..20 → "gep01".."gep20".
func GefsMemberName(member int) string {
if member == 0 {
return "gec00"
}
return fmt.Sprintf("gep%02d", member)
}
// GEFS S3 mirror.
const GEFSS3BaseURL = "https://noaa-gefs-pds.s3.amazonaws.com"
// GefsGribURL returns the S3 URL for a GEFS primary GRIB file.
func GefsGribURL(date string, runHour, member, forecastStep int, resToken string) string {
return fmt.Sprintf("%s/gefs.%s/%02d/atmos/pgrb2ap5/%s.t%02dz.pgrb2a.%s.f%03d",
GEFSS3BaseURL, date, runHour, GefsMemberName(member), runHour, resToken, forecastStep)
}
// GefsGribURLB returns the S3 URL for a GEFS secondary GRIB file.
func GefsGribURLB(date string, runHour, member, forecastStep int, resToken string) string {
return fmt.Sprintf("%s/gefs.%s/%02d/atmos/pgrb2bp5/%s.t%02dz.pgrb2b.%s.f%03d",
GEFSS3BaseURL, date, runHour, GefsMemberName(member), runHour, resToken, forecastStep)
}
// GEFS variants — 0.5° resolution, 3-hour cadence, 192h horizon.
var GEFS0p50_3h = &Variant{
ID: "gefs-0p50-3h",
Family: FamilyGEFS,
ResToken: "0p50",
Resolution: 0.5,
HourStep: 3,
MaxHour: 192,
Pressures: GFS0p50_3h.Pressures,
PressuresPgrb2: GFS0p50_3h.PressuresPgrb2,
PressuresPgrb2b: GFS0p50_3h.PressuresPgrb2b,
}

View file

@ -0,0 +1,191 @@
package gfs
import "fmt"
// Variant describes one configuration of a NOAA dataset family (GFS or GEFS).
//
// The dataset cube is a 5-D float32 array with shape
// (NumHours, NumLevels, NumVariables, NumLatitudes, NumLongitudes) where
// NumVariables and ElementSize are fixed across all GFS variants but the
// other dimensions depend on the resolution and forecast cadence.
type Variant struct {
// ID is a stable identifier ("gfs-0p50-3h", "gefs-0p50-3h", ...).
ID string
// Family identifies the dataset family the variant belongs to.
Family Family
// Resolution token used in NOAA URLs ("0p50", "0p25").
ResToken string
// Grid step in degrees (0.5, 0.25). 180 / Resolution + 1 latitudes and
// 360 / Resolution longitudes.
Resolution float64
HourStep int // hours between forecast steps
MaxHour int // largest forecast hour (inclusive)
// Pressures lists every pressure level in dataset index order, descending.
Pressures []int
// PressuresPgrb2 / PressuresPgrb2b split the pressures between the two
// downloaded GRIB files. Their union must equal Pressures.
PressuresPgrb2 []int
PressuresPgrb2b []int
pressureIndex map[int]int
pressureLevelSet map[int]LevelSet
}
// NumHours returns MaxHour/HourStep + 1.
func (v *Variant) NumHours() int { return v.MaxHour/v.HourStep + 1 }
// NumLevels returns len(Pressures).
func (v *Variant) NumLevels() int { return len(v.Pressures) }
// NumLatitudes returns 180/Resolution + 1.
func (v *Variant) NumLatitudes() int { return int(180.0/v.Resolution) + 1 }
// NumLongitudes returns 360/Resolution.
func (v *Variant) NumLongitudes() int { return int(360.0 / v.Resolution) }
// DatasetSize returns the canonical file size in bytes.
func (v *Variant) DatasetSize() int64 {
return int64(v.NumHours()) * int64(v.NumLevels()) * int64(NumVariables) *
int64(v.NumLatitudes()) * int64(v.NumLongitudes()) * int64(ElementSize)
}
// Hours returns the full list of forecast hours [0, HourStep, ..., MaxHour].
func (v *Variant) Hours() []int {
out := make([]int, 0, v.NumHours())
for h := 0; h <= v.MaxHour; h += v.HourStep {
out = append(out, h)
}
return out
}
// HourIndex returns the dataset time index for an hour, or -1 if invalid.
func (v *Variant) HourIndex(hour int) int {
if hour < 0 || hour > v.MaxHour || hour%v.HourStep != 0 {
return -1
}
return hour / v.HourStep
}
// PressureIndex returns the dataset index for a pressure level in hPa,
// or -1 when the level is unknown to this variant.
func (v *Variant) PressureIndex(hPa int) int {
v.indexLazyInit()
if i, ok := v.pressureIndex[hPa]; ok {
return i
}
return -1
}
// PressureLevelSet returns the GRIB file set carrying a pressure level.
func (v *Variant) PressureLevelSet(hPa int) (LevelSet, bool) {
v.indexLazyInit()
ls, ok := v.pressureLevelSet[hPa]
return ls, ok
}
// VariableIndex maps a GRIB (category, number) pair to a dataset variable index.
func (v *Variant) VariableIndex(parameterCategory, parameterNumber int) int {
switch {
case parameterCategory == 3 && parameterNumber == 5:
return VarHeight
case parameterCategory == 2 && parameterNumber == 2:
return VarWindU
case parameterCategory == 2 && parameterNumber == 3:
return VarWindV
default:
return -1
}
}
// GribURL returns the S3 URL for the primary (pgrb2) GRIB file.
func (v *Variant) GribURL(date string, runHour, forecastStep int) string {
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.%s.f%03d",
S3BaseURL, date, runHour, runHour, v.ResToken, forecastStep)
}
// GribURLB returns the S3 URL for the secondary (pgrb2b) GRIB file.
func (v *Variant) GribURLB(date string, runHour, forecastStep int) string {
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2b.%s.f%03d",
S3BaseURL, date, runHour, runHour, v.ResToken, forecastStep)
}
func (v *Variant) indexLazyInit() {
if v.pressureIndex != nil {
return
}
v.pressureIndex = make(map[int]int, len(v.Pressures))
for i, p := range v.Pressures {
v.pressureIndex[p] = i
}
v.pressureLevelSet = make(map[int]LevelSet, len(v.Pressures))
for _, p := range v.PressuresPgrb2 {
v.pressureLevelSet[p] = LevelSetA
}
for _, p := range v.PressuresPgrb2b {
v.pressureLevelSet[p] = LevelSetB
}
}
// Standard variants -- these mirror what NOAA publishes today.
//
// GFS0p50_3h is the historical Tawhiri default: 0.5° resolution, 3-hour
// forecast cadence, 0..192h horizon, 47 pressure levels split across the
// primary and secondary GRIB files.
//
// GFS0p25_3h mirrors the same 3-hour cadence at 0.25° resolution (the
// horizon is larger in practice but we keep 192h for parity with 0p50).
//
// GFS0p25_1h targets the 1-hourly portion NOAA publishes out to 120h.
var (
GFS0p50_3h = &Variant{
ID: "gfs-0p50-3h",
ResToken: "0p50",
Resolution: 0.5,
HourStep: 3,
MaxHour: 192,
Pressures: []int{1000, 975, 950, 925, 900, 875, 850, 825, 800, 775, 750, 725, 700, 675, 650, 625, 600, 575, 550, 525, 500, 475, 450, 425, 400, 375, 350, 325, 300, 275, 250, 225, 200, 175, 150, 125, 100, 70, 50, 30, 20, 10, 7, 5, 3, 2, 1},
PressuresPgrb2: []int{10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 350, 400, 450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 925, 950, 975, 1000},
PressuresPgrb2b: []int{1, 2, 3, 5, 7, 125, 175, 225, 275, 325, 375, 425, 475, 525, 575, 625, 675, 725, 775, 825, 875},
}
GFS0p25_3h = &Variant{
ID: "gfs-0p25-3h",
ResToken: "0p25",
Resolution: 0.25,
HourStep: 3,
MaxHour: 192,
Pressures: GFS0p50_3h.Pressures,
PressuresPgrb2: GFS0p50_3h.PressuresPgrb2,
PressuresPgrb2b: GFS0p50_3h.PressuresPgrb2b,
}
GFS0p25_1h = &Variant{
ID: "gfs-0p25-1h",
ResToken: "0p25",
Resolution: 0.25,
HourStep: 1,
MaxHour: 120,
Pressures: GFS0p50_3h.Pressures,
PressuresPgrb2: GFS0p50_3h.PressuresPgrb2,
PressuresPgrb2b: GFS0p50_3h.PressuresPgrb2b,
}
)
// VariantByID returns one of the predefined variants by its ID.
func VariantByID(id string) (*Variant, error) {
switch id {
case GFS0p50_3h.ID:
return GFS0p50_3h, nil
case GFS0p25_3h.ID:
return GFS0p25_3h, nil
case GFS0p25_1h.ID:
return GFS0p25_1h, nil
case GEFS0p50_3h.ID:
return GEFS0p50_3h, nil
default:
return nil, fmt.Errorf("unknown variant %q", id)
}
}

View file

@ -10,45 +10,49 @@ import (
// Wind is a WindField backed by a GFS dataset file.
type Wind struct {
file *File
hourAxis numerics.Axis
latAxis numerics.Axis
lngAxis numerics.Axis
}
// NewWind returns a Wind backed by file.
// NewWind returns a Wind backed by file. The axes are constructed from the
// file's variant geometry.
func NewWind(file *File) *Wind {
return &Wind{file: file}
v := file.variant
return &Wind{
file: file,
hourAxis: numerics.Axis{
Left: 0,
Step: float64(v.HourStep),
N: v.NumHours(),
Name: "hour",
},
latAxis: numerics.Axis{
Left: LatStart,
Step: v.Resolution,
N: v.NumLatitudes(),
Name: "lat",
},
lngAxis: numerics.Axis{
Left: LonStart,
Step: v.Resolution,
N: v.NumLongitudes(),
Wrap: true,
Name: "lng",
},
}
}
// Epoch returns the forecast run time of the underlying file.
func (w *Wind) Epoch() time.Time { return w.file.Epoch }
// Source returns the source identifier "noaa-gfs-0p50".
func (w *Wind) Source() string { return "noaa-gfs-0p50" }
// Source returns the variant ID (e.g. "gfs-0p50-3h").
func (w *Wind) Source() string { return w.file.variant.ID }
// Close releases the underlying file's resources.
func (w *Wind) Close() error { return w.file.Close() }
// Grid axes for the GFS 0.5-degree dataset.
var (
hourAxis = numerics.Axis{
Left: 0,
Step: float64(HourStep),
N: NumHours,
Name: "hour",
}
latAxis = numerics.Axis{
Left: LatStart,
Step: Resolution,
N: NumLatitudes,
Name: "lat",
}
lngAxis = numerics.Axis{
Left: LonStart,
Step: Resolution,
N: NumLongitudes,
Wrap: true,
Name: "lng",
}
)
// Wind samples the field at the given UNIX time, geographic coordinate, and
// altitude. Vertical interpolation matches Tawhiri: locate the two pressure
// levels whose interpolated geopotential heights bracket alt, then linearly
@ -56,15 +60,15 @@ var (
func (w *Wind) Wind(t, lat, lng, alt float64) (weather.Sample, error) {
hours := (t - float64(w.file.Epoch.Unix())) / 3600.0
bh, err := hourAxis.Locate(hours)
bh, err := w.hourAxis.Locate(hours)
if err != nil {
return weather.Sample{}, err
}
bla, err := latAxis.Locate(lat)
bla, err := w.latAxis.Locate(lat)
if err != nil {
return weather.Sample{}, err
}
bln, err := lngAxis.Locate(lng)
bln, err := w.lngAxis.Locate(lng)
if err != nil {
return weather.Sample{}, err
}
@ -76,7 +80,7 @@ func (w *Wind) Wind(t, lat, lng, alt float64) (weather.Sample, error) {
}
}
levelIdx := numerics.Bisect(0, NumLevels-2, alt, func(level int) float64 {
levelIdx := numerics.Bisect(0, w.file.variant.NumLevels()-2, alt, func(level int) float64 {
return numerics.EvalTrilinear(bs, height(level))
})