feat: polish & windviz & deploy

This commit is contained in:
Anatoly Antonov 2026-05-30 06:29:39 +09:00
parent 81b8e763bd
commit 465ad00f7b
78 changed files with 20622 additions and 2154 deletions

View file

@ -1,284 +0,0 @@
// Package admin implements dataset-management HTTP endpoints used by the
// stratoflights operator console.
//
// Endpoints:
//
// GET /api/v1/admin/datasets list stored datasets
// POST /api/v1/admin/datasets trigger a download
// DELETE /api/v1/admin/datasets/{name} delete a stored dataset by filename
// GET /api/v1/admin/jobs list all jobs
// GET /api/v1/admin/jobs/{id} fetch one job
// DELETE /api/v1/admin/jobs/{id} cancel a running job
// GET /api/v1/admin/status service status summary
package admin
import (
"context"
"encoding/json"
"net/http"
"runtime"
"time"
"go.uber.org/zap"
"predictor-refactored/internal/api/httpjson"
"predictor-refactored/internal/datasets"
)
// Handler serves all /api/v1/admin/* endpoints.
type Handler struct {
mgr *datasets.Manager
start time.Time
log *zap.Logger
}
// New wires an admin handler.
func New(mgr *datasets.Manager, log *zap.Logger) *Handler {
if log == nil {
log = zap.NewNop()
}
return &Handler{mgr: mgr, start: time.Now().UTC(), log: log}
}
// Register installs admin routes on mux.
func (h *Handler) Register(mux *http.ServeMux) {
mux.HandleFunc("GET /api/v1/admin/datasets", h.listDatasets)
mux.HandleFunc("POST /api/v1/admin/datasets", h.triggerDownload)
mux.HandleFunc("DELETE /api/v1/admin/datasets/{name}", h.deleteDataset)
mux.HandleFunc("GET /api/v1/admin/jobs", h.listJobs)
mux.HandleFunc("GET /api/v1/admin/jobs/{id}", h.getJob)
mux.HandleFunc("DELETE /api/v1/admin/jobs/{id}", h.cancelJob)
mux.HandleFunc("GET /api/v1/admin/status", h.status)
}
// datasetDTO is the JSON shape of one stored dataset.
type datasetDTO struct {
Filename string `json:"filename"`
Epoch string `json:"epoch"`
Subset *subsetDTO `json:"subset,omitempty"`
Coverage *coverageDTO `json:"coverage,omitempty"`
Loaded bool `json:"loaded"`
}
type subsetDTO struct {
Region *datasets.Region `json:"region,omitempty"`
HourRange *datasets.HourRange `json:"hour_range,omitempty"`
Members []int `json:"members,omitempty"`
}
type coverageDTO struct {
Region datasets.Region `json:"region"`
StartTime string `json:"start_time"`
EndTime string `json:"end_time"`
}
// listDatasets handles GET /api/v1/admin/datasets.
func (h *Handler) listDatasets(w http.ResponseWriter, _ *http.Request) {
stored, err := h.mgr.ListEpochs()
if err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
loaded := h.mgr.LoadedDatasets()
loadedByName := make(map[string]datasets.LoadedDatasetInfo, len(loaded))
for _, ld := range loaded {
loadedByName[ld.ID.Filename()] = ld
}
out := struct {
Source string `json:"source"`
Datasets []datasetDTO `json:"datasets"`
}{Source: h.mgr.Source(), Datasets: make([]datasetDTO, 0, len(stored))}
for _, id := range stored {
dto := datasetDTO{
Filename: id.Filename(),
Epoch: id.Epoch.UTC().Format(time.RFC3339),
}
if !id.Subset.IsGlobal() {
dto.Subset = &subsetDTO{
Region: id.Subset.Region,
HourRange: id.Subset.HourRange,
Members: id.Subset.Members,
}
}
if ld, ok := loadedByName[id.Filename()]; ok {
dto.Loaded = true
dto.Coverage = &coverageDTO{
Region: ld.Coverage.Region,
StartTime: ld.Coverage.StartTime.UTC().Format(time.RFC3339),
EndTime: ld.Coverage.EndTime.UTC().Format(time.RFC3339),
}
}
out.Datasets = append(out.Datasets, dto)
}
writeJSON(w, http.StatusOK, out)
}
// triggerDownload handles POST /api/v1/admin/datasets.
//
// Body:
// {"latest": true} — refresh the latest global dataset
// {"epoch": "2026-03-28T06:00:00Z", "subset": {...}} — explicit dataset
func (h *Handler) triggerDownload(w http.ResponseWriter, r *http.Request) {
var body struct {
Epoch string `json:"epoch,omitempty"`
Latest bool `json:"latest,omitempty"`
Subset *datasets.SubsetSpec `json:"subset,omitempty"`
}
if err := json.NewDecoder(r.Body).Decode(&body); err != nil {
writeError(w, http.StatusBadRequest, "invalid body: "+err.Error())
return
}
if !body.Latest && body.Epoch == "" {
writeError(w, http.StatusBadRequest, "specify either epoch or latest=true")
return
}
if body.Latest {
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
defer cancel()
jobID, err := h.mgr.Refresh(ctx, 0)
if err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
writeJSON(w, http.StatusAccepted, map[string]string{"job_id": jobID})
return
}
epoch, err := time.Parse(time.RFC3339, body.Epoch)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid epoch: "+err.Error())
return
}
id := datasets.DatasetID{Epoch: epoch.UTC()}
if body.Subset != nil {
id.Subset = *body.Subset
}
jobID := h.mgr.Download(id)
writeJSON(w, http.StatusAccepted, map[string]string{"job_id": jobID})
}
// deleteDataset handles DELETE /api/v1/admin/datasets/{name}.
//
// {name} is the dataset filename (DatasetID.Filename()) as returned by GET.
func (h *Handler) deleteDataset(w http.ResponseWriter, r *http.Request) {
name := r.PathValue("name")
stored, err := h.mgr.ListEpochs()
if err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
for _, id := range stored {
if id.Filename() == name {
if err := h.mgr.Remove(id); err != nil {
writeError(w, http.StatusInternalServerError, err.Error())
return
}
w.WriteHeader(http.StatusNoContent)
return
}
}
writeError(w, http.StatusNotFound, "dataset not found")
}
// listJobs handles GET /api/v1/admin/jobs.
func (h *Handler) listJobs(w http.ResponseWriter, _ *http.Request) {
jobs := h.mgr.ListJobs()
out := make([]jobDTO, 0, len(jobs))
for _, j := range jobs {
out = append(out, toDTO(j))
}
writeJSON(w, http.StatusOK, out)
}
// getJob handles GET /api/v1/admin/jobs/{id}.
func (h *Handler) getJob(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
job, ok := h.mgr.GetJob(id)
if !ok {
writeError(w, http.StatusNotFound, "job not found")
return
}
writeJSON(w, http.StatusOK, toDTO(job))
}
// cancelJob handles DELETE /api/v1/admin/jobs/{id}.
func (h *Handler) cancelJob(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
if !h.mgr.CancelJob(id) {
writeError(w, http.StatusConflict, "job not found or already terminal")
return
}
w.WriteHeader(http.StatusNoContent)
}
// status handles GET /api/v1/admin/status — a consolidated dashboard view.
func (h *Handler) status(w http.ResponseWriter, _ *http.Request) {
jobs := h.mgr.ListJobs()
stored, _ := h.mgr.ListEpochs()
loaded := h.mgr.LoadedDatasets()
counts := map[string]int{}
for _, j := range jobs {
counts[string(j.Status)]++
}
var mem runtime.MemStats
runtime.ReadMemStats(&mem)
resp := struct {
Source string `json:"source"`
Uptime string `json:"uptime"`
Goroutines int `json:"goroutines"`
MemoryMB uint64 `json:"memory_mb"`
JobsByStatus map[string]int `json:"jobs_by_status"`
Stored int `json:"stored_datasets"`
Loaded int `json:"loaded_datasets"`
}{
Source: h.mgr.Source(),
Uptime: time.Since(h.start).Round(time.Second).String(),
Goroutines: runtime.NumGoroutine(),
MemoryMB: mem.Alloc / 1024 / 1024,
JobsByStatus: counts,
Stored: len(stored),
Loaded: len(loaded),
}
writeJSON(w, http.StatusOK, resp)
}
type jobDTO struct {
ID string `json:"id"`
Source string `json:"source"`
Dataset string `json:"dataset"`
Epoch string `json:"epoch"`
Status string `json:"status"`
StartedAt string `json:"started_at"`
EndedAt string `json:"ended_at,omitempty"`
Err string `json:"error,omitempty"`
Total int `json:"total_units"`
Done int `json:"done_units"`
Bytes int64 `json:"bytes"`
}
func toDTO(j datasets.JobInfo) jobDTO {
dto := jobDTO{
ID: j.ID,
Source: j.Source,
Dataset: j.Dataset.Filename(),
Epoch: j.Dataset.Epoch.UTC().Format(time.RFC3339),
Status: string(j.Status),
StartedAt: j.StartedAt.UTC().Format(time.RFC3339),
Err: j.Err,
Total: j.Total,
Done: j.Done,
Bytes: j.Bytes,
}
if j.EndedAt != nil {
dto.EndedAt = j.EndedAt.UTC().Format(time.RFC3339)
}
return dto
}
var writeJSON = httpjson.Write
var writeError = httpjson.Error

View file

@ -1,63 +0,0 @@
package async
import (
"encoding/json"
"net/http"
"predictor-refactored/internal/api/httpjson"
"predictor-refactored/internal/api/v2"
)
// Handler implements the /api/v1/predictions{,/{id}} endpoints.
type Handler struct {
mgr *Manager
}
// NewHandler wires a handler.
func NewHandler(mgr *Manager) *Handler { return &Handler{mgr: mgr} }
// Register installs the async routes on mux.
func (h *Handler) Register(mux *http.ServeMux) {
mux.HandleFunc("POST /api/v1/predictions", h.create)
mux.HandleFunc("GET /api/v1/predictions/{id}", h.get)
mux.HandleFunc("DELETE /api/v1/predictions/{id}", h.cancel)
}
func (h *Handler) create(w http.ResponseWriter, r *http.Request) {
var req v2.PredictionRequest
dec := json.NewDecoder(r.Body)
dec.DisallowUnknownFields()
if err := dec.Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid body: "+err.Error())
return
}
info, accepted := h.mgr.Enqueue(req)
if !accepted {
writeJSON(w, http.StatusServiceUnavailable, info)
return
}
w.Header().Set("Location", "/api/v1/predictions/"+info.ID)
writeJSON(w, http.StatusAccepted, info)
}
func (h *Handler) get(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
info, ok := h.mgr.Get(id)
if !ok {
writeError(w, http.StatusNotFound, "prediction job not found")
return
}
writeJSON(w, http.StatusOK, info)
}
func (h *Handler) cancel(w http.ResponseWriter, r *http.Request) {
id := r.PathValue("id")
if !h.mgr.Cancel(id) {
writeError(w, http.StatusConflict, "job not found or already terminal")
return
}
w.WriteHeader(http.StatusNoContent)
}
var writeJSON = httpjson.Write
var writeError = httpjson.Error

View file

@ -1,11 +1,15 @@
// Package async implements the asynchronous prediction endpoints
// (/api/v1/predictions{,/{id}}) and the worker pool that executes them.
// Package async runs profile-driven predictions on a bounded worker pool and
// retains their results in memory for a configurable TTL. It is the engine
// behind the asynchronous prediction endpoints; the HTTP surface itself is
// the ogen-generated server in the parent package.
//
// Each enqueued request is assigned a job ID; the result is held in
// memory for a configurable TTL after completion.
// The package is decoupled from the request/response wire types: a RunFunc is
// injected at construction, so this file imports only the generated API types
// it stores and returns.
package async
import (
"fmt"
"sync"
"sync/atomic"
"time"
@ -13,12 +17,13 @@ import (
"github.com/google/uuid"
"go.uber.org/zap"
"predictor-refactored/internal/api/v2"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/elevation"
"predictor-refactored/internal/metrics"
apirest "predictor-refactored/pkg/rest"
)
// RunFunc executes one prediction synchronously.
type RunFunc func(req *apirest.PredictionV2Request) (*apirest.PredictionV2Response, error)
// Status is the lifecycle state of a prediction job.
type Status string
@ -30,20 +35,20 @@ const (
StatusCancelled Status = "cancelled"
)
// JobInfo is the externally-visible snapshot of one prediction job.
// JobInfo is a snapshot of one prediction job.
type JobInfo struct {
ID string `json:"id"`
Status Status `json:"status"`
CreatedAt time.Time `json:"created_at"`
StartedAt *time.Time `json:"started_at,omitempty"`
CompletedAt *time.Time `json:"completed_at,omitempty"`
Error string `json:"error,omitempty"`
Result *v2.PredictionResponse `json:"result,omitempty"`
ID string
Status Status
CreatedAt time.Time
StartedAt *time.Time
CompletedAt *time.Time
Error string
Result *apirest.PredictionV2Response
}
type job struct {
id string
req v2.PredictionRequest
req *apirest.PredictionV2Request
createdAt time.Time
mu sync.Mutex
@ -51,19 +56,15 @@ type job struct {
startedAt time.Time
completedAt time.Time
errStr string
result *v2.PredictionResponse
cancel chan struct{}
result *apirest.PredictionV2Response
}
func (j *job) snapshot() JobInfo {
j.mu.Lock()
defer j.mu.Unlock()
info := JobInfo{
ID: j.id,
Status: j.status,
CreatedAt: j.createdAt,
Error: j.errStr,
Result: j.result,
ID: j.id, Status: j.status, CreatedAt: j.createdAt,
Error: j.errStr, Result: j.result,
}
if !j.startedAt.IsZero() {
t := j.startedAt
@ -76,16 +77,14 @@ func (j *job) snapshot() JobInfo {
return info
}
// Manager runs a fixed pool of workers to execute prediction jobs and
// retains their results for the configured TTL.
// Manager runs a fixed pool of workers and retains job results for a TTL.
type Manager struct {
mgr *datasets.Manager
elev *elevation.Dataset
run RunFunc
metrics metrics.Sink
log *zap.Logger
queue chan *job
ttl time.Duration
queue chan *job
ttl time.Duration
jobsMu sync.RWMutex
jobs map[string]*job
@ -97,16 +96,14 @@ type Manager struct {
// Config controls Manager construction.
type Config struct {
// Workers is the maximum concurrent prediction executions.
Workers int
// QueueSize bounds the number of jobs waiting to start.
QueueSize int
// ResultTTL is how long completed/failed jobs are retained in memory.
ResultTTL time.Duration
Workers int // max concurrent executions
QueueSize int // pending-queue bound
ResultTTL time.Duration // retention of terminal jobs
}
// New constructs a Manager with the given config and starts the workers.
func New(cfg Config, mgr *datasets.Manager, elev *elevation.Dataset, sink metrics.Sink, log *zap.Logger) *Manager {
// New constructs a Manager and starts its workers. run executes one
// prediction; sink and log may be nil.
func New(cfg Config, run RunFunc, sink metrics.Sink, log *zap.Logger) *Manager {
if cfg.Workers <= 0 {
cfg.Workers = 4
}
@ -123,7 +120,7 @@ func New(cfg Config, mgr *datasets.Manager, elev *elevation.Dataset, sink metric
log = zap.NewNop()
}
m := &Manager{
mgr: mgr, elev: elev, metrics: sink, log: log,
run: run, metrics: sink, log: log,
queue: make(chan *job, cfg.QueueSize),
jobs: make(map[string]*job),
ttl: cfg.ResultTTL,
@ -138,15 +135,14 @@ func New(cfg Config, mgr *datasets.Manager, elev *elevation.Dataset, sink metric
return m
}
// Enqueue creates a new job from req and returns its snapshot.
// Returns false when the queue is full.
func (m *Manager) Enqueue(req v2.PredictionRequest) (JobInfo, bool) {
// Enqueue creates a job from req and returns its snapshot. The bool is false
// when the queue is full (the returned job is marked failed).
func (m *Manager) Enqueue(req *apirest.PredictionV2Request) (JobInfo, bool) {
j := &job{
id: uuid.New().String(),
req: req,
createdAt: time.Now().UTC(),
status: StatusPending,
cancel: make(chan struct{}),
}
m.jobsMu.Lock()
m.jobs[j.id] = j
@ -156,7 +152,6 @@ func (m *Manager) Enqueue(req v2.PredictionRequest) (JobInfo, bool) {
case m.queue <- j:
return j.snapshot(), true
default:
// Queue full — mark the job failed and return it.
j.mu.Lock()
j.status = StatusFailed
j.errStr = "prediction queue full"
@ -177,8 +172,11 @@ func (m *Manager) Get(id string) (JobInfo, bool) {
return j.snapshot(), true
}
// Cancel marks a not-yet-started job as cancelled. Returns false when the
// job is unknown or already terminal.
// Cancel marks a still-queued job cancelled. Returns false when the job is
// unknown or already running/terminal — a running prediction cannot be
// interrupted (the worker would otherwise overwrite the cancelled status with
// its result), so callers get an honest "too late" rather than a 204 that the
// worker silently undoes.
func (m *Manager) Cancel(id string) bool {
m.jobsMu.RLock()
j, ok := m.jobs[id]
@ -187,22 +185,19 @@ func (m *Manager) Cancel(id string) bool {
return false
}
j.mu.Lock()
terminal := j.status == StatusComplete || j.status == StatusFailed || j.status == StatusCancelled
if terminal {
j.mu.Unlock()
defer j.mu.Unlock()
if j.status != StatusPending {
return false
}
j.status = StatusCancelled
j.completedAt = time.Now().UTC()
j.mu.Unlock()
close(j.cancel)
return true
}
// Inflight returns the count of running jobs.
// Inflight returns the number of running jobs.
func (m *Manager) Inflight() int64 { return m.inflight.Load() }
// Close shuts down workers and the evictor.
// Close stops the workers and the evictor.
func (m *Manager) Close() {
close(m.closed)
close(m.queue)
@ -212,41 +207,49 @@ func (m *Manager) Close() {
func (m *Manager) worker() {
defer m.wg.Done()
for j := range m.queue {
// Check cancellation before starting.
j.mu.Lock()
cancelled := j.status == StatusCancelled
if !cancelled {
j.status = StatusRunning
j.startedAt = time.Now().UTC()
}
j.mu.Unlock()
if cancelled {
continue
}
m.inflight.Add(1)
j.mu.Lock()
j.status = StatusRunning
j.startedAt = time.Now().UTC()
j.mu.Unlock()
resp, err := v2.Run(m.mgr, m.elev, j.req)
j.mu.Lock()
j.completedAt = time.Now().UTC()
if err != nil {
j.status = StatusFailed
j.errStr = err.Error()
} else {
j.status = StatusComplete
j.result = resp
}
j.mu.Unlock()
m.inflight.Add(-1)
if err == nil {
m.metrics.Prediction("async", j.completedAt.Sub(j.startedAt), nil)
} else {
m.metrics.Prediction("async", j.completedAt.Sub(j.startedAt), err)
}
m.execute(j)
}
}
// execute runs one job, recovering from a panic in the injected RunFunc so a
// single bad prediction can't leak the inflight counter or kill the worker.
func (m *Manager) execute(j *job) {
m.inflight.Add(1)
defer m.inflight.Add(-1)
resp, err := func() (resp *apirest.PredictionV2Response, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("prediction panicked: %v", r)
}
}()
return m.run(j.req)
}()
j.mu.Lock()
j.completedAt = time.Now().UTC()
if err != nil {
j.status = StatusFailed
j.errStr = err.Error()
} else {
j.status = StatusComplete
j.result = resp
}
dur := j.completedAt.Sub(j.startedAt)
j.mu.Unlock()
m.metrics.Prediction("async", dur, err)
}
func (m *Manager) evictor() {
defer m.wg.Done()
ticker := time.NewTicker(m.ttl / 4)

189
internal/api/datasets.go Normal file
View file

@ -0,0 +1,189 @@
package api
import (
"context"
"net/http"
"runtime"
"time"
"predictor-refactored/internal/datasets"
apirest "predictor-refactored/pkg/rest"
)
// ListDatasets implements GET /api/v1/admin/datasets.
func (h *Handler) ListDatasets(_ context.Context) (*apirest.DatasetList, error) {
stored, err := h.mgr.ListEpochs()
if err != nil {
return nil, apiError(http.StatusInternalServerError, err.Error())
}
loaded := make(map[string]datasets.LoadedDatasetInfo)
for _, ld := range h.mgr.LoadedDatasets() {
loaded[ld.ID.Filename()] = ld
}
out := &apirest.DatasetList{Source: h.mgr.Source(), Datasets: make([]apirest.DatasetEntry, 0, len(stored))}
for _, id := range stored {
entry := apirest.DatasetEntry{
Filename: id.Filename(),
Epoch: id.Epoch.UTC(),
}
if !id.Subset.IsGlobal() {
entry.Subset = apirest.NewOptSubsetSpec(subsetToAPI(id.Subset))
}
if ld, ok := loaded[id.Filename()]; ok {
entry.Loaded = true
entry.Coverage = apirest.NewOptCoverage(coverageToAPI(ld.Coverage))
}
out.Datasets = append(out.Datasets, entry)
}
return out, nil
}
// TriggerDatasetDownload implements POST /api/v1/admin/datasets.
func (h *Handler) TriggerDatasetDownload(ctx context.Context, req *apirest.DownloadRequest) (*apirest.DownloadAccepted, error) {
if req.Latest.Or(false) {
dctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
jobID, err := h.mgr.Refresh(dctx, 0)
if err != nil {
return nil, apiError(http.StatusInternalServerError, err.Error())
}
return &apirest.DownloadAccepted{JobID: jobID}, nil
}
epoch, ok := req.Epoch.Get()
if !ok {
return nil, apiError(http.StatusBadRequest, "specify either epoch or latest=true")
}
id := datasets.DatasetID{Epoch: epoch.UTC()}
if s, ok := req.Subset.Get(); ok {
id.Subset = subsetFromAPI(s)
}
return &apirest.DownloadAccepted{JobID: h.mgr.Download(id)}, nil
}
// DeleteDataset implements DELETE /api/v1/admin/datasets/{name}.
func (h *Handler) DeleteDataset(_ context.Context, params apirest.DeleteDatasetParams) error {
stored, err := h.mgr.ListEpochs()
if err != nil {
return apiError(http.StatusInternalServerError, err.Error())
}
for _, id := range stored {
if id.Filename() == params.Name {
if err := h.mgr.Remove(id); err != nil {
return apiError(http.StatusInternalServerError, err.Error())
}
return nil
}
}
return apiError(http.StatusNotFound, "dataset not found")
}
// ListDatasetJobs implements GET /api/v1/admin/jobs.
func (h *Handler) ListDatasetJobs(_ context.Context) ([]apirest.DownloadJob, error) {
jobs := h.mgr.ListJobs()
out := make([]apirest.DownloadJob, 0, len(jobs))
for _, j := range jobs {
out = append(out, downloadJobToAPI(j))
}
return out, nil
}
// GetDatasetJob implements GET /api/v1/admin/jobs/{id}.
func (h *Handler) GetDatasetJob(_ context.Context, params apirest.GetDatasetJobParams) (*apirest.DownloadJob, error) {
j, ok := h.mgr.GetJob(params.ID)
if !ok {
return nil, apiError(http.StatusNotFound, "job not found")
}
dto := downloadJobToAPI(j)
return &dto, nil
}
// CancelDatasetJob implements DELETE /api/v1/admin/jobs/{id}.
func (h *Handler) CancelDatasetJob(_ context.Context, params apirest.CancelDatasetJobParams) error {
if !h.mgr.CancelJob(params.ID) {
return apiError(http.StatusConflict, "job not found or already terminal")
}
return nil
}
// GetServiceStatus implements GET /api/v1/admin/status.
func (h *Handler) GetServiceStatus(_ context.Context) (*apirest.StatusResponse, error) {
jobs := h.mgr.ListJobs()
stored, _ := h.mgr.ListEpochs()
loaded := h.mgr.LoadedDatasets()
byStatus := apirest.StatusResponseJobsByStatus{}
for _, j := range jobs {
byStatus[string(j.Status)]++
}
var mem runtime.MemStats
runtime.ReadMemStats(&mem)
return &apirest.StatusResponse{
Source: h.mgr.Source(),
Uptime: time.Since(h.started).Round(time.Second).String(),
Goroutines: runtime.NumGoroutine(),
MemoryMB: int64(mem.Alloc / 1024 / 1024),
JobsByStatus: byStatus,
StoredDatasets: len(stored),
LoadedDatasets: len(loaded),
}, nil
}
// --- dataset mapping helpers ----------------------------------------------
func downloadJobToAPI(j datasets.JobInfo) apirest.DownloadJob {
dto := apirest.DownloadJob{
ID: j.ID,
Source: j.Source,
Dataset: j.Dataset.Filename(),
Epoch: j.Dataset.Epoch.UTC(),
Status: apirest.DownloadJobStatus(j.Status),
StartedAt: j.StartedAt.UTC(),
TotalUnits: j.Total,
DoneUnits: j.Done,
Bytes: j.Bytes,
}
if j.EndedAt != nil {
dto.EndedAt = apirest.NewOptDateTime(j.EndedAt.UTC())
}
if j.Err != "" {
dto.Error = apirest.NewOptString(j.Err)
}
return dto
}
func subsetToAPI(s datasets.SubsetSpec) apirest.SubsetSpec {
out := apirest.SubsetSpec{Members: s.Members}
if s.Region != nil {
out.Region = apirest.NewOptRegion(regionToAPI(*s.Region))
}
if s.HourRange != nil {
out.HourRange = apirest.NewOptHourRange(apirest.HourRange{MinHour: s.HourRange.MinHour, MaxHour: s.HourRange.MaxHour})
}
return out
}
func subsetFromAPI(s apirest.SubsetSpec) datasets.SubsetSpec {
out := datasets.SubsetSpec{Members: s.Members}
if r, ok := s.Region.Get(); ok {
out.Region = &datasets.Region{MinLat: r.MinLat, MaxLat: r.MaxLat, MinLng: r.MinLng, MaxLng: r.MaxLng}
}
if hr, ok := s.HourRange.Get(); ok {
out.HourRange = &datasets.HourRange{MinHour: hr.MinHour, MaxHour: hr.MaxHour}
}
return out
}
func regionToAPI(r datasets.Region) apirest.Region {
return apirest.Region{MinLat: r.MinLat, MaxLat: r.MaxLat, MinLng: r.MinLng, MaxLng: r.MaxLng}
}
func coverageToAPI(c datasets.Coverage) apirest.Coverage {
return apirest.Coverage{
Region: regionToAPI(c.Region),
StartTime: c.StartTime.UTC(),
EndTime: c.EndTime.UTC(),
}
}

48
internal/api/docs/docs.go Normal file
View file

@ -0,0 +1,48 @@
// Package docs serves the human-facing API documentation: the OpenAPI
// document and a ReDoc rendering of it. The spec is embedded in the binary
// (see package apispec) so the documentation needs no external files or a
// separate server.
package docs
import (
"net/http"
apispec "predictor-refactored/api"
)
// redocHTML renders the embedded spec with ReDoc loaded from a CDN.
const redocHTML = `<!DOCTYPE html>
<html>
<head>
<title>stratoflights-predictor API</title>
<meta charset="utf-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1">
<style>body { margin: 0; padding: 0; }</style>
</head>
<body>
<redoc spec-url="/openapi.yaml"></redoc>
<script src="https://cdn.redoc.ly/redoc/latest/bundles/redoc.standalone.js"></script>
</body>
</html>`
// Handler serves the documentation endpoints.
type Handler struct{}
// New returns a docs Handler.
func New() *Handler { return &Handler{} }
// Register installs GET /docs and GET /openapi.yaml on mux.
func (h *Handler) Register(mux *http.ServeMux) {
mux.HandleFunc("GET /openapi.yaml", h.spec)
mux.HandleFunc("GET /docs", h.redoc)
}
func (h *Handler) spec(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/yaml")
_, _ = w.Write(apispec.Spec)
}
func (h *Handler) redoc(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
_, _ = w.Write([]byte(redocHTML))
}

70
internal/api/handler.go Normal file
View file

@ -0,0 +1,70 @@
package api
import (
"context"
"errors"
"net/http"
"time"
"go.uber.org/zap"
"predictor-refactored/internal/api/async"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/elevation"
"predictor-refactored/internal/engine"
"predictor-refactored/internal/metrics"
"predictor-refactored/internal/windviz"
apirest "predictor-refactored/pkg/rest"
)
// Handler implements the ogen-generated apirest.Handler interface for every
// operation in the OpenAPI spec. Operation methods are grouped by concern
// across prediction.go, datasets.go, and wind.go.
type Handler struct {
mgr *datasets.Manager
elev *elevation.Dataset
async *async.Manager
metrics metrics.Sink
cache *windviz.Cache
started time.Time
log *zap.Logger
}
var _ apirest.Handler = (*Handler)(nil)
// terrain returns the elevation dataset as an engine.TerrainProvider, or an
// untyped nil interface when no elevation dataset is loaded. Returning the
// concrete nil *elevation.Dataset directly would produce a non-nil interface
// wrapping a nil pointer, which then panics on first use — so the nil check
// must happen here, on the concrete type.
func (h *Handler) terrain() engine.TerrainProvider {
if h.elev == nil {
return nil
}
return h.elev
}
// NewError converts an error returned by a handler into the spec's default
// error response. Handlers return *apirest.DefaultErrorStatusCode (via the
// apiError helper) to control the status code; anything else is a 500.
func (h *Handler) NewError(_ context.Context, err error) *apirest.DefaultErrorStatusCode {
var coded *apirest.DefaultErrorStatusCode
if errors.As(err, &coded) {
return coded
}
h.log.Error("unhandled handler error", zap.Error(err))
return apiError(http.StatusInternalServerError, err.Error())
}
// apiError builds a coded error response carrying an HTTP status.
func apiError(status int, description string) *apirest.DefaultErrorStatusCode {
return &apirest.DefaultErrorStatusCode{
StatusCode: status,
Response: apirest.Error{
Error: apirest.ErrorError{
Type: http.StatusText(status),
Description: description,
},
},
}
}

View file

@ -1,27 +0,0 @@
// Package httpjson holds the tiny JSON response helpers shared across
// the admin, v2, and async handlers.
package httpjson
import (
"encoding/json"
"net/http"
)
// Write writes body as JSON with the given status code.
func Write(w http.ResponseWriter, status int, body any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(body)
}
// Error writes a standard error JSON body with the given status code.
//
// Shape: {"error": {"type": "...", "description": "..."}}
func Error(w http.ResponseWriter, status int, description string) {
Write(w, status, map[string]any{
"error": map[string]string{
"type": http.StatusText(status),
"description": description,
},
})
}

217
internal/api/mapping.go Normal file
View file

@ -0,0 +1,217 @@
package api
import (
"fmt"
"time"
"predictor-refactored/internal/api/async"
"predictor-refactored/internal/engine"
apirest "predictor-refactored/pkg/rest"
)
// normalizeLng folds a longitude into [0, 360) for internal use.
func normalizeLng(lng float64) float64 {
if lng < 0 {
return lng + 360
}
return lng
}
// signedLng converts an internal [0, 360) longitude back to [-180, 180).
func signedLng(lng float64) float64 {
if lng > 180 {
return lng - 360
}
return lng
}
// buildProfile translates an API prediction request into an engine profile
// using the engine's model/constraint registry.
// maxProfileStages bounds the propagator chain length to keep a single
// request's work bounded.
const maxProfileStages = 32
func buildProfile(req *apirest.PredictionV2Request, deps engine.BuildDeps) (engine.Profile, error) {
if len(req.Profile) == 0 {
return engine.Profile{}, fmt.Errorf("profile must contain at least one stage")
}
if len(req.Profile) > maxProfileStages {
return engine.Profile{}, fmt.Errorf("profile has %d stages; maximum is %d", len(req.Profile), maxProfileStages)
}
step := 60.0
tol := 0.01
if o, ok := req.Options.Get(); ok {
step = o.StepSeconds.Or(step)
tol = o.Tolerance.Or(tol)
}
if step <= 0 || step > 3600 {
return engine.Profile{}, fmt.Errorf("options.step_seconds must be in (0, 3600], got %g", step)
}
if tol <= 0 || tol >= 1 {
return engine.Profile{}, fmt.Errorf("options.tolerance must be in (0, 1), got %g", tol)
}
dir := engine.Forward
if req.Direction.Or(apirest.PredictionV2RequestDirectionForward) == apirest.PredictionV2RequestDirectionReverse {
dir = engine.Reverse
}
props := make([]*engine.Propagator, len(req.Profile))
for i, stage := range req.Profile {
if stage.Name == "" {
return engine.Profile{}, fmt.Errorf("stage %d: name is required", i)
}
built, err := engine.BuildModel(toEngineModelSpec(stage.Model), deps)
if err != nil {
return engine.Profile{}, fmt.Errorf("stage %q model: %w", stage.Name, err)
}
constraints, err := toEngineConstraints(stage.Constraints, deps)
if err != nil {
return engine.Profile{}, fmt.Errorf("stage %q: %w", stage.Name, err)
}
props[i] = &engine.Propagator{
Name: stage.Name,
Step: step,
Model: built.Model,
BuildModel: built.Build,
Constraints: constraints,
Tolerance: tol,
}
}
for i, stage := range req.Profile {
idx, ok := stage.FallbackIndex.Get()
if !ok {
continue
}
if idx < 0 || idx >= len(props) {
return engine.Profile{}, fmt.Errorf("stage %q: fallback_index %d out of range", stage.Name, idx)
}
props[i].Fallback = props[idx]
}
globals, err := toEngineConstraints(req.Globals, deps)
if err != nil {
return engine.Profile{}, fmt.Errorf("globals: %w", err)
}
return engine.Profile{Stages: props, Direction: dir, Globals: globals}, nil
}
func toEngineModelSpec(m apirest.ModelSpec) engine.ModelSpec {
out := engine.ModelSpec{
Type: string(m.Type),
Rate: m.Rate.Or(0),
SeaLevelRate: m.SeaLevelRate.Or(0),
IncludeWind: m.IncludeWind.Or(false),
}
for _, s := range m.Segments {
out.Segments = append(out.Segments, engine.PiecewiseSegmentSpec{
Until: s.Until,
Rate: s.Rate,
Reference: string(s.Reference.Or(apirest.PiecewiseSegmentReferenceAbsolute)),
})
}
return out
}
func toEngineConstraints(specs []apirest.ConstraintSpec, deps engine.BuildDeps) ([]engine.Constraint, error) {
out := make([]engine.Constraint, 0, len(specs))
for i, s := range specs {
c, err := engine.BuildConstraint(toEngineConstraintSpec(s), deps)
if err != nil {
return nil, fmt.Errorf("constraint[%d]: %w", i, err)
}
out = append(out, c)
}
return out, nil
}
func toEngineConstraintSpec(c apirest.ConstraintSpec) engine.ConstraintSpec {
spec := engine.ConstraintSpec{
Type: string(c.Type),
Op: string(c.Op.Or("")),
Limit: c.Limit.Or(0),
Action: string(c.Action.Or(apirest.ConstraintSpecActionStop)),
Mode: string(c.Mode.Or("")),
Label: c.Label.Or(""),
}
for _, v := range c.Vertices {
spec.Vertices = append(spec.Vertices, engine.PolygonVertex{Lat: v.Lat, Lng: v.Lng})
}
return spec
}
// stageResultToAPI maps one engine stage result to the API representation.
func stageResultToAPI(r engine.Result) apirest.StageResult {
out := apirest.StageResult{
Name: r.Propagator,
Outcome: apirest.StageResultOutcome(r.Outcome.String()),
Events: eventsToAPI(r.Events),
}
if r.Constraint != nil {
out.Constraint = apirest.NewOptString(r.ConstraintName)
out.Termination = apirest.NewOptTerminationInfo(apirest.TerminationInfo{
ViolationTime: time.Unix(int64(r.ViolationTime), 0).UTC(),
ViolationState: geoStateToAPI(r.ViolationState),
RefinedTime: time.Unix(int64(r.RefinedTime), 0).UTC(),
RefinedState: geoStateToAPI(r.RefinedState),
})
}
n := r.Path.Len()
out.Trajectory = make([]apirest.TrajectoryPoint, n)
for i := range n {
t, p := r.Path.At(i)
out.Trajectory[i] = apirest.TrajectoryPoint{
Time: time.Unix(int64(t), 0).UTC(),
Latitude: p.Lat,
Longitude: signedLng(p.Lng),
Altitude: p.Altitude,
}
}
return out
}
func geoStateToAPI(s engine.State) apirest.GeoState {
return apirest.GeoState{Lat: s.Lat, Lng: signedLng(s.Lng), Altitude: s.Altitude}
}
func eventsToAPI(in []engine.EventSummary) []apirest.EventSummary {
if len(in) == 0 {
return nil
}
out := make([]apirest.EventSummary, 0, len(in))
for _, e := range in {
out = append(out, apirest.EventSummary{
Type: e.Type,
Count: e.Count,
FirstTime: apirest.NewOptFloat64(e.FirstTime),
LastTime: apirest.NewOptFloat64(e.LastTime),
FirstState: apirest.NewOptGeoState(geoStateToAPI(e.FirstState)),
LastState: apirest.NewOptGeoState(geoStateToAPI(e.LastState)),
Message: apirest.NewOptString(e.Message),
})
}
return out
}
// asyncJobToAPI maps an async job snapshot to the API PredictionJob.
func asyncJobToAPI(info async.JobInfo) *apirest.PredictionJob {
job := &apirest.PredictionJob{
ID: info.ID,
Status: apirest.PredictionJobStatus(info.Status),
CreatedAt: info.CreatedAt,
}
if info.StartedAt != nil {
job.StartedAt = apirest.NewOptDateTime(*info.StartedAt)
}
if info.CompletedAt != nil {
job.CompletedAt = apirest.NewOptDateTime(*info.CompletedAt)
}
if info.Error != "" {
job.Error = apirest.NewOptString(info.Error)
}
if info.Result != nil {
job.Result = apirest.NewOptPredictionV2Response(*info.Result)
}
return job
}

View file

@ -1,51 +1,34 @@
// Package middleware contains HTTP and ogen middleware used by the API layer.
package middleware
import (
"net/http"
"time"
"github.com/ogen-go/ogen/middleware"
"go.uber.org/zap"
)
// OgenLogging is an ogen middleware that logs request duration and outcome.
// statusCoder is implemented by ogen's *...StatusCode error wrappers.
type statusCoder interface{ GetStatusCode() int }
// OgenLogging is an ogen middleware that logs each operation's duration and
// outcome. Handler errors carrying a 4xx/5xx-class status are logged at the
// appropriate level: client errors (and expected 503s during startup) at
// warn without a stacktrace, server errors at error.
func OgenLogging(log *zap.Logger) middleware.Middleware {
return func(req middleware.Request, next func(req middleware.Request) (middleware.Response, error)) (middleware.Response, error) {
lg := log.With(zap.String("op", req.OperationID))
start := time.Now()
resp, err := next(req)
dur := time.Since(start)
if err != nil {
lg.Error("request failed", zap.Duration("duration", dur), zap.Error(err))
lg := log.With(zap.String("operation", req.OperationID), zap.Duration("duration", time.Since(start)))
if err == nil {
lg.Info("request completed")
return resp, err
}
if sc, ok := err.(statusCoder); ok && sc.GetStatusCode() < 500 {
lg.Warn("request rejected", zap.Int("status", sc.GetStatusCode()), zap.NamedError("reason", err))
} else {
lg.Info("request completed", zap.Duration("duration", dur))
lg.Error("request failed", zap.Error(err))
}
return resp, err
}
}
// statusRecorder captures the response status for HTTPLogging.
type statusRecorder struct {
http.ResponseWriter
status int
}
func (r *statusRecorder) WriteHeader(code int) {
r.status = code
r.ResponseWriter.WriteHeader(code)
}
// HTTPLogging wraps the given http.Handler with a per-request log line.
func HTTPLogging(log *zap.Logger, next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
rec := &statusRecorder{ResponseWriter: w, status: 200}
next.ServeHTTP(rec, r)
log.Info("http",
zap.String("method", r.Method),
zap.String("path", r.URL.Path),
zap.Int("status", rec.status),
zap.Duration("duration", time.Since(start)))
})
}

239
internal/api/prediction.go Normal file
View file

@ -0,0 +1,239 @@
package api
import (
"context"
"net/http"
"time"
"predictor-refactored/internal/engine"
"predictor-refactored/internal/weather"
apirest "predictor-refactored/pkg/rest"
)
// ReadinessCheck implements GET /ready.
func (h *Handler) ReadinessCheck(_ context.Context) (*apirest.ReadinessResponse, error) {
resp := &apirest.ReadinessResponse{}
if field := h.mgr.Active(); field != nil {
resp.Status = apirest.ReadinessResponseStatusOk
resp.DatasetTime = apirest.NewOptDateTime(field.Epoch())
} else {
resp.Status = apirest.ReadinessResponseStatusNotReady
resp.ErrorMessage = apirest.NewOptString("no dataset loaded")
}
return resp, nil
}
// PerformPredictionV2 implements POST /api/v2/prediction.
func (h *Handler) PerformPredictionV2(_ context.Context, req *apirest.PredictionV2Request) (*apirest.PredictionV2Response, error) {
resp, err := h.runPredictionV2(req)
if err == nil {
h.metrics.Prediction("v2", resp.CompletedAt.Sub(resp.StartedAt), nil)
}
return resp, err
}
// CreatePredictionJob implements POST /api/v1/predictions.
func (h *Handler) CreatePredictionJob(_ context.Context, req *apirest.PredictionV2Request) (*apirest.PredictionJob, error) {
info, accepted := h.async.Enqueue(req)
if !accepted {
return nil, apiError(http.StatusServiceUnavailable, info.Error)
}
return asyncJobToAPI(info), nil
}
// GetPredictionJob implements GET /api/v1/predictions/{id}.
func (h *Handler) GetPredictionJob(_ context.Context, params apirest.GetPredictionJobParams) (*apirest.PredictionJob, error) {
info, ok := h.async.Get(params.ID)
if !ok {
return nil, apiError(http.StatusNotFound, "prediction job not found")
}
return asyncJobToAPI(info), nil
}
// CancelPredictionJob implements DELETE /api/v1/predictions/{id}.
func (h *Handler) CancelPredictionJob(_ context.Context, params apirest.CancelPredictionJobParams) error {
if !h.async.Cancel(params.ID) {
return apiError(http.StatusConflict, "job not found or already terminal")
}
return nil
}
// runPredictionV2 is the synchronous prediction core, shared by the v2
// endpoint and the async worker pool.
func (h *Handler) runPredictionV2(req *apirest.PredictionV2Request) (*apirest.PredictionV2Response, error) {
// Validate the request shape before checking dataset availability, so a
// malformed request is a 400 regardless of startup state.
lat := req.Launch.Latitude
rawLng := req.Launch.Longitude
alt := req.Launch.Altitude.Or(0)
if lat < -90 || lat > 90 {
return nil, apiError(http.StatusBadRequest, "launch.latitude must be in [-90, 90]")
}
if rawLng < -180 || rawLng >= 360 {
return nil, apiError(http.StatusBadRequest, "launch.longitude must be in [-180, 360)")
}
lng := normalizeLng(rawLng)
field := h.mgr.Active()
if field == nil {
return nil, apiError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up")
}
events := engine.NewEventSink()
deps := engine.BuildDeps{Wind: field, Events: events, Terrain: h.terrain()}
prof, err := buildProfile(req, deps)
if err != nil {
return nil, apiError(http.StatusBadRequest, err.Error())
}
started := time.Now().UTC()
results := prof.Run(float64(req.Launch.Time.Unix()), engine.State{Lat: lat, Lng: lng, Altitude: alt}, events)
completed := time.Now().UTC()
resp := &apirest.PredictionV2Response{
Stages: make([]apirest.StageResult, 0, len(results)),
Events: eventsToAPI(events.Snapshot()),
Dataset: apirest.DatasetInfo{Source: field.Source(), Epoch: field.Epoch()},
StartedAt: started,
CompletedAt: completed,
}
for _, r := range results {
resp.Stages = append(resp.Stages, stageResultToAPI(r))
}
return resp, nil
}
// PerformPrediction implements GET /api/v1/prediction (Tawhiri-compatible).
func (h *Handler) PerformPrediction(_ context.Context, params apirest.PerformPredictionParams) (*apirest.PredictionResponse, error) {
field := h.mgr.Active()
if field == nil {
return nil, apiError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up")
}
profileKind := "standard_profile"
if p, ok := params.Profile.Get(); ok {
profileKind = string(p)
}
ascentRate := params.AscentRate.Or(5)
descentRate := params.DescentRate.Or(5)
launchAlt := params.LaunchAltitude.Or(0)
lng := normalizeLng(params.LaunchLongitude)
launchTime := float64(params.LaunchDatetime.Unix())
events := engine.NewEventSink()
var stageNames []string
var prof engine.Profile
switch profileKind {
case "standard_profile":
stageNames = []string{"ascent", "descent"}
prof = standardProfile(field, h.terrain(), events, ascentRate, params.BurstAltitude.Or(28000), descentRate)
case "float_profile":
stopTime := params.LaunchDatetime.Add(24 * time.Hour)
if v, ok := params.StopDatetime.Get(); ok {
stopTime = v
}
stageNames = []string{"ascent", "float"}
prof = floatProfile(field, events, ascentRate, params.FloatAltitude.Or(25000), stopTime)
default:
return nil, apiError(http.StatusBadRequest, "unknown profile: "+profileKind)
}
started := time.Now().UTC()
results := prof.Run(launchTime, engine.State{Lat: params.LaunchLatitude, Lng: lng, Altitude: launchAlt}, events)
completed := time.Now().UTC()
h.metrics.Prediction(profileKind, completed.Sub(started), nil)
resp := &apirest.PredictionResponse{
Metadata: apirest.PredictionResponseMetadata{StartDatetime: started, CompleteDatetime: completed},
}
for i, r := range results {
name := "ascent"
if i < len(stageNames) {
name = stageNames[i]
}
resp.Prediction = append(resp.Prediction, tawhiriItem(name, r))
}
resp.Request = apirest.NewOptPredictionResponseRequest(apirest.PredictionResponseRequest{
Dataset: apirest.NewOptString(field.Epoch().Format("2006-01-02T15:04:05Z")),
LaunchLatitude: apirest.NewOptFloat64(params.LaunchLatitude),
LaunchLongitude: apirest.NewOptFloat64(params.LaunchLongitude),
LaunchDatetime: apirest.NewOptString(params.LaunchDatetime.Format(time.RFC3339)),
LaunchAltitude: params.LaunchAltitude,
})
if ev := events.Snapshot(); len(ev) > 0 {
resp.Warnings = apirest.NewOptPredictionResponseWarnings(apirest.PredictionResponseWarnings{})
}
return resp, nil
}
// standardProfile builds the Tawhiri ascent → descent chain.
func standardProfile(field weather.WindField, elev engine.TerrainProvider, events *engine.EventSink, ascentRate, burst, descentRate float64) engine.Profile {
wind := engine.WindTransport(field, events)
descentTerm := []engine.Constraint{engine.Altitude{Op: engine.OpLessEqual, Limit: 0, On: engine.ActionStop}}
if elev != nil {
descentTerm = []engine.Constraint{engine.TerrainContact{Provider: elev, On: engine.ActionStop}}
}
return engine.Profile{
Direction: engine.Forward,
Stages: []*engine.Propagator{
{
Name: "ascent",
Step: 60,
Model: engine.Sum(engine.ConstantRate(ascentRate), wind),
Constraints: []engine.Constraint{engine.Altitude{Op: engine.OpGreaterEqual, Limit: burst, On: engine.ActionStop}},
},
{
Name: "descent",
Step: 60,
Model: engine.Sum(engine.ParachuteDescent(descentRate), wind),
Constraints: descentTerm,
},
},
}
}
// floatProfile builds the Tawhiri ascent → float chain.
func floatProfile(field weather.WindField, events *engine.EventSink, ascentRate, floatAlt float64, stopTime time.Time) engine.Profile {
wind := engine.WindTransport(field, events)
return engine.Profile{
Direction: engine.Forward,
Stages: []*engine.Propagator{
{
Name: "ascent",
Step: 60,
Model: engine.Sum(engine.ConstantRate(ascentRate), wind),
Constraints: []engine.Constraint{engine.Altitude{Op: engine.OpGreaterEqual, Limit: floatAlt, On: engine.ActionStop}},
},
{
Name: "float",
Step: 60,
Model: wind,
Constraints: []engine.Constraint{engine.Time{Op: engine.OpGreater, Limit: float64(stopTime.Unix()), On: engine.ActionStop}},
},
},
}
}
// tawhiriItem maps one engine stage result to a v1 prediction item.
func tawhiriItem(name string, r engine.Result) apirest.PredictionResponsePredictionItem {
stage := apirest.PredictionResponsePredictionItemStageAscent
switch name {
case "descent":
stage = apirest.PredictionResponsePredictionItemStageDescent
case "float":
stage = apirest.PredictionResponsePredictionItemStageFloat
}
n := r.Path.Len()
traj := make([]apirest.TawhiriPoint, 0, n)
for i := range n {
t, p := r.Path.At(i)
traj = append(traj, apirest.TawhiriPoint{
Datetime: time.Unix(int64(t), 0).UTC(),
Latitude: p.Lat,
Longitude: signedLng(p.Lng),
Altitude: p.Altitude,
})
}
return apirest.PredictionResponsePredictionItem{Stage: stage, Trajectory: traj}
}

View file

@ -1,251 +0,0 @@
// Package tawhiri implements the legacy Tawhiri-compatible HTTP endpoint
// (GET /api/v1/prediction). The request/response shapes match the original
// Cambridge University Spaceflight predictor for drop-in compatibility.
//
// Internally the handler builds an engine.Profile from query parameters
// and dispatches it through the same engine path as the new v2 endpoint.
package tawhiri
import (
"context"
"errors"
"net/http"
"time"
"go.uber.org/zap"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/elevation"
"predictor-refactored/internal/engine"
"predictor-refactored/internal/metrics"
"predictor-refactored/internal/weather"
api "predictor-refactored/pkg/rest"
)
// Handler implements api.Handler (ogen-generated interface).
type Handler struct {
mgr *datasets.Manager
elev *elevation.Dataset
metrics metrics.Sink
log *zap.Logger
}
// New wires a Handler.
func New(mgr *datasets.Manager, elev *elevation.Dataset, sink metrics.Sink, log *zap.Logger) *Handler {
if log == nil {
log = zap.NewNop()
}
if sink == nil {
sink = metrics.Noop()
}
return &Handler{mgr: mgr, elev: elev, metrics: sink, log: log}
}
var _ api.Handler = (*Handler)(nil)
// PerformPrediction runs the Tawhiri-style prediction.
func (h *Handler) PerformPrediction(_ context.Context, params api.PerformPredictionParams) (*api.PredictionResponse, error) {
field := h.mgr.Active()
if field == nil {
return nil, newError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up")
}
profileKind := optString(params.Profile, "standard_profile")
ascentRate := optFloat(params.AscentRate, 5.0)
burstAltitude := optFloat(params.BurstAltitude, 28000.0)
descentRate := optFloat(params.DescentRate, 5.0)
launchAlt := optFloat(params.LaunchAltitude, 0.0)
lng := params.LaunchLongitude
if lng < 0 {
lng += 360
}
launchTime := float64(params.LaunchDatetime.Unix())
events := engine.NewEventSink()
var stageNames []string
var prof engine.Profile
switch profileKind {
case "standard_profile":
stageNames = []string{"ascent", "descent"}
prof = standardProfile(field, h.elev, events, ascentRate, burstAltitude, descentRate)
case "float_profile":
floatAlt := optFloat(params.FloatAltitude, 25000.0)
stopTime := params.LaunchDatetime.Add(24 * time.Hour)
if v, ok := params.StopDatetime.Get(); ok {
stopTime = v
}
stageNames = []string{"ascent", "float"}
prof = floatProfile(field, events, ascentRate, floatAlt, stopTime)
default:
return nil, newError(http.StatusBadRequest, "unknown profile: "+profileKind)
}
started := time.Now().UTC()
results := prof.Run(launchTime, engine.State{Lat: params.LaunchLatitude, Lng: lng, Altitude: launchAlt}, events)
completed := time.Now().UTC()
h.metrics.Prediction(profileKind, completed.Sub(started), nil)
resp := &api.PredictionResponse{
Metadata: api.PredictionResponseMetadata{
StartDatetime: started,
CompleteDatetime: completed,
},
}
for i, r := range results {
stageName := "ascent"
if i < len(stageNames) {
stageName = stageNames[i]
}
resp.Prediction = append(resp.Prediction, buildPredictionItem(stageName, r))
}
resp.Request = api.NewOptPredictionResponseRequest(api.PredictionResponseRequest{
Dataset: api.NewOptString(field.Epoch().Format("2006-01-02T15:04:05Z")),
LaunchLatitude: api.NewOptFloat64(params.LaunchLatitude),
LaunchLongitude: api.NewOptFloat64(params.LaunchLongitude),
LaunchDatetime: api.NewOptString(params.LaunchDatetime.Format(time.RFC3339)),
LaunchAltitude: params.LaunchAltitude,
})
if ev := events.Snapshot(); len(ev) > 0 {
// Preserve the OpenAPI-defined Warnings shape (open object).
resp.Warnings = api.NewOptPredictionResponseWarnings(api.PredictionResponseWarnings{})
}
h.log.Info("prediction complete",
zap.String("profile", profileKind),
zap.Int("stages", len(results)),
zap.Duration("elapsed", completed.Sub(started)))
return resp, nil
}
// standardProfile constructs the ascent → descent profile.
func standardProfile(field weather.WindField, elev *elevation.Dataset, events *engine.EventSink, ascentRate, burstAltitude, descentRate float64) engine.Profile {
wind := engine.WindTransport(field, events)
descentTerm := []engine.Constraint{engine.Altitude{Op: engine.OpLessEqual, Limit: 0, On: engine.ActionStop}}
if elev != nil {
descentTerm = []engine.Constraint{engine.TerrainContact{Provider: elev, On: engine.ActionStop}}
}
return engine.Profile{
Direction: engine.Forward,
Stages: []*engine.Propagator{
{
Name: "ascent",
Step: 60,
Model: engine.Sum(engine.ConstantRate(ascentRate), wind),
Constraints: []engine.Constraint{engine.Altitude{Op: engine.OpGreaterEqual, Limit: burstAltitude, On: engine.ActionStop}},
},
{
Name: "descent",
Step: 60,
Model: engine.Sum(engine.ParachuteDescent(descentRate), wind),
Constraints: descentTerm,
},
},
}
}
// floatProfile constructs the ascent → float profile.
func floatProfile(field weather.WindField, events *engine.EventSink, ascentRate, floatAlt float64, stopTime time.Time) engine.Profile {
wind := engine.WindTransport(field, events)
return engine.Profile{
Direction: engine.Forward,
Stages: []*engine.Propagator{
{
Name: "ascent",
Step: 60,
Model: engine.Sum(engine.ConstantRate(ascentRate), wind),
Constraints: []engine.Constraint{engine.Altitude{Op: engine.OpGreaterEqual, Limit: floatAlt, On: engine.ActionStop}},
},
{
Name: "float",
Step: 60,
Model: wind,
Constraints: []engine.Constraint{engine.Time{Op: engine.OpGreater, Limit: float64(stopTime.Unix()), On: engine.ActionStop}},
},
},
}
}
func buildPredictionItem(stageName string, r engine.Result) api.PredictionResponsePredictionItem {
var stageEnum api.PredictionResponsePredictionItemStage
switch stageName {
case "descent":
stageEnum = api.PredictionResponsePredictionItemStageDescent
case "float":
stageEnum = api.PredictionResponsePredictionItemStageFloat
default:
stageEnum = api.PredictionResponsePredictionItemStageAscent
}
traj := make([]api.PredictionResponsePredictionItemTrajectoryItem, 0, len(r.Points))
for _, pt := range r.Points {
ptLng := pt.Lng
if ptLng > 180 {
ptLng -= 360
}
traj = append(traj, api.PredictionResponsePredictionItemTrajectoryItem{
Datetime: time.Unix(int64(pt.Time), 0).UTC(),
Latitude: pt.Lat,
Longitude: ptLng,
Altitude: pt.Altitude,
})
}
return api.PredictionResponsePredictionItem{Stage: stageEnum, Trajectory: traj}
}
// ReadinessCheck reports whether a dataset is currently loaded.
func (h *Handler) ReadinessCheck(_ context.Context) (*api.ReadinessResponse, error) {
resp := &api.ReadinessResponse{}
if field := h.mgr.Active(); field != nil {
resp.Status = api.ReadinessResponseStatusOk
resp.DatasetTime = api.NewOptDateTime(field.Epoch())
} else {
resp.Status = api.ReadinessResponseStatusNotReady
resp.ErrorMessage = api.NewOptString("no dataset loaded")
}
return resp, nil
}
// NewError implements the ogen Handler interface for unhandled errors.
func (h *Handler) NewError(_ context.Context, err error) *api.ErrorStatusCode {
var statusErr *api.ErrorStatusCode
if errors.As(err, &statusErr) {
return statusErr
}
h.log.Error("unhandled error", zap.Error(err))
return newError(http.StatusInternalServerError, err.Error())
}
func newError(status int, description string) *api.ErrorStatusCode {
return &api.ErrorStatusCode{
StatusCode: status,
Response: api.Error{
Error: api.ErrorError{
Type: http.StatusText(status),
Description: description,
},
},
}
}
// optString returns the option's value if set, else fallback.
func optString[T ~string](o interface {
Get() (T, bool)
}, fallback string) string {
if v, ok := o.Get(); ok {
return string(v)
}
return fallback
}
// optFloat returns the option's float64 value if set, else fallback.
func optFloat(o api.OptFloat64, fallback float64) float64 {
if v, ok := o.Get(); ok {
return v
}
return fallback
}

View file

@ -1,9 +1,8 @@
// Package api wires together every HTTP-facing component of the service:
//
// - Tawhiri-compatible v1 endpoints generated from the OpenAPI spec (ogen);
// - The new v2 prediction endpoint;
// - Dataset and job admin endpoints under /api/v1/admin/;
// - Optional Prometheus-format metrics endpoint.
// Package api is the HTTP surface of the service. Every REST operation is
// defined in the OpenAPI spec (api/rest/predictor.swagger.yml) and served by
// the ogen-generated server in pkg/rest; this package implements the
// generated Handler interface and wires the server together with the
// non-OpenAPI endpoints (Prometheus metrics, ReDoc docs).
package api
import (
@ -14,22 +13,22 @@ import (
"go.uber.org/zap"
"predictor-refactored/internal/api/admin"
"predictor-refactored/internal/api/async"
"predictor-refactored/internal/api/docs"
"predictor-refactored/internal/api/middleware"
"predictor-refactored/internal/api/tawhiri"
v2 "predictor-refactored/internal/api/v2"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/elevation"
"predictor-refactored/internal/metrics"
"predictor-refactored/internal/windviz"
apirest "predictor-refactored/pkg/rest"
)
// Server is the top-level HTTP server.
type Server struct {
port int
mux *http.ServeMux
log *zap.Logger
port int
mux *http.ServeMux
async *async.Manager
log *zap.Logger
}
// Deps are the runtime dependencies the API layer needs.
@ -39,8 +38,14 @@ type Deps struct {
Metrics metrics.Sink
MetricsHandler http.Handler // optional; mounted at MetricsPath when non-nil
MetricsPath string
AsyncManager *async.Manager // optional; mounts /api/v1/predictions when non-nil
Log *zap.Logger
EnableWind bool
WindCache *windviz.Cache // optional; created if nil and EnableWind
AsyncWorkers int
AsyncQueueSize int
AsyncResultTTL time.Duration
Log *zap.Logger
}
// New wires the HTTP server. The returned Server is not yet started.
@ -51,53 +56,55 @@ func New(port int, d Deps) (*Server, error) {
if d.Metrics == nil {
d.Metrics = metrics.Noop()
}
if d.EnableWind && d.WindCache == nil {
d.WindCache = windviz.NewCache(64, 10*time.Minute)
}
mux := http.NewServeMux()
h := &Handler{
mgr: d.Manager,
elev: d.Elevation,
metrics: d.Metrics,
cache: d.WindCache,
started: time.Now().UTC(),
log: d.Log,
}
// The async worker pool runs the same prediction core as the synchronous
// endpoint; inject it so async stays decoupled from the wire types.
h.async = async.New(async.Config{
Workers: d.AsyncWorkers,
QueueSize: d.AsyncQueueSize,
ResultTTL: d.AsyncResultTTL,
}, h.runPredictionV2, d.Metrics, d.Log)
// ogen-generated server handles the Tawhiri-compat surface
// (GET /api/v1/prediction and GET /ready).
tw := tawhiri.New(d.Manager, d.Elevation, d.Metrics, d.Log)
ogenSrv, err := apirest.NewServer(tw, apirest.WithMiddleware(middleware.OgenLogging(d.Log)))
ogenSrv, err := apirest.NewServer(h, apirest.WithMiddleware(middleware.OgenLogging(d.Log)))
if err != nil {
return nil, fmt.Errorf("create ogen server: %w", err)
}
// New primary prediction endpoint.
v2h := v2.New(d.Manager, d.Elevation, d.Metrics, d.Log)
mux.Handle("/api/v2/prediction", v2h)
// Admin endpoints.
adminH := admin.New(d.Manager, d.Log)
adminH.Register(mux)
// Async prediction endpoints (optional).
if d.AsyncManager != nil {
asyncH := async.NewHandler(d.AsyncManager)
asyncH.Register(mux)
}
// Metrics endpoint.
mux := http.NewServeMux()
// Liveness: always 200 while the process is up, independent of whether a
// dataset is loaded. Container/orchestrator health checks use this; the
// readiness of the data plane is /ready (an OpenAPI operation).
mux.HandleFunc("GET /health", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"status":"alive"}`))
})
docs.New().Register(mux)
if d.MetricsHandler != nil && d.MetricsPath != "" {
mux.Handle(d.MetricsPath, d.MetricsHandler)
}
// Fallback to the ogen-generated routes (v1 + ready) for anything else.
// The ogen server owns every OpenAPI route; mount it last as the catch-all.
mux.Handle("/", ogenSrv)
return &Server{
port: port,
mux: mux,
log: d.Log,
}, nil
return &Server{port: port, mux: mux, async: h.async, log: d.Log}, nil
}
// Run starts the HTTP server and blocks until it returns.
//
// The handler chain is: CORS → request logger → mux.
// Run starts the HTTP server and blocks until ctx is cancelled or the server
// fails. The handler chain is CORS → mux (ogen routes + docs + metrics).
func (s *Server) Run(ctx context.Context) error {
srv := &http.Server{
Addr: fmt.Sprintf(":%d", s.port),
Handler: middleware.CORS(middleware.HTTPLogging(s.log, s.mux)),
Handler: middleware.CORS(s.mux),
}
errCh := make(chan error, 1)
@ -115,3 +122,10 @@ func (s *Server) Run(ctx context.Context) error {
return srv.Shutdown(shutdownCtx)
}
}
// Close releases background resources (the async worker pool).
func (s *Server) Close() {
if s.async != nil {
s.async.Close()
}
}

View file

@ -1,177 +0,0 @@
package v2
import (
"encoding/json"
"fmt"
"net/http"
"time"
"go.uber.org/zap"
"predictor-refactored/internal/api/httpjson"
"predictor-refactored/internal/datasets"
"predictor-refactored/internal/elevation"
"predictor-refactored/internal/engine"
"predictor-refactored/internal/metrics"
)
// Handler serves POST /api/v2/prediction.
type Handler struct {
mgr *datasets.Manager
elev *elevation.Dataset
metrics metrics.Sink
log *zap.Logger
}
// New wires a v2 Handler.
func New(mgr *datasets.Manager, elev *elevation.Dataset, sink metrics.Sink, log *zap.Logger) *Handler {
if log == nil {
log = zap.NewNop()
}
if sink == nil {
sink = metrics.Noop()
}
return &Handler{mgr: mgr, elev: elev, metrics: sink, log: log}
}
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
writeError(w, http.StatusMethodNotAllowed, "use POST")
return
}
var req PredictionRequest
dec := json.NewDecoder(r.Body)
dec.DisallowUnknownFields()
if err := dec.Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid request body: "+err.Error())
return
}
if err := validateRequest(req); err != nil {
writeError(w, http.StatusBadRequest, err.Error())
return
}
resp, err := Run(h.mgr, h.elev, req)
if err != nil {
if perr, ok := err.(*PredictionError); ok {
writeError(w, perr.Status, perr.Description)
return
}
writeError(w, http.StatusInternalServerError, err.Error())
return
}
h.metrics.Prediction("v2", resp.CompletedAt.Sub(resp.StartedAt), nil)
h.log.Info("v2 prediction complete",
zap.Int("stages", len(resp.Stages)),
zap.Duration("elapsed", resp.CompletedAt.Sub(resp.StartedAt)))
writeJSON(w, http.StatusOK, resp)
}
// PredictionError carries an HTTP status alongside the message so async
// callers can map the failure back to a useful HTTP response.
type PredictionError struct {
Status int
Description string
}
func (e *PredictionError) Error() string { return e.Description }
// Run executes a PredictionRequest against the manager's active wind field.
// Shared between the sync /api/v2/prediction handler and the async
// /api/v1/predictions worker.
func Run(mgr *datasets.Manager, elev *elevation.Dataset, req PredictionRequest) (*PredictionResponse, error) {
field := mgr.Active()
if field == nil {
return nil, &PredictionError{Status: http.StatusServiceUnavailable, Description: "no dataset loaded, service is starting up"}
}
lng := req.Launch.Longitude
if lng < 0 {
lng += 360
}
events := engine.NewEventSink()
deps := engine.BuildDeps{Wind: field, Events: events}
if elev != nil {
deps.Terrain = elev
}
prof, err := buildProfile(req, deps)
if err != nil {
return nil, &PredictionError{Status: http.StatusBadRequest, Description: err.Error()}
}
started := time.Now().UTC()
results := prof.Run(float64(req.Launch.Time.Unix()), engine.State{
Lat: req.Launch.Latitude, Lng: lng, Altitude: req.Launch.Altitude,
}, events)
completed := time.Now().UTC()
resp := &PredictionResponse{
Stages: make([]StageResult, 0, len(results)),
Events: events.Snapshot(),
StartedAt: started,
CompletedAt: completed,
Dataset: DatasetInfo{Source: field.Source(), Epoch: field.Epoch()},
}
for _, r := range results {
resp.Stages = append(resp.Stages, toStageResult(r))
}
return resp, nil
}
func toStageResult(r engine.Result) StageResult {
stage := StageResult{
Name: r.Propagator,
Outcome: r.Outcome.String(),
Events: r.Events,
}
if r.Constraint != nil {
stage.Constraint = r.ConstraintName
stage.Termination = &TerminationInfo{
ViolationTime: time.Unix(int64(r.ViolationTime), 0).UTC(),
ViolationState: r.ViolationState,
RefinedTime: time.Unix(int64(r.RefinedTime), 0).UTC(),
RefinedState: r.RefinedState,
}
}
stage.Trajectory = make([]TrajectoryPoint, len(r.Points))
for i, pt := range r.Points {
ptLng := pt.Lng
if ptLng > 180 {
ptLng -= 360
}
stage.Trajectory[i] = TrajectoryPoint{
Time: time.Unix(int64(pt.Time), 0).UTC(),
Latitude: pt.Lat,
Longitude: ptLng,
Altitude: pt.Altitude,
}
}
return stage
}
func validateRequest(req PredictionRequest) error {
if req.Launch.Latitude < -90 || req.Launch.Latitude > 90 {
return fmt.Errorf("launch.latitude must be in [-90, 90]")
}
if req.Launch.Longitude < -180 || req.Launch.Longitude >= 360 {
return fmt.Errorf("launch.longitude must be in [-180, 360)")
}
if len(req.Profile) == 0 {
return fmt.Errorf("profile must contain at least one stage")
}
for i, s := range req.Profile {
if s.Name == "" {
return fmt.Errorf("profile[%d].name is required", i)
}
if s.Model.Type == "" {
return fmt.Errorf("profile[%d].model.type is required", i)
}
}
return nil
}
var writeJSON = httpjson.Write
var writeError = httpjson.Error

View file

@ -1,87 +0,0 @@
package v2
import (
"fmt"
"predictor-refactored/internal/engine"
)
// buildProfile translates a PredictionRequest into an engine.Profile via
// the engine registry.
func buildProfile(req PredictionRequest, deps engine.BuildDeps) (engine.Profile, error) {
if len(req.Profile) == 0 {
return engine.Profile{}, fmt.Errorf("profile must contain at least one stage")
}
step := req.Options.StepSeconds
if step == 0 {
step = 60
}
tol := req.Options.Tolerance
if tol == 0 {
tol = 0.01
}
dir := engine.Forward
switch req.Direction {
case "", "forward":
dir = engine.Forward
case "reverse":
dir = engine.Reverse
default:
return engine.Profile{}, fmt.Errorf("unknown direction %q", req.Direction)
}
props := make([]*engine.Propagator, len(req.Profile))
for i, stage := range req.Profile {
if stage.Name == "" {
return engine.Profile{}, fmt.Errorf("stage %d: name is required", i)
}
built, err := engine.BuildModel(stage.Model, deps)
if err != nil {
return engine.Profile{}, fmt.Errorf("stage %q model: %w", stage.Name, err)
}
constraints, err := buildConstraintList(stage.Constraints, deps)
if err != nil {
return engine.Profile{}, fmt.Errorf("stage %q: %w", stage.Name, err)
}
props[i] = &engine.Propagator{
Name: stage.Name,
Step: step,
Model: built.Model,
BuildModel: built.Build,
Constraints: constraints,
Tolerance: tol,
}
}
for i, stage := range req.Profile {
if stage.FallbackIndex == nil {
continue
}
idx := *stage.FallbackIndex
if idx < 0 || idx >= len(props) {
return engine.Profile{}, fmt.Errorf("stage %q: fallback_index %d out of range", stage.Name, idx)
}
props[i].Fallback = props[idx]
}
globals, err := buildConstraintList(req.Globals, deps)
if err != nil {
return engine.Profile{}, fmt.Errorf("globals: %w", err)
}
return engine.Profile{Stages: props, Direction: dir, Globals: globals}, nil
}
func buildConstraintList(specs []engine.ConstraintSpec, deps engine.BuildDeps) ([]engine.Constraint, error) {
out := make([]engine.Constraint, 0, len(specs))
for i, spec := range specs {
c, err := engine.BuildConstraint(spec, deps)
if err != nil {
return nil, fmt.Errorf("constraint[%d]: %w", i, err)
}
out = append(out, c)
}
return out, nil
}

View file

@ -1,100 +0,0 @@
// Package v2 implements the profile-driven prediction endpoint.
//
// Endpoint: POST /api/v2/prediction
//
// The request schema is built on the engine package's ConstraintSpec and
// ModelSpec, so adding new model or constraint types in the engine requires
// no changes here — they become available automatically via the registry.
package v2
import (
"time"
"predictor-refactored/internal/engine"
)
// PredictionRequest is the body of POST /api/v2/prediction.
type PredictionRequest struct {
Launch Launch `json:"launch"`
Profile []StageSpec `json:"profile"`
Globals []engine.ConstraintSpec `json:"globals,omitempty"`
Options Options `json:"options,omitempty"`
Direction string `json:"direction,omitempty"` // "forward" (default) or "reverse"
}
// Launch is the initial state of the balloon (or, for reverse predictions,
// the known landing point).
type Launch struct {
Time time.Time `json:"time"`
Latitude float64 `json:"latitude"`
Longitude float64 `json:"longitude"`
Altitude float64 `json:"altitude"`
}
// StageSpec is one entry in the propagator chain.
type StageSpec struct {
Name string `json:"name"`
Model engine.ModelSpec `json:"model"`
Constraints []engine.ConstraintSpec `json:"constraints,omitempty"`
// FallbackIndex, when set, points to another stage in the same profile
// to transfer to on ActionFallback constraints.
FallbackIndex *int `json:"fallback_index,omitempty"`
}
// Options tweaks integrator behaviour.
type Options struct {
StepSeconds float64 `json:"step_seconds,omitempty"`
Tolerance float64 `json:"tolerance,omitempty"`
}
// PredictionResponse is the body of a successful POST response.
type PredictionResponse struct {
Stages []StageResult `json:"stages"`
Events []engine.EventSummary `json:"events,omitempty"`
Dataset DatasetInfo `json:"dataset"`
StartedAt time.Time `json:"started_at"`
CompletedAt time.Time `json:"completed_at"`
}
// StageResult is the outcome of one stage.
type StageResult struct {
Name string `json:"name"`
Outcome string `json:"outcome"`
Constraint string `json:"constraint,omitempty"`
Termination *TerminationInfo `json:"termination,omitempty"`
Events []engine.EventSummary `json:"events,omitempty"`
Trajectory []TrajectoryPoint `json:"trajectory"`
}
// TerminationInfo exposes the violation+refinement detail from the engine.
type TerminationInfo struct {
ViolationTime time.Time `json:"violation_time"`
ViolationState engine.State `json:"violation_state"`
RefinedTime time.Time `json:"refined_time"`
RefinedState engine.State `json:"refined_state"`
}
// TrajectoryPoint is one sampled point of the trajectory.
type TrajectoryPoint struct {
Time time.Time `json:"time"`
Latitude float64 `json:"latitude"`
Longitude float64 `json:"longitude"`
Altitude float64 `json:"altitude"`
}
// DatasetInfo identifies the wind dataset used.
type DatasetInfo struct {
Source string `json:"source"`
Epoch time.Time `json:"epoch"`
}
// ErrorResponse is the JSON error shape.
type ErrorResponse struct {
Error ErrorBody `json:"error"`
}
// ErrorBody is the error detail.
type ErrorBody struct {
Type string `json:"type"`
Description string `json:"description"`
}

92
internal/api/wind.go Normal file
View file

@ -0,0 +1,92 @@
package api
import (
"context"
"fmt"
"net/http"
"predictor-refactored/internal/windviz"
apirest "predictor-refactored/pkg/rest"
)
// GetWindMeta implements GET /api/v1/wind/meta.
func (h *Handler) GetWindMeta(_ context.Context) (*apirest.WindMeta, error) {
field := h.mgr.Active()
if field == nil {
return nil, apiError(http.StatusServiceUnavailable, "no dataset loaded")
}
return &apirest.WindMeta{
Source: field.Source(),
Epoch: field.Epoch().UTC(),
DefaultStep: 1.0,
MinStep: 0.25,
SuggestedAltitudes: []int{0, 1000, 5000, 10000, 15000, 20000, 30000},
Bbox: apirest.Region{MinLat: -90, MaxLat: 90, MinLng: 0, MaxLng: 360},
}, nil
}
// GetWindField implements GET /api/v1/wind/field.
func (h *Handler) GetWindField(_ context.Context, params apirest.GetWindFieldParams) ([]apirest.WindComponent, error) {
field := h.mgr.Active()
if field == nil {
return nil, apiError(http.StatusServiceUnavailable, "no dataset loaded")
}
when := field.Epoch()
if t, ok := params.Time.Get(); ok {
when = t
}
req := windviz.Request{
Time: float64(when.Unix()),
Altitude: params.Altitude.Or(0),
MinLat: params.MinLat.Or(0),
MaxLat: params.MaxLat.Or(0),
MinLng: params.MinLng.Or(0),
MaxLng: params.MaxLng.Or(0),
Step: params.Step.Or(0),
}
key := fmt.Sprintf("%s|%v|%.3f|%.3f|%.3f|%.3f|%.3f|%.3f",
field.Source(), req.Time, req.Altitude, req.MinLat, req.MaxLat, req.MinLng, req.MaxLng, req.Step)
if h.cache != nil {
if cached, ok := h.cache.Get(key); ok {
return windFieldToAPI(cached), nil
}
}
out, err := windviz.Rasterize(field, req)
if err != nil {
return nil, apiError(http.StatusBadRequest, err.Error())
}
if h.cache != nil {
h.cache.Put(key, out)
}
return windFieldToAPI(out), nil
}
// windFieldToAPI maps a rasterized field to the generated component slice.
func windFieldToAPI(f windviz.Field) []apirest.WindComponent {
out := make([]apirest.WindComponent, 0, len(f))
for _, c := range f {
out = append(out, apirest.WindComponent{
Header: apirest.WindHeader{
ParameterCategory: c.Header.ParameterCategory,
ParameterNumber: c.Header.ParameterNumber,
ParameterNumberName: apirest.NewOptString(c.Header.ParameterNumberName),
ParameterUnit: apirest.NewOptString(c.Header.ParameterUnit),
Nx: c.Header.Nx,
Ny: c.Header.Ny,
Lo1: c.Header.Lo1,
La1: c.Header.La1,
Lo2: c.Header.Lo2,
La2: c.Header.La2,
Dx: c.Header.Dx,
Dy: c.Header.Dy,
RefTime: c.Header.RefTime,
ForecastTime: c.Header.ForecastTime,
},
Data: c.Data,
})
}
return out
}

View file

@ -22,6 +22,7 @@ type Config struct {
Data DataConfig `yaml:"data"`
Download DownloadConfig `yaml:"download"`
Metrics MetricsConfig `yaml:"metrics"`
Wind WindConfig `yaml:"wind"`
Log LogConfig `yaml:"log"`
}
@ -40,16 +41,17 @@ type HTTPConfig struct {
type DataConfig struct {
Dir string `yaml:"dir"`
ElevationPath string `yaml:"elevation_path"`
// Source is the dataset source identifier; only "noaa-gfs-0p50" is supported today.
// Source is the dataset variant ID: gfs-0p50-3h (default), gfs-0p25-3h,
// gfs-0p25-1h, or gefs-0p50-3h. See weather/gfs.VariantByID.
Source string `yaml:"source"`
}
// DownloadConfig configures the dataset downloader.
type DownloadConfig struct {
Parallel int `yaml:"parallel"`
BandwidthBytesPerSecond int64 `yaml:"bandwidth_bytes_per_second"`
UpdateInterval time.Duration `yaml:"update_interval"`
FreshnessTTL time.Duration `yaml:"freshness_ttl"`
Parallel int `yaml:"parallel"`
BandwidthBytesPerSecond int64 `yaml:"bandwidth_bytes_per_second"`
UpdateInterval time.Duration `yaml:"update_interval"`
FreshnessTTL time.Duration `yaml:"freshness_ttl"`
}
// MetricsConfig configures the metrics endpoint.
@ -58,6 +60,13 @@ type MetricsConfig struct {
Path string `yaml:"path"`
}
// WindConfig configures the wind-visualization endpoints.
type WindConfig struct {
Enabled bool `yaml:"enabled"`
CacheSize int `yaml:"cache_size"`
CacheTTL time.Duration `yaml:"cache_ttl"`
}
// LogConfig configures logging.
type LogConfig struct {
Level string `yaml:"level"` // "debug", "info", "warn", "error"
@ -87,6 +96,11 @@ func Defaults() Config {
Enabled: true,
Path: "/metrics",
},
Wind: WindConfig{
Enabled: true,
CacheSize: 64,
CacheTTL: 10 * time.Minute,
},
Log: LogConfig{Level: "info"},
}
}
@ -106,16 +120,16 @@ func Load(args []string) (Config, error) {
var (
configPath = fs.String("config", os.Getenv("PREDICTOR_CONFIG_FILE"), "path to YAML config file")
// Flag-driven overrides. Empty / -1 means "not specified".
flagPort = fs.Int("port", -1, "HTTP listen port")
flagDataDir = fs.String("data-dir", "", "directory for dataset files")
flagElevation = fs.String("elevation", "", "path to ruaumoko elevation dataset")
flagParallel = fs.Int("download-parallel", -1, "max concurrent GRIB downloads")
flagBandwidth = fs.Int64("download-bandwidth", -1, "download bandwidth limit in bytes/sec (0 = unlimited)")
flagInterval = fs.Duration("update-interval", 0, "scheduler refresh interval")
flagTTL = fs.Duration("freshness-ttl", 0, "max age before a dataset is considered stale")
flagPort = fs.Int("port", -1, "HTTP listen port")
flagDataDir = fs.String("data-dir", "", "directory for dataset files")
flagElevation = fs.String("elevation", "", "path to ruaumoko elevation dataset")
flagParallel = fs.Int("download-parallel", -1, "max concurrent GRIB downloads")
flagBandwidth = fs.Int64("download-bandwidth", -1, "download bandwidth limit in bytes/sec (0 = unlimited)")
flagInterval = fs.Duration("update-interval", 0, "scheduler refresh interval")
flagTTL = fs.Duration("freshness-ttl", 0, "max age before a dataset is considered stale")
flagMetricsEnabled = fs.Bool("metrics", true, "enable Prometheus-compatible metrics endpoint")
flagMetricsPath = fs.String("metrics-path", "", "HTTP path for the metrics endpoint")
flagLogLevel = fs.String("log-level", "", "log level: debug|info|warn|error")
flagMetricsPath = fs.String("metrics-path", "", "HTTP path for the metrics endpoint")
flagLogLevel = fs.String("log-level", "", "log level: debug|info|warn|error")
)
if err := fs.Parse(args); err != nil {
return Config{}, fmt.Errorf("parse flags: %w", err)

View file

@ -0,0 +1,11 @@
//go:build !unix
package datasets
import "context"
// flockExclusive is a no-op on platforms without flock. The service targets
// Linux containers; this stub only keeps non-Unix builds compiling.
func flockExclusive(_ context.Context, _ string) (func(), error) {
return func() {}, nil
}

View file

@ -0,0 +1,50 @@
//go:build unix
package datasets
import (
"context"
"errors"
"fmt"
"os"
"syscall"
"time"
)
// lockPollInterval is how often a contended lock is retried. The lock is held
// for the duration of a dataset download (minutes), so sub-second acquisition
// latency is irrelevant.
const lockPollInterval = 150 * time.Millisecond
// flockExclusive acquires an exclusive flock on path, creating the lock file
// if needed, and blocks until it is held or ctx is cancelled.
//
// It uses non-blocking LOCK_NB attempts in a poll loop rather than a blocking
// flock in a goroutine: the file descriptor is only ever touched by this
// goroutine, so there is no race between a pending syscall and Close on
// cancellation.
func flockExclusive(ctx context.Context, path string) (func(), error) {
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0o644)
if err != nil {
return nil, fmt.Errorf("open lock file: %w", err)
}
for {
err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB)
if err == nil {
return func() {
_ = syscall.Flock(int(f.Fd()), syscall.LOCK_UN)
_ = f.Close()
}, nil
}
if !errors.Is(err, syscall.EWOULDBLOCK) {
f.Close()
return nil, fmt.Errorf("flock: %w", err)
}
select {
case <-ctx.Done():
f.Close()
return nil, ctx.Err()
case <-time.After(lockPollInterval):
}
}
}

View file

@ -27,16 +27,16 @@ const (
// JobInfo is the externally-visible snapshot of a download job.
type JobInfo struct {
ID string
Source string
Dataset DatasetID
Status JobStatus
StartedAt time.Time
EndedAt *time.Time
Err string
Total int
Done int
Bytes int64
ID string
Source string
Dataset DatasetID
Status JobStatus
StartedAt time.Time
EndedAt *time.Time
Err string
Total int
Done int
Bytes int64
}
type jobEntry struct {
@ -75,9 +75,9 @@ func (e *jobEntry) snapshot() JobInfo {
type jobProgress struct{ e *jobEntry }
func (p jobProgress) SetTotal(n int) { p.e.total.Store(int64(n)) }
func (p jobProgress) StepComplete() { p.e.done.Add(1) }
func (p jobProgress) Bytes(n int64) { p.e.bytes.Add(n) }
func (p jobProgress) SetTotal(n int) { p.e.total.Store(int64(n)) }
func (p jobProgress) StepComplete() { p.e.done.Add(1) }
func (p jobProgress) Bytes(n int64) { p.e.bytes.Add(n) }
// loadedDataset bundles a loaded WindField with its identity and coverage.
type loadedDataset struct {
@ -387,7 +387,7 @@ func (m *Manager) runDownload(ctx context.Context, e *jobEntry) {
zap.String("job", e.id),
zap.String("dataset", e.dataset.Filename()))
err := m.src.Download(ctx, e.dataset, m.store, jobProgress{e: e}, m.throttle)
err := m.downloadLocked(ctx, e)
now := time.Now().UTC()
e.mu.Lock()
@ -410,6 +410,26 @@ func (m *Manager) runDownload(ctx context.Context, e *jobEntry) {
zap.NamedError("err", err))
}
// downloadLocked runs the source download while holding the storage's
// cross-process lock, so multiple replicas sharing a node-local dataset
// volume coordinate instead of each fetching ~9 GB. After acquiring the lock
// it re-checks existence: if another replica committed the dataset while this
// one waited, it skips the download and lets the caller load the committed file.
func (m *Manager) downloadLocked(ctx context.Context, e *jobEntry) error {
release, err := m.store.Lock(ctx)
if err != nil {
return fmt.Errorf("acquire download lock: %w", err)
}
defer release()
if m.store.Exists(e.dataset) {
m.log.Info("dataset committed by another instance while waiting; skipping download",
zap.String("dataset", e.dataset.Filename()))
return nil
}
return m.src.Download(ctx, e.dataset, m.store, jobProgress{e: e}, m.throttle)
}
func (m *Manager) completeShortCircuit(ctx context.Context, e *jobEntry) {
_ = ctx
defer m.inFlight.Delete(e.dataset.Filename())

View file

@ -1,6 +1,7 @@
package datasets
import (
"context"
"errors"
"fmt"
"os"
@ -132,6 +133,13 @@ func (s *LocalStore) Remove(id DatasetID) error {
return nil
}
// Lock acquires the storage-wide download lock (an exclusive flock on a
// sentinel file in the root), serialising downloads across processes that
// share this directory.
func (s *LocalStore) Lock(ctx context.Context) (func(), error) {
return flockExclusive(ctx, filepath.Join(s.Root, ".download.lock"))
}
// BeginWrite opens or resumes a TempHandle for id.
func (s *LocalStore) BeginWrite(id DatasetID) (TempHandle, error) {
man, err := LoadManifest(s.manifestPath(id))
@ -148,8 +156,8 @@ type localHandle struct {
closed bool
}
func (h *localHandle) Path() string { return h.store.tempPath(h.id) }
func (h *localHandle) Manifest() *Manifest { return h.manifest }
func (h *localHandle) Path() string { return h.store.tempPath(h.id) }
func (h *localHandle) Manifest() *Manifest { return h.manifest }
func (h *localHandle) Commit() error {
if h.closed {

View file

@ -1,11 +1,64 @@
package datasets
import (
"context"
"os"
"testing"
"time"
)
func TestLocalStoreLockSerializes(t *testing.T) {
dir := t.TempDir()
store, _ := NewLocalStore(dir, "gfs-test")
ctx := context.Background()
release, err := store.Lock(ctx)
if err != nil {
t.Fatalf("first Lock: %v", err)
}
// A second acquisition must block until the first releases.
got := make(chan struct{})
go func() {
r2, err := store.Lock(ctx)
if err == nil {
r2()
}
close(got)
}()
select {
case <-got:
t.Fatal("second Lock acquired while first was held")
case <-time.After(100 * time.Millisecond):
// expected: still blocked
}
release()
select {
case <-got:
// expected: acquired after release
case <-time.After(2 * time.Second):
t.Fatal("second Lock did not acquire after release")
}
}
func TestLocalStoreLockContextCancel(t *testing.T) {
dir := t.TempDir()
store, _ := NewLocalStore(dir, "gfs-test")
release, err := store.Lock(context.Background())
if err != nil {
t.Fatalf("Lock: %v", err)
}
defer release()
ctx, cancel := context.WithCancel(context.Background())
cancel()
if _, err := store.Lock(ctx); err == nil {
t.Error("expected Lock to fail on cancelled context while held elsewhere")
}
}
func TestLocalStoreBeginWriteResume(t *testing.T) {
dir := t.TempDir()
store, err := NewLocalStore(dir, "gfs-test")

View file

@ -61,6 +61,12 @@ type Storage interface {
// BeginWrite opens (or resumes) a transactional handle for downloading
// id's dataset.
BeginWrite(id DatasetID) (TempHandle, error)
// Lock acquires an exclusive, storage-wide lock that serialises downloads
// across every process sharing this storage (e.g. multiple replicas on a
// node that share a dataset volume). It blocks until the lock is held or
// ctx is cancelled. The returned function releases the lock.
Lock(ctx context.Context) (release func(), err error)
}
// TempHandle is the storage state for one in-progress download.

View file

@ -16,8 +16,8 @@ import (
const (
CellsPerDegree = 120
NumLats = 180*CellsPerDegree + 1 // 21601
NumLons = 360 * CellsPerDegree // 43200
DataSize = NumLats * NumLons * 2 // 1,866,326,400 bytes (~1.74 GiB)
NumLons = 360 * CellsPerDegree // 43200
DataSize = NumLats * NumLons * 2 // 1,866,326,400 bytes (~1.74 GiB)
)
// Dataset is a memory-mapped global elevation grid.

View file

@ -2,7 +2,8 @@ package engine
import (
"fmt"
"math"
"predictor-refactored/internal/numerics"
)
// Altitude triggers when the balloon altitude satisfies Op against Limit.
@ -31,9 +32,9 @@ type Time struct {
On Action
}
func (c Time) Name() string { return fmt.Sprintf("time %s %g", c.Op, c.Limit) }
func (c Time) Violated(t float64, _ State) bool { return c.Op.Test(t, c.Limit) }
func (c Time) Action() Action { return c.On }
func (c Time) Name() string { return fmt.Sprintf("time %s %g", c.Op, c.Limit) }
func (c Time) Violated(t float64, _ State) bool { return c.Op.Test(t, c.Limit) }
func (c Time) Action() Action { return c.On }
// TerrainContact triggers when the ground elevation exceeds the balloon's
// altitude — i.e. the balloon has hit the ground.
@ -69,23 +70,30 @@ type PolygonVertex struct {
Lng float64
}
// Polygon is a constraint over a geographic polygon. The polygon is
// considered closed (last vertex connects to the first) and is interpreted
// in plate-carrée (rectangular lat/lng) coordinates with longitude
// wrap-around handling.
//
// Edges crossing the 180/-180 antimeridian are split via longitude
// normalisation against the polygon's centroid: callers that need
// great-circle accuracy should clip their polygon along the antimeridian
// before submitting.
// Polygon is a constraint over a closed geographic polygon, evaluated in
// plate-carrée coordinates with antimeridian handling (see
// numerics.PointInPolygon). Build one with NewPolygon so the flattened
// vertex slices used by the hot path are precomputed.
type Polygon struct {
Vertices []PolygonVertex
Mode PolygonMode
On Action
// Label, if set, is returned by Name. Defaults to "polygon_inside" or
// "polygon_outside" based on Mode.
Label string
// Precomputed parallel vertex slices for numerics.PointInPolygon.
polyLat, polyLng []float64
}
// NewPolygon builds a Polygon, precomputing the flattened vertex slices.
func NewPolygon(verts []PolygonVertex, mode PolygonMode, on Action, label string) Polygon {
lat := make([]float64, len(verts))
lng := make([]float64, len(verts))
for i, v := range verts {
lat[i], lng[i] = v.Lat, v.Lng
}
return Polygon{Vertices: verts, Mode: mode, On: on, Label: label, polyLat: lat, polyLng: lng}
}
func (c Polygon) Name() string {
@ -101,49 +109,9 @@ func (c Polygon) Action() Action { return c.On }
// Violated reports whether the state satisfies the polygon-containment rule.
func (c Polygon) Violated(_ float64, s State) bool {
if len(c.Vertices) < 3 {
return false
}
in := pointInPolygon(s.Lat, s.Lng, c.Vertices)
in := numerics.PointInPolygon(s.Lat, s.Lng, c.polyLat, c.polyLng)
if c.Mode == PolygonInside {
return in
}
return !in
}
// pointInPolygon implements the ray-casting algorithm in lat/lng space.
//
// All vertices and the query point are normalised to within 180° of
// verts[0] before testing, so a polygon spanning the antimeridian is
// handled correctly as long as the polygon itself spans no more than 180°
// in longitude.
func pointInPolygon(lat, lng float64, verts []PolygonVertex) bool {
if len(verts) == 0 {
return false
}
ref := verts[0].Lng
qx := normLng(lng, ref)
inside := false
n := len(verts)
for i, j := 0, n-1; i < n; j, i = i, i+1 {
yi, yj := verts[i].Lat, verts[j].Lat
xi := normLng(verts[i].Lng, ref)
xj := normLng(verts[j].Lng, ref)
if (yi > lat) != (yj > lat) {
xIntersect := (xj-xi)*(lat-yi)/(yj-yi) + xi
if qx < xIntersect {
inside = !inside
}
}
}
return inside
}
// normLng rewrites v so that it lies within 180° of ref. With ref=10 and
// v=350, normLng returns -10.
func normLng(v, ref float64) float64 {
diff := math.Mod(v-ref+540, 360) - 180
return ref + diff
}

View file

@ -46,13 +46,13 @@ func TestConstantAscentToBurst(t *testing.T) {
t.Errorf("RefinedState not populated")
}
last := results[0].Points[len(results[0].Points)-1]
lastT, last := results[0].Path.Last()
if math.Abs(last.Altitude-burst) > 5 {
t.Errorf("burst altitude = %v, want within 5m of %v", last.Altitude, burst)
}
wantTime := burst / rate
if math.Abs(last.Time-wantTime) > 1 {
t.Errorf("burst time = %v, want within 1s of %v", last.Time, wantTime)
if math.Abs(lastT-wantTime) > 1 {
t.Errorf("burst time = %v, want within 1s of %v", lastT, wantTime)
}
}
@ -87,7 +87,7 @@ func TestProfileWithFallback(t *testing.T) {
t.Errorf("second outcome = %v, want OutcomeStopped", results[1].Outcome)
}
last := results[1].Points[len(results[1].Points)-1]
_, last := results[1].Path.Last()
if math.Abs(last.Altitude) > 5 {
t.Errorf("final altitude = %v, want within 5m of 0", last.Altitude)
}
@ -103,12 +103,12 @@ func TestReverseDirection(t *testing.T) {
prof := Profile{Stages: []*Propagator{desc}, Direction: Reverse}
results := prof.Run(0, State{Altitude: 100}, NewEventSink())
last := results[0].Points[len(results[0].Points)-1]
lastT, last := results[0].Path.Last()
if math.Abs(last.Altitude-200) > 1 {
t.Errorf("reverse final altitude = %v, want ~200", last.Altitude)
}
if last.Time >= 0 {
t.Errorf("reverse final time = %v, want < 0", last.Time)
if lastT >= 0 {
t.Errorf("reverse final time = %v, want < 0", lastT)
}
}
@ -206,15 +206,25 @@ func TestWindTransportEmitsAboveModel(t *testing.T) {
}
}
func TestStateAddWrapsLongitude(t *testing.T) {
s := stateAdd(State{Lat: 0, Lng: 350, Altitude: 0}, 1, State{Lng: 20})
if math.Abs(s.Lng-10) > 1e-9 {
t.Errorf("addState wrap: lng = %v, want 10", s.Lng)
func TestNoTerminatorStopsAtStepCap(t *testing.T) {
// A stage that ascends forever with no constraint must not loop endlessly;
// the integrator's step backstop stops it and records a max_steps event.
sink := NewEventSink()
prof := Profile{
Stages: []*Propagator{{Name: "runaway", Step: 60, Model: ConstantRate(5)}},
Direction: Forward,
}
results := prof.Run(0, State{}, sink)
mid := stateLerp(State{Lng: 350}, State{Lng: 10}, 0.5)
if math.Abs(mid.Lng-0) > 1e-9 && math.Abs(mid.Lng-360) > 1e-9 {
t.Errorf("lerpState lng wrap: %v, want 0 or 360", mid.Lng)
if results[0].Outcome != OutcomeContinued {
t.Errorf("outcome = %v, want OutcomeContinued (step cap)", results[0].Outcome)
}
if results[0].Path.Len() != DefaultMaxSteps+1 {
t.Errorf("path len = %d, want %d", results[0].Path.Len(), DefaultMaxSteps+1)
}
ev := sink.Snapshot()
if len(ev) != 1 || ev[0].Type != "max_steps" {
t.Errorf("expected a max_steps event, got %+v", ev)
}
}
@ -226,7 +236,7 @@ func TestPolygonInside(t *testing.T) {
{Lat: 1, Lng: 1},
{Lat: 1, Lng: -1},
}
c := Polygon{Vertices: square, Mode: PolygonInside, On: ActionStop}
c := NewPolygon(square, PolygonInside, ActionStop, "")
if !c.Violated(0, State{Lat: 0, Lng: 0}) {
t.Errorf("origin should be inside the square")
}
@ -244,7 +254,7 @@ func TestPolygonOutsideAntimeridian(t *testing.T) {
{Lat: 10, Lng: 190},
{Lat: 10, Lng: 170},
}
c := Polygon{Vertices: poly, Mode: PolygonInside, On: ActionStop}
c := NewPolygon(poly, PolygonInside, ActionStop, "")
// A point at the antimeridian.
if !c.Violated(0, State{Lat: 0, Lng: 180}) {
t.Errorf("(0, 180) should be inside the antimeridian polygon")

View file

@ -4,6 +4,7 @@ import (
"math"
"sort"
"predictor-refactored/internal/numerics"
"predictor-refactored/internal/weather"
)
@ -45,29 +46,10 @@ func ConstantRate(rate float64) Model {
func ParachuteDescent(seaLevelRate float64) Model {
k := seaLevelRate * 1.1045
return func(_ float64, s State) State {
return State{Altitude: -k / math.Sqrt(nasaDensity(s.Altitude))}
return State{Altitude: -k / math.Sqrt(numerics.NasaDensity(s.Altitude))}
}
}
// nasaDensity returns air density (kg/m^3) for an altitude in metres,
// using the NASA simple atmosphere model.
// See https://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html.
func nasaDensity(alt float64) float64 {
var temp, pressure float64
switch {
case alt > 25000:
temp = -131.21 + 0.00299*alt
pressure = 2.488 * math.Pow((temp+273.1)/216.6, -11.388)
case alt > 11000:
temp = -56.46
pressure = 22.65 * math.Exp(1.73-0.000157*alt)
default:
temp = 15.04 - 0.00649*alt
pressure = 101.29 * math.Pow((temp+273.1)/288.08, 5.256)
}
return pressure / (0.2869 * (temp + 273.1))
}
// RateSegment is one entry in a Piecewise rate schedule. Until is the UNIX
// timestamp at which this segment ends — the model emits the segment's
// Rate for all t < Until. The final segment's Rate is held indefinitely.

View file

@ -30,39 +30,30 @@ func (p *Profile) Run(t0 float64, launch State, events *EventSink) []Result {
results := make([]Result, 0, len(p.Stages))
t, s := t0, launch
for i := 0; i < len(p.Stages); i++ {
stage := p.Stages[i]
ctx := StageContext{
ProfileStart: t0,
PropagatorStart: t,
Launch: launch,
PropagatorState: s,
Direction: p.Direction,
}
res := stage.run(ctx, t, s, p.Globals, events)
for _, stage := range p.Stages {
res := stage.run(p.context(t0, t, launch, s), t, s, p.Globals, events)
results = append(results, res)
last := res.Points[len(res.Points)-1]
t = last.Time
s = State{Lat: last.Lat, Lng: last.Lng, Altitude: last.Altitude}
t, s = res.Path.Last()
// Follow Fallback chains until none remains.
for res.Outcome == OutcomeFallback && stage.Fallback != nil {
stage = stage.Fallback
ctx = StageContext{
ProfileStart: t0,
PropagatorStart: t,
Launch: launch,
PropagatorState: s,
Direction: p.Direction,
}
res = stage.run(ctx, t, s, p.Globals, events)
res = stage.run(p.context(t0, t, launch, s), t, s, p.Globals, events)
results = append(results, res)
last = res.Points[len(res.Points)-1]
t = last.Time
s = State{Lat: last.Lat, Lng: last.Lng, Altitude: last.Altitude}
t, s = res.Path.Last()
}
}
return results
}
// context builds the StageContext for a stage starting at (tStart, sStart).
func (p *Profile) context(t0, tStart float64, launch, sStart State) StageContext {
return StageContext{
ProfileStart: t0,
PropagatorStart: tStart,
Launch: launch,
PropagatorState: sStart,
Direction: p.Direction,
}
}

View file

@ -1,8 +1,6 @@
package engine
import (
"predictor-refactored/internal/numerics"
)
import "predictor-refactored/internal/numerics"
// Propagator advances state under one Model, checking a set of Constraints
// after every integration step.
@ -11,9 +9,12 @@ import (
// violation point and emits it as its final trajectory point. The Action of
// the triggering constraint controls what the surrounding Profile does
// next: stop the profile, transfer to Fallback, or clip and continue.
//
// The per-step numerics (RK4 stepping, crossing refinement) are delegated to
// the numerics package; this type owns only the orchestration: constraint
// evaluation, action dispatch, and trajectory assembly.
type Propagator struct {
// Name identifies the propagator in trajectory metadata. Optional —
// callers using sequential profile chains may leave it empty.
// Name identifies the propagator in trajectory metadata. Optional.
Name string
// Step is the magnitude of the integration step in seconds (always positive).
@ -39,6 +40,18 @@ type Propagator struct {
Tolerance float64
}
// estimatedSteps is the initial Path capacity; a typical balloon stage is a
// few hundred 60-second steps.
const estimatedSteps = 256
// DefaultMaxSteps bounds the number of integration steps a single propagator
// may take. It is a safety backstop, not a physical limit: a profile whose
// constraints never fire (e.g. a stage with no effective terminator) would
// otherwise integrate forever and exhaust memory. At the default 60-second
// step this allows ~8 simulated years, far beyond any real flight, so it only
// ever trips on a misconfigured profile.
const DefaultMaxSteps = 1_000_000
// run integrates the model from (t0, s0) in direction dir, returning a Result.
// globals are constraints injected by the Profile and checked alongside the
// propagator's local Constraints. events receives non-fatal observations.
@ -58,70 +71,53 @@ func (p *Propagator) run(ctx StageContext, t0 float64, s0 State, globals []Const
constraints = p.BuildConstraints(ctx)
}
deriv := numerics.Deriv[State](func(t float64, s State) State { return model(t, s) })
add := numerics.VecAdd[State](stateAdd)
lerp := numerics.VecLerp[State](stateLerp)
field := numerics.Field(model)
out := Result{
Propagator: p.Name,
Outcome: OutcomeContinued,
Points: []TrajectoryPoint{{
Time: t0, Lat: s0.Lat, Lng: s0.Lng, Altitude: s0.Altitude,
}},
}
out := Result{Propagator: p.Name, Outcome: OutcomeContinued, Path: numerics.NewPath(estimatedSteps)}
out.Path.Append(t0, s0)
t := t0
s := s0
for {
s2 := numerics.RK4Step(t, s, dt, deriv, add)
t, s := t0, s0
for range DefaultMaxSteps {
s2 := numerics.RK4Step(t, s, dt, field)
t2 := t + dt
c, fired := firstFiring(constraints, globals, t2, s2)
if !fired {
t, s = t2, s2
out.Points = append(out.Points, TrajectoryPoint{
Time: t, Lat: s.Lat, Lng: s.Lng, Altitude: s.Altitude,
})
out.Path.Append(t, s)
continue
}
// Record the unrefined violation.
out.ViolationTime = t2
out.ViolationState = s2
out.ViolationTime, out.ViolationState = t2, s2
t3, s3 := numerics.RefineCrossing(t, s, t2, s2, c.Violated, tol)
out.Constraint, out.ConstraintName = c, c.Name()
trig := numerics.Trigger[State](func(tt float64, ss State) bool { return c.Violated(tt, ss) })
t3, s3 := numerics.RefineTrigger(t, s, t2, s2, trig, lerp, tol)
out.RefinedTime = t3
out.RefinedState = s3
out.Constraint = c
out.ConstraintName = c.Name()
switch c.Action() {
case ActionClip:
if c.Action() == ActionClip {
s3 = clipToConstraint(c, s3)
out.RefinedState = s3
out.Points = append(out.Points, TrajectoryPoint{
Time: t3, Lat: s3.Lat, Lng: s3.Lng, Altitude: s3.Altitude,
})
out.RefinedTime, out.RefinedState = t3, s3
out.Path.Append(t3, s3)
t, s = t3, s3
continue
case ActionFallback:
out.Points = append(out.Points, TrajectoryPoint{
Time: t3, Lat: s3.Lat, Lng: s3.Lng, Altitude: s3.Altitude,
})
out.Outcome = OutcomeFallback
out.Events = events.Snapshot()
return out
default: // ActionStop
out.Points = append(out.Points, TrajectoryPoint{
Time: t3, Lat: s3.Lat, Lng: s3.Lng, Altitude: s3.Altitude,
})
out.Outcome = OutcomeStopped
out.Events = events.Snapshot()
return out
}
out.RefinedTime, out.RefinedState = t3, s3
out.Path.Append(t3, s3)
if c.Action() == ActionFallback {
out.Outcome = OutcomeFallback
} else {
out.Outcome = OutcomeStopped
}
out.Events = events.Snapshot()
return out
}
// Step cap reached without any constraint firing — the profile has no
// effective terminator for this stage. Stop safely rather than loop forever.
events.Emit("max_steps", t, s,
"integration step limit reached without a constraint firing; check the stage's terminator")
out.Outcome = OutcomeContinued
out.Events = events.Snapshot()
return out
}
// firstFiring scans local then global constraints for the first one whose
@ -140,9 +136,9 @@ func firstFiring(local, globals []Constraint, t float64, s State) (Constraint, b
return nil, false
}
// clipToConstraint adjusts s so that the given constraint is exactly
// satisfied (not violated). Defined only for constraints with a
// well-defined coordinate boundary; others fall through unchanged.
// clipToConstraint adjusts s so the given constraint is exactly satisfied.
// Defined only for constraints with a well-defined coordinate boundary;
// others fall through unchanged.
func clipToConstraint(c Constraint, s State) State {
if alt, ok := c.(Altitude); ok {
s.Altitude = alt.Limit

View file

@ -72,7 +72,7 @@ type BuiltModel struct {
}
var (
regMu sync.RWMutex
regMu sync.RWMutex
constraintFactories = map[string]ConstraintFactory{}
modelFactories = map[string]ModelFactory{}
)
@ -202,7 +202,7 @@ func buildPolygon(spec ConstraintSpec, _ BuildDeps) (Constraint, error) {
default:
return nil, fmt.Errorf("polygon: unknown mode %q", spec.Mode)
}
return Polygon{Vertices: spec.Vertices, Mode: mode, On: act, Label: spec.Label}, nil
return NewPolygon(spec.Vertices, mode, act, spec.Label), nil
}
func buildConstantRate(spec ModelSpec, _ BuildDeps) (BuiltModel, error) {
@ -224,34 +224,19 @@ func buildWind(_ ModelSpec, deps BuildDeps) (BuiltModel, error) {
}
func buildPiecewise(spec ModelSpec, deps BuildDeps) (BuiltModel, error) {
needsCtx := false
for _, seg := range spec.Segments {
if seg.Reference == "propagator_start" {
needsCtx = true
break
for _, s := range spec.Segments {
switch s.Reference {
case "", "absolute", "profile_start", "propagator_start":
default:
return BuiltModel{}, fmt.Errorf("piecewise: unknown segment reference %q", s.Reference)
}
}
if !needsCtx {
// Eager build: resolve any "profile_start" relative segments using
// the launch time we know at build time only when we have one.
// Without context, treat profile_start the same as absolute (the
// caller is expected to pre-resolve), and absolute as absolute.
segs := make([]RateSegment, 0, len(spec.Segments))
for _, s := range spec.Segments {
if s.Reference == "profile_start" {
return BuiltModel{}, fmt.Errorf("piecewise: profile_start reference requires a stage context — supply via lazy build")
}
segs = append(segs, RateSegment{Until: s.Until, Rate: s.Rate})
}
base := Piecewise(segs)
return BuiltModel{Model: maybeAddWind(base, spec.IncludeWind, deps)}, nil
}
// Lazy build — captures spec into a closure.
// Always build lazily: the profile runner supplies a StageContext before
// each stage, which is what resolves absolute / profile-relative /
// propagator-relative segment times uniformly.
return BuiltModel{
Build: func(ctx StageContext) Model {
segs := resolveSegments(spec.Segments, ctx)
base := Piecewise(segs)
return maybeAddWind(base, spec.IncludeWind, deps)
return maybeAddWind(Piecewise(resolveSegments(spec.Segments, ctx)), spec.IncludeWind, deps)
},
}, nil
}

View file

@ -1,50 +0,0 @@
package engine
import "math"
// pymod returns a % b with Python semantics: the result has the sign of b,
// so for b > 0 the result is always in [0, b).
func pymod(a, b float64) float64 {
r := math.Mod(a, b)
if r < 0 {
r += b
}
return r
}
// stateAdd is the RK4 integrator's update operation y + k*dy, with longitude
// kept wrapped to [0, 360).
//
// Time is not stored in State — it is tracked separately by the integrator
// and passed to Model.
func stateAdd(y State, k float64, dy State) State {
return State{
Lat: y.Lat + k*dy.Lat,
Lng: pymod(y.Lng+k*dy.Lng, 360),
Altitude: y.Altitude + k*dy.Altitude,
}
}
// stateLerp computes the linear interpolation of two states by parameter l
// in [0, 1]. Longitude uses lngLerp so that wrap-around is handled.
func stateLerp(a, b State, l float64) State {
return State{
Lat: (1-l)*a.Lat + l*b.Lat,
Lng: lngLerp(a.Lng, b.Lng, l),
Altitude: (1-l)*a.Altitude + l*b.Altitude,
}
}
// lngLerp interpolates between two longitudes in [0, 360), choosing the
// shorter great-circle arc.
func lngLerp(a, b, l float64) float64 {
l2 := 1 - l
if a > b {
a, b = b, a
l, l2 = l2, l
}
if b-a < 180 {
return l2*a + l*b
}
return pymod(l2*(a+360)+l*b, 360)
}

View file

@ -2,21 +2,23 @@
// propagators (model-driven integrators) into profiles (ordered chains)
// over a wind field.
//
// The engine orchestrates the calculation; the numerically heavy work
// (RK4 stepping, crossing refinement, interpolation, atmosphere density,
// vector and polygon math) lives in the numerics package so it can be
// reimplemented in a faster language without touching this layer.
//
// The engine has no direct dependency on any specific data source: wind
// data is consumed through weather.WindField and terrain data through
// any type satisfying TerrainProvider.
package engine
// State holds the spatial state of the balloon. When returned by a Model
// the same struct is interpreted as the per-second time derivative.
type State struct {
// Lat is degrees latitude in [-90, 90].
Lat float64 `json:"lat"`
// Lng is degrees longitude in [0, 360).
Lng float64 `json:"lng"`
// Altitude is metres above mean sea level.
Altitude float64 `json:"altitude"`
}
import "predictor-refactored/internal/numerics"
// State is the spatial state of the balloon: latitude/longitude in degrees,
// altitude in metres. When returned by a Model the same struct is the
// per-second derivative. It is an alias of numerics.GeoVec so the engine and
// the numeric core share one hot-path value type without conversions.
type State = numerics.GeoVec
// Model returns the time derivative of state at (t, s).
//
@ -24,14 +26,6 @@ type State struct {
// sign of dt for reverse propagation.
type Model func(t float64, s State) State
// TrajectoryPoint is one sampled point of an integration result.
type TrajectoryPoint struct {
Time float64 // UNIX seconds
Lat float64
Lng float64
Altitude float64
}
// Direction is the time direction of integration.
type Direction int8
@ -134,8 +128,8 @@ type Result struct {
// Propagator is the propagator's Name.
Propagator string
// Points is the emitted trajectory.
Points []TrajectoryPoint
// Path is the emitted trajectory in struct-of-arrays form.
Path numerics.Path
// Outcome describes how the propagator terminated.
Outcome Outcome

View file

@ -31,6 +31,6 @@ func Noop() Sink { return noop{} }
type noop struct{}
func (noop) Prediction(string, time.Duration, error) {}
func (noop) Prediction(string, time.Duration, error) {}
func (noop) Download(string, time.Duration, string, int64) {}
func (noop) ActiveEpoch(time.Time) {}
func (noop) ActiveEpoch(time.Time) {}

View file

@ -0,0 +1,25 @@
package numerics
import "math"
// NasaDensity returns air density in kg/m^3 at the given altitude in metres,
// using the NASA piecewise standard-atmosphere model.
// See https://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html.
//
// The model is split into three altitude bands (troposphere, lower
// stratosphere, upper stratosphere); density is pressure / (0.2869 * T_K).
func NasaDensity(alt float64) float64 {
var temp, pressure float64
switch {
case alt > 25000:
temp = -131.21 + 0.00299*alt
pressure = 2.488 * math.Pow((temp+273.1)/216.6, -11.388)
case alt > 11000:
temp = -56.46
pressure = 22.65 * math.Exp(1.73-0.000157*alt)
default:
temp = 15.04 - 0.00649*alt
pressure = 101.29 * math.Pow((temp+273.1)/288.08, 5.256)
}
return pressure / (0.2869 * (temp + 273.1))
}

View file

@ -0,0 +1,41 @@
package numerics
import "math"
// PointInPolygon reports whether (lat, lng) lies inside the closed polygon
// whose vertices are given as parallel latitude/longitude slices (degrees).
//
// The test is ray casting in plate-carrée space. Every longitude is
// normalised to within 180° of the first vertex before testing, so a polygon
// spanning the antimeridian is handled correctly as long as it spans no more
// than 180° in longitude. polyLat and polyLng must have equal length >= 3.
func PointInPolygon(lat, lng float64, polyLat, polyLng []float64) bool {
n := len(polyLat)
if n < 3 || len(polyLng) != n {
return false
}
ref := polyLng[0]
qx := NormalizeLng(lng, ref)
inside := false
for i, j := 0, n-1; i < n; j, i = i, i+1 {
yi, yj := polyLat[i], polyLat[j]
xi := NormalizeLng(polyLng[i], ref)
xj := NormalizeLng(polyLng[j], ref)
if (yi > lat) != (yj > lat) {
xIntersect := (xj-xi)*(lat-yi)/(yj-yi) + xi
if qx < xIntersect {
inside = !inside
}
}
}
return inside
}
// NormalizeLng rewrites v so that it lies within 180° of ref. For example,
// NormalizeLng(350, 10) returns -10. Used to make longitude comparisons
// continuous across the antimeridian.
func NormalizeLng(v, ref float64) float64 {
return ref + math.Mod(v-ref+540, 360) - 180
}

View file

@ -56,31 +56,74 @@ func (a Axis) Locate(value float64) (Bracket, error) {
return Bracket{Lo: lo, Hi: hi, Frac: pos - float64(lo)}, nil
}
// EvalTrilinear samples a 3D field via f at the eight corners defined by b3
// and returns the trilinearly interpolated value.
// TrilinearWeights returns the eight corner weights for a (axis0, axis1,
// axis2) bracket triple, in the canonical visiting order
//
// The corners are visited in the order (axis0 outer, axis2 inner), matching
// the Cython reference. With f(i,j,k) = a*i + b*j + c*k + d this returns
// a*pos0 + b*pos1 + c*pos2 + d exactly, modulo floating-point rounding.
func EvalTrilinear(b3 [3]Bracket, f func(i, j, k int) float64) float64 {
// (0,0,0) (0,0,1) (0,1,0) (0,1,1) (1,0,0) (1,0,1) (1,1,0) (1,1,1)
//
// where the bit triple selects Lo (0) or Hi (1) on each axis. The weights sum
// to 1. Pair this with Dot8 over corner values fetched in the same order.
func TrilinearWeights(b3 [3]Bracket) [8]float64 {
wa0, wa1 := 1-b3[0].Frac, b3[0].Frac
wb0, wb1 := 1-b3[1].Frac, b3[1].Frac
wc0, wc1 := 1-b3[2].Frac, b3[2].Frac
wa0wb0 := wa0 * wb0
wa0wb1 := wa0 * wb1
wa1wb0 := wa1 * wb0
wa1wb1 := wa1 * wb1
return [8]float64{
wa0wb0 * wc0,
wa0wb0 * wc1,
wa0wb1 * wc0,
wa0wb1 * wc1,
wa1wb0 * wc0,
wa1wb0 * wc1,
wa1wb1 * wc0,
wa1wb1 * wc1,
}
}
// Dot8 returns the multiply-accumulate sum w[0]*v[0] + ... + w[7]*v[7].
//
// The fixed length and straight-line accumulation are written so the Go
// compiler can keep the values in registers and a future hand-vectorised
// port can replace the body with a single SIMD MAC. The accumulation order
// is fixed (ascending index) so results are reproducible.
func Dot8(w, v *[8]float64) float64 {
acc := w[0] * v[0]
acc = w[1]*v[1] + acc
acc = w[2]*v[2] + acc
acc = w[3]*v[3] + acc
acc = w[4]*v[4] + acc
acc = w[5]*v[5] + acc
acc = w[6]*v[6] + acc
acc = w[7]*v[7] + acc
return acc
}
// EvalTrilinear samples a 3D field via f at the eight corners defined by b3
// and returns the trilinearly interpolated value.
//
// Corners are visited in the canonical order documented on TrilinearWeights.
// With f(i,j,k) = a*i + b*j + c*k + d this returns a*pos0 + b*pos1 + c*pos2
// + d, modulo floating-point rounding. For the hot path prefer precomputing
// weights once via TrilinearWeights and reducing with Dot8.
func EvalTrilinear(b3 [3]Bracket, f func(i, j, k int) float64) float64 {
w := TrilinearWeights(b3)
a0, a1 := b3[0].Lo, b3[0].Hi
bb0, bb1 := b3[1].Lo, b3[1].Hi
b0, b1 := b3[1].Lo, b3[1].Hi
c0, c1 := b3[2].Lo, b3[2].Hi
return wa0*wb0*wc0*f(a0, bb0, c0) +
wa0*wb0*wc1*f(a0, bb0, c1) +
wa0*wb1*wc0*f(a0, bb1, c0) +
wa0*wb1*wc1*f(a0, bb1, c1) +
wa1*wb0*wc0*f(a1, bb0, c0) +
wa1*wb0*wc1*f(a1, bb0, c1) +
wa1*wb1*wc0*f(a1, bb1, c0) +
wa1*wb1*wc1*f(a1, bb1, c1)
}
// Lerp returns (1-l)*a + l*b.
func Lerp(a, b, l float64) float64 {
return (1-l)*a + l*b
v := [8]float64{
f(a0, b0, c0),
f(a0, b0, c1),
f(a0, b1, c0),
f(a0, b1, c1),
f(a1, b0, c0),
f(a1, b0, c1),
f(a1, b1, c0),
f(a1, b1, c1),
}
return Dot8(&w, &v)
}

View file

@ -1,57 +1,46 @@
package numerics
// VecAdd computes y + k*dy on the domain state type S.
// Any coordinate-wrap or other domain-specific operation lives here.
type VecAdd[S any] func(y S, k float64, dy S) S
// Field returns the time derivative of a geographic state at (t, y).
// The derivative is direction-independent; the integrator applies the sign
// of dt for reverse-time integration.
type Field func(t float64, y GeoVec) GeoVec
// VecLerp computes (1-l)*a + l*b on the domain state type S.
type VecLerp[S any] func(a, b S, l float64) S
// Deriv computes the time derivative of state.
type Deriv[S any] func(t float64, y S) S
// Trigger reports whether a termination condition holds at (t, y).
type Trigger[S any] func(t float64, y S) bool
// Crossed reports whether a termination condition holds at (t, y).
type Crossed func(t float64, y GeoVec) bool
// RK4Step performs one classical Runge-Kutta-4 step from (t, y) with step dt.
// dt may be negative to integrate backwards in time.
func RK4Step[S any](t float64, y S, dt float64, deriv Deriv[S], add VecAdd[S]) S {
k1 := deriv(t, y)
k2 := deriv(t+dt/2, add(y, dt/2, k1))
k3 := deriv(t+dt/2, add(y, dt/2, k2))
k4 := deriv(t+dt, add(y, dt, k3))
// dt may be negative to integrate backwards in time. Longitude wrapping is
// applied at every intermediate add via GeoAdd, matching the reference
// integrator. The function performs no heap allocation.
func RK4Step(t float64, y GeoVec, dt float64, f Field) GeoVec {
half := dt / 2
k1 := f(t, y)
k2 := f(t+half, GeoAdd(y, half, k1))
k3 := f(t+half, GeoAdd(y, half, k2))
k4 := f(t+dt, GeoAdd(y, dt, k3))
y2 := y
y2 = add(y2, dt/6, k1)
y2 = add(y2, dt/3, k2)
y2 = add(y2, dt/3, k3)
y2 = add(y2, dt/6, k4)
y2 := GeoAdd(y, dt/6, k1)
y2 = GeoAdd(y2, dt/3, k2)
y2 = GeoAdd(y2, dt/3, k3)
y2 = GeoAdd(y2, dt/6, k4)
return y2
}
// RefineTrigger locates the trigger point between (t1, y1) (trigger not fired)
// and (t2, y2) (trigger fired) via binary search in the linear-interpolation
// parameter space, stopping when the parameter interval is narrower than tol.
// RefineCrossing locates a crossing between (t1, y1) (not crossed) and
// (t2, y2) (crossed) by binary search in the linear-interpolation parameter
// space, stopping when the parameter interval is narrower than tol.
//
// Returns the final midpoint sampled, matching the behavior of Tawhiri's
// solver.pyx (the returned point is *not* guaranteed to satisfy the trigger;
// for tol << 1 the difference is at most one tolerance-width either side).
func RefineTrigger[S any](
t1 float64, y1 S,
t2 float64, y2 S,
trigger Trigger[S],
lerp VecLerp[S],
tol float64,
) (float64, S) {
// It returns the final midpoint sampled, matching Tawhiri's solver.pyx: the
// returned point is not guaranteed to satisfy the predicate, but for tol << 1
// it is within one tolerance-width of the true crossing.
func RefineCrossing(t1 float64, y1 GeoVec, t2 float64, y2 GeoVec, crossed Crossed, tol float64) (float64, GeoVec) {
left, right := 0.0, 1.0
t3 := t2
y3 := y2
t3, y3 := t2, y2
for right-left > tol {
mid := (left + right) / 2
t3 = Lerp(t1, t2, mid)
y3 = lerp(y1, y2, mid)
if trigger(t3, y3) {
y3 = GeoLerp(y1, y2, mid)
if crossed(t3, y3) {
right = mid
} else {
left = mid
@ -59,3 +48,47 @@ func RefineTrigger[S any](
}
return t3, y3
}
// Path is a struct-of-arrays trajectory: parallel slices of time and the
// three state components. SoA layout keeps each component contiguous, which
// is friendlier to cache and to vectorised post-processing than a slice of
// point structs, and lets the integrator append with a single bounds check
// per component.
type Path struct {
T []float64
Lat []float64
Lng []float64
Altitude []float64
}
// NewPath returns a Path with capacity reserved for n points.
func NewPath(n int) Path {
return Path{
T: make([]float64, 0, n),
Lat: make([]float64, 0, n),
Lng: make([]float64, 0, n),
Altitude: make([]float64, 0, n),
}
}
// Len returns the number of points in the path.
func (p *Path) Len() int { return len(p.T) }
// Append adds one point to the path.
func (p *Path) Append(t float64, y GeoVec) {
p.T = append(p.T, t)
p.Lat = append(p.Lat, y.Lat)
p.Lng = append(p.Lng, y.Lng)
p.Altitude = append(p.Altitude, y.Altitude)
}
// Last returns the final (t, state) of the path. It panics on an empty path.
func (p *Path) Last() (float64, GeoVec) {
i := len(p.T) - 1
return p.T[i], GeoVec{Lat: p.Lat[i], Lng: p.Lng[i], Altitude: p.Altitude[i]}
}
// At returns the point at index i.
func (p *Path) At(i int) (float64, GeoVec) {
return p.T[i], GeoVec{Lat: p.Lat[i], Lng: p.Lng[i], Altitude: p.Altitude[i]}
}

View file

@ -5,57 +5,74 @@ import (
"testing"
)
// scalarAdd / scalarLerp let us drive RK4 on a plain float64.
func scalarAdd(y float64, k float64, dy float64) float64 { return y + k*dy }
func scalarLerpF(a, b float64, l float64) float64 { return Lerp(a, b, l) }
func TestRK4ExponentialDecay(t *testing.T) {
// dy/dt = -y → exact: y(t) = y0 * exp(-t).
deriv := func(_ float64, y float64) float64 { return -y }
// dAlt/dt = -Alt → exact: Alt(t) = Alt0 * exp(-t).
f := func(_ float64, y GeoVec) GeoVec { return GeoVec{Altitude: -y.Altitude} }
y := 1.0
tnow := 0.0
dt := 0.01
y := GeoVec{Altitude: 1}
tnow, dt := 0.0, 0.01
for range 100 {
y = RK4Step(tnow, y, dt, deriv, scalarAdd)
y = RK4Step(tnow, y, dt, f)
tnow += dt
}
want := math.Exp(-1.0)
if math.Abs(y-want) > 1e-8 {
t.Errorf("RK4 exp decay at t=1: got %v, want %v (diff %v)", y, want, y-want)
if math.Abs(y.Altitude-want) > 1e-8 {
t.Errorf("RK4 exp decay at t=1: got %v, want %v", y.Altitude, want)
}
}
func TestRK4ReverseTime(t *testing.T) {
// dy/dt = y → exact: y(t) = y0 * exp(t).
// Integrating from t=1 backwards with dt=-0.01 over 100 steps should give y0.
deriv := func(_ float64, y float64) float64 { return y }
// dAlt/dt = Alt → exact: Alt(t) = Alt0 * exp(t).
f := func(_ float64, y GeoVec) GeoVec { return GeoVec{Altitude: y.Altitude} }
y := math.E
tnow := 1.0
dt := -0.01
y := GeoVec{Altitude: math.E}
tnow, dt := 1.0, -0.01
for range 100 {
y = RK4Step(tnow, y, dt, deriv, scalarAdd)
y = RK4Step(tnow, y, dt, f)
tnow += dt
}
if math.Abs(y-1.0) > 1e-8 {
t.Errorf("RK4 reverse: got %v, want 1.0 (diff %v)", y, y-1.0)
if math.Abs(y.Altitude-1.0) > 1e-8 {
t.Errorf("RK4 reverse: got %v, want 1.0", y.Altitude)
}
}
func TestRefineTrigger(t *testing.T) {
// y crosses 0 at l=0.4 between y1=1 and y2=-1.5.
y1, y2 := 1.0, -1.5
t1, t2 := 0.0, 1.0
trig := func(_ float64, y float64) bool { return y <= 0 }
func TestRefineCrossing(t *testing.T) {
y1 := GeoVec{Altitude: 1}
y2 := GeoVec{Altitude: -1.5}
crossed := func(_ float64, y GeoVec) bool { return y.Altitude <= 0 }
tr, yr := RefineTrigger(t1, y1, t2, y2, trig, scalarLerpF, 0.001)
// The exact crossing is at l = 1/(1+1.5) = 0.4 → t = 0.4, y = 0.
tr, yr := RefineCrossing(0, y1, 1, y2, crossed, 0.001)
if math.Abs(tr-0.4) > 0.01 {
t.Errorf("Refined t = %v, want ~0.4", tr)
t.Errorf("refined t = %v, want ~0.4", tr)
}
if math.Abs(yr) > 0.01 {
t.Errorf("Refined y = %v, want ~0", yr)
if math.Abs(yr.Altitude) > 0.01 {
t.Errorf("refined alt = %v, want ~0", yr.Altitude)
}
}
func TestGeoAddWrapsLongitude(t *testing.T) {
y := GeoAdd(GeoVec{Lng: 350}, 1, GeoVec{Lng: 20})
if math.Abs(y.Lng-10) > 1e-9 {
t.Errorf("GeoAdd wrap: lng = %v, want 10", y.Lng)
}
}
func TestGeoLerpWrap(t *testing.T) {
mid := GeoLerp(GeoVec{Lng: 350}, GeoVec{Lng: 10}, 0.5)
if math.Abs(mid.Lng) > 1e-9 && math.Abs(mid.Lng-360) > 1e-9 {
t.Errorf("GeoLerp lng wrap: %v, want 0 or 360", mid.Lng)
}
}
func TestPathSoA(t *testing.T) {
p := NewPath(4)
p.Append(0, GeoVec{Lat: 1, Lng: 2, Altitude: 3})
p.Append(60, GeoVec{Lat: 4, Lng: 5, Altitude: 6})
if p.Len() != 2 {
t.Fatalf("len = %d, want 2", p.Len())
}
tt, last := p.Last()
if tt != 60 || last.Lat != 4 {
t.Errorf("last = %v, %+v", tt, last)
}
}

66
internal/numerics/vec.go Normal file
View file

@ -0,0 +1,66 @@
package numerics
import "math"
// GeoVec is a geographic state vector: latitude and longitude in degrees and
// altitude in metres. The same struct represents a per-second derivative,
// in which case the fields are deg/s and m/s.
//
// GeoVec is the hot-path state type for the integrator. It is a small value
// type (three float64) and is passed by value to stay allocation-free; a
// future SIMD/SoA batch integrator can lift these fields into parallel
// slices (see Path).
type GeoVec struct {
Lat float64 `json:"lat"`
Lng float64 `json:"lng"`
Altitude float64 `json:"altitude"`
}
// PyMod returns a mod b with Python semantics: the result carries the sign of
// b, so for b > 0 it always lies in [0, b).
func PyMod(a, b float64) float64 {
r := math.Mod(a, b)
if r < 0 {
r += b
}
return r
}
// GeoAdd returns y + k*dy with longitude wrapped to [0, 360). Latitude and
// altitude accumulate linearly. This is the integrator's state-update step.
func GeoAdd(y GeoVec, k float64, dy GeoVec) GeoVec {
return GeoVec{
Lat: y.Lat + k*dy.Lat,
Lng: PyMod(y.Lng+k*dy.Lng, 360),
Altitude: y.Altitude + k*dy.Altitude,
}
}
// GeoLerp linearly interpolates two geographic states by parameter l in
// [0, 1]. Longitude takes the shorter great-circle arc.
func GeoLerp(a, b GeoVec, l float64) GeoVec {
return GeoVec{
Lat: (1-l)*a.Lat + l*b.Lat,
Lng: LngLerp(a.Lng, b.Lng, l),
Altitude: (1-l)*a.Altitude + l*b.Altitude,
}
}
// LngLerp interpolates between two longitudes in [0, 360), choosing the
// shorter arc and wrapping the result back into range.
func LngLerp(a, b, l float64) float64 {
l2 := 1 - l
if a > b {
a, b = b, a
l, l2 = l2, l
}
if b-a < 180 {
return l2*a + l*b
}
return PyMod(l2*(a+360)+l*b, 360)
}
// Lerp returns (1-l)*a + l*b.
func Lerp(a, b, l float64) float64 {
return (1-l)*a + l*b
}

View file

@ -108,6 +108,14 @@ func (d *File) Val(hour, level, variable, lat, lng int) float32 {
return math.Float32frombits(binary.LittleEndian.Uint32(d.mm[off : off+4]))
}
// ValByElem reads the float32 at a precomputed flat element index (not a byte
// offset). The wind sampler uses this to read the eight interpolation corners
// after computing their flat indices once via cube strides.
func (d *File) ValByElem(elem int64) float32 {
off := elem * ElementSize
return math.Float32frombits(binary.LittleEndian.Uint32(d.mm[off : off+4]))
}
// SetVal writes one cell. Only valid on writable files.
func (d *File) SetVal(hour, level, variable, lat, lng int, val float32) {
off := d.offset(hour, level, variable, lat, lng)

View file

@ -8,39 +8,40 @@ import (
)
// Wind is a WindField backed by a GFS dataset file.
//
// The cube is addressed in flat element units with fixed strides so the
// sampler can compute the eight horizontal interpolation corners once and
// reach any (level, variable) by adding constant strides — avoiding the
// five-multiply offset computation per corner per evaluation.
type Wind struct {
file *File
hourAxis numerics.Axis
latAxis numerics.Axis
lngAxis numerics.Axis
hourStride int64 // elements between successive hours
levelStride int64 // elements between successive pressure levels
varStride int64 // elements between successive variables
latStride int64 // elements between successive latitudes
}
// NewWind returns a Wind backed by file. The axes are constructed from the
// file's variant geometry.
// NewWind returns a Wind backed by file. Axes and strides are derived from
// the file's variant geometry.
func NewWind(file *File) *Wind {
v := file.variant
nLat := v.NumLatitudes()
nLng := v.NumLongitudes()
nLev := v.NumLevels()
return &Wind{
file: file,
hourAxis: numerics.Axis{
Left: 0,
Step: float64(v.HourStep),
N: v.NumHours(),
Name: "hour",
},
latAxis: numerics.Axis{
Left: LatStart,
Step: v.Resolution,
N: v.NumLatitudes(),
Name: "lat",
},
lngAxis: numerics.Axis{
Left: LonStart,
Step: v.Resolution,
N: v.NumLongitudes(),
Wrap: true,
Name: "lng",
},
file: file,
hourAxis: numerics.Axis{Left: 0, Step: float64(v.HourStep), N: v.NumHours(), Name: "hour"},
latAxis: numerics.Axis{Left: LatStart, Step: v.Resolution, N: nLat, Name: "lat"},
lngAxis: numerics.Axis{Left: LonStart, Step: v.Resolution, N: nLng, Wrap: true, Name: "lng"},
hourStride: int64(nLev) * NumVariables * int64(nLat) * int64(nLng),
levelStride: NumVariables * int64(nLat) * int64(nLng),
varStride: int64(nLat) * int64(nLng),
latStride: int64(nLng),
}
}
@ -72,38 +73,53 @@ func (w *Wind) Wind(t, lat, lng, alt float64) (weather.Sample, error) {
if err != nil {
return weather.Sample{}, err
}
bs := [3]numerics.Bracket{bh, bla, bln}
height := func(level int) func(i, j, k int) float64 {
return func(i, j, k int) float64 {
return float64(w.file.Val(i, level, VarHeight, j, k))
weights := numerics.TrilinearWeights([3]numerics.Bracket{bh, bla, bln})
// Flat element index of each of the eight horizontal corners, at level 0
// variable 0, in the canonical TrilinearWeights order (hour outer, lng
// inner). Reaching a given (level, variable) corner only adds constant
// strides.
var base [8]int64
hours2 := [2]int64{int64(bh.Lo) * w.hourStride, int64(bh.Hi) * w.hourStride}
lats2 := [2]int64{int64(bla.Lo) * w.latStride, int64(bla.Hi) * w.latStride}
lngs2 := [2]int64{int64(bln.Lo), int64(bln.Hi)}
i := 0
for _, h := range hours2 {
for _, la := range lats2 {
for _, ln := range lngs2 {
base[i] = h + la + ln
i++
}
}
}
sample := func(level int, varIdx int64) float64 {
off := int64(level)*w.levelStride + varIdx*w.varStride
var vals [8]float64
for k := range 8 {
vals[k] = float64(w.file.ValByElem(base[k] + off))
}
return numerics.Dot8(&weights, &vals)
}
// Largest pressure level whose interpolated geopotential height is below alt.
levelIdx := numerics.Bisect(0, w.file.variant.NumLevels()-2, alt, func(level int) float64 {
return numerics.EvalTrilinear(bs, height(level))
return sample(level, VarHeight)
})
lowerHGT := numerics.EvalTrilinear(bs, height(levelIdx))
upperHGT := numerics.EvalTrilinear(bs, height(levelIdx+1))
lowerHGT := sample(levelIdx, VarHeight)
upperHGT := sample(levelIdx+1, VarHeight)
var altFrac float64
altFrac := 0.5
if lowerHGT != upperHGT {
altFrac = (upperHGT - alt) / (upperHGT - lowerHGT)
} else {
altFrac = 0.5
}
component := func(level, variable int) float64 {
return numerics.EvalTrilinear(bs, func(i, j, k int) float64 {
return float64(w.file.Val(i, level, variable, j, k))
})
}
lowerU := component(levelIdx, VarWindU)
upperU := component(levelIdx+1, VarWindU)
lowerV := component(levelIdx, VarWindV)
upperV := component(levelIdx+1, VarWindV)
lowerU := sample(levelIdx, VarWindU)
upperU := sample(levelIdx+1, VarWindU)
lowerV := sample(levelIdx, VarWindV)
upperV := sample(levelIdx+1, VarWindV)
return weather.Sample{
U: lowerU*altFrac + upperU*(1-altFrac),

View file

@ -0,0 +1,69 @@
package gfs
import (
"math"
"path/filepath"
"testing"
"time"
)
// testVariant is a tiny cube (2 hours × 3 levels × 3 lat × 4 lng) used to
// exercise the sampler without allocating a multi-gigabyte real dataset.
func testVariant() *Variant {
return &Variant{
ID: "gfs-test",
ResToken: "test",
Resolution: 90, // 180/90+1 = 3 lats, 360/90 = 4 lngs
HourStep: 3,
MaxHour: 3, // 2 hours
Pressures: []int{1000, 500, 100},
PressuresPgrb2: []int{1000, 500, 100},
PressuresPgrb2b: []int{},
}
}
func TestWindSampler(t *testing.T) {
v := testVariant()
path := filepath.Join(t.TempDir(), "cube.bin")
f, err := Create(path, v)
if err != nil {
t.Fatalf("Create: %v", err)
}
// HGT increases with level so the altitude bisection has a gradient;
// U and V are constant so interpolation must return them exactly.
for h := range v.NumHours() {
for lvl := range v.NumLevels() {
for la := range v.NumLatitudes() {
for ln := range v.NumLongitudes() {
f.SetVal(h, lvl, VarHeight, la, ln, float32(lvl*1000))
f.SetVal(h, lvl, VarWindU, la, ln, 7)
f.SetVal(h, lvl, VarWindV, la, ln, 3)
}
}
}
}
f.Flush()
f.Close()
epoch := time.Date(2026, 1, 1, 0, 0, 0, 0, time.UTC)
rf, err := Open(path, v, epoch)
if err != nil {
t.Fatalf("Open: %v", err)
}
defer rf.Close()
w := NewWind(rf)
// Query at the dataset epoch, equator, lng 45, altitude 500m (between
// level 0 @ 0m and level 1 @ 1000m).
s, err := w.Wind(float64(epoch.Unix()), 0, 45, 500)
if err != nil {
t.Fatalf("Wind: %v", err)
}
if math.Abs(s.U-7) > 1e-5 || math.Abs(s.V-3) > 1e-5 {
t.Errorf("constant wind not recovered: got U=%v V=%v, want 7,3", s.U, s.V)
}
if s.AboveModel {
t.Errorf("AboveModel should be false at altitude within model range")
}
}

63
internal/windviz/cache.go Normal file
View file

@ -0,0 +1,63 @@
package windviz
import (
"sync"
"time"
)
// Cache is a small bounded cache of rasterized fields keyed by request
// parameters and dataset epoch. It is safe for concurrent use.
//
// Visualization requests repeat heavily (a frontend re-fetches the same
// layer as users pan within a tile), so even a tiny cache removes most
// recomputation. Eviction is simplest-possible: when full, the whole map is
// cleared. Entries also expire after TTL.
type Cache struct {
mu sync.Mutex
entries map[string]cacheEntry
max int
ttl time.Duration
now func() time.Time
}
type cacheEntry struct {
field Field
expires time.Time
}
// NewCache returns a cache holding up to max entries for ttl each.
func NewCache(max int, ttl time.Duration) *Cache {
if max <= 0 {
max = 64
}
if ttl <= 0 {
ttl = 10 * time.Minute
}
return &Cache{
entries: make(map[string]cacheEntry, max),
max: max,
ttl: ttl,
now: time.Now,
}
}
// Get returns the cached field for key, if present and unexpired.
func (c *Cache) Get(key string) (Field, bool) {
c.mu.Lock()
defer c.mu.Unlock()
e, ok := c.entries[key]
if !ok || c.now().After(e.expires) {
return nil, false
}
return e.field, true
}
// Put stores field under key.
func (c *Cache) Put(key string, field Field) {
c.mu.Lock()
defer c.mu.Unlock()
if len(c.entries) >= c.max {
c.entries = make(map[string]cacheEntry, c.max)
}
c.entries[key] = cacheEntry{field: field, expires: c.now().Add(c.ttl)}
}

179
internal/windviz/windviz.go Normal file
View file

@ -0,0 +1,179 @@
// Package windviz rasterizes a weather.WindField into the JSON grid format
// consumed by browser velocity layers such as leaflet-velocity and
// wind-layer (the "gfs.json" / wind-js-server format).
//
// The module is decoupled from any specific dataset: it samples any
// weather.WindField on a regular latitude/longitude grid at a chosen time
// and altitude, downsampling by a configurable step to bound payload size.
package windviz
import (
"fmt"
"time"
"predictor-refactored/internal/weather"
)
// Request describes a wind-field rasterization.
type Request struct {
// Time is the forecast time to sample (UNIX seconds). Sampling outside
// the field's temporal coverage returns an error.
Time float64
// Altitude is the altitude in metres to sample at.
Altitude float64
// Bounding box in degrees. Latitudes in [-90, 90]; longitudes in
// [0, 360). For a global field use 0..360 (the rasterizer drops the
// duplicate 360° column).
MinLat, MaxLat float64
MinLng, MaxLng float64
// Step is the grid resolution in degrees (e.g. 1.0). Smaller is denser.
Step float64
}
// Component is one wind-js-server record: a header plus a flat data grid.
type Component struct {
Header Header `json:"header"`
Data []float64 `json:"data"`
}
// Header is the wind-js-server grid header. Field names and semantics match
// what leaflet-velocity / wind-layer expect.
type Header struct {
ParameterCategory int `json:"parameterCategory"`
ParameterNumber int `json:"parameterNumber"`
ParameterNumberName string `json:"parameterNumberName"`
ParameterUnit string `json:"parameterUnit"`
Nx int `json:"nx"`
Ny int `json:"ny"`
Lo1 float64 `json:"lo1"`
La1 float64 `json:"la1"`
Lo2 float64 `json:"lo2"`
La2 float64 `json:"la2"`
Dx float64 `json:"dx"`
Dy float64 `json:"dy"`
RefTime string `json:"refTime"`
ForecastTime int `json:"forecastTime"`
}
// Field is the two-component (U then V) payload. JSON-encoding a Field
// produces the array the velocity layers consume directly.
type Field []Component
const (
defaultStep = 1.0
minStep = 0.25 // clamp to bound output size
maxCells = 1 << 21
)
// Rasterize samples field over req and returns the U/V grid payload.
//
// Data is laid out in wind-js scan order: row 0 is the northernmost
// latitude (la1), each row runs west→east, longitudes increasing. Per-cell
// sampling errors (e.g. altitude outside the model) are written as 0 rather
// than failing the whole request; a time outside coverage is a hard error.
func Rasterize(field weather.WindField, req Request) (Field, error) {
step := req.Step
if step <= 0 {
step = defaultStep
}
if step < minStep {
step = minStep
}
minLat, maxLat := req.MinLat, req.MaxLat
minLng, maxLng := req.MinLng, req.MaxLng
if minLat == 0 && maxLat == 0 {
minLat, maxLat = -90, 90
}
if minLng == 0 && maxLng == 0 {
minLng, maxLng = 0, 360
}
if maxLat <= minLat {
return nil, fmt.Errorf("invalid bounding box latitude")
}
// Longitudes may arrive in either the [0, 360) or the [-180, 180]
// convention (the latter is what the rest of the API emits). Detect a
// full-globe span first, then fold a regional box's western edge into
// [0, 360); per-cell sampling re-folds via normLng so an eastern edge
// past 360° still reads the correct column.
lngSpan := maxLng - minLng
if lngSpan <= 0 {
return nil, fmt.Errorf("invalid bounding box longitude")
}
global := lngSpan >= 360-1e-9
var nx int
if global {
// Drop the duplicate wrap column so the layer tiles cleanly.
minLng = 0
nx = int(360/step + 0.5)
maxLng = float64(nx-1) * step
} else {
minLng = normLng(minLng)
maxLng = minLng + lngSpan
nx = int(lngSpan/step+0.5) + 1
}
ny := int((maxLat-minLat)/step+0.5) + 1
if nx < 1 || ny < 1 {
return nil, fmt.Errorf("empty grid")
}
if nx*ny > maxCells {
return nil, fmt.Errorf("grid too large (%d cells); increase step or shrink bbox", nx*ny)
}
u := make([]float64, nx*ny)
v := make([]float64, nx*ny)
// Row 0 = north (la1); rows descend in latitude.
for j := range ny {
lat := maxLat - float64(j)*step
for i := range nx {
lng := minLng + float64(i)*step
s, err := field.Wind(req.Time, lat, normLng(lng), req.Altitude)
idx := j*nx + i
if err != nil {
continue // leave as 0
}
u[idx] = s.U
v[idx] = s.V
}
}
refTime := time.Unix(int64(req.Time), 0).UTC().Format("2006-01-02T15:04:05.000Z")
mk := func(num int, name string, data []float64) Component {
return Component{
Header: Header{
ParameterCategory: 2,
ParameterNumber: num,
ParameterNumberName: name,
ParameterUnit: "m.s-1",
Nx: nx,
Ny: ny,
Lo1: minLng,
La1: maxLat,
Lo2: maxLng,
La2: minLat,
Dx: step,
Dy: step,
RefTime: refTime,
ForecastTime: 0,
},
Data: data,
}
}
return Field{
mk(2, "eastward_wind", u),
mk(3, "northward_wind", v),
}, nil
}
// normLng folds a longitude into [0, 360) for sampling.
func normLng(lng float64) float64 {
for lng < 0 {
lng += 360
}
for lng >= 360 {
lng -= 360
}
return lng
}

View file

@ -0,0 +1,96 @@
package windviz
import (
"testing"
"time"
"predictor-refactored/internal/weather"
)
// constWind is a WindField returning a fixed sample everywhere.
type constWind struct {
u, v float64
epoch time.Time
}
func (c constWind) Wind(_ float64, _, _, _ float64) (weather.Sample, error) {
return weather.Sample{U: c.u, V: c.v}, nil
}
func (c constWind) Epoch() time.Time { return c.epoch }
func (c constWind) Source() string { return "test" }
func TestRasterizeGlobalDropsDuplicateColumn(t *testing.T) {
f := constWind{u: 5, v: -3, epoch: time.Unix(0, 0)}
out, err := Rasterize(f, Request{MinLng: 0, MaxLng: 360, Step: 90})
if err != nil {
t.Fatalf("Rasterize: %v", err)
}
if len(out) != 2 {
t.Fatalf("expected 2 components, got %d", len(out))
}
u := out[0]
// 360/90 = 4 columns (no duplicate 360°); lat -90..90 step 90 = 3 rows.
if u.Header.Nx != 4 || u.Header.Ny != 3 {
t.Errorf("grid = %dx%d, want 4x3", u.Header.Nx, u.Header.Ny)
}
if len(u.Data) != 12 {
t.Errorf("data len = %d, want 12", len(u.Data))
}
if u.Header.La1 != 90 || u.Header.La2 != -90 {
t.Errorf("lat range = %v..%v, want 90..-90 (north first)", u.Header.La1, u.Header.La2)
}
if u.Header.Lo1 != 0 || u.Header.Lo2 != 270 {
t.Errorf("lng range = %v..%v, want 0..270", u.Header.Lo1, u.Header.Lo2)
}
for _, d := range u.Data {
if d != 5 {
t.Errorf("U data = %v, want 5", d)
break
}
}
if out[0].Header.ParameterNumber != 2 || out[1].Header.ParameterNumber != 3 {
t.Errorf("component order should be U(2) then V(3)")
}
}
func TestRasterizeSignedLongitudeConvention(t *testing.T) {
f := constWind{u: 1, v: 2, epoch: time.Unix(0, 0)}
// A [-180, 180] global request must be detected as global and tiled
// without a duplicate seam column, identical to a 0..360 request.
signed, err := Rasterize(f, Request{MinLng: -180, MaxLng: 180, Step: 90})
if err != nil {
t.Fatalf("signed-global Rasterize: %v", err)
}
if signed[0].Header.Nx != 4 {
t.Errorf("signed-global nx = %d, want 4 (no duplicate column)", signed[0].Header.Nx)
}
// A western-hemisphere box must not 400; its western edge folds into [0,360).
west, err := Rasterize(f, Request{MinLat: 10, MaxLat: 20, MinLng: -100, MaxLng: -50, Step: 10})
if err != nil {
t.Fatalf("western-box Rasterize: %v", err)
}
if west[0].Header.Lo1 != 260 {
t.Errorf("western-box lo1 = %v, want 260 (=-100 folded)", west[0].Header.Lo1)
}
}
func TestRasterizeStepClamp(t *testing.T) {
f := constWind{epoch: time.Unix(0, 0)}
// step below min gets clamped, not rejected.
if _, err := Rasterize(f, Request{MinLat: -1, MaxLat: 1, MinLng: 0, MaxLng: 2, Step: 0.01}); err != nil {
t.Fatalf("Rasterize with tiny step: %v", err)
}
}
func TestCacheRoundTrip(t *testing.T) {
c := NewCache(2, time.Minute)
if _, ok := c.Get("a"); ok {
t.Errorf("empty cache should miss")
}
c.Put("a", Field{})
if _, ok := c.Get("a"); !ok {
t.Errorf("cache should hit after put")
}
}