This commit is contained in:
Anatoly Antonov 2026-05-18 02:09:07 +09:00
parent c4f355a32e
commit 7a8d5d13fa
72 changed files with 4510 additions and 4104 deletions

158
internal/dataset/dataset.go Normal file
View file

@ -0,0 +1,158 @@
package dataset
import "fmt"
// Dimensions and layout of the wind dataset.
//
// The dataset is a dense five-dimensional array of shape
// (65, 47, 3, 361, 720) = (hour, pressure_level, variable, latitude, longitude),
// which matches the reference predictor exactly.
const (
	// Axis lengths, outermost axis first.
	NumHours      = 65  // forecast hours 0, 3, 6, ..., 192
	NumLevels     = 47  // pressure levels
	NumVariables  = 3   // height, wind_u, wind_v
	NumLatitudes  = 361 // -90.0 to +90.0 in 0.5 degree steps
	NumLongitudes = 720 // 0.0 to 359.5 in 0.5 degree steps

	// Forecast time axis.
	HourStep = 3   // hours between forecast time steps
	MaxHour  = 192 // maximum forecast hour

	// Horizontal grid.
	Resolution = 0.5   // grid resolution in degrees
	LatStart   = -90.0 // first latitude in the dataset
	LonStart   = 0.0   // first longitude in the dataset

	// Positions of the variables along the variable axis.
	VarHeight = 0
	VarWindU  = 1
	VarWindV  = 2

	// ElementSize is the byte size of one stored value (float32 = 4 bytes).
	ElementSize = 4
)
// DatasetSize is the exact byte length of a dataset file: the product of the
// five axis lengths and the element size.
// 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200
const DatasetSize int64 = NumHours * NumLevels * NumVariables *
	NumLatitudes * NumLongitudes * ElementSize
// LevelSet names the GRIB file set in which a pressure level is found.
type LevelSet int

// The two file sets a pressure level can belong to.
const (
	LevelSetA LevelSet = 0 // pgrb2 (primary)
	LevelSetB LevelSet = 1 // pgrb2b (secondary)
)
// Pressures lists all 47 pressure levels in hPa, from highest pressure to
// lowest. Index 0 = 1000 hPa (near surface), index 46 = 1 hPa (high atmosphere).
var Pressures = [NumLevels]int{
	// 1000 down to 125 hPa in 25 hPa steps.
	1000, 975, 950, 925, 900, 875, 850, 825, 800,
	775, 750, 725, 700, 675, 650, 625, 600, 575,
	550, 525, 500, 475, 450, 425, 400, 375, 350,
	325, 300, 275, 250, 225, 200, 175, 150, 125,
	// 100 hPa and above, unevenly spaced.
	100, 70, 50, 30, 20, 10, 7, 5, 3, 2, 1,
}
// Lookup tables derived from the level lists; both are filled once by init.
var (
	// pressureIndex maps pressure in hPa to its index in the Pressures array.
	pressureIndex map[int]int
	// pressureLevelSet maps pressure in hPa to its GRIB file set.
	pressureLevelSet map[int]LevelSet
)

// init builds pressureIndex from Pressures and pressureLevelSet from the two
// per-file level lists.
func init() {
	byPressure := make(map[int]int, NumLevels)
	for i := 0; i < len(Pressures); i++ {
		byPressure[Pressures[i]] = i
	}
	pressureIndex = byPressure

	sets := make(map[int]LevelSet, NumLevels)
	for _, hPa := range PressuresPgrb2 {
		sets[hPa] = LevelSetA
	}
	for _, hPa := range PressuresPgrb2b {
		sets[hPa] = LevelSetB
	}
	pressureLevelSet = sets
}
// PressuresPgrb2 lists the 26 levels (hPa) carried by the primary pgrb2 file.
var PressuresPgrb2 = []int{
	10, 20, 30, 50, 70,
	100, 150, 200, 250, 300, 350, 400, 450,
	500, 550, 600, 650, 700, 750, 800, 850,
	900, 925, 950, 975, 1000,
}
// PressuresPgrb2b lists the 21 levels (hPa) carried by the secondary pgrb2b file.
var PressuresPgrb2b = []int{
	1, 2, 3, 5, 7,
	125, 175, 225, 275, 325, 375, 425, 475,
	525, 575, 625, 675, 725, 775, 825, 875,
}
// PressureIndex looks up the position of a pressure level (hPa) along the
// dataset's level axis. Unknown levels yield -1.
func PressureIndex(hPa int) int {
	if idx, found := pressureIndex[hPa]; found {
		return idx
	}
	return -1
}
// PressureLevelSet reports the GRIB file set that carries the given pressure
// level (hPa). The boolean is false when the level is not part of the dataset,
// in which case the returned LevelSet is the zero value.
func PressureLevelSet(hPa int) (LevelSet, bool) {
	set, known := pressureLevelSet[hPa]
	if !known {
		return set, false
	}
	return set, true
}
// HourIndex converts a forecast hour into its index along the time axis.
// It returns -1 when the hour lies outside [0, MaxHour] or is not a multiple
// of HourStep.
func HourIndex(hour int) int {
	inRange := hour >= 0 && hour <= MaxHour
	onStep := hour%HourStep == 0
	if inRange && onStep {
		return hour / HourStep
	}
	return -1
}
// Hours lists every forecast hour in ascending order: [0, 3, 6, ..., 192].
func Hours() []int {
	hours := make([]int, 0, NumHours)
	for i := 0; i*HourStep <= MaxHour; i++ {
		hours = append(hours, i*HourStep)
	}
	return hours
}
// S3BaseURL is the base URL of the S3 endpoint serving NOAA GFS data.
// GribURL and GribURLB prepend it to every GRIB file path they build.
const S3BaseURL = "https://noaa-gfs-bdp-pds.s3.amazonaws.com"
// GribURL builds the S3 URL of the primary (pgrb2) GRIB file for a model run.
// date is inserted verbatim after "gfs." (callers in this package pass
// YYYYMMDD); runHour is zero-padded to two digits and forecastStep to three.
func GribURL(date string, runHour, forecastStep int) string {
	const layout = "%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d"
	url := fmt.Sprintf(layout, S3BaseURL, date, runHour, runHour, forecastStep)
	return url
}
// GribURLB builds the S3 URL of the secondary (pgrb2b) GRIB file for a model
// run. date is inserted verbatim after "gfs."; runHour is zero-padded to two
// digits and forecastStep to three.
func GribURLB(date string, runHour, forecastStep int) string {
	const layout = "%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2b.0p50.f%03d"
	url := fmt.Sprintf(layout, S3BaseURL, date, runHour, runHour, forecastStep)
	return url
}
// GribFileName builds the local filename of a primary (pgrb2) GRIB file, with
// runHour zero-padded to two digits and forecastStep to three.
func GribFileName(runHour, forecastStep int) string {
	const layout = "gfs.t%02dz.pgrb2.0p50.f%03d"
	name := fmt.Sprintf(layout, runHour, forecastStep)
	return name
}
// GribFileNameB builds the local filename of a secondary (pgrb2b) GRIB file,
// with runHour zero-padded to two digits and forecastStep to three.
func GribFileNameB(runHour, forecastStep int) string {
	const layout = "gfs.t%02dz.pgrb2b.0p50.f%03d"
	name := fmt.Sprintf(layout, runHour, forecastStep)
	return name
}
// VariableIndex maps a GRIB (parameter category, parameter number) pair to the
// dataset's variable index. Pairs other than the three stored variables
// yield -1.
func VariableIndex(parameterCategory, parameterNumber int) int {
	switch parameterCategory {
	case 3:
		if parameterNumber == 5 {
			return VarHeight // Geopotential Height
		}
	case 2:
		switch parameterNumber {
		case 2:
			return VarWindU // U-component of wind
		case 3:
			return VarWindV // V-component of wind
		}
	}
	return -1
}

View file

@ -0,0 +1,152 @@
package dataset
import (
"testing"
)
// TestDatasetShape pins the five axis lengths of the dataset.
func TestDatasetShape(t *testing.T) {
	if got := NumHours; got != 65 {
		t.Errorf("NumHours = %d, want 65", got)
	}
	if got := NumLevels; got != 47 {
		t.Errorf("NumLevels = %d, want 47", got)
	}
	if got := NumVariables; got != 3 {
		t.Errorf("NumVariables = %d, want 3", got)
	}
	if got := NumLatitudes; got != 361 {
		t.Errorf("NumLatitudes = %d, want 361", got)
	}
	if got := NumLongitudes; got != 720 {
		t.Errorf("NumLongitudes = %d, want 720", got)
	}
}
// TestDatasetSize checks the total file size against the hand-computed value.
func TestDatasetSize(t *testing.T) {
	// 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200
	const want int64 = 9_528_667_200
	if got := DatasetSize; got != want {
		t.Errorf("DatasetSize = %d, want %d", got, want)
	}
}
// TestPressureLevels verifies the level table: its length, its end points,
// strict descending order, and the sizes of the two per-file level lists.
func TestPressureLevels(t *testing.T) {
	if n := len(Pressures); n != 47 {
		t.Fatalf("len(Pressures) = %d, want 47", n)
	}

	// Highest pressure (near surface) comes first.
	if first := Pressures[0]; first != 1000 {
		t.Errorf("Pressures[0] = %d, want 1000", first)
	}
	// Lowest pressure (high atmosphere) comes last.
	if last := Pressures[46]; last != 1 {
		t.Errorf("Pressures[46] = %d, want 1", last)
	}

	// Every level must be strictly below its predecessor.
	for i := 1; i < len(Pressures); i++ {
		prev, cur := Pressures[i-1], Pressures[i]
		if cur >= prev {
			t.Errorf("Pressures not descending at [%d]: %d >= %d", i, cur, prev)
		}
	}

	// Total levels: 26 from pgrb2 + 21 from pgrb2b = 47
	if n := len(PressuresPgrb2); n != 26 {
		t.Errorf("len(PressuresPgrb2) = %d, want 26", n)
	}
	if n := len(PressuresPgrb2b); n != 21 {
		t.Errorf("len(PressuresPgrb2b) = %d, want 21", n)
	}
}
// TestPressureIndex checks the first, last and a middle level, plus the -1
// result for a level that is not in the table.
func TestPressureIndex(t *testing.T) {
	if got := PressureIndex(1000); got != 0 {
		t.Errorf("PressureIndex(1000) = %d, want 0", got)
	}
	if got := PressureIndex(1); got != 46 {
		t.Errorf("PressureIndex(1) = %d, want 46", got)
	}
	if got := PressureIndex(500); got != 20 {
		t.Errorf("PressureIndex(500) = %d, want 20", got)
	}
	if got := PressureIndex(9999); got != -1 {
		t.Errorf("PressureIndex(9999) = %d, want -1", got)
	}
}
// TestPressureLevelSet checks representative levels of both file sets and
// that every level in Pressures has an assignment.
func TestPressureLevelSet(t *testing.T) {
	// 1000 mb should be in pgrb2 (A)
	set, found := PressureLevelSet(1000)
	if !(found && set == LevelSetA) {
		t.Errorf("PressureLevelSet(1000) = %v, %v; want A, true", set, found)
	}

	// 125 mb should be in pgrb2b (B)
	set, found = PressureLevelSet(125)
	if !(found && set == LevelSetB) {
		t.Errorf("PressureLevelSet(125) = %v, %v; want B, true", set, found)
	}

	// 1, 2, 3, 5, 7 should be in pgrb2b (B)
	topLevels := []int{1, 2, 3, 5, 7}
	for _, hPa := range topLevels {
		set, found := PressureLevelSet(hPa)
		if !(found && set == LevelSetB) {
			t.Errorf("PressureLevelSet(%d) = %v, %v; want B, true", hPa, set, found)
		}
	}

	// No level of the dataset may be left without a file set.
	for i := range Pressures {
		hPa := Pressures[i]
		if _, found := PressureLevelSet(hPa); !found {
			t.Errorf("PressureLevelSet(%d) not found", hPa)
		}
	}
}
// TestHourIndex covers valid hours at both ends of the range and the two
// rejection cases (off-step and out of range).
func TestHourIndex(t *testing.T) {
	if got := HourIndex(0); got != 0 {
		t.Errorf("HourIndex(0) = %d, want 0", got)
	}
	if got := HourIndex(3); got != 1 {
		t.Errorf("HourIndex(3) = %d, want 1", got)
	}
	if got := HourIndex(192); got != 64 {
		t.Errorf("HourIndex(192) = %d, want 64", got)
	}
	if got := HourIndex(1); got != -1 {
		t.Errorf("HourIndex(1) = %d, want -1 (not multiple of 3)", got)
	}
	if got := HourIndex(195); got != -1 {
		t.Errorf("HourIndex(195) = %d, want -1 (out of range)", got)
	}
}
// TestHours checks the length and both end points of the forecast-hour list.
func TestHours(t *testing.T) {
	all := Hours()
	count := len(all)
	if count != NumHours {
		t.Fatalf("len(Hours()) = %d, want %d", count, NumHours)
	}
	if first := all[0]; first != 0 {
		t.Errorf("Hours()[0] = %d, want 0", first)
	}
	if last := all[count-1]; last != MaxHour {
		t.Errorf("Hours()[last] = %d, want %d", last, MaxHour)
	}
}
// TestVariableIndex checks the three recognised GRIB parameters and one
// unrecognised pair.
func TestVariableIndex(t *testing.T) {
	if got := VariableIndex(3, 5); got != VarHeight {
		t.Errorf("HGT: got %d, want %d", got, VarHeight)
	}
	if got := VariableIndex(2, 2); got != VarWindU {
		t.Errorf("UGRD: got %d, want %d", got, VarWindU)
	}
	if got := VariableIndex(2, 3); got != VarWindV {
		t.Errorf("VGRD: got %d, want %d", got, VarWindV)
	}
	if got := VariableIndex(0, 0); got != -1 {
		t.Errorf("unknown: got %d, want -1", got)
	}
}

140
internal/dataset/file.go Normal file
View file

@ -0,0 +1,140 @@
package dataset
import (
"encoding/binary"
"fmt"
"math"
"os"
"time"
mmap "github.com/edsrzf/mmap-go"
)
// File is a wind dataset backed by a memory-mapped file.
type File struct {
	mm       mmap.MMap // mapping of the file contents
	file     *os.File  // handle the mapping was created from; closed by Close
	writable bool      // true for files made by Create, false for Open

	// DSTime is the forecast run time (UTC).
	DSTime time.Time
}
// Open maps an existing dataset file read-only.
// The file must be exactly DatasetSize bytes long; dsTime is stored on the
// returned File as its forecast run time.
func Open(path string, dsTime time.Time) (*File, error) {
	handle, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("open dataset: %w", err)
	}
	// fail releases the handle and hands the error back to the caller.
	fail := func(cause error) (*File, error) {
		handle.Close()
		return nil, cause
	}

	stat, err := handle.Stat()
	if err != nil {
		return fail(fmt.Errorf("stat dataset: %w", err))
	}
	if size := stat.Size(); size != DatasetSize {
		return fail(fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, size))
	}

	mapping, err := mmap.Map(handle, mmap.RDONLY, 0)
	if err != nil {
		return fail(fmt.Errorf("mmap dataset: %w", err))
	}
	return &File{mm: mapping, file: handle, writable: false, DSTime: dsTime}, nil
}
// Create creates a new dataset file for writing.
// The file is created (or truncated if it exists), sized to DatasetSize and
// mmap'd read-write. If sizing or mapping fails, the partially initialised
// file is closed and removed so that no stray file is left at path.
func Create(path string) (*File, error) {
	f, err := os.Create(path)
	if err != nil {
		return nil, fmt.Errorf("create dataset: %w", err)
	}
	// discard releases the handle and deletes the unusable file. Removal is
	// best-effort: the original failure is the error worth reporting.
	discard := func() {
		f.Close()
		os.Remove(path)
	}

	if err := f.Truncate(DatasetSize); err != nil {
		discard()
		return nil, fmt.Errorf("truncate dataset: %w", err)
	}
	mm, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0)
	if err != nil {
		discard()
		return nil, fmt.Errorf("mmap dataset: %w", err)
	}
	return &File{mm: mm, file: f, writable: true}, nil
}
// offset returns the byte position of element [hour][level][variable][lat][lon].
// The array is laid out in row-major (C) order with
// shape = (65, 47, 3, 361, 720), matching the reference implementation; the
// linear element index is accumulated axis by axis and then scaled by the
// element size.
func offset(hour, level, variable, lat, lon int) int64 {
	linear := ((((int64(hour)*NumLevels+int64(level))*NumVariables+
		int64(variable))*NumLatitudes+
		int64(lat))*NumLongitudes +
		int64(lon))
	return linear * ElementSize
}
// Val returns the float32 stored at [hour][level][variable][lat][lon].
// Values are decoded as little-endian IEEE-754 bit patterns.
func (d *File) Val(hour, level, variable, lat, lon int) float32 {
	start := offset(hour, level, variable, lat, lon)
	cell := d.mm[start : start+4]
	return math.Float32frombits(binary.LittleEndian.Uint32(cell))
}
// SetVal stores val at [hour][level][variable][lat][lon] as a little-endian
// IEEE-754 bit pattern. It must only be called on writable (created) datasets.
func (d *File) SetVal(hour, level, variable, lat, lon int, val float32) {
	start := offset(hour, level, variable, lat, lon)
	cell := d.mm[start : start+4]
	binary.LittleEndian.PutUint32(cell, math.Float32bits(val))
}
// BlitGribData copies one decoded GRIB grid into the dataset at position
// [hourIdx][levelIdx][varIdx].
//
// gribData must hold exactly NumLatitudes*NumLongitudes float64 values in GRIB
// scan order (north-to-south, west-to-east). Rows are flipped while copying so
// that dataset latitude index 0 = -90 (south) and the last index = +90 (north).
// Values are narrowed to float32.
//
// An error is returned, and nothing is written, when the dataset is not open
// for writing, when any index is outside its axis, or when gribData has the
// wrong length.
func (d *File) BlitGribData(hourIdx, levelIdx, varIdx int, gribData []float64) error {
	// Writing through the read-only mapping made by Open would fault, so
	// reject it up front with an ordinary error.
	if !d.writable || d.mm == nil {
		return fmt.Errorf("dataset is not open for writing")
	}
	if hourIdx < 0 || hourIdx >= NumHours ||
		levelIdx < 0 || levelIdx >= NumLevels ||
		varIdx < 0 || varIdx >= NumVariables {
		return fmt.Errorf("grid position out of range: hour=%d level=%d variable=%d",
			hourIdx, levelIdx, varIdx)
	}
	expected := NumLatitudes * NumLongitudes
	if len(gribData) != expected {
		return fmt.Errorf("grib data has %d values, expected %d", len(gribData), expected)
	}

	for lat := 0; lat < NumLatitudes; lat++ {
		// GRIB scans north-to-south: row 0 = 90N, last row = 90S.
		// Dataset stores south-to-north: index 0 = 90S, last index = 90N.
		rowStart := (NumLatitudes - 1 - lat) * NumLongitudes
		srcRow := gribData[rowStart : rowStart+NumLongitudes]
		for lon, v := range srcRow {
			d.SetVal(hourIdx, levelIdx, varIdx, lat, lon, float32(v))
		}
	}
	return nil
}
// Flush writes the mapped contents back to disk.
// It is a no-op once the mapping has been released.
func (d *File) Flush() error {
	if d.mm == nil {
		return nil
	}
	return d.mm.Flush()
}
// Close unmaps and closes the dataset file.
//
// Both resources are always released and their fields cleared, even when
// unmapping fails, so a repeated Close (for example a deferred one after an
// explicit call) is a no-op instead of unmapping or closing a second time.
// The first error encountered is returned; an unmap failure is wrapped with
// "unmap:".
func (d *File) Close() error {
	var firstErr error
	if d.mm != nil {
		if err := d.mm.Unmap(); err != nil {
			firstErr = fmt.Errorf("unmap: %w", err)
		}
		d.mm = nil
	}
	if d.file != nil {
		if err := d.file.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
		d.file = nil
	}
	return firstErr
}

View file

@ -0,0 +1,58 @@
package downloader
import (
"os"
"strconv"
"time"
)
// Config holds downloader configuration, loaded from environment variables.
type Config struct {
// DataDir is the directory for storing dataset files and temporary GRIB data.
DataDir string
// Parallel is the maximum number of concurrent GRIB downloads.
Parallel int
// UpdateInterval is how often the scheduler checks for new forecast data.
UpdateInterval time.Duration
// DatasetTTL is how long a dataset is considered fresh before a new one is needed.
DatasetTTL time.Duration
}
// DefaultConfig returns a fresh Config populated with the built-in defaults.
func DefaultConfig() *Config {
	defaults := new(Config)
	defaults.DataDir = "/tmp/predictor-data"
	defaults.Parallel = 8
	defaults.UpdateInterval = 6 * time.Hour
	defaults.DatasetTTL = 48 * time.Hour
	return defaults
}
// LoadConfig loads configuration from environment variables, falling back to
// the defaults for anything unset or invalid.
//
// Recognised variables:
//   - PREDICTOR_DATA_DIR: any non-empty string
//   - PREDICTOR_DOWNLOAD_PARALLEL: a positive integer
//   - PREDICTOR_UPDATE_INTERVAL: a positive time.ParseDuration string
//   - PREDICTOR_DATASET_TTL: a positive time.ParseDuration string
//
// Values that fail to parse, or that are zero or negative, are ignored, so
// the returned Config never carries a non-positive Parallel, UpdateInterval
// or DatasetTTL taken from the environment.
func LoadConfig() *Config {
	cfg := DefaultConfig()

	if dir := os.Getenv("PREDICTOR_DATA_DIR"); dir != "" {
		cfg.DataDir = dir
	}
	if raw := os.Getenv("PREDICTOR_DOWNLOAD_PARALLEL"); raw != "" {
		if n, err := strconv.Atoi(raw); err == nil && n > 0 {
			cfg.Parallel = n
		}
	}

	// overrideDuration replaces *dst with the value of the named variable when
	// it parses as a strictly positive duration, mirroring the n > 0 rule
	// applied to Parallel above.
	overrideDuration := func(name string, dst *time.Duration) {
		raw := os.Getenv(name)
		if raw == "" {
			return
		}
		if parsed, err := time.ParseDuration(raw); err == nil && parsed > 0 {
			*dst = parsed
		}
	}
	overrideDuration("PREDICTOR_UPDATE_INTERVAL", &cfg.UpdateInterval)
	overrideDuration("PREDICTOR_DATASET_TTL", &cfg.DatasetTTL)

	return cfg
}

View file

@ -0,0 +1,441 @@
package downloader
import (
"context"
"fmt"
"io"
"math"
"net/http"
"os"
"path/filepath"
"sync/atomic"
"time"
"predictor-refactored/internal/dataset"
"github.com/nilsmagnus/grib/griblib"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
)
// Downloader fetches GFS forecast data from S3 and assembles it into dataset
// files.
type Downloader struct {
	cfg    *Config      // downloader settings (data directory, parallelism)
	client *http.Client // shared HTTP client used for every request
	log    *zap.Logger  // destination for progress and diagnostic messages
}
// NewDownloader builds a Downloader that uses cfg and log, with an HTTP client
// whose requests time out after two minutes.
func NewDownloader(cfg *Config, log *zap.Logger) *Downloader {
	httpClient := &http.Client{Timeout: 2 * time.Minute}
	return &Downloader{cfg: cfg, client: httpClient, log: log}
}
// neededVariables holds the GRIB variable names the dataset stores; idx
// entries for any other variable are filtered out.
var neededVariables = map[string]bool{
	"HGT":  true, // geopotential height
	"UGRD": true, // U-component of wind
	"VGRD": true, // V-component of wind
}
// FindLatestRun finds the most recent available GFS model run on S3.
//
// Starting from the current 6-hour cycle (UTC) and walking backwards, it sends
// a HEAD request for the .idx file of each run's last forecast step
// (dataset.MaxHour); a 200 response confirms the run. Up to 8 runs are checked.
// Probes that fail or return another status are skipped.
//
// If ctx is cancelled or expires, the context's error is returned instead of
// the generic "no recent GFS forecast found" error, and no further probes are
// attempted.
func (d *Downloader) FindLatestRun(ctx context.Context) (time.Time, error) {
	const (
		runInterval = 6 * time.Hour
		maxRuns     = 8
	)

	now := time.Now().UTC()
	hour := now.Hour() - (now.Hour() % 6)
	current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC)

	for i := 0; i < maxRuns; i, current = i+1, current.Add(-runInterval) {
		// Every request would fail once the context is done; stop early and
		// report the real reason.
		if err := ctx.Err(); err != nil {
			return time.Time{}, err
		}

		date := current.Format("20060102")
		url := dataset.GribURL(date, current.Hour(), dataset.MaxHour) + ".idx"

		req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
		if err != nil {
			continue
		}
		resp, err := d.client.Do(req)
		if err != nil {
			continue
		}
		resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			continue
		}

		d.log.Info("found latest model run",
			zap.Time("run", current),
			zap.String("verified_url", url))
		return current, nil
	}

	// The last probe may have failed because of cancellation.
	if err := ctx.Err(); err != nil {
		return time.Time{}, err
	}
	return time.Time{}, fmt.Errorf("no recent GFS forecast found (checked 8 runs)")
}
// progress accumulates download statistics shared by concurrent goroutines.
// The two counters are atomics; the remaining fields are set once by
// newProgress and only read afterwards.
type progress struct {
	bytesDownloaded atomic.Int64 // total bytes fetched so far
	stepsCompleted  atomic.Int64 // number of finished steps
	totalSteps      int64        // number of steps expected overall
	startTime       time.Time    // when tracking began
	log             *zap.Logger  // where progress lines are written
}
// newProgress starts a tracker for totalSteps steps; the clock starts now.
func newProgress(totalSteps int, log *zap.Logger) *progress {
	tracker := new(progress)
	tracker.totalSteps = int64(totalSteps)
	tracker.startTime = time.Now()
	tracker.log = log
	return tracker
}
// addBytes atomically adds n to the running total of downloaded bytes.
func (p *progress) addBytes(n int64) {
p.bytesDownloaded.Add(n)
}
// completeStep records one finished step and logs a progress line with the
// step count, percentage, megabytes downloaded, average speed and an ETA.
//
// The ETA extrapolates the average time per finished step over the remaining
// steps. It is rounded to whole seconds before being split into minutes and
// seconds, so the seconds part can never print as "60". It is left empty once
// all steps are done.
func (p *progress) completeStep() {
	done := p.stepsCompleted.Add(1)
	total := p.totalSteps
	elapsed := time.Since(p.startTime).Seconds()
	mbDownloaded := float64(p.bytesDownloaded.Load()) / (1024 * 1024)

	// Guard both divisions so a zero total or a zero elapsed time yields 0
	// rather than Inf/NaN in the log output.
	pct := 0.0
	if total > 0 {
		pct = float64(done) / float64(total) * 100
	}
	mbPerSec := 0.0
	if elapsed > 0 {
		mbPerSec = mbDownloaded / elapsed
	}

	// Estimate remaining
	eta := ""
	if done > 0 && done < total {
		remaining := elapsed / float64(done) * float64(total-done)
		secs := int64(math.Round(remaining))
		if secs >= 60 {
			eta = fmt.Sprintf("%dm%02ds", secs/60, secs%60)
		} else {
			eta = fmt.Sprintf("%ds", secs)
		}
	}

	p.log.Info("download progress",
		zap.String("progress", fmt.Sprintf("%d/%d", done, total)),
		zap.String("percent", fmt.Sprintf("%.1f%%", pct)),
		zap.String("downloaded", fmt.Sprintf("%.1f MB", mbDownloaded)),
		zap.String("speed", fmt.Sprintf("%.1f MB/s", mbPerSec)),
		zap.String("eta", eta))
}
// Download downloads a complete forecast for the given model run and
// assembles it into a dataset file, returning the path of the finished file.
//
// The dataset is built in "<final>.downloading" inside cfg.DataDir and renamed
// to its final name (the run time formatted as YYYYMMDDHH) only after every
// step has been written, flushed and the file closed successfully. If a file
// of the right size already exists at the final path it is reused without
// downloading.
//
// For every forecast hour both the pgrb2 and the pgrb2b file are processed,
// with at most cfg.Parallel downloads in flight (a non-positive setting is
// treated as 1). The first failing step cancels the rest. On any failure, or
// if ctx ends, the temporary file is unmapped, closed and removed.
func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error) {
	date := run.Format("20060102")
	runHour := run.Hour()
	finalPath := filepath.Join(d.cfg.DataDir, run.Format("2006010215"))
	tempPath := finalPath + ".downloading"

	// Check if final dataset already exists
	if info, err := os.Stat(finalPath); err == nil && info.Size() == dataset.DatasetSize {
		d.log.Info("dataset already exists", zap.String("path", finalPath))
		return finalPath, nil
	}

	steps := dataset.Hours()
	totalSteps := len(steps) * 2 // pgrb2 + pgrb2b per step
	prog := newProgress(totalSteps, d.log)

	d.log.Info("starting dataset download",
		zap.Time("run", run),
		zap.Int("total_steps", totalSteps),
		zap.String("temp_path", tempPath))

	// Create the dataset file
	ds, err := dataset.Create(tempPath)
	if err != nil {
		return "", fmt.Errorf("create dataset: %w", err)
	}
	// Safety net for unexpected exits; every regular path below closes the
	// file explicitly first.
	defer ds.Close()

	// discard releases the mapping before deleting the partial file, so the
	// removal also works on platforms that refuse to delete open files.
	discard := func() {
		ds.Close()
		os.Remove(tempPath)
	}

	// A zero-capacity semaphore would never admit a worker and a negative
	// capacity panics, so fall back to sequential downloads.
	parallel := d.cfg.Parallel
	if parallel < 1 {
		parallel = 1
	}
	g, gctx := errgroup.WithContext(ctx)
	sem := make(chan struct{}, parallel)

	// schedule waits for a free slot and starts one download. It returns false
	// without starting anything once the group context is done, so a failure
	// or cancellation stops further scheduling immediately.
	schedule := func(step, hourIdx int, kind, url string, levelSet dataset.LevelSet) bool {
		select {
		case sem <- struct{}{}:
		case <-gctx.Done():
			return false
		}
		g.Go(func() error {
			defer func() { <-sem }()
			if err := d.downloadAndBlit(gctx, ds, url, hourIdx, levelSet, prog); err != nil {
				return fmt.Errorf("step %d %s: %w", step, kind, err)
			}
			prog.completeStep()
			return nil
		})
		return true
	}

	aborted := false
	for _, step := range steps {
		hourIdx := dataset.HourIndex(step)
		if hourIdx < 0 {
			continue
		}
		if !schedule(step, hourIdx, "pgrb2", dataset.GribURL(date, runHour, step), dataset.LevelSetA) ||
			!schedule(step, hourIdx, "pgrb2b", dataset.GribURLB(date, runHour, step), dataset.LevelSetB) {
			aborted = true
			break
		}
	}

	if err := g.Wait(); err != nil {
		discard()
		return "", err
	}
	if aborted {
		// No worker failed, so scheduling stopped because the caller's
		// context ended; the dataset is incomplete.
		discard()
		if err := ctx.Err(); err != nil {
			return "", err
		}
		return "", context.Canceled
	}

	elapsed := time.Since(prog.startTime)
	totalMB := float64(prog.bytesDownloaded.Load()) / (1024 * 1024)
	d.log.Info("download complete, flushing to disk",
		zap.String("downloaded", fmt.Sprintf("%.1f MB", totalMB)),
		zap.Duration("elapsed", elapsed),
		zap.String("avg_speed", fmt.Sprintf("%.1f MB/s", totalMB/elapsed.Seconds())))

	// Flush to disk
	if err := ds.Flush(); err != nil {
		discard()
		return "", fmt.Errorf("flush dataset: %w", err)
	}

	// Close before rename; a failed close means the file cannot be trusted.
	if err := ds.Close(); err != nil {
		os.Remove(tempPath)
		return "", fmt.Errorf("close dataset: %w", err)
	}

	// Atomic rename
	if err := os.Rename(tempPath, finalPath); err != nil {
		os.Remove(tempPath)
		return "", fmt.Errorf("rename dataset: %w", err)
	}

	d.log.Info("dataset ready", zap.String("path", finalPath))
	return finalPath, nil
}
// DownloadAndBlit downloads the needed GRIB fields of the given level set from
// baseURL and writes them into ds at time index hourIdx.
// It delegates to downloadAndBlit with a nil progress tracker, so no download
// statistics are recorded.
func (d *Downloader) DownloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet) error {
return d.downloadAndBlit(ctx, ds, baseURL, hourIdx, levelSet, nil)
}
// downloadAndBlit fetches the fields of one GRIB file that belong to levelSet
// and writes them into ds at time index hourIdx. prog may be nil, in which
// case no byte counts are recorded.
//
// It downloads the ".idx" companion of baseURL, selects the entries for the
// needed variables at pressure levels of levelSet, downloads just those byte
// ranges into a temporary file, decodes the GRIB messages and blits every
// message that maps to a known variable and pressure level. When the index
// yields no relevant entries a warning is logged and nil is returned. A
// message that fails to blit is logged and skipped.
func (d *Downloader) downloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet, prog *progress) error {
	// Step 1: fetch the index file.
	idxURL := baseURL + ".idx"
	rawIdx, err := d.httpGet(ctx, idxURL)
	if err != nil {
		return fmt.Errorf("download idx: %w", err)
	}

	// Step 2: keep the needed variables, then only levels of this level set.
	allEntries := ParseIdx(rawIdx)
	wanted := FilterIdx(allEntries, neededVariables)

	var selected []IdxEntry
	for i := range wanted {
		set, known := dataset.PressureLevelSet(wanted[i].LevelMB)
		if known && set == levelSet {
			selected = append(selected, wanted[i])
		}
	}
	if len(selected) == 0 {
		d.log.Warn("no relevant entries found in idx",
			zap.String("url", idxURL),
			zap.Int("total_entries", len(allEntries)),
			zap.Int("filtered", len(wanted)))
		return nil
	}

	// Step 3: download the selected byte ranges into a temp file.
	gribPath, err := d.downloadRangesToTempFile(ctx, baseURL, EntriesToRanges(selected), prog)
	if err != nil {
		return fmt.Errorf("download ranges: %w", err)
	}
	defer os.Remove(gribPath)

	// Step 4: decode the GRIB messages; the handle is closed before the
	// decode result is inspected.
	gribFile, err := os.Open(gribPath)
	if err != nil {
		return fmt.Errorf("open temp grib: %w", err)
	}
	messages, err := griblib.ReadMessages(gribFile)
	gribFile.Close()
	if err != nil {
		return fmt.Errorf("read grib messages: %w", err)
	}

	// Step 5: blit each usable message into the dataset.
	for _, msg := range messages {
		if msg.Section4.ProductDefinitionTemplateNumber != 0 {
			continue
		}
		product := msg.Section4.ProductDefinitionTemplate

		varIdx := dataset.VariableIndex(int(product.ParameterCategory), int(product.ParameterNumber))
		isobaric := product.FirstSurface.Type == 100 // isobaric surface
		if varIdx < 0 || !isobaric {
			continue
		}

		// The surface value is divided by 100 to get hPa/mb.
		pressureMB := int(math.Round(float64(product.FirstSurface.Value) / 100.0))
		levelIdx := dataset.PressureIndex(pressureMB)
		if levelIdx < 0 {
			continue
		}

		if err := ds.BlitGribData(hourIdx, levelIdx, varIdx, msg.Data()); err != nil {
			d.log.Warn("blit failed",
				zap.Int("var", varIdx),
				zap.Int("level_mb", pressureMB),
				zap.Error(err))
		}
	}
	return nil
}
// downloadRangesToTempFile fetches each byte range of baseURL in order and
// appends it to a new temporary file in cfg.DataDir, producing a file of
// concatenated GRIB messages. It returns the temp file's path; the caller is
// responsible for removing it. On any failure the temp file is closed and
// removed. When prog is non-nil the size of every range is added to it.
func (d *Downloader) downloadRangesToTempFile(ctx context.Context, baseURL string, ranges []ByteRange, prog *progress) (string, error) {
	out, err := os.CreateTemp(d.cfg.DataDir, "grib-*.tmp")
	if err != nil {
		return "", fmt.Errorf("create temp file: %w", err)
	}
	outPath := out.Name()

	// abort closes and deletes the partial file, then reports cause.
	abort := func(cause error) (string, error) {
		out.Close()
		os.Remove(outPath)
		return "", cause
	}

	for i := range ranges {
		start, end := ranges[i].Start, ranges[i].End
		chunk, err := d.httpGetRange(ctx, baseURL, start, end)
		if err != nil {
			return abort(fmt.Errorf("download range %d-%d: %w", start, end, err))
		}
		if _, err := out.Write(chunk); err != nil {
			return abort(fmt.Errorf("write temp: %w", err))
		}
		if prog != nil {
			prog.addBytes(int64(len(chunk)))
		}
	}

	if err := out.Close(); err != nil {
		os.Remove(outPath)
		return "", err
	}
	return outPath, nil
}
// httpGet downloads a URL and returns the body bytes.
// Only a 200 response is accepted; failures are retried (see getWithRetry).
func (d *Downloader) httpGet(ctx context.Context, url string) ([]byte, error) {
	body, _, err := d.getWithRetry(ctx, url, "", func(status int) error {
		if status != http.StatusOK {
			return fmt.Errorf("HTTP %d for %s", status, url)
		}
		return nil
	})
	return body, err
}

// httpGetRange downloads the inclusive byte range [start, end] of a URL.
//
// A 206 response body is returned as received. A 200 response means the
// server ignored the Range header and sent the whole resource; in that case
// the requested slice is cut out of the full body (clipped at its end) so the
// caller never receives bytes outside the range. Failures are retried (see
// getWithRetry).
func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64) ([]byte, error) {
	rangeHeader := fmt.Sprintf("bytes=%d-%d", start, end)
	body, status, err := d.getWithRetry(ctx, url, rangeHeader, func(status int) error {
		if status != http.StatusPartialContent && status != http.StatusOK {
			return fmt.Errorf("HTTP %d for range %d-%d of %s", status, start, end, url)
		}
		return nil
	})
	if err != nil {
		return nil, err
	}
	if status != http.StatusOK {
		return body, nil
	}

	// Full-body reply: extract the requested window ourselves.
	size := int64(len(body))
	if start < 0 || end < start || start >= size {
		return nil, fmt.Errorf("range %d-%d of %s not satisfiable: full response is %d bytes", start, end, url, size)
	}
	stop := end + 1
	if stop > size {
		stop = size
	}
	return body[start:stop], nil
}

// getWithRetry performs a GET request up to three times, waiting 2s and then
// 4s between attempts. rangeHeader, when non-empty, is sent as the Range
// header. checkStatus decides whether a response status is acceptable; its
// error, a transport error or a body read error triggers the next attempt.
// The body is always read and closed before the status is judged. It returns
// the body together with the accepted status code. Waiting is abandoned, and
// no further attempt is made, as soon as ctx is done.
func (d *Downloader) getWithRetry(ctx context.Context, url, rangeHeader string, checkStatus func(status int) error) ([]byte, int, error) {
	const maxAttempts = 3

	var lastErr error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		if attempt > 0 {
			select {
			case <-time.After(time.Duration(attempt*2) * time.Second):
			case <-ctx.Done():
				return nil, 0, ctx.Err()
			}
		}

		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			return nil, 0, err
		}
		if rangeHeader != "" {
			req.Header.Set("Range", rangeHeader)
		}

		resp, err := d.client.Do(req)
		if err != nil {
			// A finished context makes every further attempt fail the same way.
			if ctxErr := ctx.Err(); ctxErr != nil {
				return nil, 0, ctxErr
			}
			lastErr = err
			continue
		}
		body, readErr := io.ReadAll(resp.Body)
		resp.Body.Close()

		if err := checkStatus(resp.StatusCode); err != nil {
			lastErr = err
			continue
		}
		if readErr != nil {
			lastErr = readErr
			continue
		}
		return body, resp.StatusCode, nil
	}
	return nil, 0, fmt.Errorf("after 3 attempts: %w", lastErr)
}

157
internal/downloader/idx.go Normal file
View file

@ -0,0 +1,157 @@
package downloader
import (
"fmt"
"strconv"
"strings"
)
// IdxEntry is one parsed line of a GRIB .idx file.
// Example line: "15:1207405:d=2024010100:HGT:1000 mb:0 hour fcst:"
type IdxEntry struct {
	Index     int    // message number (first field of the line)
	Offset    int64  // byte offset at which the message starts
	Variable  string // "HGT", "UGRD", "VGRD", etc.
	LevelMB   int    // pressure level in mb (0 if not a pressure level)
	Hour      int    // forecast hour
	EndOffset int64  // byte after this message (from next entry's offset, or -1 if last)
}
// Length reports how many bytes this GRIB message occupies, computed as
// EndOffset - Offset. It returns -1 when EndOffset is not positive, i.e. the
// end of the message is unknown.
func (e *IdxEntry) Length() int64 {
	if e.EndOffset > 0 {
		return e.EndOffset - e.Offset
	}
	return -1
}
// ParseIdx turns the body of a .idx file into entries, one per parseable line.
// A line needs at least seven colon-separated fields, a numeric message index
// and a numeric offset; anything else is silently skipped. Each entry's
// EndOffset is the Offset of the entry that follows it; the last entry keeps
// EndOffset = -1.
func ParseIdx(body []byte) []IdxEntry {
	var entries []IdxEntry
	for _, raw := range strings.Split(string(body), "\n") {
		line := strings.TrimSpace(raw)
		if line == "" {
			continue
		}
		fields := strings.Split(line, ":")
		if len(fields) < 7 {
			continue
		}

		// Both numeric fields must parse for the line to count.
		index, indexErr := strconv.Atoi(fields[0])
		if indexErr != nil {
			continue
		}
		start, startErr := strconv.ParseInt(fields[1], 10, 64)
		if startErr != nil {
			continue
		}

		entries = append(entries, IdxEntry{
			Index:     index,
			Offset:    start,
			Variable:  fields[3],
			LevelMB:   parseLevelMB(fields[4]),
			Hour:      parseHour(fields[5]),
			EndOffset: -1, // overwritten below for all but the last entry
		})
	}

	// A message ends where its successor begins.
	for i := 1; i < len(entries); i++ {
		entries[i-1].EndOffset = entries[i].Offset
	}
	return entries
}
// FilterIdx selects the entries whose variable is enabled in variables, whose
// level is a recognised pressure level (LevelMB > 0) and whose length is
// known and positive. Because the last entry of an index has no known length,
// it is never selected.
func FilterIdx(entries []IdxEntry, variables map[string]bool) []IdxEntry {
	var kept []IdxEntry
	for i := range entries {
		entry := entries[i]
		wanted := variables[entry.Variable]
		onPressureLevel := entry.LevelMB > 0
		if wanted && onPressureLevel && entry.Length() > 0 {
			kept = append(kept, entry)
		}
	}
	return kept
}
// parseLevelMB extracts the pressure in mb from a level string such as
// "1000 mb". Strings that do not end in " mb", or whose leading part is not an
// integer, yield 0.
func parseLevelMB(s string) int {
	const unit = " mb"

	level := strings.TrimSpace(s)
	if !strings.HasSuffix(level, unit) {
		return 0
	}
	value, err := strconv.Atoi(level[:len(level)-len(unit)])
	if err != nil {
		return 0
	}
	return value
}
// parseHour extracts the forecast hour from a string such as "0 hour fcst".
// The analysis marker "anl" counts as hour 0. Anything that is not an integer
// once the " hour fcst" suffix is removed yields -1.
func parseHour(s string) int {
	label := strings.TrimSpace(s)
	if label == "anl" {
		return 0
	}
	hours, err := strconv.Atoi(strings.TrimSuffix(label, " hour fcst"))
	if err != nil {
		return -1
	}
	return hours
}
// ByteRange is a byte range suitable for an HTTP Range download.
// Each range covers the one GRIB message described by Entry.
type ByteRange struct {
Start int64 // first byte of the range
End int64 // last byte of the range (inclusive)
Entry IdxEntry // idx entry the range was derived from
}
// EntriesToRanges builds one inclusive byte range per entry, spanning Offset
// through EndOffset-1. Entries without a known positive length are left out.
func EntriesToRanges(entries []IdxEntry) []ByteRange {
	ranges := make([]ByteRange, 0, len(entries))
	for i := range entries {
		entry := entries[i]
		if entry.Length() > 0 {
			ranges = append(ranges, ByteRange{
				Start: entry.Offset,
				End:   entry.EndOffset - 1, // EndOffset is exclusive, End is inclusive
				Entry: entry,
			})
		}
	}
	return ranges
}
// FormatRange returns the HTTP Range header value for this byte range, in the
// form "bytes=<Start>-<End>" with End inclusive.
func (r ByteRange) FormatRange() string {
return fmt.Sprintf("bytes=%d-%d", r.Start, r.End)
}

View file

@ -0,0 +1,110 @@
package downloader
import (
"testing"
)
// sampleIdx is a nine-line .idx fixture. It mixes the needed variables
// (HGT, UGRD, VGRD) with an unneeded one (TMP) and a non-pressure level
// ("2 m above ground"); the final line is an HGT entry whose length is
// unknown because no entry follows it.
const sampleIdx = `1:0:d=2024010100:HGT:1000 mb:0 hour fcst:
2:289012:d=2024010100:HGT:975 mb:0 hour fcst:
3:541876:d=2024010100:TMP:1000 mb:0 hour fcst:
4:789012:d=2024010100:UGRD:1000 mb:0 hour fcst:
5:1045678:d=2024010100:VGRD:1000 mb:0 hour fcst:
6:1298765:d=2024010100:UGRD:975 mb:0 hour fcst:
7:1567890:d=2024010100:UGRD:2 m above ground:0 hour fcst:
8:1812345:d=2024010100:VGRD:975 mb:0 hour fcst:
9:2098765:d=2024010100:HGT:500 mb:3 hour fcst:
`
// TestParseIdx parses the fixture and checks the entry count, the fields of
// the first entry, the non-pressure level and the open-ended last entry.
func TestParseIdx(t *testing.T) {
	parsed := ParseIdx([]byte(sampleIdx))
	if n := len(parsed); n != 9 {
		t.Fatalf("expected 9 entries, got %d", n)
	}

	// First entry: every field plus the linked end offset.
	first := parsed[0]
	fieldsOK := first.Index == 1 && first.Offset == 0 && first.Variable == "HGT" &&
		first.LevelMB == 1000 && first.Hour == 0
	if !fieldsOK {
		t.Errorf("entry 0: got %+v", first)
	}
	if first.EndOffset != 289012 {
		t.Errorf("entry 0 EndOffset: got %d, want 289012", first.EndOffset)
	}

	// "2 m above ground" is not a pressure level.
	aboveGround := parsed[6] // UGRD at "2 m above ground"
	if aboveGround.LevelMB != 0 {
		t.Errorf("non-pressure level should have LevelMB=0, got %d", aboveGround.LevelMB)
	}

	// Nothing follows the last entry, so its end stays unknown.
	if tail := parsed[len(parsed)-1]; tail.EndOffset != -1 {
		t.Errorf("last entry EndOffset: got %d, want -1", tail.EndOffset)
	}
}
func TestFilterIdx(t *testing.T) {
entries := ParseIdx([]byte(sampleIdx))
filtered := FilterIdx(entries, neededVariables)
// Should include HGT/UGRD/VGRD at pressure levels, exclude TMP and "above ground"
// And exclude last entry (no EndOffset)
for _, e := range filtered {
if !neededVariables[e.Variable] {
t.Errorf("unexpected variable %s", e.Variable)
}
if e.LevelMB <= 0 {
t.Errorf("non-pressure level included: %+v", e)
}
if e.Length() <= 0 {
t.Errorf("entry with unknown length included: %+v", e)
}
}
// Count expected: HGT@1000, HGT@975, UGRD@1000, VGRD@1000, UGRD@975, VGRD@975 = 6
// But HGT@500 at 3hr fcst is the last entry (no EndOffset), so excluded
if len(filtered) != 6 {
t.Errorf("expected 6 filtered entries, got %d", len(filtered))
for _, e := range filtered {
t.Logf(" %s %d mb (offset %d, len %d)", e.Variable, e.LevelMB, e.Offset, e.Length())
}
}
}
// TestParseLevelMB runs parseLevelMB over pressure and non-pressure levels.
func TestParseLevelMB(t *testing.T) {
	type levelCase struct {
		input string
		want  int
	}
	cases := []levelCase{
		{input: "1000 mb", want: 1000},
		{input: "975 mb", want: 975},
		{input: "1 mb", want: 1},
		{input: "2 m above ground", want: 0},
		{input: "surface", want: 0},
		{input: "tropopause", want: 0},
	}
	for i := range cases {
		tc := cases[i]
		if got := parseLevelMB(tc.input); got != tc.want {
			t.Errorf("parseLevelMB(%q) = %d, want %d", tc.input, got, tc.want)
		}
	}
}
// TestParseHour runs parseHour over forecast-hour strings and "anl".
func TestParseHour(t *testing.T) {
	type hourCase struct {
		input string
		want  int
	}
	cases := []hourCase{
		{input: "0 hour fcst", want: 0},
		{input: "3 hour fcst", want: 3},
		{input: "192 hour fcst", want: 192},
		{input: "anl", want: 0},
	}
	for i := range cases {
		tc := cases[i]
		if got := parseHour(tc.input); got != tc.want {
			t.Errorf("parseHour(%q) = %d, want %d", tc.input, got, tc.want)
		}
	}
}

View file

@ -0,0 +1,113 @@
package elevation
import (
"encoding/binary"
"fmt"
"math"
"os"
mmap "github.com/edsrzf/mmap-go"
)
// Layout of the global elevation file, compatible with ruaumoko.
//
// Binary format: int16 little-endian elevation values in metres, row-major
// (lat, lon). The latitude axis runs from -90 to +90 (south to north) and the
// longitude axis from 0 to 360 (wrapping). Resolution is 120 cells per degree
// (30 arc-seconds).
const (
	// CellsPerDegree is the number of grid cells in one degree.
	CellsPerDegree = 120
	// NumLats is the number of latitude rows: 21601.
	NumLats = 180*CellsPerDegree + 1
	// NumLons is the number of longitude columns: 43200.
	NumLons = 360 * CellsPerDegree
	// DataSize is the file size in bytes: 1,866,326,400 (~1.74 GiB).
	DataSize = NumLats * NumLons * 2
)
// Dataset is a memory-mapped global elevation grid.
// Values are created by Open and released with Close.
type Dataset struct {
mm mmap.MMap // read-only mapping of the whole file; set to nil by Close
file *os.File // handle backing the mapping; set to nil by Close
}
// Open memory-maps the elevation dataset stored at path for read-only access.
//
// The file must be exactly DataSize bytes long. On any failure the file
// handle is released and a wrapped error is returned.
func Open(path string) (*Dataset, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, fmt.Errorf("open elevation: %w", err)
	}

	// From here on every failure path has to release the file handle.
	fail := func(cause error) (*Dataset, error) {
		file.Close()
		return nil, cause
	}

	stat, err := file.Stat()
	if err != nil {
		return fail(fmt.Errorf("stat elevation: %w", err))
	}
	if size := stat.Size(); size != DataSize {
		return fail(fmt.Errorf("elevation dataset should be %d bytes (was %d)", DataSize, size))
	}

	mapped, err := mmap.Map(file, mmap.RDONLY, 0)
	if err != nil {
		return fail(fmt.Errorf("mmap elevation: %w", err))
	}
	return &Dataset{mm: mapped, file: file}, nil
}
// getCell returns the raw int16 elevation stored at grid position
// (latIdx, lngIdx). Latitude indices outside the grid are clamped to the
// nearest row; longitude indices wrap around modulo NumLons.
func (d *Dataset) getCell(latIdx, lngIdx int) int16 {
	switch {
	case latIdx < 0:
		latIdx = 0
	case latIdx >= NumLats:
		latIdx = NumLats - 1
	}

	// Go's % keeps the sign of the dividend, so shift negatives back in range.
	if lngIdx %= NumLons; lngIdx < 0 {
		lngIdx += NumLons
	}

	byteOff := 2 * (latIdx*NumLons + lngIdx)
	raw := d.mm[byteOff : byteOff+2]
	return int16(binary.LittleEndian.Uint16(raw))
}
// Get returns the elevation in metres at (lat, lng), bilinearly interpolated
// between the four surrounding grid cells.
// lat: -90 to +90, lng: 0 to 360 (negative values are shifted by +360).
func (d *Dataset) Get(lat, lng float64) float64 {
	if lng < 0 {
		lng += 360
	}

	// Fractional grid coordinates of the query point.
	row := (lat + 90.0) * CellsPerDegree
	col := lng * CellsPerDegree

	r0, c0 := int(math.Floor(row)), int(math.Floor(col))
	wLat := row - float64(r0)
	wLng := col - float64(c0)

	// Interpolate along longitude on the lower and upper rows, then blend the
	// two rows along latitude. getCell clamps/wraps the +1 neighbours.
	lower := (1-wLng)*float64(d.getCell(r0, c0)) + wLng*float64(d.getCell(r0, c0+1))
	upper := (1-wLng)*float64(d.getCell(r0+1, c0)) + wLng*float64(d.getCell(r0+1, c0+1))
	return (1-wLat)*lower + wLat*upper
}
// Close unmaps the dataset and closes the underlying file.
//
// Both resources are always released; the first error encountered (unmap
// first, then file close) is returned. Calling Close again is a no-op that
// returns nil.
func (d *Dataset) Close() error {
	var firstErr error
	if d.mm != nil {
		// Previously the Unmap error was dropped on the floor.
		firstErr = d.mm.Unmap()
		d.mm = nil
	}
	if d.file != nil {
		if err := d.file.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
		d.file = nil
	}
	return firstErr
}

View file

@ -1,23 +0,0 @@
package updater
import (
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
env "github.com/caarlos0/env/v11"
)
// Config holds the scheduling settings of the GRIB updater job.
// It is populated from environment variables by NewConfig.
type Config struct {
// Interval is the value reported by Job.GetInterval (env INTERVAL, default 6h).
Interval time.Duration `env:"INTERVAL" envDefault:"6h"`
// Timeout is the value reported by Job.GetTimeout (env TIMEOUT, default 45m).
Timeout time.Duration `env:"TIMEOUT" envDefault:"45m"`
}
// NewConfig loads the updater configuration from the environment.
//
// Variables are read with the GSN_PREDICTOR_GRIB_UPDATER_ prefix, e.g.
// GSN_PREDICTOR_GRIB_UPDATER_INTERVAL. The prefix must be supplied through
// Options.Prefix: Options.PrefixTagName only renames the struct tag used to
// look up nested-struct prefixes, so passing the prefix there left every
// variable unprefixed (plain INTERVAL / TIMEOUT).
func NewConfig() (*Config, error) {
	cfg := &Config{}
	opts := env.Options{Prefix: "GSN_PREDICTOR_GRIB_UPDATER_"}
	if err := env.ParseWithOptions(cfg, opts); err != nil {
		return nil, errcodes.Wrap(err, "failed to parse GRIB updater config")
	}
	return cfg, nil
}

View file

@ -1,8 +0,0 @@
package updater
import "context"
// GribService defines the interface for GRIB operations needed by the updater job.
// It is the only dependency Job has on the GRIB service.
type GribService interface {
// Update is invoked once per job execution; a non-nil error fails the job.
Update(ctx context.Context) error
}

View file

@ -1,51 +0,0 @@
package updater
import (
"context"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
"git.intra.yksa.space/gsn/predictor/internal/pkg/log"
"go.uber.org/zap"
)
// Job is the periodic GRIB update task: it exposes its schedule settings
// through the Get* methods and runs GribService.Update in Execute.
type Job struct {
service GribService // performs the actual update
config *Config // schedule settings (interval, timeout)
}
// New builds an updater Job that calls service.Update on every execution and
// reports the interval and timeout stored in config.
func New(service GribService, config *Config) *Job {
	job := new(Job)
	job.service = service
	job.config = config
	return job
}
// GetInterval returns the configured interval between job runs.
func (j *Job) GetInterval() time.Duration {
return j.config.Interval
}
// GetTimeout returns the configured timeout for a single job run.
func (j *Job) GetTimeout() time.Duration {
return j.config.Timeout
}
// GetCount always returns 1.
// NOTE(review): presumably the number of instances the scheduler should run — confirm against the scheduler.
func (j *Job) GetCount() int {
return 1
}
// GetAsync always returns false.
// NOTE(review): presumably tells the scheduler to run the job synchronously — confirm against the scheduler.
func (j *Job) GetAsync() bool {
return false
}
// Execute runs one GRIB update through the configured service, logging the
// start, the outcome, and returning a wrapped error when the update fails.
func (j *Job) Execute(ctx context.Context) error {
	// Named logger (not log) so the package identifier stays usable below.
	logger := log.Ctx(ctx)
	logger.Info("executing GRIB update job")

	err := j.service.Update(ctx)
	if err != nil {
		logger.Error("GRIB update failed", zap.Error(err))
		return errcodes.Wrap(err, "failed to update GRIB data")
	}

	logger.Info("GRIB update completed successfully")
	return nil
}

View file

@ -1,89 +0,0 @@
package ds
import (
"time"
api "git.intra.yksa.space/gsn/predictor/pkg/rest"
)
// PredictionParameters is the internal, pointer-based form of the prediction
// request. A nil field means the parameter was not supplied by the caller
// (see ConvertFlatPredictionParams).
type PredictionParameters struct {
LaunchLatitude *float64
LaunchLongitude *float64
LaunchDatetime *time.Time
LaunchAltitude *float64
Profile *string
AscentRate *float64
BurstAltitude *float64
DescentRate *float64
FloatAltitude *float64
StopDatetime *time.Time
AscentCurve *string // base64
DescentCurve *string // base64
Interpolate *bool
Format *string
Dataset *time.Time
// Add other parameters as needed
}
// PredicitonResult holds one optional-valued prediction result record.
// NOTE(review): the type name is misspelled ("Prediciton"); renaming it would
// break existing references, so it is left as is.
type PredicitonResult struct {
Latitude *float64
Longitude *float64
Altitude *float64
Timestamp *time.Time
WindU *float64
WindV *float64
// Add other result fields as needed
}
// presentOrNil turns the (value, ok) pair returned by an optional's Get into
// a pointer: nil when the value is absent, otherwise a pointer to its own copy.
func presentOrNil[T any](value T, ok bool) *T {
	if !ok {
		return nil
	}
	return &value
}

// ConvertFlatPredictionParams converts the flat ogen request parameters into
// the internal pointer-based model. Every optional that is set becomes a
// non-nil pointer; unset optionals stay nil.
func ConvertFlatPredictionParams(params api.PerformPredictionParams) *PredictionParameters {
	converted := &PredictionParameters{
		LaunchLatitude:  presentOrNil(params.LaunchLatitude.Get()),
		LaunchLongitude: presentOrNil(params.LaunchLongitude.Get()),
		LaunchDatetime:  presentOrNil(params.LaunchDatetime.Get()),
		LaunchAltitude:  presentOrNil(params.LaunchAltitude.Get()),
		AscentRate:      presentOrNil(params.AscentRate.Get()),
		BurstAltitude:   presentOrNil(params.BurstAltitude.Get()),
		DescentRate:     presentOrNil(params.DescentRate.Get()),
		FloatAltitude:   presentOrNil(params.FloatAltitude.Get()),
		StopDatetime:    presentOrNil(params.StopDatetime.Get()),
		AscentCurve:     presentOrNil(params.AscentCurve.Get()),
		DescentCurve:    presentOrNil(params.DescentCurve.Get()),
		Interpolate:     presentOrNil(params.Interpolate.Get()),
		Dataset:         presentOrNil(params.Dataset.Get()),
	}

	// Profile and Format are not plain strings in the API model, so they are
	// converted explicitly before taking their address.
	if profile, ok := params.Profile.Get(); ok {
		text := string(profile)
		converted.Profile = &text
	}
	if format, ok := params.Format.Get(); ok {
		text := string(format)
		converted.Format = &text
	}
	return converted
}

View file

@ -1,102 +0,0 @@
package errcodes
import (
	"errors"
	"net/http"
	"strings"
)
// ErrorCode is an error that carries an HTTP status code alongside its message.
type ErrorCode struct {
StatusCode int // HTTP status associated with the error
Message string // text returned by Error()
Details string // optional extra context; not part of Error()
}
// New creates an ErrorCode with the given HTTP status and message. Any
// details are joined with single spaces into the Details field.
func New(statusCode int, message string, details ...string) *ErrorCode {
	joined := strings.Join(details, " ")
	return &ErrorCode{StatusCode: statusCode, Message: message, Details: joined}
}
// Error implements the error interface. It returns only Message; Details is
// not included.
func (e *ErrorCode) Error() string {
return e.Message
}
func IsErr(err error) bool {
_, ok := err.(*ErrorCode)
return ok
}
// AsErr extracts the *ErrorCode carried by err.
//
// The whole error chain is searched with errors.As, so an ErrorCode wrapped
// by fmt.Errorf("...: %w", err) or errors.Join is found as well; the previous
// direct type assertion only matched when err itself was an *ErrorCode.
// It returns (nil, false) for a nil error or when no ErrorCode is present.
func AsErr(err error) (*ErrorCode, bool) {
	var errcode *ErrorCode
	if errors.As(err, &errcode) {
		return errcode, true
	}
	return nil, false
}
// Join combines several errors into a single ErrorCode.
//
// Nil errors are skipped. Messages are joined with "; ", and the non-empty
// Details of ErrorCode values are joined with spaces. The status code is taken
// from the first non-nil error when it is an ErrorCode, and defaults to
// 500 Internal Server Error otherwise. (Previously the status was read from
// errs[0] even when that entry was nil, so Join(nil, ErrNoDataset) reported
// 500 instead of 404.)
// Join returns nil when there is no non-nil error.
func Join(errs ...error) error {
	var (
		messages   []string
		details    []string
		statusCode = http.StatusInternalServerError
		seenFirst  bool
	)
	for _, err := range errs {
		if err == nil {
			continue
		}
		errcode, ok := AsErr(err)
		if !seenFirst {
			// Only the first real error decides the status code.
			seenFirst = true
			if ok {
				statusCode = errcode.StatusCode
			}
		}
		if !ok {
			messages = append(messages, err.Error())
			continue
		}
		messages = append(messages, errcode.Message)
		if errcode.Details != "" {
			details = append(details, errcode.Details)
		}
	}
	if len(messages) == 0 {
		return nil
	}
	return New(statusCode, strings.Join(messages, "; "), details...)
}
// Wrap returns a new ErrorCode whose Message is message and whose Details
// carry the text of err. If err is an ErrorCode its status code is kept and
// its Message and Details become the new Details; any other error yields
// status 500 with err.Error() as Details. Wrap(nil, ...) returns nil.
func Wrap(err error, message string) error {
	if err == nil {
		return nil
	}
	known, ok := AsErr(err)
	if !ok {
		return New(http.StatusInternalServerError, message, err.Error())
	}
	return New(known.StatusCode, message, known.Message, known.Details)
}
// Predefined error values, each pairing an HTTP status code with a fixed
// message. They are pointers, so every name refers to one shared instance.
var (
ErrNoDataset = New(http.StatusNotFound, "no grib dataset found")
ErrOutOfBounds = New(http.StatusBadRequest, "requested time is out of bounds")
ErrConfig = New(http.StatusInternalServerError, "configuration error")
ErrConfigInvalidEnv = New(http.StatusInternalServerError, "invalid environment configuration")
ErrConfigMissingRequired = New(http.StatusInternalServerError, "missing required configuration")
ErrDownload = New(http.StatusInternalServerError, "download error")
ErrProcessing = New(http.StatusInternalServerError, "data processing error")
ErrNoCubeFilesFound = New(http.StatusNotFound, "no cube files found")
ErrNoValidCubeFilesFound = New(http.StatusNotFound, "no valid cube files found")
ErrLatestCubeFileIsTooOld = New(http.StatusNotFound, "latest cube file is too old")
ErrScheduler = New(http.StatusInternalServerError, "scheduler error")
ErrSchedulerInvalidJob = New(http.StatusBadRequest, "invalid job configuration")
ErrSchedulerTimeoutTooLong = New(http.StatusBadRequest, "job timeout too long", "timeout cannot exceed interval")
)

View file

@ -1,100 +0,0 @@
# GRIB Module
Этот модуль реализует функциональность для работы с GRIB-файлами, аналогичную tawhiri-downloader и tawhiri, но на Go.
## Основные возможности
- **Скачивание GRIB-файлов** с NOMADS (GFS прогнозы)
- **Сборка 5D-куба** (время, давление, широта, долгота, переменные u/v)
- **Эффективное хранение** с использованием mmap
- **Интерполяция** ветровых данных для произвольных координат и времени
- **Кэширование** результатов (in-memory)
- **Распределенные блокировки** для предотвращения дублирования загрузок
## Архитектура
### Основные компоненты
- **Downloader** - скачивает GRIB-файлы с NOMADS
- **Cube** - управляет 5D-массивом данных через mmap
- **Extractor** - выполняет интерполяцию данных
- **Cache** - кэширует результаты запросов
- **Service** - основной интерфейс для работы с модулем
### Структура данных
5D-куб содержит:
- **Время**: 97 временных срезов (0, 1, 2, ..., 96 часов, шаг 1 час)
- **Давление**: 47 уровней давления (1000, 975, 950, ..., 1 hPa)
- **Широта**: 361 точка (-90° до +90°)
- **Долгота**: 720 точек (0° до 359.5°)
- **Переменные**: геопотенциальная высота, u-ветер и v-ветер (3 переменные)
## Использование
```go
// Создание сервиса
cfg := grib.ServiceConfig{
Dir: "/tmp/grib",
TTL: 24 * time.Hour,
CacheTTL: 1 * time.Hour,
Parallel: 4,
Client: &http.Client{Timeout: 30 * time.Second},
}
service, err := grib.New(cfg)
if err != nil {
log.Fatal(err)
}
defer service.Close()
// Обновление данных
err = service.Update(ctx)
// Извлечение ветровых данных
wind, err := service.Extract(ctx, lat, lon, alt, timestamp)
// wind[0] - u-компонента ветра
// wind[1] - v-компонента ветра
```
## Интерполяция
Модуль выполняет 16-точечную интерполяцию:
1. **Временная интерполяция** между двумя ближайшими срезами
2. **Интерполяция по давлению** между двумя ближайшими уровнями
3. **Билинейная интерполяция** по широте и долготе
## Кэширование
- **In-memory кэш**: быстрый доступ к недавно запрошенным данным
## Расписание обновлений
Рекомендуемая частота вызова `Update()`:
- **Каждые 6 часов** - для получения свежих GFS прогнозов
- **При запуске** - для загрузки начальных данных
- **По требованию** - при отсутствии данных для запрашиваемого времени
## Отличия от tawhiri
### Преимущества Go-реализации:
- **Высокая производительность** (mmap, конкурентные загрузки)
- **Эффективное использование памяти** (не загружает весь массив в RAM)
- **Горизонтальное масштабирование** (stateless, множество реплик)
- **Встроенное кэширование** (in-memory)
### Особенности:
- Использует `github.com/nilsmagnus/grib` вместо pygrib
- Реализует собственную логику интерполяции
## Конфигурация
### Переменные окружения:
- `PREDICTOR_GRIB_DATASET_URL` - URL источника данных (опционально)
### Параметры ServiceConfig:
- `Dir` - директория для хранения файлов
- `TTL` - время жизни данных (по умолчанию 24 часа)
- `CacheTTL` - время жизни кэша (по умолчанию 1 час)
- `Parallel` - количество параллельных загрузок
- `Client` - HTTP клиент для загрузок

View file

@ -1,36 +0,0 @@
package grib
import (
"sync"
"time"
)
// vec is the cached value type: a pair of float64 (Extract stores {u, v} wind components in it).
type vec [2]float64
// item is one cache entry: a value together with its expiry instant.
type item struct {
v vec // cached value
exp time.Time // entry is valid strictly before this instant
}
// memCache is an in-memory cache of vec values keyed by uint64 with a fixed
// time-to-live. Expired entries are removed only when they are looked up
// again (see get); there is no background eviction.
type memCache struct {
ttl time.Duration // lifetime given to every entry stored by set
m sync.Map // uint64 key -> item
}
// get returns the value cached under k and true when an unexpired entry
// exists. An expired entry is deleted and reported as a miss.
func (c *memCache) get(k uint64) (vec, bool) {
	raw, found := c.m.Load(k)
	if !found {
		return vec{}, false
	}
	entry := raw.(item)
	if !time.Now().Before(entry.exp) {
		// Stale: drop it so the map does not keep it forever.
		c.m.Delete(k)
		return vec{}, false
	}
	return entry.v, true
}
// set stores v under k, replacing any existing entry; the entry expires ttl
// after the moment of the call.
func (c *memCache) set(k uint64, v vec) {
c.m.Store(k, item{v, time.Now().Add(c.ttl)})
}

View file

@ -1,32 +0,0 @@
package grib
import (
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
env "github.com/caarlos0/env/v11"
)
// Config holds the GRIB service settings, populated from environment
// variables by NewConfig.
type Config struct {
// Dir is the directory holding downloaded GRIB files and assembled *.cube files.
Dir string `env:"DIR" envDefault:"/tmp/grib"`
// TTL is the maximum age at which a dataset is still considered fresh.
TTL time.Duration `env:"TTL" envDefault:"24h"`
// CacheTTL is the lifetime of entries in the in-memory result cache.
CacheTTL time.Duration `env:"CACHE_TTL" envDefault:"1h"`
// Parallel is the number of concurrent downloads.
Parallel int `env:"PARALLEL" envDefault:"8"`
// DatasetURL is the base URL used by the HTTP downloader (when UseS3 is false).
DatasetURL string `env:"DATASET_URL" envDefault:"https://nomads.ncep.noaa.gov/pub/data/nccf/com/gfs/prod"`
// S3 configuration
// UseS3 selects the S3 downloader instead of the HTTP downloader.
UseS3 bool `env:"USE_S3" envDefault:"true"`
S3Bucket string `env:"S3_BUCKET" envDefault:"noaa-gfs-bdp-pds"`
S3Region string `env:"S3_REGION" envDefault:"us-east-1"`
// NOTE(review): S3Timeout is not referenced by the service code visible here — confirm it is used.
S3Timeout time.Duration `env:"S3_TIMEOUT" envDefault:"300s"`
}
// NewConfig loads the GRIB service configuration from the environment.
//
// Variables are read with the GSN_PREDICTOR_GRIB_ prefix, e.g.
// GSN_PREDICTOR_GRIB_DIR. The prefix must be supplied through Options.Prefix:
// Options.PrefixTagName only renames the struct tag used to look up
// nested-struct prefixes, so passing the prefix there left every variable
// unprefixed (plain DIR, TTL, ...).
func NewConfig() (*Config, error) {
	cfg := &Config{}
	opts := env.Options{Prefix: "GSN_PREDICTOR_GRIB_"}
	if err := env.ParseWithOptions(cfg, opts); err != nil {
		return nil, errcodes.Wrap(err, "failed to parse GRIB config")
	}
	return cfg, nil
}

View file

@ -1,55 +0,0 @@
package grib
import (
"encoding/binary"
"math"
"os"
mmap "github.com/edsrzf/mmap-go"
)
// cube is a read-only, memory-mapped array of little-endian float32 values.
// Each variable occupies bytesPerVar consecutive bytes, laid out row-major as
// (time, pressure level, latitude, longitude) — see val.
type cube struct {
mm mmap.MMap // mapping of the whole cube file
t, p, lat, lon int // axis lengths: time steps, pressure levels, latitudes, longitudes
bytesPerVar int64 // size in bytes of one variable's block
file *os.File // handle backing the mapping
}
// openCube memory-maps the cube file at path for reading.
//
// The file must be large enough to hold three variables of
// 97 x 47 x 361 x 720 float32 values (the layout assembleCube writes).
// A smaller file — for instance one produced with an older grid layout — is
// rejected here, because reading it through val would otherwise panic with
// an out-of-range slice access.
func openCube(path string) (*cube, error) {
	const (
		nT   = 97 // 0-96 hours with step 1 hour
		nP   = 47 // 47 pressure levels matching tawhiri
		nLat = 361
		nLon = 720
		nVar = 3 // gh, u, v
	)
	const bytesPerVar = int64(nT) * nP * nLat * nLon * 4

	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	info, err := f.Stat()
	if err != nil {
		f.Close()
		return nil, err
	}
	if info.Size() < nVar*bytesPerVar {
		f.Close()
		return nil, &os.PathError{Op: "check cube size", Path: path, Err: os.ErrInvalid}
	}
	mm, err := mmap.Map(f, mmap.RDONLY, 0)
	if err != nil {
		f.Close()
		return nil, err
	}
	return &cube{mm: mm, t: nT, p: nP, lat: nLat, lon: nLon, bytesPerVar: bytesPerVar, file: f}, nil
}
// val reads the float32 stored for variable varIdx at time index ti,
// pressure index pi, latitude index y and longitude index x. Indices are not
// range-checked; an offset outside the mapping panics.
func (c *cube) val(varIdx, ti, pi, y, x int) float32 {
	// Row-major cell index within one variable: ((ti*p + pi)*lat + y)*lon + x.
	cell := ti
	cell = cell*c.p + pi
	cell = cell*c.lat + y
	cell = cell*c.lon + x

	start := int64(varIdx)*c.bytesPerVar + int64(cell)*4
	return math.Float32frombits(binary.LittleEndian.Uint32(c.mm[start : start+4]))
}
// Close unmaps the cube and closes its file. Both steps are always
// attempted; the first error (unmap first, then close) is returned.
// Previously the Unmap error was silently discarded.
func (c *cube) Close() error {
	var firstErr error
	if c.mm != nil {
		firstErr = c.mm.Unmap()
	}
	if c.file != nil {
		if err := c.file.Close(); err != nil && firstErr == nil {
			firstErr = err
		}
	}
	return firstErr
}

View file

@ -1,13 +0,0 @@
package grib
// dataset pairs an opened cube with the start time of the forecast run it
// was built from.
type dataset struct {
cube *cube
runUTC int64 // unix seconds
}
// Close releases the underlying cube, if there is one.
func (d *dataset) Close() error {
	if d.cube == nil {
		return nil
	}
	return d.cube.Close()
}

View file

@ -1,91 +0,0 @@
package grib
import (
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
"golang.org/x/sync/errgroup"
)
// Downloader fetches GRIB files over HTTP into a local directory.
type Downloader struct {
Dir string // destination directory for downloaded files
Parallel int // number of concurrent downloads
Client *http.Client // client used for every request
DatasetURL string // base URL; see fileURL for the path appended to it
}
// fileURL builds the download URL of one forecast file:
// <DatasetURL>/gfs.<run>/<HH>/atmos/gfs.t<HH>z.pgrb2.0p50.f<FFF>,
// where run is the date string, hour the cycle hour and step the forecast hour.
func (d *Downloader) fileURL(run string, hour int, step int) string {
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d", d.DatasetURL, run, hour, hour, step)
}
// fetch downloads url into dst.
//
// If dst already exists nothing is downloaded. Otherwise the body is written
// to dst+".part" and renamed to dst only after a complete, successful
// transfer; on any error the .part file is removed. A non-200 response is
// reported as a wrapped ErrDownload.
func (d *Downloader) fetch(ctx context.Context, url, dst string) (err error) {
	// Check if final file already exists
	if _, statErr := os.Stat(dst); statErr == nil {
		return nil
	}

	// Build the request before touching the disk. The error used to be
	// ignored, which handed a nil request to Client.Do (and panicked) when the
	// URL could not be parsed.
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return err
	}

	tmp := dst + ".part"
	// Remove a stale .part file left over from an interrupted download.
	os.Remove(tmp)
	f, err := os.Create(tmp)
	if err != nil {
		return err
	}
	closed := false
	// Cleanup on any error; relies on the named return value err.
	defer func() {
		if !closed {
			f.Close()
		}
		if err != nil {
			os.Remove(tmp)
		}
	}()

	resp, err := d.Client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return errcodes.Wrap(errcodes.ErrDownload, "bad status: "+resp.Status)
	}
	if _, err = io.Copy(f, resp.Body); err != nil {
		return err
	}

	// Close the file before renaming so buffered data is on disk.
	closed = true
	if err = f.Close(); err != nil {
		return err
	}
	// If the rename fails, err is set and the deferred cleanup removes tmp.
	return os.Rename(tmp, dst)
}
// Run downloads every forecast step of the given run into d.Dir, with at
// most d.Parallel downloads in flight.
//
// A Parallel value below 1 is treated as 1: an unbuffered semaphore would
// block forever on the first acquire and a negative size would panic.
// Once the context is cancelled or a download fails, no further downloads
// are started. The first download error is returned; if the caller's context
// was cancelled before all steps were started, its error is returned.
func (d *Downloader) Run(ctx context.Context, run time.Time) error {
	runStr := run.Format("20060102")
	hour := run.Hour()

	parallel := d.Parallel
	if parallel < 1 {
		parallel = 1
	}

	g, gctx := errgroup.WithContext(ctx)
	sem := make(chan struct{}, parallel)
	interrupted := false

launch:
	for _, step := range steps {
		step := step
		select {
		case sem <- struct{}{}:
		case <-gctx.Done():
			interrupted = true
			break launch
		}
		g.Go(func() error {
			defer func() { <-sem }()
			url := d.fileURL(runStr, hour, step)
			dst := filepath.Join(d.Dir, fileName(run, step))
			return d.fetch(gctx, url, dst)
		})
	}

	if err := g.Wait(); err != nil {
		return err
	}
	if interrupted {
		// No download failed, so the group context ended because the
		// caller's context did; report that instead of a false success.
		return ctx.Err()
	}
	return nil
}

View file

@ -1,54 +0,0 @@
package grib
import "math"
// lerp linearly interpolates between a (at t = 0) and b (at t = 1).
// Values of t outside [0, 1] extrapolate along the same line.
func lerp(a, b, t float64) float64 {
	return a + (b-a)*t
}
// uv returns the (u, v) wind components at the given position, altitude and
// forecast time, using 16-point linear interpolation over
// (time, pressure, latitude, longitude).
//
// lat is in degrees (-90..90), lon in degrees (any value, wrapped to
// [0, 360)), alt in metres (converted to pressure with pressureFromAlt) and
// tHours is the time since the start of the run in hours.
//
// Index handling at the grid edges:
//   - longitude wraps, and the fractional weight is taken before wrapping so
//     it always lies in [0, 1) (lon >= 360 used to yield a weight of ~720);
//   - latitude and time are clamped to the grid and the upper neighbour is
//     clamped to the last index, so lat = 90 or tHours = 96 no longer read one
//     row / time step past the end of the data;
//   - the pressure bracket stops at the last two levels, so pressures below
//     the top level extrapolate from that pair instead of indexing past
//     pressureLevels. Pressures above the first level extrapolate as before.
//
// NOTE(review): rows are addressed as (lat+90)*2, i.e. row 0 is assumed to be
// -90 degrees. assembleCube copies GRIB rows without reordering and its comment
// mentions a north->south scan — confirm the two agree.
func (d *dataset) uv(lat, lon, alt float64, tHours float64) (float64, float64) {
	c := d.cube

	// Longitude: wrap into [0, 360) and split into cell index + weight.
	lon = math.Mod(lon, 360)
	if lon < 0 {
		lon += 360
	}
	ix := lon * 2
	xFloor := math.Floor(ix)
	wx := ix - xFloor
	x0 := int(xFloor) % c.lon
	x1 := (x0 + 1) % c.lon

	// Latitude: clamp to the grid; at the last row the upper neighbour is the
	// row itself (its weight is 0 there).
	iy := (lat + 90) * 2
	if maxY := float64(c.lat - 1); iy > maxY {
		iy = maxY
	} else if iy < 0 {
		iy = 0
	}
	y0 := int(math.Floor(iy))
	y1 := y0 + 1
	if y1 >= c.lat {
		y1 = c.lat - 1
	}
	wy := iy - float64(y0)

	// Time: hourly steps, clamped the same way as latitude.
	th := tHours
	if maxT := float64(c.t - 1); th > maxT {
		th = maxT
	} else if th < 0 {
		th = 0
	}
	it0 := int(math.Floor(th))
	it1 := it0 + 1
	if it1 >= c.t {
		it1 = c.t - 1
	}
	wt := th - float64(it0)

	// Pressure: find the bracketing pair of levels (levels are descending).
	// ip0 never exceeds len-2 so that ip1 is always a valid index.
	p := pressureFromAlt(alt)
	ip0 := 0
	for ip0+2 < len(pressureLevels) && pressureLevels[ip0+1] > p {
		ip0++
	}
	ip1 := ip0 + 1
	wp := (pressureLevels[ip0] - p) / (pressureLevels[ip0] - pressureLevels[ip1])

	// fetch bilinearly interpolates u (variable 1) and v (variable 2) in the
	// horizontal plane for one (time, pressure) pair.
	fetch := func(ti, pi int) (float64, float64) {
		u00 := d.cube.val(1, ti, pi, y0, x0)
		u10 := d.cube.val(1, ti, pi, y0, x1)
		u01 := d.cube.val(1, ti, pi, y1, x0)
		u11 := d.cube.val(1, ti, pi, y1, x1)
		v00 := d.cube.val(2, ti, pi, y0, x0)
		v10 := d.cube.val(2, ti, pi, y0, x1)
		v01 := d.cube.val(2, ti, pi, y1, x0)
		v11 := d.cube.val(2, ti, pi, y1, x1)
		uxy := (1-wy)*((1-wx)*float64(u00)+wx*float64(u10)) + wy*((1-wx)*float64(u01)+wx*float64(u11))
		vxy := (1-wy)*((1-wx)*float64(v00)+wx*float64(v10)) + wy*((1-wx)*float64(v01)+wx*float64(v11))
		return uxy, vxy
	}

	u0p0, v0p0 := fetch(it0, ip0)
	u0p1, v0p1 := fetch(it0, ip1)
	u1p0, v1p0 := fetch(it1, ip0)
	u1p1, v1p1 := fetch(it1, ip1)

	// Interpolate along pressure at each time, then along time.
	uLow := lerp(u0p0, u0p1, wp)
	vLow := lerp(v0p0, v0p1, wp)
	uHig := lerp(u1p0, u1p1, wp)
	vHig := lerp(v1p0, v1p1, wp)
	return lerp(uLow, uHig, wt), lerp(vLow, vHig, wt)
}

View file

@ -1,321 +0,0 @@
package grib
import (
"context"
"encoding/binary"
"math"
"net/http"
"os"
"path/filepath"
"strings"
"sync/atomic"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
"github.com/edsrzf/mmap-go"
"github.com/nilsmagnus/grib/griblib"
)
// Service is the public interface of the GRIB module.
type Service interface {
// Update makes sure a fresh dataset is loaded, downloading and assembling one if needed.
Update(ctx context.Context) error
// Extract returns the interpolated {u, v} wind at the given position, altitude and time.
Extract(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error)
// Close releases the currently loaded dataset.
Close() error
// GetStatus reports whether a dataset is loaded and fresh, and its run time.
GetStatus() (ready bool, lastUpdate time.Time, isFresh bool, errMsg string)
}
// service is the Service implementation.
type service struct {
cfg *Config // settings supplied to New
cache memCache // in-memory cache of Extract results
data atomic.Pointer[dataset] // currently loaded dataset; nil until one is loaded
}
// New creates the GRIB service.
//
// A zero cfg.TTL is replaced by 24 hours (cfg is modified in place) and
// cfg.Dir is created if missing. An existing cube file is loaded when
// possible; failing to find or load one is not an error, the dataset is then
// obtained by the first Update call.
func New(cfg *Config) (Service, error) {
	if cfg.TTL == 0 {
		cfg.TTL = 24 * time.Hour
	}
	if err := os.MkdirAll(cfg.Dir, 0o755); err != nil {
		return nil, err
	}

	svc := &service{cfg: cfg}
	svc.cache.ttl = cfg.CacheTTL

	// Best effort only: the service must be able to start without any data
	// on disk, so the error is intentionally discarded.
	_ = svc.loadExistingDataset()
	return svc, nil
}
// loadExistingDataset loads the most recently modified *.cube file from the
// data directory, provided its modification time is within the TTL. The run
// time is parsed from the file name (YYYYMMDD_HH.cube).
func (s *service) loadExistingDataset() error {
	candidates, err := filepath.Glob(filepath.Join(s.cfg.Dir, "*.cube"))
	if err != nil {
		return err
	}
	if len(candidates) == 0 {
		return errcodes.ErrNoCubeFilesFound
	}

	// Pick the newest file by modification time; unreadable entries are skipped.
	var (
		newestPath string
		newestMod  time.Time
	)
	for _, path := range candidates {
		st, statErr := os.Stat(path)
		if statErr != nil {
			continue
		}
		if mod := st.ModTime(); mod.After(newestMod) {
			newestPath, newestMod = path, mod
		}
	}
	if newestPath == "" {
		return errcodes.ErrNoValidCubeFilesFound
	}
	if time.Since(newestMod) > s.cfg.TTL {
		return errcodes.Wrap(errcodes.ErrLatestCubeFileIsTooOld, "latest cube file is too old")
	}

	opened, err := openCube(newestPath)
	if err != nil {
		return err
	}

	stamp := strings.TrimSuffix(filepath.Base(newestPath), ".cube")
	run, err := time.Parse("20060102_15", stamp)
	if err != nil {
		// The name does not encode a run time; do not keep the mapping open.
		opened.Close()
		return err
	}

	s.data.Store(&dataset{cube: opened, runUTC: run.Unix()})
	return nil
}
// Update makes sure a dataset is loaded: it returns early when the loaded
// dataset's run time is within the TTL, otherwise it picks the run that
// nearestRun yields for "now minus 24 hours", reuses that run's cube file if it
// exists and is fresh, or downloads the GRIB files (S3 or HTTP, per config)
// and assembles the cube. The new dataset replaces the old one and the
// in-memory cache is reset.
//
// NOTE(review): the dataset being replaced is not closed here — confirm
// whether its mapping and file handle are released elsewhere.
func (s *service) Update(ctx context.Context) error {
// Check if we already have fresh data
if d := s.data.Load(); d != nil {
runTime := time.Unix(d.runUTC, 0)
if time.Since(runTime) < s.cfg.TTL {
// Data is still fresh, no need to update
return nil
}
}
// NOTE(review): this repeats the check above verbatim; no lock is acquired
// in between, so it is not a double-checked locking pattern as written.
if d := s.data.Load(); d != nil {
runTime := time.Unix(d.runUTC, 0)
if time.Since(runTime) < s.cfg.TTL {
// Another instance already updated the data
return nil
}
}
// The target run is derived from the time 24 hours ago.
run := nearestRun(time.Now().UTC().Add(-24 * time.Hour))
// Check if we already have this run
cubePath := filepath.Join(s.cfg.Dir, run.Format("20060102_15")) + ".cube"
if _, err := os.Stat(cubePath); err == nil {
// File exists, check if it's fresh
info, err := os.Stat(cubePath)
if err == nil && time.Since(info.ModTime()) < s.cfg.TTL {
// File is fresh, just load it
c, err := openCube(cubePath)
if err != nil {
return err
}
ds := &dataset{cube: c, runUTC: run.Unix()}
s.data.Store(ds)
// Start with an empty cache for the new dataset.
s.cache = memCache{ttl: s.cfg.CacheTTL}
return nil
}
}
// Download new data using S3 or HTTP
var downloadErr error
if s.cfg.UseS3 {
s3dl, err := NewS3Downloader(s.cfg.Dir, s.cfg.Parallel, s.cfg.S3Bucket, s.cfg.S3Region)
if err != nil {
return errcodes.Wrap(err, "failed to create S3 downloader")
}
downloadErr = s3dl.Run(ctx, run)
} else {
dl := Downloader{
Dir: s.cfg.Dir,
Parallel: s.cfg.Parallel,
Client: http.DefaultClient,
DatasetURL: s.cfg.DatasetURL,
}
downloadErr = dl.Run(ctx, run)
}
if downloadErr != nil {
return downloadErr
}
// Assemble cube if it doesn't exist
// (an existing but stale cube file is reused as is)
if _, err := os.Stat(cubePath); err != nil {
if err := assembleCube(s.cfg.Dir, run, cubePath); err != nil {
return err
}
}
c, err := openCube(cubePath)
if err != nil {
return err
}
ds := &dataset{cube: c, runUTC: run.Unix()}
s.data.Store(ds)
// Start with an empty cache for the new dataset.
s.cache = memCache{ttl: s.cfg.CacheTTL}
return nil
}
// assembleCube builds the cube file for run at cubePath from the GRIB files
// of every forecast step found in dir.
//
// Layout: three variables (0 = geopotential height, 1 = u-wind, 2 = v-wind),
// each 97 time steps x 47 pressure levels x 361 x 720 little-endian float32.
// Only product-definition-template-0 messages on isobaric surfaces (type 100)
// whose pressure is one of pressureLevels are used.
//
// The cube is written to cubePath+".part" and renamed to cubePath only after
// everything succeeded. On any error the partial file is removed, so an
// aborted assembly can no longer leave behind a full-size but incomplete cube
// that a later Update would load as valid. The file handle and the mapping
// are released on every path. A message whose value count is not 361*720 is
// reported as a processing error instead of being copied over neighbouring
// slots.
func assembleCube(dir string, run time.Time, cubePath string) (err error) {
	const (
		gridCells  = 361 * 720
		gridBytes  = gridCells * 4
		sizePerVar = 97 * 47 * gridBytes // 97 time steps (0-96 hours), 47 pressure levels
	)
	total := int64(sizePerVar * 3) // 3 variables: gh, u, v

	tmpPath := cubePath + ".part"
	f, err := os.Create(tmpPath)
	if err != nil {
		return err
	}
	// Runs last (after the unmap below): close, then publish or discard.
	defer func() {
		if closeErr := f.Close(); err == nil {
			err = closeErr
		}
		if err == nil {
			err = os.Rename(tmpPath, cubePath)
		}
		if err != nil {
			os.Remove(tmpPath)
		}
	}()

	if err = f.Truncate(total); err != nil {
		return err
	}
	mm, err := mmap.MapRegion(f, int(total), mmap.RDWR, 0, 0)
	if err != nil {
		return err
	}
	defer func() {
		if unmapErr := mm.Unmap(); err == nil {
			err = unmapErr
		}
	}()

	// Pressure in whole hPa -> level index.
	pIndex := make(map[int]int, len(pressureLevels))
	for i, p := range pressureLevels {
		pIndex[int(math.Round(p))] = i
	}

	for ti, step := range steps {
		file, openErr := os.Open(filepath.Join(dir, fileName(run, step)))
		if openErr != nil {
			return openErr
		}
		messages, readErr := griblib.ReadMessages(file)
		file.Close() // Close immediately after reading
		if readErr != nil {
			return readErr
		}

		for _, m := range messages {
			if m.Section4.ProductDefinitionTemplateNumber != 0 {
				continue
			}
			product := m.Section4.ProductDefinitionTemplate

			// ParameterCategory 2 = momentum: ParameterNumber 2 = u-wind, 3 = v-wind.
			// ParameterCategory 3 = mass: ParameterNumber 5 = geopotential height.
			// Variable order matches tawhiri: ['gh', 'u', 'v'] (indices 0, 1, 2).
			var varIdx int
			switch {
			case product.ParameterCategory == 2 && product.ParameterNumber == 2:
				varIdx = 1
			case product.ParameterCategory == 2 && product.ParameterNumber == 3:
				varIdx = 2
			case product.ParameterCategory == 3 && product.ParameterNumber == 5:
				varIdx = 0
			default:
				continue
			}

			// Only isobaric surfaces (type 100) carry a pressure level.
			if product.FirstSurface.Type != 100 {
				continue
			}
			// Surface value is divided by 100 to get hPa.
			pressure := float64(product.FirstSurface.Value) / 100.0
			pIdx, ok := pIndex[int(math.Round(pressure))]
			if !ok {
				continue
			}

			vals := m.Data()
			if len(vals) != gridCells {
				return errcodes.Wrap(errcodes.ErrProcessing, "unexpected grid size in GRIB message")
			}
			// NOTE(review): values are stored in the order the GRIB library
			// returns them (original comment: scan north->south, west->east,
			// row-major) — confirm this matches the row order readers assume.
			raw := make([]byte, gridBytes)
			for i, v := range vals {
				binary.LittleEndian.PutUint32(raw[i*4:], math.Float32bits(float32(v)))
			}
			base := int64(varIdx*sizePerVar + (ti*47+pIdx)*gridBytes)
			copy(mm[base:base+gridBytes], raw)
		}
	}
	return mm.Flush()
}
// Extract returns the interpolated {u, v} wind at (lat, lon, alt) for time ts.
//
// It fails with ErrNoDataset when nothing is loaded and with ErrOutOfBounds
// when ts lies before the run start or more than 96 hours after it. Results
// are served from and stored in the in-memory cache.
func (s *service) Extract(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error) {
	ds := s.data.Load()
	if ds == nil {
		return [2]float64{}, errcodes.ErrNoDataset
	}

	runStart := time.Unix(ds.runUTC, 0)
	if ts.Before(runStart) || ts.After(runStart.Add(96*time.Hour)) {
		return [2]float64{}, errcodes.ErrOutOfBounds
	}

	cacheKey := encodeKey(lat, lon, alt, ts)
	if hit, ok := s.cache.get(cacheKey); ok {
		return [2]float64(hit), nil
	}

	// Cache miss: interpolate at the offset (in hours) from the run start.
	u, v := ds.uv(lat, lon, alt, ts.Sub(runStart).Hours())
	wind := [2]float64{u, v}
	s.cache.set(cacheKey, vec(wind))
	return wind, nil
}
// Close releases the currently loaded dataset, if any.
func (s *service) Close() error {
	ds := s.data.Load()
	if ds == nil {
		return nil
	}
	return ds.Close()
}
// GetStatus reports the state of the loaded dataset: ready and isFresh are
// true only when a dataset is loaded and its run time is younger than the
// TTL; lastUpdate is the run time; errMsg explains a not-ready state.
func (s *service) GetStatus() (ready bool, lastUpdate time.Time, isFresh bool, errMsg string) {
	ds := s.data.Load()
	if ds == nil {
		return false, time.Time{}, false, "no dataset loaded"
	}

	runTime := time.Unix(ds.runUTC, 0)
	if age := time.Since(runTime); age >= s.cfg.TTL {
		return false, runTime, false, "dataset is too old"
	}
	return true, runTime, true, ""
}

View file

@ -1,16 +0,0 @@
package grib
import "math"
// pressureLevels lists the 47 pressure levels (matching the tawhiri
// configuration) in descending order, from 1000 down to 1. The position in
// this slice is the pressure-level index used in the cube.
var pressureLevels = []float64{
1000, 975, 950, 925, 900, 875, 850, 825, 800, 775,
750, 725, 700, 675, 650, 625, 600, 575, 550, 525,
500, 475, 450, 425, 400, 375, 350, 325, 300, 275,
250, 225, 200, 175, 150, 125, 100, 70, 50, 30,
20, 10, 7, 5, 3, 2, 1,
}
// pressureFromAlt converts an altitude to a pressure with the ICAO ISA
// barometric formula: 1013.25 * (1 - alt/44307.69396)^5.255877.
// The result is in the same unit as pressureLevels (it is compared against them).
func pressureFromAlt(alt float64) float64 {
	const (
		seaLevel = 1013.25
		scale    = 44307.69396
		exponent = 5.255877
	)
	ratio := 1 - alt/scale
	return seaLevel * math.Pow(ratio, exponent)
}

View file

@ -1,265 +0,0 @@
package grib
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/s3"
"golang.org/x/sync/errgroup"
)
// S3Downloader downloads GRIB files from AWS S3
type S3Downloader struct {
Dir string // local destination directory
Parallel int // number of concurrent downloads
Bucket string // S3 bucket holding the GFS data
Region string // AWS region of the bucket
Client *s3.Client // client used for all S3 calls
}
// NewS3Downloader builds an S3Downloader for a public bucket. The AWS
// configuration is loaded for the given region with anonymous credentials.
func NewS3Downloader(dir string, parallel int, bucket, region string) (*S3Downloader, error) {
	awsCfg, err := config.LoadDefaultConfig(
		context.Background(),
		config.WithRegion(region),
		config.WithCredentialsProvider(aws.AnonymousCredentials{}),
	)
	if err != nil {
		return nil, errcodes.Wrap(err, "failed to load AWS config")
	}

	downloader := &S3Downloader{
		Dir:      dir,
		Parallel: parallel,
		Bucket:   bucket,
		Region:   region,
		Client:   s3.NewFromConfig(awsCfg),
	}
	return downloader, nil
}
// s3Key generates the S3 key for a GRIB file
// Path format: gfs.YYYYMMDD/HH/atmos/gfs.tHHz.pgrb2.0p50.fFFF
// run is the YYYYMMDD date string, hour the cycle hour (HH) and step the
// forecast hour (FFF).
func (d *S3Downloader) s3Key(run string, hour int, step int) string {
return fmt.Sprintf("gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d", run, hour, hour, step)
}
// CheckFileExists issues a HeadObject request for key and reports whether the
// object exists together with its size (0 when the response carries no
// content length). A not-found response yields (false, 0, nil); any other
// failure is returned as a wrapped error.
func (d *S3Downloader) CheckFileExists(ctx context.Context, key string) (bool, int64, error) {
	head, err := d.Client.HeadObject(ctx, &s3.HeadObjectInput{
		Bucket: aws.String(d.Bucket),
		Key:    aws.String(key),
	})
	switch {
	case err == nil:
		// Object exists; fall through to read its size.
	case isNotFoundError(err):
		return false, 0, nil
	default:
		return false, 0, errcodes.Wrap(err, "failed to check file existence")
	}

	var size int64
	if head.ContentLength != nil {
		size = *head.ContentLength
	}
	return true, size, nil
}
// isNotFoundError reports whether err looks like a "not found" response. The
// decision is made on the error text: it matches when the message contains
// "NotFound", "404" or "NoSuchKey".
func isNotFoundError(err error) bool {
	if err == nil {
		return false
	}
	msg := err.Error()
	for _, marker := range [...]string{"NotFound", "404", "NoSuchKey"} {
		if contains(msg, marker) {
			return true
		}
	}
	return false
}
// contains reports whether substr occurs anywhere in s.
// (Same result as the standard library's strings.Contains.)
func contains(s, substr string) bool {
	switch {
	case len(s) < len(substr):
		return false
	case len(s) == len(substr):
		return s == substr
	default:
		return findSubstring(s, substr)
	}
}

// findSubstring scans s from left to right and reports whether any window of
// len(substr) bytes equals substr.
func findSubstring(s, substr string) bool {
	width := len(substr)
	for start := 0; start+width <= len(s); start++ {
		if s[start:start+width] == substr {
			return true
		}
	}
	return false
}
// ListAvailableFiles returns the keys of all objects stored under
// gfs.<run>/<HH>/atmos/ in the bucket, following pagination to the end.
func (d *S3Downloader) ListAvailableFiles(ctx context.Context, run string, hour int) ([]string, error) {
	pager := s3.NewListObjectsV2Paginator(d.Client, &s3.ListObjectsV2Input{
		Bucket: aws.String(d.Bucket),
		Prefix: aws.String(fmt.Sprintf("gfs.%s/%02d/atmos/", run, hour)),
	})

	var keys []string
	for pager.HasMorePages() {
		page, err := pager.NextPage(ctx)
		if err != nil {
			return nil, errcodes.Wrap(err, "failed to list S3 objects")
		}
		for i := range page.Contents {
			if key := page.Contents[i].Key; key != nil {
				keys = append(keys, *key)
			}
		}
	}
	return keys, nil
}
// fetchFromS3 downloads key to dst, making up to three attempts.
//
// Nothing is downloaded when dst already exists. Between attempts it waits
// 2s and then 4s. The wait is interruptible: when ctx is cancelled the
// function stops waiting and retrying immediately instead of sleeping through
// the back-off, and returns the last attempt's error wrapped as before.
func (d *S3Downloader) fetchFromS3(ctx context.Context, key, dst string) (err error) {
	// Check if final file already exists
	if _, statErr := os.Stat(dst); statErr == nil {
		return nil
	}

	const maxRetries = 3
	var lastErr error

attempts:
	for attempt := 0; attempt < maxRetries; attempt++ {
		if attempt > 0 {
			// Exponential back-off: 2s before the second attempt, 4s before the third.
			backoff := time.NewTimer(time.Duration(1<<uint(attempt)) * time.Second)
			select {
			case <-backoff.C:
			case <-ctx.Done():
				backoff.Stop()
				break attempts
			}
		}
		lastErr = d.fetchFromS3Once(ctx, key, dst)
		if lastErr == nil {
			return nil
		}
		if ctx.Err() != nil {
			// Cancelled: further attempts could only fail the same way.
			break
		}
	}
	return errcodes.Wrap(lastErr, fmt.Sprintf("failed after %d retries", maxRetries))
}
// fetchFromS3Once performs a single download attempt.
// The object is written to dst+".part" and renamed to dst after a complete
// transfer. The named return value err is inspected by the deferred cleanup:
// whenever the function returns a non-nil error the .part file is removed.
func (d *S3Downloader) fetchFromS3Once(ctx context.Context, key, dst string) (err error) {
tmp := dst + ".part"
// Remove old .part file if it exists
os.Remove(tmp)
f, err := os.Create(tmp)
if err != nil {
return err
}
// fileClosed records a successful explicit Close so the defer does not close again.
fileClosed := false
// Cleanup .part file on any error (using named return value)
defer func() {
if !fileClosed {
f.Close()
}
if err != nil {
os.Remove(tmp)
}
}()
// Check if file exists in S3
exists, size, checkErr := d.CheckFileExists(ctx, key)
if checkErr != nil {
return errcodes.Wrap(checkErr, "failed to check S3 file existence")
}
if !exists {
return errcodes.Wrap(errcodes.ErrDownload, fmt.Sprintf("file not found in S3: %s", key))
}
// Download from S3
input := &s3.GetObjectInput{
Bucket: aws.String(d.Bucket),
Key: aws.String(key),
}
result, err := d.Client.GetObject(ctx, input)
if err != nil {
return errcodes.Wrap(err, "failed to get S3 object")
}
defer result.Body.Close()
// Copy to local file
written, err := io.Copy(f, result.Body)
if err != nil {
return errcodes.Wrap(err, fmt.Sprintf("failed to write S3 object to file %s", dst))
}
// Verify size if available
// (size is the content length reported by the HeadObject check above)
if size > 0 && written != size {
return errcodes.Wrap(errcodes.ErrDownload, fmt.Sprintf("size mismatch: got %d bytes, expected %d", written, size))
}
// Close file before rename
if err := f.Close(); err != nil {
return err
}
fileClosed = true
// If rename fails, err will be set and defer will cleanup .part file
return os.Rename(tmp, dst)
}
// Run downloads all required GRIB files for a forecast run.
//
// The bucket is listed first; it is an error when the run has no objects at
// all. Steps whose file is not in the listing are skipped without failing,
// since some forecast hours may not be published yet. The remaining steps
// are downloaded with at most d.Parallel transfers in flight.
//
// A Parallel value below 1 is treated as 1: an unbuffered semaphore would
// block forever on the first acquire and a negative size would panic.
// Once the context is cancelled or a download fails, no further downloads
// are started. The first download error is returned; if the caller's context
// was cancelled before all steps were started, its error is returned.
func (d *S3Downloader) Run(ctx context.Context, run time.Time) error {
	runStr := run.Format("20060102")
	hour := run.Hour()

	// First, list available files to verify they exist
	availableFiles, err := d.ListAvailableFiles(ctx, runStr, hour)
	if err != nil {
		return errcodes.Wrap(err, "failed to list available files")
	}
	if len(availableFiles) == 0 {
		return errcodes.Wrap(errcodes.ErrDownload, fmt.Sprintf("no files found for run %s/%02d", runStr, hour))
	}

	// Set of available keys for quick lookup
	available := make(map[string]struct{}, len(availableFiles))
	for _, file := range availableFiles {
		available[file] = struct{}{}
	}

	parallel := d.Parallel
	if parallel < 1 {
		parallel = 1
	}

	g, gctx := errgroup.WithContext(ctx)
	sem := make(chan struct{}, parallel)
	interrupted := false

launch:
	for _, step := range steps {
		step := step
		key := d.s3Key(runStr, hour, step)
		if _, ok := available[key]; !ok {
			// Not published (yet): skip rather than fail the whole run.
			continue
		}
		select {
		case sem <- struct{}{}:
		case <-gctx.Done():
			interrupted = true
			break launch
		}
		g.Go(func() error {
			defer func() { <-sem }()
			dst := filepath.Join(d.Dir, fileName(run, step))
			return d.fetchFromS3(gctx, key, dst)
		})
	}

	if err := g.Wait(); err != nil {
		return err
	}
	if interrupted {
		// No download failed, so the group context ended because the
		// caller's context did; report that instead of a false success.
		return ctx.Err()
	}
	return nil
}

View file

@ -1,34 +0,0 @@
package grib
import (
"fmt"
"hash/fnv"
"time"
)
// steps lists the forecast hours that are downloaded and assembled:
// 0, 1, 2, ..., 96 (97 entries, one per hour).
// NOTE(review): the original comment states that GFS provides hourly data for
// 0-120 hours — confirm this holds for the 0p50 product requested here.
var steps = func() []int {
	hours := make([]int, 97)
	for i := range hours {
		hours[i] = i
	}
	return hours
}()
// nearestRun returns the start of the most recent 6-hourly model cycle
// (00, 06, 12 or 18 UTC) at or before t.
//
// All calendar fields are taken from t converted to UTC, so the result is
// correct for inputs in any time zone, including those whose local date
// differs from the UTC date.
func nearestRun(t time.Time) time.Time {
	u := t.UTC()
	h := u.Hour() - u.Hour()%6
	return time.Date(u.Year(), u.Month(), u.Day(), h, 0, 0, 0, time.UTC)
}
// fileName builds the local file name for one forecast step of a run:
// the run's hour as a two-digit cycle and the step as a three-digit
// forecast hour.
func fileName(run time.Time, step int) string {
	const pattern = "gfs.t%02dz.pgrb2.0p50.f%03d"
	cycle := run.Hour()
	return fmt.Sprintf(pattern, cycle, step)
}
// encodeKey hashes the default fmt representation of each argument, in
// order and without separators, into a 64-bit FNV-1a digest.
func encodeKey(a ...any) uint64 {
	hasher := fnv.New64a()
	for i := range a {
		// Writes to a hash.Hash never return an error.
		hasher.Write([]byte(fmt.Sprint(a[i])))
	}
	return hasher.Sum64()
}

View file

@ -1,23 +0,0 @@
package log
import (
"context"
"go.uber.org/zap"
)
// ctxLogKey is the private context key under which ToCtx stores the logger
// and Ctx looks it up; being an unexported type it cannot collide with keys
// defined by other packages.
type ctxLogKey struct{}
// ToCtx returns a copy of ctx that carries lg, retrievable later with Ctx.
func ToCtx(ctx context.Context, lg *zap.Logger) context.Context {
	var key ctxLogKey
	return context.WithValue(ctx, key, lg)
}
// Ctx returns the logger stored in ctx by ToCtx. When ctx carries no logger
// (or a nil one), it reports that through the global zap logger and returns
// the global logger instead.
func Ctx(ctx context.Context) *zap.Logger {
	if lg, ok := ctx.Value(ctxLogKey{}).(*zap.Logger); ok && lg != nil {
		return lg
	}
	zap.L().Error("no logger in context, using global")
	return zap.L()
}

View file

@ -0,0 +1,153 @@
package prediction
import (
"fmt"
"predictor-refactored/internal/dataset"
)
// Exact port of the reference interpolation logic (interpolate.pyx).
// 4D interpolation: time, latitude, longitude, altitude (via geopotential height).
// lerp1 holds an index and interpolation weight for one axis.
// pick produces these in pairs (lower and upper bracket) whose weights sum to 1.
type lerp1 struct {
index int // grid index along the axis
lerp float64 // weight applied to the value at index
}
// lerp3 holds indices and a combined weight for the (hour, lat, lon) axes.
// pick3 fills lerp with the product of the three per-axis weights.
type lerp3 struct {
hour, lat, lng int // grid indices along the hour, latitude and longitude axes
lerp float64 // combined weight of this corner
}
// RangeError reports that a query coordinate lies outside the dataset grid.
type RangeError struct {
	Variable string  // name of the offending axis, e.g. "lat"
	Value    float64 // the rejected coordinate value
}

// Error implements the error interface.
func (e *RangeError) Error() string {
	const format = "%s=%f out of range"
	return fmt.Sprintf(format, e.Variable, e.Value)
}
// pick locates value on a regular axis and returns the two bracketing grid
// points with their linear interpolation weights.
//
// left is the coordinate of grid index 0, step the spacing between grid
// points and n the number of points on the axis. The fractional grid
// position is truncated toward zero (as the Cython <long> cast in the
// reference does) to obtain the lower index; a *RangeError naming
// variableName is returned when that index is negative or has no upper
// neighbour.
func pick(left, step float64, n int, value float64, variableName string) ([2]lerp1, error) {
	pos := (value - left) / step
	lo := int(pos) // truncation toward zero
	if lo >= 0 && lo < n-1 {
		frac := pos - float64(lo)
		var bracket [2]lerp1
		bracket[0] = lerp1{index: lo, lerp: 1 - frac}
		bracket[1] = lerp1{index: lo + 1, lerp: frac}
		return bracket, nil
	}
	return [2]lerp1{}, &RangeError{Variable: variableName, Value: value}
}
// pick3 computes the 8 trilinear interpolation corners for (hour, lat, lng).
//
// The hour axis starts at 0 with a step of 3 over 65 points, latitude starts
// at -90 with a step of 0.5 over 361 points, and longitude starts at 0 with a
// step of 0.5 over 720 points. Corners are ordered with hour as the slowest
// and longitude as the fastest varying axis. Any out-of-range coordinate
// yields the error from pick.
func pick3(hour, lat, lng float64) ([8]lerp3, error) {
	var none [8]lerp3

	hourAxis, err := pick(0, 3, 65, hour, "hour")
	if err != nil {
		return none, err
	}
	latAxis, err := pick(-90, 0.5, 361, lat, "lat")
	if err != nil {
		return none, err
	}
	// Longitude wraps: pretend the axis has one extra point, then fold the
	// upper index 720 back onto 0.
	lngAxis, err := pick(0, 0.5, 720+1, lng, "lng")
	if err != nil {
		return none, err
	}
	if lngAxis[1].index == 720 {
		lngAxis[1].index = 0
	}

	var corners [8]lerp3
	for n := range corners {
		// Bit 2 selects the hour bracket, bit 1 latitude, bit 0 longitude.
		h, la, lo := hourAxis[n>>2], latAxis[(n>>1)&1], lngAxis[n&1]
		corners[n] = lerp3{
			hour: h.index,
			lat:  la.index,
			lng:  lo.index,
			lerp: h.lerp * la.lerp * lo.lerp,
		}
	}
	return corners, nil
}
// interp3 returns the weighted sum of the dataset values at the 8 corners in
// lerps, for the given variable and pressure level.
func interp3(ds *dataset.File, lerps [8]lerp3, variable, level int) float64 {
	sum := 0.0
	for _, corner := range lerps {
		sample := ds.Val(corner.hour, level, variable, corner.lat, corner.lng)
		sum += float64(sample) * corner.lerp
	}
	return sum
}
// search binary-searches pressure level indices 0..45 (the topmost level is
// excluded so that index+1 is always valid) for the largest index whose
// interpolated geopotential height is below target. It returns 0 when no
// probed level lies below target.
func search(ds *dataset.File, lerps [8]lerp3, target float64) int {
	lo, hi := 0, 45
	for lo < hi {
		// Round the midpoint up so that lo = probe always makes progress.
		probe := (lo + hi + 1) / 2
		if height := interp3(ds, lerps, dataset.VarHeight, probe); target <= height {
			hi = probe - 1
			continue
		}
		lo = probe
	}
	return lo
}
// interp4 interpolates variable between pressure levels altLerp.index and
// altLerp.index+1. altLerp.lerp is the weight of the lower level; the upper
// level receives 1 minus that weight.
func interp4(ds *dataset.File, lerps [8]lerp3, altLerp lerp1, variable int) float64 {
	w := altLerp.lerp
	atLower := interp3(ds, lerps, variable, altLerp.index)
	atUpper := interp3(ds, lerps, variable, altLerp.index+1)
	return atLower*w + atUpper*(1-w)
}
// GetWind returns the interpolated (u, v) wind components at a position.
//
// hour: fractional hours since dataset start.
// lat: latitude in degrees (-90 to +90).
// lng: longitude in degrees (0 to 360).
// alt: altitude in metres above sea level.
//
// The altitude is bracketed between two pressure levels via their
// interpolated geopotential heights. When the altitude lies above the upper
// bracketing level the weight goes negative (extrapolation) and
// warnings.AltitudeTooHigh is incremented. An error is returned when hour,
// lat or lng fall outside the grid.
func GetWind(ds *dataset.File, warnings *Warnings, hour, lat, lng, alt float64) (u, v float64, err error) {
	corners, err := pick3(hour, lat, lng)
	if err != nil {
		return 0, 0, err
	}

	level := search(ds, corners, alt)
	below := interp3(ds, corners, dataset.VarHeight, level)
	above := interp3(ds, corners, dataset.VarHeight, level+1)

	// Weight of the lower level; fall back to an even split when both
	// levels report the same height to avoid dividing by zero.
	weight := 0.5
	if below != above {
		weight = (above - alt) / (above - below)
	}
	if weight < 0 {
		warnings.AltitudeTooHigh.Add(1)
	}

	bracket := lerp1{index: level, lerp: weight}
	u = interp4(ds, corners, bracket, dataset.VarWindU)
	v = interp4(ds, corners, bracket, dataset.VarWindV)
	return u, v, nil
}

View file

@ -0,0 +1,188 @@
package prediction
import (
"math"
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/elevation"
)
// Exact port of the reference flight models (models.py).
// Angle conversion factors used by the lateral movement model.
const (
// pi180 converts degrees to radians when multiplied.
pi180 = math.Pi / 180.0
// _180pi converts radians to degrees when multiplied.
_180pi = 180.0 / math.Pi
)
// --- Up/Down Models ---
// ConstantAscent returns a model whose only motion is a fixed vertical
// velocity of ascentRate (m/s); the lateral derivatives are always zero.
func ConstantAscent(ascentRate float64) Model {
	climb := func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
		dalt = ascentRate
		return 0, 0, dalt
	}
	return climb
}
// DragDescent returns a descent-under-parachute model.
// seaLevelDescentRate is the descent rate at sea level (m/s, positive value).
// The vertical velocity is negative and scales with the inverse square root
// of the air density from nasaDensity at the current altitude.
// NOTE(review): the factor 1.1045 is presumably the square root of the
// sea-level air density (about 1.22) — confirm against the reference.
func DragDescent(seaLevelDescentRate float64) Model {
	k := seaLevelDescentRate * 1.1045
	fall := func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
		rho := nasaDensity(alt)
		return 0, 0, -k / math.Sqrt(rho)
	}
	return fall
}
// nasaDensity computes air density at altitude alt using the three-layer
// NASA atmosphere model (above 25000, above 11000, and everything else).
// Reference: http://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html
func nasaDensity(alt float64) float64 {
	// Offset added to the layer temperature before it is used in ratios.
	const kelvinOffset = 273.1

	var temp, pressure float64
	if alt > 25000 {
		temp = -131.21 + 0.00299*alt
		pressure = 2.488 * math.Pow((temp+kelvinOffset)/216.6, -11.388)
	} else if alt > 11000 {
		temp = -56.46
		pressure = 22.65 * math.Exp(1.73-0.000157*alt)
	} else {
		temp = 15.04 - 0.00649*alt
		pressure = 101.29 * math.Pow((temp+kelvinOffset)/288.08, 5.256)
	}
	return pressure / (0.2869 * (temp + kelvinOffset))
}
// --- Sideways Models ---
// WindVelocity returns a model that moves the balloon laterally with the wind.
// ds is the wind dataset and dsEpoch the dataset start time as a UNIX
// timestamp; the model converts its time argument to hours since dsEpoch
// before querying GetWind. Wind speeds are converted to degrees per second on
// a sphere of radius 6371009 m plus the current altitude. When GetWind fails
// the model returns zero movement.
func WindVelocity(ds *dataset.File, dsEpoch float64, warnings *Warnings) Model {
	const earthRadius = 6371009.0
	return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
		u, v, err := GetWind(ds, warnings, (t-dsEpoch)/3600.0, lat, lng, alt)
		if err != nil {
			return 0, 0, 0
		}
		r := earthRadius + alt
		return _180pi * v / r, _180pi * u / (r * math.Cos(lat*pi180)), 0
	}
}
// --- Model Combinations ---
// LinearModel returns a model whose derivatives are the component-wise sum
// of the derivatives of all given models, evaluated at the same state.
func LinearModel(models ...Model) Model {
	return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
		for i := range models {
			a, b, c := models[i](t, lat, lng, alt)
			dlat, dlng, dalt = dlat+a, dlng+b, dalt+c
		}
		return dlat, dlng, dalt
	}
}
// --- Termination Criteria ---
// BurstTermination returns a terminator that fires once the altitude has
// reached or exceeded burstAltitude.
func BurstTermination(burstAltitude float64) Terminator {
	reached := func(t, lat, lng, alt float64) bool {
		return burstAltitude <= alt
	}
	return reached
}
// SeaLevelTermination fires when the altitude is at or below zero; the time
// and position arguments are ignored.
func SeaLevelTermination(t, lat, lng, alt float64) bool {
	const seaLevel = 0.0
	return seaLevel >= alt
}
// TimeTermination returns a terminator that fires once the time argument is
// strictly later than maxTime.
func TimeTermination(maxTime float64) Terminator {
	expired := func(t, lat, lng, alt float64) bool {
		return maxTime < t
	}
	return expired
}
// ElevationTermination returns a terminator that fires when the altitude is
// below the ground level reported by elev.Get at the current position.
// NOTE(review): longitude is passed to elev.Get unmodified; any
// normalisation is presumably done inside the elevation package — confirm.
func ElevationTermination(elev *elevation.Dataset) Terminator {
	belowGround := func(t, lat, lng, alt float64) bool {
		ground := elev.Get(lat, lng)
		return ground > alt
	}
	return belowGround
}
// --- Pre-Defined Profiles ---
// Stage pairs a model with its termination criterion.
// RunPrediction integrates each stage's Model until its Terminator fires.
type Stage struct {
Model Model // derivative function integrated during this stage
Terminator Terminator // condition that ends this stage
}
// StandardProfile builds the two stages of a standard high-altitude balloon
// flight: constant-rate ascent with wind drift until burstAltitude, then
// parachute descent with wind drift. The descent ends at ground level when
// elev is non-nil and at sea level otherwise.
func StandardProfile(ascentRate, burstAltitude, descentRate float64,
	ds *dataset.File, dsEpoch float64, warnings *Warnings,
	elev *elevation.Dataset) []Stage {
	wind := WindVelocity(ds, dsEpoch, warnings)

	landing := Terminator(SeaLevelTermination)
	if elev != nil {
		landing = ElevationTermination(elev)
	}

	ascent := Stage{
		Model:      LinearModel(ConstantAscent(ascentRate), wind),
		Terminator: BurstTermination(burstAltitude),
	}
	descent := Stage{
		Model:      LinearModel(DragDescent(descentRate), wind),
		Terminator: landing,
	}
	return []Stage{ascent, descent}
}
// FloatProfile builds the two stages of a floating balloon flight:
// constant-rate ascent with wind drift until floatAltitude, then pure wind
// drift until the UNIX time of stopTime has passed.
func FloatProfile(ascentRate, floatAltitude float64, stopTime time.Time,
	ds *dataset.File, dsEpoch float64, warnings *Warnings) []Stage {
	wind := WindVelocity(ds, dsEpoch, warnings)
	deadline := float64(stopTime.Unix())

	ascent := Stage{
		Model:      LinearModel(ConstantAscent(ascentRate), wind),
		Terminator: BurstTermination(floatAltitude),
	}
	drift := Stage{
		Model:      wind,
		Terminator: TimeTermination(deadline),
	}
	return []Stage{ascent, drift}
}
// RunPrediction integrates the given profile stages in order, starting from
// the launch state, and returns one StageResult per stage.
// launchTime is a UNIX timestamp.
func RunPrediction(launchTime float64, lat, lng, alt float64, stages []Stage) []StageResult {
	// link is the anonymous struct shape that Solve expects for each stage.
	type link = struct {
		Model      Model
		Terminator Terminator
	}
	chain := make([]link, 0, len(stages))
	for _, st := range stages {
		chain = append(chain, link{Model: st.Model, Terminator: st.Terminator})
	}
	return Solve(launchTime, lat, lng, alt, chain)
}

View file

@ -0,0 +1,180 @@
package prediction
import "math"
// Exact port of the reference RK4 solver (solver.pyx).
// Integrates balloon state using RK4 with dt=60 seconds.
// Termination uses binary search refinement (tolerance 0.01).
// Vec holds the balloon state: latitude, longitude, altitude.
// It is also used to carry the derivatives returned by a Model.
type Vec struct {
Lat float64
Lng float64
Alt float64
}
// Model is a function that returns (dlat/dt, dlng/dt, dalt/dt) given state.
// t is UNIX timestamp, lat/lng in degrees, alt in metres.
type Model func(t float64, lat, lng, alt float64) (dlat, dlng, dalt float64)
// Terminator returns true when integration should stop.
// It receives the same state arguments as a Model.
type Terminator func(t float64, lat, lng, alt float64) bool
// StageResult holds the trajectory points for one flight stage.
type StageResult struct {
Points []TrajectoryPoint // points in integration order, starting with the stage's initial state
}
// TrajectoryPoint is a single point in a trajectory (used by solver).
type TrajectoryPoint struct {
T float64 // UNIX timestamp
Lat float64
Lng float64
Alt float64
}
// pymod returns a % b with Python semantics: for b > 0 a negative remainder
// from math.Mod is shifted up by b so the result is never negative.
func pymod(a, b float64) float64 {
	rem := math.Mod(a, b)
	switch {
	case rem < 0:
		return rem + b
	default:
		return rem
	}
}
// vecadd returns a + k*b component-wise, with the resulting longitude
// wrapped into [0, 360) via pymod.
func vecadd(a Vec, k float64, b Vec) Vec {
	sum := a
	sum.Lat += k * b.Lat
	sum.Lng = pymod(a.Lng+k*b.Lng, 360.0)
	sum.Alt += k * b.Alt
	return sum
}
// scalarLerp linearly interpolates between a (at l = 0) and b (at l = 1),
// i.e. it returns (1-l)*a + l*b.
func scalarLerp(a, b, l float64) float64 {
	weightA := 1 - l
	return weightA*a + l*b
}
// lngLerp interpolates between longitudes a (at l = 0) and b (at l = 1)
// along the shorter way round, handling the 0/360 wrap-around.
func lngLerp(a, b, l float64) float64 {
	// wa and wb are the weights of a and b; they travel with their values
	// when the pair is reordered so that a <= b.
	wa, wb := 1-l, l
	if a > b {
		a, b = b, a
		wa, wb = wb, wa
	}

	// Going directly covers b - a degrees; going the other way round covers
	// (a + 360) - b degrees.
	if b-a < 180.0 {
		return wa*a + wb*b
	}
	return pymod(wa*(a+360)+wb*b, 360.0)
}
// vecLerp interpolates between states a (at l = 0) and b (at l = 1):
// latitude and altitude linearly, longitude with wrap-around handling.
func vecLerp(a, b Vec, l float64) Vec {
	var out Vec
	out.Lat = scalarLerp(a.Lat, b.Lat, l)
	out.Lng = lngLerp(a.Lng, b.Lng, l)
	out.Alt = scalarLerp(a.Alt, b.Alt, l)
	return out
}
// rk4 integrates from initial conditions using RK4.
// dt=60.0 seconds, terminationTolerance=0.01.
//
// The returned slice starts with the initial state and gains one point per
// completed step. After each step the terminator is evaluated on the new
// state only; when it fires, the last point is refined by bisection between
// the previous and the new state and integration stops.
//
// The loop has no other exit: if terminator never returns true, rk4 does not
// return.
func rk4(t float64, lat, lng, alt float64, model Model, terminator Terminator) []TrajectoryPoint {
const dt = 60.0
const terminationTolerance = 0.01
y := Vec{Lat: lat, Lng: lng, Alt: alt}
result := []TrajectoryPoint{{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt}}
for {
// Evaluate model at 4 points (standard RK4)
k1lat, k1lng, k1alt := model(t, y.Lat, y.Lng, y.Alt)
k1 := Vec{Lat: k1lat, Lng: k1lng, Alt: k1alt}
mid1 := vecadd(y, dt/2, k1)
k2lat, k2lng, k2alt := model(t+dt/2, mid1.Lat, mid1.Lng, mid1.Alt)
k2 := Vec{Lat: k2lat, Lng: k2lng, Alt: k2alt}
mid2 := vecadd(y, dt/2, k2)
k3lat, k3lng, k3alt := model(t+dt/2, mid2.Lat, mid2.Lng, mid2.Alt)
k3 := Vec{Lat: k3lat, Lng: k3lng, Alt: k3alt}
end := vecadd(y, dt, k3)
k4lat, k4lng, k4alt := model(t+dt, end.Lat, end.Lng, end.Alt)
k4 := Vec{Lat: k4lat, Lng: k4lng, Alt: k4alt}
// y2 = y + dt/6*k1 + dt/3*k2 + dt/3*k3 + dt/6*k4
// Each vecadd wraps the longitude, so the order of these additions is
// part of the numerical result.
y2 := y
y2 = vecadd(y2, dt/6, k1)
y2 = vecadd(y2, dt/3, k2)
y2 = vecadd(y2, dt/3, k3)
y2 = vecadd(y2, dt/6, k4)
t2 := t + dt
if terminator(t2, y2.Lat, y2.Lng, y2.Alt) {
// Binary search to refine the termination point.
// Find l in [0, 1] such that (t3, y3) = lerp((t, y), (t2, y2), l)
// is near where the terminator fires.
left := 0.0
right := 1.0
var t3 float64
var y3 Vec
t3 = t2
y3 = y2
// The interval halves each iteration, so this runs until its width
// drops to terminationTolerance or below.
for right-left > terminationTolerance {
mid := (left + right) / 2
t3 = scalarLerp(t, t2, mid)
y3 = vecLerp(y, y2, mid)
if terminator(t3, y3.Lat, y3.Lng, y3.Alt) {
right = mid
} else {
left = mid
}
}
// The recorded point is the last midpoint probed, which may lie on
// either side of the termination boundary.
result = append(result, TrajectoryPoint{T: t3, Lat: y3.Lat, Lng: y3.Lng, Alt: y3.Alt})
break
}
// Update current state
t = t2
y = y2
result = append(result, TrajectoryPoint{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt})
}
return result
}
// Solve integrates a chain of (model, terminator) stages in order and
// returns one StageResult per stage. Each stage starts from the last point
// of the previous one; the first stage starts from the given state.
func Solve(t, lat, lng, alt float64, chain []struct {
	Model      Model
	Terminator Terminator
}) []StageResult {
	var out []StageResult
	for i := range chain {
		pts := rk4(t, lat, lng, alt, chain[i].Model, chain[i].Terminator)
		out = append(out, StageResult{Points: pts})

		// Hand the final state over to the next stage.
		if n := len(pts); n > 0 {
			end := pts[n-1]
			t, lat, lng, alt = end.T, end.Lat, end.Lng, end.Alt
		}
	}
	return out
}

View file

@ -0,0 +1,21 @@
package prediction
import "sync/atomic"
// Warnings collects counters for warning conditions raised during a
// prediction run. Counters are atomic, so a Warnings value must be shared by
// pointer.
type Warnings struct {
	// AltitudeTooHigh counts how often it was incremented during the run.
	AltitudeTooHigh atomic.Int64
}

// ToMap renders the warnings that have fired at least once as a map suitable
// for JSON serialization. The map is empty (not nil) when nothing fired.
func (w *Warnings) ToMap() map[string]any {
	out := map[string]any{}
	count := w.AltitudeTooHigh.Load()
	if count <= 0 {
		return out
	}
	out["altitude_too_high"] = map[string]any{
		"count":       count,
		"description": "The altitude went too high, above the max forecast wind. Wind data will be unreliable",
	}
	return out
}

View file

@ -1,12 +0,0 @@
package service
import (
"context"
"time"
)
// Grib is the weather-data backend the service depends on.
type Grib interface {
// Update refreshes the backend's forecast data.
Update(ctx context.Context) error
// Extract returns a pair of values for the given position and time.
// NOTE(review): callers in this package read element 0 as wind U and
// element 1 as wind V — confirm against the implementation.
Extract(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error)
// Close releases the backend's resources.
Close() error
}

View file

@ -1,516 +0,0 @@
package service
import (
"context"
"encoding/base64"
"encoding/json"
"math"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/ds"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
"git.intra.yksa.space/gsn/predictor/internal/pkg/log"
"go.uber.org/zap"
)
// ErrInvalidParameters is returned by PerformPrediction when the launch
// latitude, longitude, altitude or datetime is missing. It carries code 400.
var ErrInvalidParameters = errcodes.New(400, "missing required prediction parameters")
// Stage represents a prediction stage (ascent, descent, float)
type Stage struct {
Name string
Results []ds.PredicitonResult
StartTime time.Time
EndTime time.Time
}
// CustomCurve represents a custom ascent/descent curve
// Altitude and Time are parallel slices: entry i of Time is paired with
// entry i of Altitude when a rate is derived from the curve.
type CustomCurve struct {
Altitude []float64 `json:"altitude"`
Time []float64 `json:"time"` // seconds from start
}
// PerformPrediction runs a flight prediction for the given parameters and
// returns the concatenated results of all stages.
//
// Launch latitude, longitude, altitude and datetime are required; when any
// is missing ErrInvalidParameters is returned. Optional parameters default
// to: profile "standard_profile", ascent rate 5.0, burst altitude 30000.0,
// descent rate 5.0 and float altitude 0.0. An unknown profile name yields a
// 400 error.
//
// Custom ascent/descent curves are best-effort: a curve that cannot be
// decoded is ignored (the prediction proceeds as if it was not supplied),
// and the decode error is logged as a warning rather than dropped silently.
func (s *Service) PerformPrediction(ctx context.Context, params ds.PredictionParameters) ([]ds.PredicitonResult, error) {
	// Validate required parameters.
	if params.LaunchLatitude == nil || params.LaunchLongitude == nil || params.LaunchAltitude == nil || params.LaunchDatetime == nil {
		return nil, ErrInvalidParameters
	}

	// Resolve optional parameters to their defaults.
	profile := "standard_profile"
	if params.Profile != nil {
		profile = *params.Profile
	}
	ascentRate := 5.0
	if params.AscentRate != nil {
		ascentRate = *params.AscentRate
	}
	burstAltitude := 30000.0
	if params.BurstAltitude != nil {
		burstAltitude = *params.BurstAltitude
	}
	descentRate := 5.0
	if params.DescentRate != nil {
		descentRate = *params.DescentRate
	}
	floatAltitude := 0.0
	if params.FloatAltitude != nil {
		floatAltitude = *params.FloatAltitude
	}

	// decodeCurve parses an optional encoded curve; invalid input is
	// reported and treated as absent.
	decodeCurve := func(encoded *string, which string) *CustomCurve {
		if encoded == nil || *encoded == "" {
			return nil
		}
		curve, err := parseCustomCurve(*encoded)
		if err != nil {
			log.Ctx(ctx).Warn("Ignoring invalid custom curve",
				zap.String("curve", which),
				zap.Error(err),
			)
			return nil
		}
		return curve
	}
	ascentCurve := decodeCurve(params.AscentCurve, "ascent")
	descentCurve := decodeCurve(params.DescentCurve, "descent")

	log.Ctx(ctx).Info("Starting prediction",
		zap.String("profile", profile),
		zap.Float64("lat", *params.LaunchLatitude),
		zap.Float64("lon", *params.LaunchLongitude),
		zap.Float64("alt", *params.LaunchAltitude),
		zap.Time("time", *params.LaunchDatetime),
	)

	var allResults []ds.PredicitonResult
	switch profile {
	case "standard_profile":
		allResults = s.standardProfile(ctx, params, ascentRate, burstAltitude, descentRate, ascentCurve, descentCurve)
	case "float_profile":
		allResults = s.floatProfile(ctx, params, ascentRate, burstAltitude, floatAltitude, descentRate, ascentCurve, descentCurve)
	case "reverse_profile":
		allResults = s.reverseProfile(ctx, params, ascentRate, burstAltitude, descentRate, ascentCurve, descentCurve)
	case "custom_profile":
		allResults = s.customProfile(ctx, params, ascentCurve, descentCurve)
	default:
		return nil, errcodes.New(400, "unsupported profile: "+profile)
	}

	log.Ctx(ctx).Info("Prediction complete", zap.Int("total_steps", len(allResults)))
	return allResults, nil
}
// standardProfile simulates an ascent to burstAltitude followed by a descent
// to altitude 0, starting the descent from the last ascent point. When the
// ascent yields no points, no descent is simulated.
func (s *Service) standardProfile(ctx context.Context, params ds.PredictionParameters, ascentRate, burstAltitude, descentRate float64, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult {
	var results []ds.PredicitonResult

	// Stage 1: ascent.
	ascent := s.simulateAscent(ctx, params, ascentRate, burstAltitude, ascentCurve)
	results = append(results, ascent...)
	if len(ascent) == 0 {
		return results
	}

	// Stage 2: descent from where the ascent ended.
	apex := ascent[len(ascent)-1]
	descent := s.simulateDescent(ctx, ds.PredictionParameters{
		LaunchLatitude:  apex.Latitude,
		LaunchLongitude: apex.Longitude,
		LaunchAltitude:  apex.Altitude,
		LaunchDatetime:  apex.Timestamp,
	}, descentRate, 0, descentCurve)
	return append(results, descent...)
}
// floatProfile simulates an ascent to floatAltitude, a 30-minute float and
// then a descent to altitude 0. Each stage starts from the last point of the
// previous one; a stage that yields no points ends the chain.
// burstAltitude is accepted but not used by this profile.
func (s *Service) floatProfile(ctx context.Context, params ds.PredictionParameters, ascentRate, burstAltitude, floatAltitude, descentRate float64, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult {
	var results []ds.PredicitonResult

	// Stage 1: ascent to float altitude.
	ascent := s.simulateAscent(ctx, params, ascentRate, floatAltitude, ascentCurve)
	results = append(results, ascent...)
	if len(ascent) == 0 {
		return results
	}

	// Stage 2: float for 30 minutes.
	drift := s.simulateFloat(ctx, ascent[len(ascent)-1], 30*time.Minute)
	results = append(results, drift...)
	if len(drift) == 0 {
		return results
	}

	// Stage 3: descent from the end of the float.
	end := drift[len(drift)-1]
	descent := s.simulateDescent(ctx, ds.PredictionParameters{
		LaunchLatitude:  end.Latitude,
		LaunchLongitude: end.Longitude,
		LaunchAltitude:  end.Altitude,
		LaunchDatetime:  end.Timestamp,
	}, descentRate, 0, descentCurve)
	return append(results, descent...)
}
// reverseProfile simulates an ascent to burstAltitude, then a descent to the
// float altitude from params (0 when not given) and, only when that float
// altitude is positive, a 30-minute float from the end of the descent.
func (s *Service) reverseProfile(ctx context.Context, params ds.PredictionParameters, ascentRate, burstAltitude, descentRate float64, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult {
	var results []ds.PredicitonResult

	// Stage 1: ascent.
	ascent := s.simulateAscent(ctx, params, ascentRate, burstAltitude, ascentCurve)
	results = append(results, ascent...)
	if len(ascent) == 0 {
		return results
	}

	// Stage 2: descent down to the float altitude (if specified).
	floatAlt := 0.0
	if params.FloatAltitude != nil {
		floatAlt = *params.FloatAltitude
	}
	apex := ascent[len(ascent)-1]
	descent := s.simulateDescent(ctx, ds.PredictionParameters{
		LaunchLatitude:  apex.Latitude,
		LaunchLongitude: apex.Longitude,
		LaunchAltitude:  apex.Altitude,
		LaunchDatetime:  apex.Timestamp,
	}, descentRate, floatAlt, descentCurve)
	results = append(results, descent...)
	if floatAlt <= 0 || len(descent) == 0 {
		return results
	}

	// Stage 3: float from where the descent ended.
	drift := s.simulateFloat(ctx, descent[len(descent)-1], 30*time.Minute)
	return append(results, drift...)
}
// customProfile simulates the custom ascent when ascentCurve is given, then
// the custom descent when descentCurve is given and earlier results exist to
// start from.
func (s *Service) customProfile(ctx context.Context, params ds.PredictionParameters, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult {
	var results []ds.PredicitonResult
	if ascentCurve != nil {
		results = append(results, s.simulateCustomAscent(ctx, params, ascentCurve)...)
	}
	if descentCurve == nil || len(results) == 0 {
		return results
	}

	start := results[len(results)-1]
	descent := s.simulateCustomDescent(ctx, ds.PredictionParameters{
		LaunchLatitude:  start.Latitude,
		LaunchLongitude: start.Longitude,
		LaunchAltitude:  start.Altitude,
		LaunchDatetime:  start.Timestamp,
	}, descentCurve)
	return append(results, descent...)
}
// rk4Step advances the state (lat, lon, alt) by one classical Runge-Kutta
// step of dt seconds and returns the new latitude, longitude and altitude.
//
// lat and lon are in degrees, alt in metres. windFunc returns the (u, v)
// wind in m/s at a position and time; it is sampled at t, twice at t+dt/2
// and once at t+dt. altRate is the constant vertical velocity in m/s.
// Lateral motion is computed on a sphere of radius 6371009 m plus altitude.
//
// Time offsets are converted to time.Duration with sub-second precision, so
// the intermediate samples are taken at the correct instants for any dt,
// including fractional or odd step sizes.
func rk4Step(lat, lon, alt float64, t time.Time, dt float64, windFunc func(lat, lon, alt float64, t time.Time) (float64, float64), altRate float64) (float64, float64, float64) {
	const (
		toRad       = math.Pi / 180.0
		toDeg       = 180.0 / math.Pi
		earthRadius = 6371009.0
	)

	// after returns t shifted by a possibly fractional number of seconds.
	// Converting the float to a Duration before scaling by time.Second would
	// truncate to whole seconds.
	after := func(seconds float64) time.Time {
		return t.Add(time.Duration(seconds * float64(time.Second)))
	}

	// f evaluates the state derivative (degrees/s, degrees/s, m/s).
	f := func(lat, lon, alt float64, at time.Time) (float64, float64, float64) {
		windU, windV := windFunc(lat, lon, alt, at)
		r := earthRadius + alt
		dlat := toDeg * windV / r
		dlon := toDeg * windU / (r * math.Cos(lat*toRad))
		return dlat, dlon, altRate
	}

	tHalf := after(0.5 * dt)
	tFull := after(dt)

	k1Lat, k1Lon, k1Alt := f(lat, lon, alt, t)
	k2Lat, k2Lon, k2Alt := f(lat+0.5*k1Lat*dt, lon+0.5*k1Lon*dt, alt+0.5*k1Alt*dt, tHalf)
	k3Lat, k3Lon, k3Alt := f(lat+0.5*k2Lat*dt, lon+0.5*k2Lon*dt, alt+0.5*k2Alt*dt, tHalf)
	k4Lat, k4Lon, k4Alt := f(lat+k3Lat*dt, lon+k3Lon*dt, alt+k3Alt*dt, tFull)

	latNew := lat + (dt/6.0)*(k1Lat+2*k2Lat+2*k3Lat+k4Lat)
	lonNew := lon + (dt/6.0)*(k1Lon+2*k2Lon+2*k3Lon+k4Lon)
	altNew := alt + (dt/6.0)*(k1Alt+2*k2Alt+2*k3Alt+k4Alt)
	return latNew, lonNew, altNew
}
// simulateAscent integrates the ascent from the launch state in params until
// the altitude reaches targetAltitude, using 10-second RK4 steps and
// recording a result roughly every 60 seconds.
//
// The first result is always the launch state with zero wind. The state at
// which the target altitude is reached is not recorded. The vertical rate is
// ascentRate, or the rate derived from customCurve when one is given. Wind
// lookup failures are logged and treated as calm air.
//
// The simulation stops early, returning the results gathered so far, when
// ctx is cancelled or when the effective ascent rate is not positive (the
// target could then never be reached).
func (s *Service) simulateAscent(ctx context.Context, params ds.PredictionParameters, ascentRate, targetAltitude float64, customCurve *CustomCurve) []ds.PredicitonResult {
	const dt = 10.0             // simulation step in seconds
	const outputInterval = 60.0 // output every 60 seconds

	lat := *params.LaunchLatitude
	lon := *params.LaunchLongitude
	alt := *params.LaunchAltitude
	timeCur := *params.LaunchDatetime

	results := make([]ds.PredicitonResult, 0, 1000)

	// record appends the current state with the given wind; every result
	// gets its own copies because the result fields are pointers.
	record := func(windU, windV float64) {
		latCopy, lonCopy, altCopy, timeCopy := lat, lon, alt, timeCur
		results = append(results, ds.PredicitonResult{
			Latitude:  &latCopy,
			Longitude: &lonCopy,
			Altitude:  &altCopy,
			Timestamp: &timeCopy,
			WindU:     &windU,
			WindV:     &windV,
		})
	}
	record(0, 0)
	nextOutputTime := timeCur.Add(time.Duration(outputInterval) * time.Second)

	windFunc := func(lat, lon, alt float64, t time.Time) (float64, float64) {
		w, err := s.ExtractWind(ctx, lat, lon, alt, t)
		if err != nil {
			log.Ctx(ctx).Warn("Wind extraction failed during ascent", zap.Error(err))
			return 0, 0
		}
		return w[0], w[1]
	}

	for alt < targetAltitude {
		if err := ctx.Err(); err != nil {
			log.Ctx(ctx).Warn("Ascent simulation cancelled", zap.Error(err))
			break
		}

		altRate := ascentRate
		if customCurve != nil {
			altRate = s.getCustomAltitudeRate(customCurve, alt, ascentRate)
		}
		if altRate <= 0 {
			// Without upward motion the loop would never terminate.
			log.Ctx(ctx).Warn("Non-positive ascent rate, stopping ascent", zap.Float64("rate", altRate))
			break
		}

		lat, lon, alt = rk4Step(lat, lon, alt, timeCur, dt, windFunc, altRate)
		timeCur = timeCur.Add(time.Duration(dt) * time.Second)
		if alt >= targetAltitude {
			break
		}

		if !timeCur.Before(nextOutputTime) {
			wU, wV := windFunc(lat, lon, alt, timeCur)
			record(wU, wV)
			nextOutputTime = nextOutputTime.Add(time.Duration(outputInterval) * time.Second)
		}
	}
	return results
}
// simulateDescent integrates the descent from the start state in params
// until the altitude drops to targetAltitude, using 10-second RK4 steps and
// recording a result roughly every 60 seconds.
//
// The first result is always the start state with zero wind. The state at
// which the target altitude is reached is not recorded. The vertical rate is
// -descentRate, or the negated rate derived from customCurve when one is
// given. Wind lookup failures are logged and treated as calm air.
//
// The simulation stops early, returning the results gathered so far, when
// ctx is cancelled or when the effective vertical rate is not negative (the
// target could then never be reached).
func (s *Service) simulateDescent(ctx context.Context, params ds.PredictionParameters, descentRate, targetAltitude float64, customCurve *CustomCurve) []ds.PredicitonResult {
	const dt = 10.0             // simulation step in seconds
	const outputInterval = 60.0 // output every 60 seconds

	lat := *params.LaunchLatitude
	lon := *params.LaunchLongitude
	alt := *params.LaunchAltitude
	timeCur := *params.LaunchDatetime

	results := make([]ds.PredicitonResult, 0, 1000)

	// record appends the current state with the given wind; every result
	// gets its own copies because the result fields are pointers.
	record := func(windU, windV float64) {
		latCopy, lonCopy, altCopy, timeCopy := lat, lon, alt, timeCur
		results = append(results, ds.PredicitonResult{
			Latitude:  &latCopy,
			Longitude: &lonCopy,
			Altitude:  &altCopy,
			Timestamp: &timeCopy,
			WindU:     &windU,
			WindV:     &windV,
		})
	}
	record(0, 0)
	nextOutputTime := timeCur.Add(time.Duration(outputInterval) * time.Second)

	windFunc := func(lat, lon, alt float64, t time.Time) (float64, float64) {
		w, err := s.ExtractWind(ctx, lat, lon, alt, t)
		if err != nil {
			log.Ctx(ctx).Warn("Wind extraction failed during descent", zap.Error(err))
			return 0, 0
		}
		return w[0], w[1]
	}

	for alt > targetAltitude {
		if err := ctx.Err(); err != nil {
			log.Ctx(ctx).Warn("Descent simulation cancelled", zap.Error(err))
			break
		}

		altRate := -descentRate
		if customCurve != nil {
			altRate = -s.getCustomAltitudeRate(customCurve, alt, descentRate)
		}
		if altRate >= 0 {
			// Without downward motion the loop would never terminate.
			log.Ctx(ctx).Warn("Non-negative vertical rate, stopping descent", zap.Float64("rate", altRate))
			break
		}

		lat, lon, alt = rk4Step(lat, lon, alt, timeCur, dt, windFunc, altRate)
		timeCur = timeCur.Add(time.Duration(dt) * time.Second)
		if alt <= targetAltitude {
			break
		}

		if !timeCur.Before(nextOutputTime) {
			wU, wV := windFunc(lat, lon, alt, timeCur)
			record(wU, wV)
			nextOutputTime = nextOutputTime.Add(time.Duration(outputInterval) * time.Second)
		}
	}
	return results
}
// simulateFloat drifts the balloon with the wind at constant altitude for
// the given duration, starting from startResult. It advances in 10-second
// forward steps and records a result roughly every 60 seconds; the first
// result is the start state with zero wind. A wind lookup failure is logged
// and ends the float early.
func (s *Service) simulateFloat(ctx context.Context, startResult ds.PredicitonResult, duration time.Duration) []ds.PredicitonResult {
	const dt = 10.0             // simulation step in seconds
	const outputInterval = 60.0 // output every 60 seconds

	lat, lon, alt := *startResult.Latitude, *startResult.Longitude, *startResult.Altitude
	now := *startResult.Timestamp
	deadline := now.Add(duration)

	points := make([]ds.PredicitonResult, 0, 1000)

	// emit appends the current state with the given wind, using fresh
	// copies because the result fields are pointers.
	emit := func(windU, windV float64) {
		latCopy, lonCopy, altCopy, timeCopy := lat, lon, alt, now
		points = append(points, ds.PredicitonResult{
			Latitude:  &latCopy,
			Longitude: &lonCopy,
			Altitude:  &altCopy,
			Timestamp: &timeCopy,
			WindU:     &windU,
			WindV:     &windV,
		})
	}

	// Always include the initial float point.
	emit(0, 0)
	nextEmit := now.Add(time.Duration(outputInterval) * time.Second)

	for now.Before(deadline) {
		wind, err := s.ExtractWind(ctx, lat, lon, alt, now)
		if err != nil {
			log.Ctx(ctx).Warn("Wind extraction failed during float", zap.Error(err))
			break
		}

		// Both rates are evaluated at the latitude before this step.
		latDot := (wind[1] / 111320.0)
		lonDot := (wind[0] / (40075000.0 * math.Cos(lat*math.Pi/180) / 360.0))
		lat += latDot * dt
		lon += lonDot * dt
		// alt remains constant during float
		now = now.Add(time.Duration(dt) * time.Second)

		if now.Before(nextEmit) {
			continue
		}
		emit(wind[0], wind[1])
		nextEmit = nextEmit.Add(time.Duration(outputInterval) * time.Second)
	}
	return points
}
// simulateCustomAscent runs simulateAscent with the custom curve, a fallback
// rate of 5.0 and a fixed target altitude of 30000.0.
func (s *Service) simulateCustomAscent(ctx context.Context, params ds.PredictionParameters, curve *CustomCurve) []ds.PredicitonResult {
	const (
		fallbackRate   = 5.0
		targetAltitude = 30000.0
	)
	return s.simulateAscent(ctx, params, fallbackRate, targetAltitude, curve)
}
// simulateCustomDescent runs simulateDescent with the custom curve, a
// fallback rate of 5.0 and a fixed target altitude of 0.0.
func (s *Service) simulateCustomDescent(ctx context.Context, params ds.PredictionParameters, curve *CustomCurve) []ds.PredicitonResult {
	const (
		fallbackRate   = 5.0
		targetAltitude = 0.0
	)
	return s.simulateDescent(ctx, params, fallbackRate, targetAltitude, curve)
}
// getCustomAltitudeRate returns the vertical rate (altitude change per
// second) of the first curve segment whose altitude range contains
// currentAltitude.
//
// defaultRate is returned when the curve is nil, when fewer than two
// altitude/time pairs are available, when no segment brackets the altitude,
// or when the bracketing segment has zero altitude change or a
// non-increasing time.
//
// Only the leading part of the curve where both Altitude and Time have
// entries is considered, so a curve whose Time slice is shorter than its
// Altitude slice cannot cause an out-of-range access.
func (s *Service) getCustomAltitudeRate(curve *CustomCurve, currentAltitude, defaultRate float64) float64 {
	if curve == nil {
		return defaultRate
	}

	// Number of complete (altitude, time) pairs.
	pairs := len(curve.Altitude)
	if len(curve.Time) < pairs {
		pairs = len(curve.Time)
	}
	if pairs < 2 {
		return defaultRate
	}

	// Find the two points in the curve that bracket the current altitude.
	for i := 0; i < pairs-1; i++ {
		alt1, alt2 := curve.Altitude[i], curve.Altitude[i+1]
		if currentAltitude < alt1 || currentAltitude > alt2 {
			continue
		}
		time1, time2 := curve.Time[i], curve.Time[i+1]
		if alt2 == alt1 || time2 <= time1 {
			return defaultRate
		}
		// Rate of the segment: change in altitude per second.
		return (alt2 - alt1) / (time2 - time1)
	}
	return defaultRate
}
// parseCustomCurve decodes a standard-base64 string containing the JSON
// representation of a CustomCurve. It returns the decoding or unmarshalling
// error unchanged on failure.
func parseCustomCurve(base64Data string) (*CustomCurve, error) {
	raw, err := base64.StdEncoding.DecodeString(base64Data)
	if err != nil {
		return nil, err
	}
	curve := new(CustomCurve)
	if err = json.Unmarshal(raw, curve); err != nil {
		return nil, err
	}
	return curve, nil
}

View file

@ -2,59 +2,244 @@ package service
import (
"context"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/log"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/downloader"
"predictor-refactored/internal/elevation"
"go.uber.org/zap"
)
// Service orchestrates the dataset lifecycle and provides prediction capabilities.
type Service struct {
grib Grib
mu sync.RWMutex
ds *dataset.File
elev *elevation.Dataset
cfg *downloader.Config
dl *downloader.Downloader
log *zap.Logger
updating sync.Mutex // prevents concurrent downloads
}
func New(gribService Grib) (*Service, error) {
svc := &Service{
grib: gribService,
// New creates a new Service.
func New(cfg *downloader.Config, log *zap.Logger) *Service {
return &Service{
cfg: cfg,
dl: downloader.NewDownloader(cfg, log),
log: log,
}
}
// LoadElevation opens the elevation dataset at path and stores it on the
// service. When opening fails, a warning is logged and the service keeps
// whatever elevation dataset it had (none by default), which makes
// predictions fall back to sea-level termination.
func (s *Service) LoadElevation(path string) {
	elev, err := elevation.Open(path)
	if err == nil {
		s.elev = elev
		s.log.Info("elevation dataset loaded", zap.String("path", path))
		return
	}
	s.log.Warn("elevation dataset not available, using sea-level termination",
		zap.String("path", path), zap.Error(err))
}
// Elevation returns the elevation dataset stored by LoadElevation, or nil
// when none has been loaded. The field is read without taking the service
// mutex.
func (s *Service) Elevation() *elevation.Dataset {
	elev := s.elev
	return elev
}
// Ready reports whether a dataset is currently loaded. The check is made
// under the service's read lock.
func (s *Service) Ready() bool {
	s.mu.RLock()
	loaded := s.ds != nil
	s.mu.RUnlock()
	return loaded
}
// DatasetTime returns the DSTime of the currently loaded dataset and true,
// or the zero time and false when no dataset is loaded.
func (s *Service) DatasetTime() (time.Time, bool) {
	s.mu.RLock()
	defer s.mu.RUnlock()

	if current := s.ds; current != nil {
		return current.DSTime, true
	}
	return time.Time{}, false
}
// Dataset returns the currently loaded dataset (nil when none is loaded).
// The pointer is read under the service's read lock; the lock is released
// before the caller uses the dataset.
func (s *Service) Dataset() *dataset.File {
	s.mu.RLock()
	current := s.ds
	s.mu.RUnlock()
	return current
}
// Update checks for and downloads new forecast data if needed.
func (s *Service) Update(ctx context.Context) error {
if !s.updating.TryLock() {
s.log.Info("update already in progress, skipping")
return nil
}
defer s.updating.Unlock()
// Check if current dataset is still fresh
if dsTime, ok := s.DatasetTime(); ok {
if time.Since(dsTime) < s.cfg.DatasetTTL {
s.log.Info("dataset still fresh, skipping update",
zap.Time("dataset_time", dsTime),
zap.Duration("age", time.Since(dsTime)))
return nil
}
}
return svc, nil
}
// Try loading an existing dataset from disk first
if err := s.loadExistingDataset(); err == nil {
return nil
}
// UpdateWeatherData updates weather forecast data using the configured grib service
func (s *Service) UpdateWeatherData(ctx context.Context) error {
return s.grib.Update(ctx)
}
// Find latest available model run
run, err := s.dl.FindLatestRun(ctx)
if err != nil {
return err
}
// ExtractWind extracts wind data for given coordinates and time
func (s *Service) ExtractWind(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error) {
return s.grib.Extract(ctx, lat, lon, alt, ts)
}
// Download and assemble
path, err := s.dl.Download(ctx, run)
if err != nil {
return err
}
// Update updates the GRIB data (implements updater.GribService)
func (s *Service) Update(ctx context.Context) error {
return s.UpdateWeatherData(ctx)
}
// Open the new dataset
ds, err := dataset.Open(path, run)
if err != nil {
return err
}
// Start starts the service
func (s *Service) Start() {
log.Ctx(context.Background()).Info("service started")
}
// Swap in the new dataset
s.setDataset(ds)
s.log.Info("dataset loaded", zap.Time("run", run), zap.String("path", path))
// Stop stops the service
func (s *Service) Stop() {
log.Ctx(context.Background()).Info("service stopped")
}
// Clean old datasets
s.cleanOldDatasets(path)
// Close closes the service and releases resources
func (s *Service) Close() error {
s.Stop()
return nil
}
func (s *Service) GetGribStatus(ctx context.Context) (ready bool, lastUpdate time.Time, isFresh bool, errMsg string) {
if gribStatus, ok := s.grib.(interface {
GetStatus() (ready bool, lastUpdate time.Time, isFresh bool, errMsg string)
}); ok {
return gribStatus.GetStatus()
// loadExistingDataset tries to find and load an existing dataset from the data directory.
func (s *Service) loadExistingDataset() error {
entries, err := os.ReadDir(s.cfg.DataDir)
if err != nil {
return err
}
return false, time.Time{}, false, "grib service does not implement GetStatus"
// Collect valid dataset files (name is YYYYMMDDHH, no extension, correct size)
type candidate struct {
name string
path string
run time.Time
}
var candidates []candidate
for _, e := range entries {
if e.IsDir() || strings.Contains(e.Name(), ".") {
continue
}
if len(e.Name()) != 10 {
continue
}
run, err := time.Parse("2006010215", e.Name())
if err != nil {
continue
}
path := filepath.Join(s.cfg.DataDir, e.Name())
info, err := os.Stat(path)
if err != nil || info.Size() != dataset.DatasetSize {
continue
}
if time.Since(run) > s.cfg.DatasetTTL {
continue
}
candidates = append(candidates, candidate{name: e.Name(), path: path, run: run})
}
if len(candidates) == 0 {
return os.ErrNotExist
}
// Pick the newest
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].run.After(candidates[j].run)
})
best := candidates[0]
ds, err := dataset.Open(best.path, best.run)
if err != nil {
return err
}
s.setDataset(ds)
s.log.Info("loaded existing dataset",
zap.Time("run", best.run),
zap.String("path", best.path))
return nil
}
// setDataset installs ds as the current dataset and closes the one it
// replaces. The swap happens under the write lock; the close happens after the
// lock is released, and a close failure is logged rather than returned.
//
// NOTE(review): Dataset() hands out the raw pointer without reference
// counting, so a caller still reading the previous dataset may have it closed
// underneath it — confirm dataset.File tolerates that.
func (s *Service) setDataset(ds *dataset.File) {
	var previous *dataset.File
	func() {
		s.mu.Lock()
		defer s.mu.Unlock()
		previous, s.ds = s.ds, ds
	}()

	if previous == nil {
		return
	}
	if err := previous.Close(); err != nil {
		s.log.Error("failed to close old dataset", zap.Error(err))
	}
}
// cleanOldDatasets removes stale dataset files and leftover download temp
// files from the data directory, keeping the file at keepPath.
//
// Only two kinds of non-directory entries are removed:
//   - entries whose name is a dataset run stamp, i.e. exactly 10 characters
//     that parse with the layout "2006010215" (YYYYMMDDHH);
//   - entries whose name ends in ".downloading".
//
// Everything else is left alone. Checking the name length alone is not enough:
// unrelated files such as "README.txt" are also 10 characters long and must
// not be deleted.
//
// Failures are logged and never returned; cleanup is best-effort.
func (s *Service) cleanOldDatasets(keepPath string) {
	entries, err := os.ReadDir(s.cfg.DataDir)
	if err != nil {
		s.log.Warn("failed to list data directory for cleanup",
			zap.String("path", s.cfg.DataDir), zap.Error(err))
		return
	}

	// filepath.Join returns a cleaned path, so normalise keepPath the same way
	// before comparing; otherwise "dir//name" or "./dir/name" would not match.
	keep := filepath.Clean(keepPath)

	for _, e := range entries {
		if e.IsDir() {
			continue
		}

		name := e.Name()
		path := filepath.Join(s.cfg.DataDir, name)
		if path == keep {
			continue
		}

		isTemp := strings.HasSuffix(name, ".downloading")
		isDataset := false
		if len(name) == 10 {
			_, parseErr := time.Parse("2006010215", name)
			isDataset = parseErr == nil
		}
		if !isTemp && !isDataset {
			continue
		}

		if err := os.Remove(path); err != nil {
			s.log.Warn("failed to remove old file", zap.String("path", path), zap.Error(err))
			continue
		}
		s.log.Info("removed old dataset", zap.String("path", path))
	}
}
// Close releases the loaded dataset, if any, and clears the reference so the
// service reports not-ready afterwards. It returns the error from closing the
// dataset, or nil when nothing was loaded.
func (s *Service) Close() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	open := s.ds
	if open == nil {
		return nil
	}

	closeErr := open.Close()
	s.ds = nil
	return closeErr
}

View file

@ -3,42 +3,28 @@ package middleware
import (
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
"git.intra.yksa.space/gsn/predictor/internal/pkg/log"
"github.com/ogen-go/ogen/middleware"
"go.uber.org/zap"
)
func Logging() middleware.Middleware {
// Logging returns an ogen middleware that logs request duration.
func Logging(log *zap.Logger) middleware.Middleware {
return func(req middleware.Request, next func(req middleware.Request) (middleware.Response, error)) (middleware.Response, error) {
lg := log.Ctx(req.Context).With(
zap.String("operationId", req.OperationID),
)
lg.Info("started request")
req.Context = log.ToCtx(req.Context, lg)
lg := log.With(zap.String("operation", req.OperationID))
start := time.Now()
resp, err := next(req)
dur := time.Since(start).Microseconds()
dur := time.Since(start)
if err != nil {
if errcode, ok := err.(*errcodes.ErrorCode); ok {
lg.Error("request error",
zap.Int("status_code", errcode.StatusCode),
zap.String("message", errcode.Message),
zap.String("details", errcode.Details),
)
} else {
lg.Error("request internal error",
zap.Error(err),
)
}
lg.Error("request failed",
zap.Duration("duration", dur),
zap.Error(err))
} else {
lg.Info("request completed",
zap.Duration("duration", dur))
}
lg.Info("done request", zap.Float64("duration_ms", float64(dur)/float64(1000)))
return resp, err
}
}

View file

@ -1,24 +0,0 @@
package rest
import (
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
env "github.com/caarlos0/env/v11"
)
// Config holds the REST transport settings. Each field is populated from the
// environment variable named in its `env` tag, falling back to `envDefault`
// when the variable is unset.
type Config struct {
// Host is the address to bind to.
Host string `env:"HOST" envDefault:"0.0.0.0"`
// Port is the TCP port to listen on.
Port int `env:"PORT" envDefault:"8080"`
// The timeouts are kept as strings such as "30s".
// NOTE(review): presumably parsed with time.ParseDuration by the consumer — confirm.
ReadTimeout string `env:"READ_TIMEOUT" envDefault:"30s"`
WriteTimeout string `env:"WRITE_TIMEOUT" envDefault:"30s"`
IdleTimeout string `env:"IDLE_TIMEOUT" envDefault:"60s"`
}
// NewConfig builds the REST configuration from environment variables.
//
// Every variable is looked up with the GSN_PREDICTOR_REST_ prefix (for example
// GSN_PREDICTOR_REST_PORT). Unset variables fall back to the envDefault values
// declared on Config. A parse failure is returned wrapped via errcodes.Wrap.
func NewConfig() (*Config, error) {
	cfg := &Config{}

	// Options.Prefix is the environment-variable prefix. Options.PrefixTagName
	// only renames the struct tag used to declare nested prefixes, so setting
	// it to the prefix string left every variable unprefixed.
	opts := env.Options{
		Prefix: "GSN_PREDICTOR_REST_",
	}
	if err := env.ParseWithOptions(cfg, opts); err != nil {
		return nil, errcodes.Wrap(err, "failed to parse REST config")
	}
	return cfg, nil
}

View file

@ -1,14 +1,16 @@
package handler
import (
"context"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/ds"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/elevation"
)
// Service defines the interface the handler needs from the service layer.
type Service interface {
UpdateWeatherData(ctx context.Context) error
ExtractWind(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error)
PerformPrediction(ctx context.Context, params ds.PredictionParameters) ([]ds.PredicitonResult, error)
// Ready reports whether a forecast dataset is currently loaded.
Ready() bool
// DatasetTime returns the forecast time of the loaded dataset; the boolean
// is false when no dataset is loaded.
DatasetTime() (time.Time, bool)
// Dataset returns the loaded dataset, or nil when none is loaded.
Dataset() *dataset.File
// Elevation returns the elevation dataset; it may be nil.
Elevation() *elevation.Dataset
}

View file

@ -5,190 +5,212 @@ import (
"net/http"
"time"
"git.intra.yksa.space/gsn/predictor/internal/pkg/ds"
"git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes"
api "git.intra.yksa.space/gsn/predictor/pkg/rest"
"predictor-refactored/internal/prediction"
api "predictor-refactored/pkg/rest"
"go.uber.org/zap"
)
var (
_ api.Handler = (*Handler)(nil)
)
var _ api.Handler = (*Handler)(nil)
// Handler implements the ogen-generated api.Handler interface.
// It holds the service-layer dependency and the logger used by the endpoints.
type Handler struct {
// svc provides dataset state and the elevation dataset to the endpoints.
svc Service
// log is the logger used for request-level messages.
log *zap.Logger
}
func New(svc Service) *Handler {
return &Handler{
svc: svc,
}
// New returns a Handler that serves requests using svc and logs through log.
func New(svc Service, log *zap.Logger) *Handler {
	h := new(Handler)
	h.svc, h.log = svc, log
	return h
}
func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredictionParams) (*api.PredictionResult, error) {
internalParams := ds.ConvertFlatPredictionParams(params)
if internalParams == nil {
return nil, errcodes.New(http.StatusBadRequest, "invalid or missing parameters")
}
results, err := h.svc.PerformPrediction(ctx, *internalParams)
if err != nil {
return nil, err
}
if len(results) == 0 {
return nil, errcodes.New(http.StatusInternalServerError, "no prediction results")
// PerformPrediction implements the prediction endpoint.
func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredictionParams) (*api.PredictionResponse, error) {
if !h.svc.Ready() {
return nil, newError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up")
}
// Group results into stages (ascent and descent)
stages := h.groupResultsIntoStages(results)
ds := h.svc.Dataset()
if ds == nil {
return nil, newError(http.StatusServiceUnavailable, "dataset unavailable")
}
// Map to OpenAPI schema
var predictionItems []api.PredictionResultPredictionItem
dsEpoch := float64(ds.DSTime.Unix())
for _, stage := range stages {
var trajectory []api.PredictionResultPredictionItemTrajectoryItem
// Parse parameters with defaults
profile := "standard_profile"
if p, ok := params.Profile.Get(); ok {
profile = string(p)
}
for _, result := range stage.Results {
traj := api.PredictionResultPredictionItemTrajectoryItem{
Datetime: *result.Timestamp,
Latitude: *result.Latitude,
Longitude: *result.Longitude,
Altitude: *result.Altitude,
ascentRate := 5.0
if v, ok := params.AscentRate.Get(); ok {
ascentRate = v
}
burstAltitude := 28000.0
if v, ok := params.BurstAltitude.Get(); ok {
burstAltitude = v
}
descentRate := 5.0
if v, ok := params.DescentRate.Get(); ok {
descentRate = v
}
launchAlt := 0.0
if v, ok := params.LaunchAltitude.Get(); ok {
launchAlt = v
}
// Normalize longitude to [0, 360)
lng := params.LaunchLongitude
if lng < 0 {
lng += 360.0
}
launchTime := float64(params.LaunchDatetime.Unix())
warnings := &prediction.Warnings{}
// Build profile chain
elev := h.svc.Elevation()
var stages []prediction.Stage
switch profile {
case "standard_profile":
stages = prediction.StandardProfile(
ascentRate, burstAltitude, descentRate,
ds, dsEpoch, warnings, elev)
case "float_profile":
floatAlt := 25000.0
if v, ok := params.FloatAltitude.Get(); ok {
floatAlt = v
}
stopTime := params.LaunchDatetime.Add(24 * time.Hour)
if v, ok := params.StopDatetime.Get(); ok {
stopTime = v
}
stages = prediction.FloatProfile(
ascentRate, floatAlt, stopTime,
ds, dsEpoch, warnings)
default:
return nil, newError(http.StatusBadRequest, "unknown profile: "+profile)
}
// Run prediction
startTime := time.Now().UTC()
results := prediction.RunPrediction(launchTime, params.LaunchLatitude, lng, launchAlt, stages)
completeTime := time.Now().UTC()
// Build response
stageNames := []string{"ascent", "descent"}
if profile == "float_profile" {
stageNames = []string{"ascent", "float"}
}
var predItems []api.PredictionResponsePredictionItem
for i, sr := range results {
stageName := "ascent"
if i < len(stageNames) {
stageName = stageNames[i]
}
var stageEnum api.PredictionResponsePredictionItemStage
switch stageName {
case "ascent":
stageEnum = api.PredictionResponsePredictionItemStageAscent
case "descent":
stageEnum = api.PredictionResponsePredictionItemStageDescent
case "float":
stageEnum = api.PredictionResponsePredictionItemStageFloat
}
var traj []api.PredictionResponsePredictionItemTrajectoryItem
for _, pt := range sr.Points {
ptLng := pt.Lng
if ptLng > 180 {
ptLng -= 360
}
trajectory = append(trajectory, traj)
}
item := api.PredictionResultPredictionItem{
Stage: stage.Stage,
Trajectory: trajectory,
}
predictionItems = append(predictionItems, item)
}
metadata := api.PredictionResultMetadata{
StartDatetime: *results[0].Timestamp,
CompleteDatetime: *results[len(results)-1].Timestamp,
}
resp := &api.PredictionResult{
Metadata: metadata,
Prediction: predictionItems,
}
return resp, nil
}
// StageResult represents a stage with its results
type StageResult struct {
Stage api.PredictionResultPredictionItemStage
Results []ds.PredicitonResult
}
// groupResultsIntoStages groups the prediction results into ascent and descent stages
func (h *Handler) groupResultsIntoStages(results []ds.PredicitonResult) []StageResult {
if len(results) == 0 {
return nil
}
var stages []StageResult
var currentStage []ds.PredicitonResult
var currentStageType api.PredictionResultPredictionItemStage
// Determine if we're in ascent or descent based on altitude changes
prevAlt := *results[0].Altitude
currentStage = append(currentStage, results[0])
currentStageType = api.PredictionResultPredictionItemStageAscent
for i := 1; i < len(results); i++ {
result := results[i]
currentAlt := *result.Altitude
// Determine if we're still in the same stage
var stageType api.PredictionResultPredictionItemStage
if currentAlt > prevAlt {
stageType = api.PredictionResultPredictionItemStageAscent
} else if currentAlt < prevAlt {
stageType = api.PredictionResultPredictionItemStageDescent
} else {
// Same altitude - continue with current stage
stageType = currentStageType
}
// If stage type changed, finalize current stage and start new one
if stageType != currentStageType && len(currentStage) > 0 {
stages = append(stages, StageResult{
Stage: currentStageType,
Results: currentStage,
traj = append(traj, api.PredictionResponsePredictionItemTrajectoryItem{
Datetime: time.Unix(int64(pt.T), 0).UTC(),
Latitude: pt.Lat,
Longitude: ptLng,
Altitude: pt.Alt,
})
currentStage = nil
currentStageType = stageType
}
currentStage = append(currentStage, result)
prevAlt = currentAlt
}
// Add the final stage
if len(currentStage) > 0 {
stages = append(stages, StageResult{
Stage: currentStageType,
Results: currentStage,
predItems = append(predItems, api.PredictionResponsePredictionItem{
Stage: stageEnum,
Trajectory: traj,
})
}
return stages
}
func (h *Handler) NewError(ctx context.Context, err error) *api.ErrorStatusCode {
if errcode, ok := err.(*errcodes.ErrorCode); ok {
resp := api.Error{
Message: errcode.Message,
}
if errcode.Details != "" {
resp.Details = api.NewOptString(errcode.Details)
}
return &api.ErrorStatusCode{
StatusCode: errcode.StatusCode,
Response: resp,
}
resp := &api.PredictionResponse{
Prediction: predItems,
Metadata: api.PredictionResponseMetadata{
StartDatetime: startTime,
CompleteDatetime: completeTime,
},
}
// Echo request
resp.Request = api.NewOptPredictionResponseRequest(api.PredictionResponseRequest{
Dataset: api.NewOptString(ds.DSTime.Format("2006-01-02T15:04:05Z")),
LaunchLatitude: api.NewOptFloat64(params.LaunchLatitude),
LaunchLongitude: api.NewOptFloat64(params.LaunchLongitude),
LaunchDatetime: api.NewOptString(params.LaunchDatetime.Format(time.RFC3339)),
LaunchAltitude: params.LaunchAltitude,
})
// Warnings
warnMap := warnings.ToMap()
if len(warnMap) > 0 {
resp.Warnings = api.NewOptPredictionResponseWarnings(api.PredictionResponseWarnings{})
}
h.log.Info("prediction complete",
zap.String("profile", profile),
zap.Int("stages", len(results)),
zap.Duration("elapsed", completeTime.Sub(startTime)))
return resp, nil
}
// ReadinessCheck implements the health check endpoint.
//
// When the service has no dataset the response carries the not-ready status
// and the message "no dataset loaded". Otherwise the status is ok and, if the
// service can report it, the dataset time is included. The returned error is
// always nil.
func (h *Handler) ReadinessCheck(ctx context.Context) (*api.ReadinessResponse, error) {
	out := &api.ReadinessResponse{}

	if !h.svc.Ready() {
		out.Status = api.ReadinessResponseStatusNotReady
		out.ErrorMessage = api.NewOptString("no dataset loaded")
		return out, nil
	}

	out.Status = api.ReadinessResponseStatusOk
	if loadedAt, known := h.svc.DatasetTime(); known {
		out.DatasetTime = api.NewOptDateTime(loadedAt)
	}
	return out, nil
}
// NewError converts an error returned by a handler into an ErrorStatusCode.
//
// An error that already is an *api.ErrorStatusCode is passed through
// unchanged. Anything else is logged as unhandled and reported as a 500 whose
// description is the error text.
//
// NOTE(review): only a direct type match is recognised; an ErrorStatusCode
// wrapped with fmt.Errorf("%w") would be reported as a 500. errors.As would
// cover that case if wrapping is ever introduced.
func (h *Handler) NewError(ctx context.Context, err error) *api.ErrorStatusCode {
	switch typed := err.(type) {
	case *api.ErrorStatusCode:
		return typed
	default:
		h.log.Error("unhandled error", zap.Error(err))
		return newError(http.StatusInternalServerError, err.Error())
	}
}
func newError(status int, description string) *api.ErrorStatusCode {
return &api.ErrorStatusCode{
StatusCode: http.StatusInternalServerError,
StatusCode: status,
Response: api.Error{
Message: "undefined internal error",
Details: api.NewOptString(err.Error()),
Error: api.ErrorError{
Type: http.StatusText(status),
Description: description,
},
},
}
}
// ReadinessCheck reports service readiness based on the GRIB status exposed by
// the underlying service.
//
// If the service does not provide GetGribStatus the response has the error
// status and an explanatory message. Otherwise the status is ok when the GRIB
// data is ready, error when it is not ready and a message was reported, and
// not-ready in the remaining case. The returned error is always nil.
func (h *Handler) ReadinessCheck(ctx context.Context) (*api.ReadinessResponse, error) {
	type gribStatusProvider interface {
		GetGribStatus(ctx context.Context) (ready bool, lastUpdate time.Time, isFresh bool, errMsg string)
	}

	provider, supported := h.svc.(gribStatusProvider)
	if !supported {
		return &api.ReadinessResponse{
			Status:       api.ReadinessResponseStatusError,
			IsFresh:      api.NewOptBool(false),
			LastUpdate:   api.NewOptDateTime(time.Time{}),
			ErrorMessage: api.NewOptString("service does not implement GetGribStatus"),
		}, nil
	}

	ready, updatedAt, fresh, message := provider.GetGribStatus(ctx)

	state := api.ReadinessResponseStatusNotReady
	switch {
	case ready:
		state = api.ReadinessResponseStatusOk
	case message != "":
		state = api.ReadinessResponseStatusError
	}

	return &api.ReadinessResponse{
		Status:       state,
		IsFresh:      api.NewOptBool(fresh),
		LastUpdate:   api.NewOptDateTime(updatedAt),
		ErrorMessage: api.NewOptString(message),
	}, nil
}

View file

@ -5,43 +5,71 @@ import (
"fmt"
"net/http"
"git.intra.yksa.space/gsn/predictor/internal/pkg/log"
"git.intra.yksa.space/gsn/predictor/internal/transport/middleware"
handler "git.intra.yksa.space/gsn/predictor/internal/transport/rest/handler"
api "git.intra.yksa.space/gsn/predictor/pkg/rest"
"github.com/rs/cors"
"predictor-refactored/internal/transport/middleware"
"predictor-refactored/internal/transport/rest/handler"
api "predictor-refactored/pkg/rest"
"go.uber.org/zap"
)
// Transport wraps the ogen HTTP server.
type Transport struct {
cfg *Config
// srv is the ogen-generated server that routes requests to the handler.
srv *api.Server
// handler is the API implementation passed to the ogen server.
handler *handler.Handler
// port is the TCP port Run listens on.
port int
// log is the logger used for transport-level messages.
log *zap.Logger
}
func New(handler *handler.Handler, cfg *Config) (*Transport, error) {
// New creates a new REST transport.
func New(h *handler.Handler, port int, log *zap.Logger) (*Transport, error) {
srv, err := api.NewServer(
handler,
api.WithMiddleware(middleware.Logging()),
h,
api.WithMiddleware(middleware.Logging(log)),
)
if err != nil {
return nil, err
return nil, fmt.Errorf("create ogen server: %w", err)
}
return &Transport{
srv: srv,
cfg: cfg,
handler: handler,
handler: h,
port: port,
log: log,
}, nil
}
func (t *Transport) Run() {
log.Ctx(context.Background()).Info("started")
// Run starts the HTTP server. Blocks until the server stops.
func (t *Transport) Run() error {
mux := http.NewServeMux()
mux.Handle("/", t.srv)
cors.AllowAll().Handler(mux)
if err := http.ListenAndServe(fmt.Sprintf(":%d", t.cfg.Port), t.srv); err != nil {
panic(err)
httpSrv := &http.Server{
Addr: fmt.Sprintf(":%d", t.port),
Handler: corsMiddleware(mux),
}
t.log.Info("starting HTTP server", zap.Int("port", t.port))
return httpSrv.ListenAndServe()
}
// Shutdown is intended to gracefully stop the HTTP server. It is currently a
// no-op: ctx is ignored and the result is always nil.
//
// NOTE(review): Run builds its own http.Server and keeps no reference to it,
// so nothing in Transport can stop that server. Confirm whether main.go really
// owns a separate http.Server, as the comment below states.
func (t *Transport) Shutdown(ctx context.Context) error {
// The ogen server doesn't have a shutdown method;
// shutdown is handled by the http.Server in main.go
return nil
}
// corsMiddleware wraps next with permissive CORS handling.
//
// Every response gets the Access-Control-Allow-Origin, -Methods and -Headers
// headers. OPTIONS requests are answered directly with 204 No Content and are
// never forwarded to next; all other methods are passed through.
func corsMiddleware(next http.Handler) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		headers := w.Header()
		headers.Set("Access-Control-Allow-Origin", "*")
		headers.Set("Access-Control-Allow-Methods", "GET, OPTIONS")
		headers.Set("Access-Control-Allow-Headers", "Content-Type")

		if r.Method != http.MethodOptions {
			next.ServeHTTP(w, r)
			return
		}
		// Preflight: reply without invoking the wrapped handler.
		w.WriteHeader(http.StatusNoContent)
	}
	return http.HandlerFunc(serve)
}