forked from gsn/predictor
feat: refactor
This commit is contained in:
parent
82ef1cb3b8
commit
51bbf3c579
44 changed files with 8589 additions and 0 deletions
158
internal/dataset/dataset.go
Normal file
158
internal/dataset/dataset.go
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
package dataset
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Dataset shape constants.
// Shape: (65, 47, 3, 361, 720) = (hour, pressure_level, variable, latitude, longitude)
// This matches the reference predictor exactly.
const (
	NumHours     = 65  // 0, 3, 6, ..., 192
	NumLevels    = 47  // pressure levels
	NumVariables = 3   // height, wind_u, wind_v
	NumLatitudes = 361 // -90.0 to +90.0 in 0.5 degree steps
	NumLongitudes = 720 // 0.0 to 359.5 in 0.5 degree steps

	HourStep   = 3     // hours between forecast time steps
	MaxHour    = 192   // maximum forecast hour
	Resolution = 0.5   // grid resolution in degrees
	LatStart   = -90.0 // first latitude in the dataset (index 0 = south pole)
	LonStart   = 0.0   // first longitude in the dataset

	// Variable indices within the dataset (third axis of the shape above).
	VarHeight = 0
	VarWindU  = 1
	VarWindV  = 2

	ElementSize = 4 // float32 = 4 bytes
)
|
||||
|
||||
// DatasetSize is the total size of the dataset file in bytes.
// 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200
// Each factor is widened to int64 before multiplying so the product
// cannot overflow on 32-bit platforms.
const DatasetSize int64 = int64(NumHours) * int64(NumLevels) * int64(NumVariables) *
	int64(NumLatitudes) * int64(NumLongitudes) * int64(ElementSize)
|
||||
|
||||
// LevelSet identifies which GRIB file set a pressure level belongs to.
// GFS publishes the levels split across two files per forecast step:
// the primary pgrb2 file and the secondary pgrb2b file (see PressuresPgrb2
// and PressuresPgrb2b for the exact split).
type LevelSet int

const (
	LevelSetA LevelSet = iota // pgrb2 (primary)
	LevelSetB                 // pgrb2b (secondary)
)
|
||||
|
||||
// Pressures contains the 47 pressure levels in hPa, sorted descending.
// Index 0 = 1000 hPa (near surface), Index 46 = 1 hPa (high atmosphere).
// The position of a level in this array is its index along the dataset's
// second axis; pressureIndex (built in init) provides the reverse lookup.
var Pressures = [NumLevels]int{
	1000, 975, 950, 925, 900, 875, 850, 825, 800, 775,
	750, 725, 700, 675, 650, 625, 600, 575, 550, 525,
	500, 475, 450, 425, 400, 375, 350, 325, 300, 275,
	250, 225, 200, 175, 150, 125, 100, 70, 50, 30,
	20, 10, 7, 5, 3, 2, 1,
}
|
||||
|
||||
// pressureIndex maps pressure in hPa to its index in the Pressures array.
// Populated once in init; read-only afterwards.
var pressureIndex map[int]int

// pressureLevelSet maps pressure in hPa to its GRIB file set.
// Populated once in init; read-only afterwards.
var pressureLevelSet map[int]LevelSet
|
||||
|
||||
func init() {
|
||||
pressureIndex = make(map[int]int, NumLevels)
|
||||
for i, p := range Pressures {
|
||||
pressureIndex[p] = i
|
||||
}
|
||||
|
||||
pressureLevelSet = make(map[int]LevelSet, NumLevels)
|
||||
for _, p := range PressuresPgrb2 {
|
||||
pressureLevelSet[p] = LevelSetA
|
||||
}
|
||||
for _, p := range PressuresPgrb2b {
|
||||
pressureLevelSet[p] = LevelSetB
|
||||
}
|
||||
}
|
||||
|
||||
// PressuresPgrb2 contains levels found in the primary pgrb2 file (26 levels).
// Together with PressuresPgrb2b this covers all 47 entries of Pressures.
var PressuresPgrb2 = []int{
	10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 350, 400,
	450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 925,
	950, 975, 1000,
}
|
||||
|
||||
// PressuresPgrb2b contains levels found in the secondary pgrb2b file (21 levels).
// Disjoint from PressuresPgrb2; the union is the full Pressures set.
var PressuresPgrb2b = []int{
	1, 2, 3, 5, 7, 125, 175, 225, 275, 325, 375, 425,
	475, 525, 575, 625, 675, 725, 775, 825, 875,
}
|
||||
|
||||
// PressureIndex returns the dataset index for a given pressure level in hPa.
|
||||
// Returns -1 if the level is not found.
|
||||
func PressureIndex(hPa int) int {
|
||||
idx, ok := pressureIndex[hPa]
|
||||
if !ok {
|
||||
return -1
|
||||
}
|
||||
return idx
|
||||
}
|
||||
|
||||
// PressureLevelSet returns which GRIB file set a pressure level belongs to.
|
||||
func PressureLevelSet(hPa int) (LevelSet, bool) {
|
||||
ls, ok := pressureLevelSet[hPa]
|
||||
return ls, ok
|
||||
}
|
||||
|
||||
// HourIndex returns the dataset time index for a forecast hour.
|
||||
// Returns -1 if the hour is invalid (not a multiple of HourStep or out of range).
|
||||
func HourIndex(hour int) int {
|
||||
if hour < 0 || hour > MaxHour || hour%HourStep != 0 {
|
||||
return -1
|
||||
}
|
||||
return hour / HourStep
|
||||
}
|
||||
|
||||
// Hours returns all forecast hours as a slice: [0, 3, 6, ..., 192].
|
||||
func Hours() []int {
|
||||
out := make([]int, 0, NumHours)
|
||||
for h := 0; h <= MaxHour; h += HourStep {
|
||||
out = append(out, h)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// S3BaseURL is the public NOAA GFS bucket that GribURL / GribURLB
// build their object URLs from.
const S3BaseURL = "https://noaa-gfs-bdp-pds.s3.amazonaws.com"
|
||||
|
||||
// GribURL returns the S3 URL for a primary (pgrb2) GRIB file.
|
||||
func GribURL(date string, runHour, forecastStep int) string {
|
||||
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d",
|
||||
S3BaseURL, date, runHour, runHour, forecastStep)
|
||||
}
|
||||
|
||||
// GribURLB returns the S3 URL for a secondary (pgrb2b) GRIB file.
|
||||
func GribURLB(date string, runHour, forecastStep int) string {
|
||||
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2b.0p50.f%03d",
|
||||
S3BaseURL, date, runHour, runHour, forecastStep)
|
||||
}
|
||||
|
||||
// GribFileName returns the local filename for a primary GRIB file,
// e.g. "gfs.t00z.pgrb2.0p50.f003".
func GribFileName(runHour, forecastStep int) string {
	return "gfs.t" + fmt.Sprintf("%02d", runHour) +
		"z.pgrb2.0p50.f" + fmt.Sprintf("%03d", forecastStep)
}
|
||||
|
||||
// GribFileNameB returns the local filename for a secondary GRIB file,
// e.g. "gfs.t00z.pgrb2b.0p50.f003".
func GribFileNameB(runHour, forecastStep int) string {
	return "gfs.t" + fmt.Sprintf("%02d", runHour) +
		"z.pgrb2b.0p50.f" + fmt.Sprintf("%03d", forecastStep)
}
|
||||
|
||||
// VariableIndex returns the dataset variable index for a GRIB parameter.
|
||||
// Returns -1 if the parameter is not recognized.
|
||||
func VariableIndex(parameterCategory, parameterNumber int) int {
|
||||
switch {
|
||||
case parameterCategory == 3 && parameterNumber == 5:
|
||||
return VarHeight // Geopotential Height
|
||||
case parameterCategory == 2 && parameterNumber == 2:
|
||||
return VarWindU // U-component of wind
|
||||
case parameterCategory == 2 && parameterNumber == 3:
|
||||
return VarWindV // V-component of wind
|
||||
default:
|
||||
return -1
|
||||
}
|
||||
}
|
||||
152
internal/dataset/dataset_test.go
Normal file
152
internal/dataset/dataset_test.go
Normal file
|
|
@ -0,0 +1,152 @@
|
|||
package dataset
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestDatasetShape pins the five grid dimensions to the reference
// predictor's shape (65, 47, 3, 361, 720).
func TestDatasetShape(t *testing.T) {
	if NumHours != 65 {
		t.Errorf("NumHours = %d, want 65", NumHours)
	}
	if NumLevels != 47 {
		t.Errorf("NumLevels = %d, want 47", NumLevels)
	}
	if NumVariables != 3 {
		t.Errorf("NumVariables = %d, want 3", NumVariables)
	}
	if NumLatitudes != 361 {
		t.Errorf("NumLatitudes = %d, want 361", NumLatitudes)
	}
	if NumLongitudes != 720 {
		t.Errorf("NumLongitudes = %d, want 720", NumLongitudes)
	}
}
|
||||
|
||||
// TestDatasetSize checks the derived file size against the hand-computed
// product of all dimensions times 4 bytes per float32.
func TestDatasetSize(t *testing.T) {
	// 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200
	want := int64(9_528_667_200)
	if DatasetSize != want {
		t.Errorf("DatasetSize = %d, want %d", DatasetSize, want)
	}
}
|
||||
|
||||
// TestPressureLevels validates the Pressures table: count, endpoints,
// strict descending order, and the 26+21 split across the two GRIB sets.
func TestPressureLevels(t *testing.T) {
	if len(Pressures) != 47 {
		t.Fatalf("len(Pressures) = %d, want 47", len(Pressures))
	}

	// First should be 1000 (highest pressure, near surface)
	if Pressures[0] != 1000 {
		t.Errorf("Pressures[0] = %d, want 1000", Pressures[0])
	}
	// Last should be 1 (lowest pressure, high atmosphere)
	if Pressures[46] != 1 {
		t.Errorf("Pressures[46] = %d, want 1", Pressures[46])
	}

	// Should be sorted descending
	for i := 1; i < len(Pressures); i++ {
		if Pressures[i] >= Pressures[i-1] {
			t.Errorf("Pressures not descending at [%d]: %d >= %d", i, Pressures[i], Pressures[i-1])
		}
	}

	// Total levels: 26 from pgrb2 + 21 from pgrb2b = 47
	if len(PressuresPgrb2) != 26 {
		t.Errorf("len(PressuresPgrb2) = %d, want 26", len(PressuresPgrb2))
	}
	if len(PressuresPgrb2b) != 21 {
		t.Errorf("len(PressuresPgrb2b) = %d, want 21", len(PressuresPgrb2b))
	}
}
|
||||
|
||||
// TestPressureIndex checks both endpoints, a mid-table level, and the
// -1 sentinel for an unknown pressure.
func TestPressureIndex(t *testing.T) {
	if PressureIndex(1000) != 0 {
		t.Errorf("PressureIndex(1000) = %d, want 0", PressureIndex(1000))
	}
	if PressureIndex(1) != 46 {
		t.Errorf("PressureIndex(1) = %d, want 46", PressureIndex(1))
	}
	if PressureIndex(500) != 20 {
		t.Errorf("PressureIndex(500) = %d, want 20", PressureIndex(500))
	}
	if PressureIndex(9999) != -1 {
		t.Errorf("PressureIndex(9999) = %d, want -1", PressureIndex(9999))
	}
}
|
||||
|
||||
// TestPressureLevelSet spot-checks representative levels in each GRIB set
// and then verifies every known level has some assignment.
func TestPressureLevelSet(t *testing.T) {
	// 1000 mb should be in pgrb2 (A)
	ls, ok := PressureLevelSet(1000)
	if !ok || ls != LevelSetA {
		t.Errorf("PressureLevelSet(1000) = %v, %v; want A, true", ls, ok)
	}

	// 125 mb should be in pgrb2b (B)
	ls, ok = PressureLevelSet(125)
	if !ok || ls != LevelSetB {
		t.Errorf("PressureLevelSet(125) = %v, %v; want B, true", ls, ok)
	}

	// 1, 2, 3, 5, 7 should be in pgrb2b (B)
	for _, p := range []int{1, 2, 3, 5, 7} {
		ls, ok := PressureLevelSet(p)
		if !ok || ls != LevelSetB {
			t.Errorf("PressureLevelSet(%d) = %v, %v; want B, true", p, ls, ok)
		}
	}

	// Every pressure level should have a level set assignment
	for _, p := range Pressures {
		_, ok := PressureLevelSet(p)
		if !ok {
			t.Errorf("PressureLevelSet(%d) not found", p)
		}
	}
}
|
||||
|
||||
// TestHourIndex checks the endpoints, one interior step, and both
// rejection cases (off-step hour, hour past MaxHour).
func TestHourIndex(t *testing.T) {
	if HourIndex(0) != 0 {
		t.Errorf("HourIndex(0) = %d, want 0", HourIndex(0))
	}
	if HourIndex(3) != 1 {
		t.Errorf("HourIndex(3) = %d, want 1", HourIndex(3))
	}
	if HourIndex(192) != 64 {
		t.Errorf("HourIndex(192) = %d, want 64", HourIndex(192))
	}
	if HourIndex(1) != -1 {
		t.Errorf("HourIndex(1) = %d, want -1 (not multiple of 3)", HourIndex(1))
	}
	if HourIndex(195) != -1 {
		t.Errorf("HourIndex(195) = %d, want -1 (out of range)", HourIndex(195))
	}
}
|
||||
|
||||
// TestHours checks length and both endpoints of the generated hour list.
func TestHours(t *testing.T) {
	hours := Hours()
	if len(hours) != NumHours {
		t.Fatalf("len(Hours()) = %d, want %d", len(hours), NumHours)
	}
	if hours[0] != 0 {
		t.Errorf("Hours()[0] = %d, want 0", hours[0])
	}
	if hours[len(hours)-1] != MaxHour {
		t.Errorf("Hours()[last] = %d, want %d", hours[len(hours)-1], MaxHour)
	}
}
|
||||
|
||||
// TestVariableIndex maps each supported GRIB (category, number) pair to
// its dataset variable and checks an unknown pair returns -1.
func TestVariableIndex(t *testing.T) {
	if VariableIndex(3, 5) != VarHeight {
		t.Errorf("HGT: got %d, want %d", VariableIndex(3, 5), VarHeight)
	}
	if VariableIndex(2, 2) != VarWindU {
		t.Errorf("UGRD: got %d, want %d", VariableIndex(2, 2), VarWindU)
	}
	if VariableIndex(2, 3) != VarWindV {
		t.Errorf("VGRD: got %d, want %d", VariableIndex(2, 3), VarWindV)
	}
	if VariableIndex(0, 0) != -1 {
		t.Errorf("unknown: got %d, want -1", VariableIndex(0, 0))
	}
}
|
||||
140
internal/dataset/file.go
Normal file
140
internal/dataset/file.go
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
package dataset
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
mmap "github.com/edsrzf/mmap-go"
|
||||
)
|
||||
|
||||
// File represents an mmap-backed wind dataset file.
// The mapping covers the full 5-D float32 grid described by the shape
// constants in dataset.go; Val/SetVal index into it via offset.
type File struct {
	mm       mmap.MMap // memory-mapped view of the entire dataset file
	file     *os.File  // underlying handle; kept open while the mapping lives
	writable bool      // true when produced by Create (RDWR mapping)
	DSTime   time.Time // forecast run time (UTC)
}
|
||||
|
||||
// Open opens an existing dataset file for reading.
|
||||
func Open(path string, dsTime time.Time) (*File, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open dataset: %w", err)
|
||||
}
|
||||
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, fmt.Errorf("stat dataset: %w", err)
|
||||
}
|
||||
if info.Size() != DatasetSize {
|
||||
f.Close()
|
||||
return nil, fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, info.Size())
|
||||
}
|
||||
|
||||
mm, err := mmap.Map(f, mmap.RDONLY, 0)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, fmt.Errorf("mmap dataset: %w", err)
|
||||
}
|
||||
|
||||
return &File{mm: mm, file: f, writable: false, DSTime: dsTime}, nil
|
||||
}
|
||||
|
||||
// Create creates a new dataset file for writing.
|
||||
// The file is truncated to the correct size and mmap'd read-write.
|
||||
func Create(path string) (*File, error) {
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create dataset: %w", err)
|
||||
}
|
||||
|
||||
if err := f.Truncate(DatasetSize); err != nil {
|
||||
f.Close()
|
||||
return nil, fmt.Errorf("truncate dataset: %w", err)
|
||||
}
|
||||
|
||||
mm, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, fmt.Errorf("mmap dataset: %w", err)
|
||||
}
|
||||
|
||||
return &File{mm: mm, file: f, writable: true}, nil
|
||||
}
|
||||
|
||||
// offset computes the byte offset for element [hour][level][variable][lat][lon].
|
||||
// Row-major C-order indexing matching the reference implementation:
|
||||
// shape = (65, 47, 3, 361, 720)
|
||||
func offset(hour, level, variable, lat, lon int) int64 {
|
||||
idx := int64(hour)
|
||||
idx = idx*int64(NumLevels) + int64(level)
|
||||
idx = idx*int64(NumVariables) + int64(variable)
|
||||
idx = idx*int64(NumLatitudes) + int64(lat)
|
||||
idx = idx*int64(NumLongitudes) + int64(lon)
|
||||
return idx * int64(ElementSize)
|
||||
}
|
||||
|
||||
// Val reads a float32 value from the dataset at [hour][level][variable][lat][lon].
|
||||
func (d *File) Val(hour, level, variable, lat, lon int) float32 {
|
||||
off := offset(hour, level, variable, lat, lon)
|
||||
bits := binary.LittleEndian.Uint32(d.mm[off : off+4])
|
||||
return math.Float32frombits(bits)
|
||||
}
|
||||
|
||||
// SetVal writes a float32 value to the dataset at [hour][level][variable][lat][lon].
|
||||
// Only valid on writable (created) datasets.
|
||||
func (d *File) SetVal(hour, level, variable, lat, lon int, val float32) {
|
||||
off := offset(hour, level, variable, lat, lon)
|
||||
binary.LittleEndian.PutUint32(d.mm[off:off+4], math.Float32bits(val))
|
||||
}
|
||||
|
||||
// BlitGribData copies decoded GRIB grid data into the dataset at the given position.
|
||||
// gribData is 361*720 float64 values in GRIB scan order (north-to-south, west-to-east).
|
||||
// This function flips the latitude so that dataset index 0 = -90 (south) and 360 = +90 (north).
|
||||
func (d *File) BlitGribData(hourIdx, levelIdx, varIdx int, gribData []float64) error {
|
||||
expected := NumLatitudes * NumLongitudes
|
||||
if len(gribData) != expected {
|
||||
return fmt.Errorf("grib data has %d values, expected %d", len(gribData), expected)
|
||||
}
|
||||
|
||||
for lat := 0; lat < NumLatitudes; lat++ {
|
||||
for lon := 0; lon < NumLongitudes; lon++ {
|
||||
// GRIB scans north-to-south: row 0 = 90N, row 360 = 90S
|
||||
// Dataset stores south-to-north: index 0 = -90 (90S), index 360 = +90 (90N)
|
||||
gribIdx := (360-lat)*NumLongitudes + lon
|
||||
val := float32(gribData[gribIdx])
|
||||
d.SetVal(hourIdx, levelIdx, varIdx, lat, lon, val)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flush flushes the mmap to disk.
|
||||
func (d *File) Flush() error {
|
||||
if d.mm != nil {
|
||||
return d.mm.Flush()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close unmaps and closes the dataset file.
|
||||
func (d *File) Close() error {
|
||||
if d.mm != nil {
|
||||
if err := d.mm.Unmap(); err != nil {
|
||||
d.file.Close()
|
||||
return fmt.Errorf("unmap: %w", err)
|
||||
}
|
||||
d.mm = nil
|
||||
}
|
||||
if d.file != nil {
|
||||
err := d.file.Close()
|
||||
d.file = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
58
internal/downloader/config.go
Normal file
58
internal/downloader/config.go
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
package downloader
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config holds downloader configuration, loaded from environment variables.
// See LoadConfig for the PREDICTOR_* variables that override each field
// and DefaultConfig for the fallback values.
type Config struct {
	// DataDir is the directory for storing dataset files and temporary GRIB data.
	// Overridden by PREDICTOR_DATA_DIR.
	DataDir string

	// Parallel is the maximum number of concurrent GRIB downloads.
	// Overridden by PREDICTOR_DOWNLOAD_PARALLEL (positive integer).
	Parallel int

	// UpdateInterval is how often the scheduler checks for new forecast data.
	// Overridden by PREDICTOR_UPDATE_INTERVAL (Go duration string).
	UpdateInterval time.Duration

	// DatasetTTL is how long a dataset is considered fresh before a new one is needed.
	// Overridden by PREDICTOR_DATASET_TTL (Go duration string).
	DatasetTTL time.Duration
}
|
||||
|
||||
// DefaultConfig returns the default configuration.
|
||||
func DefaultConfig() *Config {
|
||||
return &Config{
|
||||
DataDir: "/tmp/predictor-data",
|
||||
Parallel: 8,
|
||||
UpdateInterval: 6 * time.Hour,
|
||||
DatasetTTL: 48 * time.Hour,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadConfig loads configuration from environment variables, falling back to defaults.
|
||||
func LoadConfig() *Config {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
if v := os.Getenv("PREDICTOR_DATA_DIR"); v != "" {
|
||||
cfg.DataDir = v
|
||||
}
|
||||
if v := os.Getenv("PREDICTOR_DOWNLOAD_PARALLEL"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||
cfg.Parallel = n
|
||||
}
|
||||
}
|
||||
if v := os.Getenv("PREDICTOR_UPDATE_INTERVAL"); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil {
|
||||
cfg.UpdateInterval = d
|
||||
}
|
||||
}
|
||||
if v := os.Getenv("PREDICTOR_DATASET_TTL"); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil {
|
||||
cfg.DatasetTTL = d
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
380
internal/downloader/downloader.go
Normal file
380
internal/downloader/downloader.go
Normal file
|
|
@ -0,0 +1,380 @@
|
|||
package downloader
|
||||
|
||||
import (
	"context"
	"fmt"
	"io"
	"math"
	"net/http"
	"os"
	"path/filepath"
	"sync/atomic"
	"time"

	"predictor-refactored/internal/dataset"

	"github.com/nilsmagnus/grib/griblib"
	"go.uber.org/zap"
	"golang.org/x/sync/errgroup"
)
|
||||
|
||||
// Downloader handles fetching GFS forecast data from S3 and assembling dataset files.
type Downloader struct {
	cfg    *Config      // data dir, parallelism, and freshness settings
	client *http.Client // shared client for all S3 requests (see NewDownloader)
	log    *zap.Logger
}
|
||||
|
||||
// NewDownloader creates a new Downloader.
|
||||
func NewDownloader(cfg *Config, log *zap.Logger) *Downloader {
|
||||
return &Downloader{
|
||||
cfg: cfg,
|
||||
client: &http.Client{
|
||||
Timeout: 2 * time.Minute,
|
||||
},
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// neededVariables is the set of GRIB variable names we need:
// HGT (geopotential height), UGRD/VGRD (wind u/v components) —
// the same three variables indexed by dataset.VariableIndex.
var neededVariables = map[string]bool{
	"HGT":  true,
	"UGRD": true,
	"VGRD": true,
}
|
||||
|
||||
// FindLatestRun finds the most recent available GFS model run on S3.
// It checks the last forecast step of each run to confirm availability:
// runs start every 6 hours (00/06/12/18 UTC), and up to 8 runs (48 hours)
// are probed newest-first with a HEAD request on the final step's .idx file.
func (d *Downloader) FindLatestRun(ctx context.Context) (time.Time, error) {
	// Round the current UTC time down to the most recent 6-hour run boundary.
	now := time.Now().UTC()
	hour := now.Hour() - (now.Hour() % 6)
	current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC)

	for i := 0; i < 8; i++ {
		date := current.Format("20060102")
		// Probe the .idx of the final forecast step (MaxHour): if it exists,
		// the run is treated as complete and usable.
		url := dataset.GribURL(date, current.Hour(), dataset.MaxHour) + ".idx"

		req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
		if err != nil {
			current = current.Add(-6 * time.Hour)
			continue
		}

		resp, err := d.client.Do(req)
		if err == nil {
			resp.Body.Close()
			if resp.StatusCode == http.StatusOK {
				d.log.Info("found latest model run",
					zap.Time("run", current),
					zap.String("verified_url", url))
				return current, nil
			}
		}

		// Not found (or the request failed): fall back to the previous run.
		current = current.Add(-6 * time.Hour)
	}

	return time.Time{}, fmt.Errorf("no recent GFS forecast found (checked 8 runs)")
}
|
||||
|
||||
// Download downloads a complete forecast and assembles a dataset file.
|
||||
// Returns the path to the completed dataset file.
|
||||
func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error) {
|
||||
date := run.Format("20060102")
|
||||
runHour := run.Hour()
|
||||
|
||||
finalPath := filepath.Join(d.cfg.DataDir, run.Format("2006010215"))
|
||||
tempPath := finalPath + ".downloading"
|
||||
|
||||
// Check if final dataset already exists
|
||||
if info, err := os.Stat(finalPath); err == nil && info.Size() == dataset.DatasetSize {
|
||||
d.log.Info("dataset already exists", zap.String("path", finalPath))
|
||||
return finalPath, nil
|
||||
}
|
||||
|
||||
d.log.Info("starting dataset download",
|
||||
zap.Time("run", run),
|
||||
zap.String("temp_path", tempPath))
|
||||
|
||||
// Create the dataset file
|
||||
ds, err := dataset.Create(tempPath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create dataset: %w", err)
|
||||
}
|
||||
defer ds.Close()
|
||||
|
||||
steps := dataset.Hours()
|
||||
totalSteps := len(steps) * 2 // pgrb2 + pgrb2b per step
|
||||
completed := 0
|
||||
|
||||
// Process each forecast step with bounded concurrency
|
||||
g, ctx := errgroup.WithContext(ctx)
|
||||
sem := make(chan struct{}, d.cfg.Parallel)
|
||||
|
||||
for _, step := range steps {
|
||||
step := step
|
||||
hourIdx := dataset.HourIndex(step)
|
||||
if hourIdx < 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Download pgrb2 (level set A)
|
||||
sem <- struct{}{}
|
||||
g.Go(func() error {
|
||||
defer func() { <-sem }()
|
||||
url := dataset.GribURL(date, runHour, step)
|
||||
err := d.DownloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetA)
|
||||
if err != nil {
|
||||
return fmt.Errorf("step %d pgrb2: %w", step, err)
|
||||
}
|
||||
completed++
|
||||
d.log.Debug("step complete",
|
||||
zap.Int("step", step),
|
||||
zap.String("set", "pgrb2"),
|
||||
zap.Int("progress", completed),
|
||||
zap.Int("total", totalSteps))
|
||||
return nil
|
||||
})
|
||||
|
||||
// Download pgrb2b (level set B)
|
||||
sem <- struct{}{}
|
||||
g.Go(func() error {
|
||||
defer func() { <-sem }()
|
||||
url := dataset.GribURLB(date, runHour, step)
|
||||
err := d.DownloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetB)
|
||||
if err != nil {
|
||||
return fmt.Errorf("step %d pgrb2b: %w", step, err)
|
||||
}
|
||||
completed++
|
||||
d.log.Debug("step complete",
|
||||
zap.Int("step", step),
|
||||
zap.String("set", "pgrb2b"),
|
||||
zap.Int("progress", completed),
|
||||
zap.Int("total", totalSteps))
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
os.Remove(tempPath)
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Flush to disk
|
||||
if err := ds.Flush(); err != nil {
|
||||
os.Remove(tempPath)
|
||||
return "", fmt.Errorf("flush dataset: %w", err)
|
||||
}
|
||||
|
||||
// Close before rename
|
||||
ds.Close()
|
||||
|
||||
// Atomic rename
|
||||
if err := os.Rename(tempPath, finalPath); err != nil {
|
||||
os.Remove(tempPath)
|
||||
return "", fmt.Errorf("rename dataset: %w", err)
|
||||
}
|
||||
|
||||
d.log.Info("dataset download complete", zap.String("path", finalPath))
|
||||
return finalPath, nil
|
||||
}
|
||||
|
||||
// DownloadAndBlit downloads needed GRIB fields from a URL and writes them into the dataset.
// Pipeline: fetch the .idx sidecar, select the HGT/UGRD/VGRD entries whose
// pressure level belongs to the requested level set, range-download just
// those messages into a temp file, decode them, and blit each grid into
// the dataset at (hourIdx, levelIdx, varIdx). Individual blit failures are
// logged and skipped rather than aborting the whole step.
func (d *Downloader) DownloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet) error {
	// 1. Download .idx
	idxURL := baseURL + ".idx"
	idxBody, err := d.httpGet(ctx, idxURL)
	if err != nil {
		return fmt.Errorf("download idx: %w", err)
	}

	// 2. Parse and filter
	entries := ParseIdx(idxBody)
	filtered := FilterIdx(entries, neededVariables)

	// Further filter to only levels in this level set
	var relevant []IdxEntry
	for _, e := range filtered {
		ls, ok := dataset.PressureLevelSet(e.LevelMB)
		if ok && ls == levelSet {
			relevant = append(relevant, e)
		}
	}

	// No matching entries is treated as success (nothing to write),
	// but logged loudly since it usually indicates an upstream change.
	if len(relevant) == 0 {
		d.log.Warn("no relevant entries found in idx",
			zap.String("url", idxURL),
			zap.Int("total_entries", len(entries)),
			zap.Int("filtered", len(filtered)))
		return nil
	}

	// 3. Download byte ranges and write to temp file
	ranges := EntriesToRanges(relevant)
	tmpFile, err := d.downloadRangesToTempFile(ctx, baseURL, ranges)
	if err != nil {
		return fmt.Errorf("download ranges: %w", err)
	}
	defer os.Remove(tmpFile)

	// 4. Read GRIB messages from temp file
	f, err := os.Open(tmpFile)
	if err != nil {
		return fmt.Errorf("open temp grib: %w", err)
	}

	messages, err := griblib.ReadMessages(f)
	f.Close()
	if err != nil {
		return fmt.Errorf("read grib messages: %w", err)
	}

	// 5. Decode and blit each message into the dataset
	for _, msg := range messages {
		// Only template 0 (analysis/forecast at a point in time) is handled.
		if msg.Section4.ProductDefinitionTemplateNumber != 0 {
			continue
		}

		product := msg.Section4.ProductDefinitionTemplate

		varIdx := dataset.VariableIndex(int(product.ParameterCategory), int(product.ParameterNumber))
		if varIdx < 0 {
			continue
		}

		if product.FirstSurface.Type != 100 { // isobaric surface
			continue
		}

		// Surface value is in Pa; the dataset's level table is in hPa/mb.
		pressurePa := float64(product.FirstSurface.Value)
		pressureMB := int(math.Round(pressurePa / 100.0))
		levelIdx := dataset.PressureIndex(pressureMB)
		if levelIdx < 0 {
			continue
		}

		data := msg.Data()
		if err := ds.BlitGribData(hourIdx, levelIdx, varIdx, data); err != nil {
			// Skip this field but keep processing the rest of the file.
			d.log.Warn("blit failed",
				zap.Int("var", varIdx),
				zap.Int("level_mb", pressureMB),
				zap.Error(err))
			continue
		}
	}

	return nil
}
|
||||
|
||||
// downloadRangesToTempFile downloads multiple byte ranges from a URL,
|
||||
// concatenating them into a single temp file (valid concatenated GRIB messages).
|
||||
func (d *Downloader) downloadRangesToTempFile(ctx context.Context, baseURL string, ranges []ByteRange) (string, error) {
|
||||
tmpFile, err := os.CreateTemp(d.cfg.DataDir, "grib-*.tmp")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
for _, r := range ranges {
|
||||
data, err := d.httpGetRange(ctx, baseURL, r.Start, r.End)
|
||||
if err != nil {
|
||||
tmpFile.Close()
|
||||
os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("download range %d-%d: %w", r.Start, r.End, err)
|
||||
}
|
||||
if _, err := tmpFile.Write(data); err != nil {
|
||||
tmpFile.Close()
|
||||
os.Remove(tmpPath)
|
||||
return "", fmt.Errorf("write temp: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
os.Remove(tmpPath)
|
||||
return "", err
|
||||
}
|
||||
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
// httpGet downloads a URL and returns the body bytes.
// It retries up to 3 times with linear backoff (2s, 4s between attempts),
// honoring context cancellation while waiting. Both transport errors and
// non-200 statuses count as retryable failures; the last error is returned
// wrapped after the final attempt.
func (d *Downloader) httpGet(ctx context.Context, url string) ([]byte, error) {
	var lastErr error
	for attempt := 0; attempt < 3; attempt++ {
		if attempt > 0 {
			// Back off before every retry; bail out immediately on cancellation.
			select {
			case <-time.After(time.Duration(attempt*2) * time.Second):
			case <-ctx.Done():
				return nil, ctx.Err()
			}
		}

		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			// Malformed request: retrying cannot help.
			return nil, err
		}

		resp, err := d.client.Do(req)
		if err != nil {
			lastErr = err
			continue
		}

		// Read the body fully before checking status so the connection
		// is drained and can be reused by the transport.
		body, err := io.ReadAll(resp.Body)
		resp.Body.Close()

		if resp.StatusCode != http.StatusOK {
			lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url)
			continue
		}
		if err != nil {
			lastErr = err
			continue
		}

		return body, nil
	}

	return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}
|
||||
|
||||
// httpGetRange downloads a byte range from a URL (start and end inclusive,
// matching HTTP Range semantics). Same retry policy as httpGet: up to 3
// attempts with linear backoff, context-aware waits. Accepts both 206
// (partial content) and 200 (server ignored the Range header) responses.
func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64) ([]byte, error) {
	var lastErr error
	for attempt := 0; attempt < 3; attempt++ {
		if attempt > 0 {
			// Back off before every retry; bail out immediately on cancellation.
			select {
			case <-time.After(time.Duration(attempt*2) * time.Second):
			case <-ctx.Done():
				return nil, ctx.Err()
			}
		}

		req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
		if err != nil {
			// Malformed request: retrying cannot help.
			return nil, err
		}
		req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))

		resp, err := d.client.Do(req)
		if err != nil {
			lastErr = err
			continue
		}

		// Read the body fully before checking status so the connection
		// is drained and can be reused by the transport.
		body, err := io.ReadAll(resp.Body)
		resp.Body.Close()

		if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK {
			lastErr = fmt.Errorf("HTTP %d for range %d-%d of %s", resp.StatusCode, start, end, url)
			continue
		}
		if err != nil {
			lastErr = err
			continue
		}

		return body, nil
	}

	return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}
|
||||
157
internal/downloader/idx.go
Normal file
157
internal/downloader/idx.go
Normal file
|
|
@ -0,0 +1,157 @@
|
|||
package downloader
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// IdxEntry represents a single parsed line from a GRIB .idx file.
// Example line: "15:1207405:d=2024010100:HGT:1000 mb:0 hour fcst:"
// Fields map to the colon-separated idx columns parsed by ParseIdx.
type IdxEntry struct {
	Index     int    // 1-based message number (first idx column)
	Offset    int64  // byte offset of this message in the GRIB file
	Variable  string // "HGT", "UGRD", "VGRD", etc.
	LevelMB   int    // pressure level in mb (0 if not a pressure level)
	Hour      int    // forecast hour (-1 if unparsable; "anl" parses as 0)
	EndOffset int64  // byte after this message (from next entry's offset, or -1 if last)
}
|
||||
|
||||
// Length returns the byte length of this GRIB message, or -1 if unknown.
|
||||
func (e *IdxEntry) Length() int64 {
|
||||
if e.EndOffset <= 0 {
|
||||
return -1
|
||||
}
|
||||
return e.EndOffset - e.Offset
|
||||
}
|
||||
|
||||
// ParseIdx parses a .idx file body and returns all entries.
|
||||
// Lines that can't be parsed are silently skipped.
|
||||
func ParseIdx(body []byte) []IdxEntry {
|
||||
lines := strings.Split(string(body), "\n")
|
||||
var entries []IdxEntry
|
||||
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.Split(line, ":")
|
||||
if len(parts) < 7 {
|
||||
continue
|
||||
}
|
||||
|
||||
idx, err := strconv.Atoi(parts[0])
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
offset, err := strconv.ParseInt(parts[1], 10, 64)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
variable := parts[3]
|
||||
levelStr := parts[4]
|
||||
hourStr := parts[5]
|
||||
|
||||
levelMB := parseLevelMB(levelStr)
|
||||
hour := parseHour(hourStr)
|
||||
|
||||
entries = append(entries, IdxEntry{
|
||||
Index: idx,
|
||||
Offset: offset,
|
||||
Variable: variable,
|
||||
LevelMB: levelMB,
|
||||
Hour: hour,
|
||||
EndOffset: -1, // filled in below
|
||||
})
|
||||
}
|
||||
|
||||
// Fill in EndOffset from the next entry's Offset.
|
||||
for i := 0; i < len(entries)-1; i++ {
|
||||
entries[i].EndOffset = entries[i+1].Offset
|
||||
}
|
||||
|
||||
return entries
|
||||
}
|
||||
|
||||
// FilterIdx returns entries matching the given variables at pressure levels.
|
||||
// Only entries with a recognized pressure level (levelMB > 0) are returned.
|
||||
func FilterIdx(entries []IdxEntry, variables map[string]bool) []IdxEntry {
|
||||
var filtered []IdxEntry
|
||||
for _, e := range entries {
|
||||
if !variables[e.Variable] {
|
||||
continue
|
||||
}
|
||||
if e.LevelMB <= 0 {
|
||||
continue
|
||||
}
|
||||
// Must have a known length (not the last entry) or be handled specially
|
||||
if e.Length() <= 0 {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, e)
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// parseLevelMB parses a level string like "1000 mb" and returns the pressure in mb.
// Returns 0 if the string is not a pressure level (e.g. "surface", "2 m above ground").
func parseLevelMB(s string) int {
	const suffix = " mb"

	trimmed := strings.TrimSpace(s)
	if !strings.HasSuffix(trimmed, suffix) {
		return 0
	}
	if n, err := strconv.Atoi(trimmed[:len(trimmed)-len(suffix)]); err == nil {
		return n
	}
	return 0
}
|
||||
|
||||
// parseHour parses a forecast hour string like "0 hour fcst" or "anl".
// "anl" (analysis) counts as hour 0. Returns -1 if it can't be parsed.
func parseHour(s string) int {
	h := strings.TrimSpace(s)
	if h == "anl" {
		return 0
	}
	n, err := strconv.Atoi(strings.TrimSuffix(h, " hour fcst"))
	if err != nil {
		return -1
	}
	return n
}
|
||||
|
||||
// ByteRange describes one contiguous GRIB message as a byte range suitable
// for an HTTP Range download (both bounds inclusive, RFC 7233 style).
// Each range covers one contiguous GRIB message.
type ByteRange struct {
	Start int64
	End   int64 // inclusive
	Entry IdxEntry
}
|
||||
|
||||
// EntriesToRanges converts filtered idx entries to byte ranges.
|
||||
func EntriesToRanges(entries []IdxEntry) []ByteRange {
|
||||
ranges := make([]ByteRange, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
if e.Length() <= 0 {
|
||||
continue
|
||||
}
|
||||
ranges = append(ranges, ByteRange{
|
||||
Start: e.Offset,
|
||||
End: e.EndOffset - 1, // inclusive
|
||||
Entry: e,
|
||||
})
|
||||
}
|
||||
return ranges
|
||||
}
|
||||
|
||||
// FormatRange returns an HTTP Range header value for a byte range,
// e.g. "bytes=0-1023". Start and End are both inclusive.
func (r ByteRange) FormatRange() string {
	return fmt.Sprintf("bytes=%d-%d", r.Start, r.End)
}
|
||||
110
internal/downloader/idx_test.go
Normal file
110
internal/downloader/idx_test.go
Normal file
|
|
@ -0,0 +1,110 @@
|
|||
package downloader
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// sampleIdx is a hand-built GRIB .idx fixture ("seq:offset:date:var:level:fcst:").
// It mixes pressure levels, a non-pressure level ("2 m above ground"), and a
// variable (TMP) the downloader is expected to filter out; the final HGT@500
// line exercises the unknown-end-offset case.
const sampleIdx = `1:0:d=2024010100:HGT:1000 mb:0 hour fcst:
2:289012:d=2024010100:HGT:975 mb:0 hour fcst:
3:541876:d=2024010100:TMP:1000 mb:0 hour fcst:
4:789012:d=2024010100:UGRD:1000 mb:0 hour fcst:
5:1045678:d=2024010100:VGRD:1000 mb:0 hour fcst:
6:1298765:d=2024010100:UGRD:975 mb:0 hour fcst:
7:1567890:d=2024010100:UGRD:2 m above ground:0 hour fcst:
8:1812345:d=2024010100:VGRD:975 mb:0 hour fcst:
9:2098765:d=2024010100:HGT:500 mb:3 hour fcst:
`
|
||||
|
||||
// TestParseIdx checks field parsing, end-offset chaining, and the
// non-pressure-level / last-entry edge cases of ParseIdx.
func TestParseIdx(t *testing.T) {
	entries := ParseIdx([]byte(sampleIdx))
	if len(entries) != 9 {
		t.Fatalf("expected 9 entries, got %d", len(entries))
	}

	// Check first entry
	e := entries[0]
	if e.Index != 1 || e.Offset != 0 || e.Variable != "HGT" || e.LevelMB != 1000 || e.Hour != 0 {
		t.Errorf("entry 0: got %+v", e)
	}
	// EndOffset must come from the second entry's offset.
	if e.EndOffset != 289012 {
		t.Errorf("entry 0 EndOffset: got %d, want 289012", e.EndOffset)
	}

	// Check "2 m above ground" is not a pressure level
	e = entries[6] // UGRD at "2 m above ground"
	if e.LevelMB != 0 {
		t.Errorf("non-pressure level should have LevelMB=0, got %d", e.LevelMB)
	}

	// Last entry should have EndOffset = -1
	last := entries[len(entries)-1]
	if last.EndOffset != -1 {
		t.Errorf("last entry EndOffset: got %d, want -1", last.EndOffset)
	}
}
|
||||
|
||||
// TestFilterIdx verifies FilterIdx keeps only wanted variables on pressure
// levels with a known length. neededVariables is the package-level wanted-set
// defined elsewhere in this package (presumably HGT/UGRD/VGRD — the expected
// count of 6 below depends on that).
func TestFilterIdx(t *testing.T) {
	entries := ParseIdx([]byte(sampleIdx))
	filtered := FilterIdx(entries, neededVariables)

	// Should include HGT/UGRD/VGRD at pressure levels, exclude TMP and "above ground"
	// And exclude last entry (no EndOffset)
	for _, e := range filtered {
		if !neededVariables[e.Variable] {
			t.Errorf("unexpected variable %s", e.Variable)
		}
		if e.LevelMB <= 0 {
			t.Errorf("non-pressure level included: %+v", e)
		}
		if e.Length() <= 0 {
			t.Errorf("entry with unknown length included: %+v", e)
		}
	}

	// Count expected: HGT@1000, HGT@975, UGRD@1000, VGRD@1000, UGRD@975, VGRD@975 = 6
	// But HGT@500 at 3hr fcst is the last entry (no EndOffset), so excluded
	if len(filtered) != 6 {
		t.Errorf("expected 6 filtered entries, got %d", len(filtered))
		for _, e := range filtered {
			t.Logf("  %s %d mb (offset %d, len %d)", e.Variable, e.LevelMB, e.Offset, e.Length())
		}
	}
}
|
||||
|
||||
// TestParseLevelMB is a table test for pressure-level strings; non-pressure
// levels must map to 0.
func TestParseLevelMB(t *testing.T) {
	tests := []struct {
		input string
		want  int
	}{
		{"1000 mb", 1000},
		{"975 mb", 975},
		{"1 mb", 1},
		{"2 m above ground", 0},
		{"surface", 0},
		{"tropopause", 0},
	}
	for _, tt := range tests {
		got := parseLevelMB(tt.input)
		if got != tt.want {
			t.Errorf("parseLevelMB(%q) = %d, want %d", tt.input, got, tt.want)
		}
	}
}
|
||||
|
||||
// TestParseHour is a table test for forecast-hour strings; "anl" (analysis)
// is treated as hour 0.
func TestParseHour(t *testing.T) {
	tests := []struct {
		input string
		want  int
	}{
		{"0 hour fcst", 0},
		{"3 hour fcst", 3},
		{"192 hour fcst", 192},
		{"anl", 0},
	}
	for _, tt := range tests {
		got := parseHour(tt.input)
		if got != tt.want {
			t.Errorf("parseHour(%q) = %d, want %d", tt.input, got, tt.want)
		}
	}
}
|
||||
113
internal/elevation/elevation.go
Normal file
113
internal/elevation/elevation.go
Normal file
|
|
@ -0,0 +1,113 @@
|
|||
package elevation
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"math"
|
||||
"os"
|
||||
|
||||
mmap "github.com/edsrzf/mmap-go"
|
||||
)
|
||||
|
||||
// Dataset provides global elevation lookup, compatible with ruaumoko.
// Binary format: int16 little-endian elevation values in metres, row-major (lat, lon).
// Latitude axis: -90 to +90 (south to north), Longitude axis: 0 to 360 (wraps).
// Resolution: 120 cells per degree (30 arc-seconds).
const (
	CellsPerDegree = 120
	// +1 because both poles (-90 and +90) are grid rows.
	NumLats = 180*CellsPerDegree + 1 // 21601
	// Longitude wraps, so 360 is the same column as 0 — no +1.
	NumLons  = 360 * CellsPerDegree // 43200
	DataSize = NumLats * NumLons * 2 // 1,866,326,400 bytes (~1.74 GiB)
)
|
||||
|
||||
// Dataset is a memory-mapped global elevation grid.
// Not safe for use after Close; concurrent read-only Get calls share the
// same mapping.
type Dataset struct {
	mm   mmap.MMap // read-only mapping of the whole file
	file *os.File  // kept open for the lifetime of the mapping
}
|
||||
|
||||
// Open opens an existing elevation dataset file.
|
||||
func Open(path string) (*Dataset, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open elevation: %w", err)
|
||||
}
|
||||
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, fmt.Errorf("stat elevation: %w", err)
|
||||
}
|
||||
if info.Size() != DataSize {
|
||||
f.Close()
|
||||
return nil, fmt.Errorf("elevation dataset should be %d bytes (was %d)", DataSize, info.Size())
|
||||
}
|
||||
|
||||
mm, err := mmap.Map(f, mmap.RDONLY, 0)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, fmt.Errorf("mmap elevation: %w", err)
|
||||
}
|
||||
|
||||
return &Dataset{mm: mm, file: f}, nil
|
||||
}
|
||||
|
||||
// getCell reads the int16 elevation at grid indices (latIdx, lngIdx).
|
||||
func (d *Dataset) getCell(latIdx, lngIdx int) int16 {
|
||||
// Clamp latitude
|
||||
if latIdx < 0 {
|
||||
latIdx = 0
|
||||
}
|
||||
if latIdx >= NumLats {
|
||||
latIdx = NumLats - 1
|
||||
}
|
||||
// Wrap longitude
|
||||
lngIdx = lngIdx % NumLons
|
||||
if lngIdx < 0 {
|
||||
lngIdx += NumLons
|
||||
}
|
||||
|
||||
off := (latIdx*NumLons + lngIdx) * 2
|
||||
return int16(binary.LittleEndian.Uint16(d.mm[off : off+2]))
|
||||
}
|
||||
|
||||
// Get returns the interpolated elevation in metres at the given coordinates.
|
||||
// lat: -90 to +90, lng: 0 to 360 (or -180 to 180, will be normalised).
|
||||
func (d *Dataset) Get(lat, lng float64) float64 {
|
||||
// Normalise longitude to [0, 360)
|
||||
if lng < 0 {
|
||||
lng += 360
|
||||
}
|
||||
|
||||
// Convert to cell coordinates
|
||||
latCell := (lat + 90.0) * CellsPerDegree
|
||||
lngCell := lng * CellsPerDegree
|
||||
|
||||
lat0 := int(math.Floor(latCell))
|
||||
lng0 := int(math.Floor(lngCell))
|
||||
latFrac := latCell - float64(lat0)
|
||||
lngFrac := lngCell - float64(lng0)
|
||||
|
||||
// Bilinear interpolation
|
||||
v00 := float64(d.getCell(lat0, lng0))
|
||||
v10 := float64(d.getCell(lat0+1, lng0))
|
||||
v01 := float64(d.getCell(lat0, lng0+1))
|
||||
v11 := float64(d.getCell(lat0+1, lng0+1))
|
||||
|
||||
return (1-latFrac)*((1-lngFrac)*v00+lngFrac*v01) +
|
||||
latFrac*((1-lngFrac)*v10+lngFrac*v11)
|
||||
}
|
||||
|
||||
// Close unmaps and closes the dataset.
|
||||
func (d *Dataset) Close() error {
|
||||
if d.mm != nil {
|
||||
d.mm.Unmap()
|
||||
d.mm = nil
|
||||
}
|
||||
if d.file != nil {
|
||||
err := d.file.Close()
|
||||
d.file = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
153
internal/prediction/interpolate.go
Normal file
153
internal/prediction/interpolate.go
Normal file
|
|
@ -0,0 +1,153 @@
|
|||
package prediction
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"predictor-refactored/internal/dataset"
|
||||
)
|
||||
|
||||
// Exact port of the reference interpolation logic (interpolate.pyx).
|
||||
// 4D interpolation: time, latitude, longitude, altitude (via geopotential height).
|
||||
|
||||
// lerp1 holds an index and interpolation weight for one axis.
type lerp1 struct {
	index int
	lerp  float64
}

// lerp3 holds indices and a combined weight for the (hour, lat, lon) axes.
type lerp3 struct {
	hour, lat, lng int
	lerp           float64
}

// RangeError indicates a coordinate is outside the dataset bounds.
type RangeError struct {
	Variable string
	Value    float64
}

func (e *RangeError) Error() string {
	return fmt.Sprintf("%s=%f out of range", e.Variable, e.Value)
}

// pick computes interpolation indices and weights for a single axis.
// left: axis start, step: axis spacing, n: number of points, value: query value.
// Returns the lower and upper bracketing points; the weights sum to 1.
// Note the upper grid point (index n-1) is itself out of range — the exact
// axis maximum is rejected, mirroring the reference implementation.
func pick(left, step float64, n int, value float64, variableName string) ([2]lerp1, error) {
	pos := (value - left) / step
	cell := int(pos) // truncates toward zero, same as the Cython <long> cast
	if cell < 0 || cell >= n-1 {
		return [2]lerp1{}, &RangeError{Variable: variableName, Value: value}
	}
	frac := pos - float64(cell)
	lo := lerp1{index: cell, lerp: 1 - frac}
	hi := lerp1{index: cell + 1, lerp: frac}
	return [2]lerp1{lo, hi}, nil
}
|
||||
|
||||
// pick3 computes 8 trilinear interpolation weights for (hour, lat, lng).
|
||||
func pick3(hour, lat, lng float64) ([8]lerp3, error) {
|
||||
lhour, err := pick(0, 3, 65, hour, "hour")
|
||||
if err != nil {
|
||||
return [8]lerp3{}, err
|
||||
}
|
||||
llat, err := pick(-90, 0.5, 361, lat, "lat")
|
||||
if err != nil {
|
||||
return [8]lerp3{}, err
|
||||
}
|
||||
// Longitude wraps: tell pick the axis is one larger, then wrap index 720 → 0
|
||||
llng, err := pick(0, 0.5, 720+1, lng, "lng")
|
||||
if err != nil {
|
||||
return [8]lerp3{}, err
|
||||
}
|
||||
if llng[1].index == 720 {
|
||||
llng[1].index = 0
|
||||
}
|
||||
|
||||
var out [8]lerp3
|
||||
i := 0
|
||||
for _, a := range lhour {
|
||||
for _, b := range llat {
|
||||
for _, c := range llng {
|
||||
out[i] = lerp3{
|
||||
hour: a.index,
|
||||
lat: b.index,
|
||||
lng: c.index,
|
||||
lerp: a.lerp * b.lerp * c.lerp,
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// interp3 performs 8-point weighted interpolation at a given variable and pressure level.
|
||||
func interp3(ds *dataset.File, lerps [8]lerp3, variable, level int) float64 {
|
||||
var r float64
|
||||
for i := 0; i < 8; i++ {
|
||||
v := ds.Val(lerps[i].hour, level, variable, lerps[i].lat, lerps[i].lng)
|
||||
r += float64(v) * lerps[i].lerp
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// search finds the largest pressure level index where interpolated geopotential
// height is less than the target altitude. Searches levels 0..45 (excludes topmost,
// so index+1 is always a valid level for the caller's upper bracket).
//
// Binary search invariant: levels below `lower` have height < target is
// maintained via the biased midpoint (lower+upper+1)/2, which guarantees
// progress when upper == lower+1. Returns 0 even if the target is below the
// lowest level's height (the caller handles extrapolation via the lerp weight).
func search(ds *dataset.File, lerps [8]lerp3, target float64) int {
	lower, upper := 0, 45

	for lower < upper {
		mid := (lower + upper + 1) / 2
		test := interp3(ds, lerps, dataset.VarHeight, mid)
		if target <= test {
			// mid is at or above the target altitude: discard it and above.
			upper = mid - 1
		} else {
			// mid is still below the target: it is a candidate answer.
			lower = mid
		}
	}

	return lower
}
|
||||
|
||||
// interp4 performs altitude-interpolated wind lookup using two bracketing levels.
|
||||
func interp4(ds *dataset.File, lerps [8]lerp3, altLerp lerp1, variable int) float64 {
|
||||
lower := interp3(ds, lerps, variable, altLerp.index)
|
||||
upper := interp3(ds, lerps, variable, altLerp.index+1)
|
||||
return lower*altLerp.lerp + upper*(1-altLerp.lerp)
|
||||
}
|
||||
|
||||
// GetWind returns interpolated (u, v) wind components for the given position.
// hour: fractional hours since dataset start.
// lat: latitude in degrees (-90 to +90).
// lng: longitude in degrees (0 to 360).
// alt: altitude in metres above sea level.
//
// Altitude is resolved via the geopotential-height variable: search locates
// the bracketing pressure levels and the wind is blended between them.
// Returns a *RangeError when (hour, lat, lng) falls outside the dataset grid.
func GetWind(ds *dataset.File, warnings *Warnings, hour, lat, lng, alt float64) (u, v float64, err error) {
	lerps, err := pick3(hour, lat, lng)
	if err != nil {
		return 0, 0, err
	}

	// Bracket the target altitude between two pressure levels by their
	// interpolated geopotential heights.
	altidx := search(ds, lerps, alt)
	lower := interp3(ds, lerps, dataset.VarHeight, altidx)
	upper := interp3(ds, lerps, dataset.VarHeight, altidx+1)

	// Weight of the LOWER level; 1 at the lower height, 0 at the upper.
	// Degenerate bracket (equal heights) falls back to an even blend.
	var altLerp float64
	if lower != upper {
		altLerp = (upper - alt) / (upper - lower)
	} else {
		altLerp = 0.5
	}

	// Negative weight means alt is above the upper bracket, i.e. above the
	// top of the forecast column — record it but extrapolate anyway.
	if altLerp < 0 {
		warnings.AltitudeTooHigh.Add(1)
	}

	alt1 := lerp1{index: altidx, lerp: altLerp}
	u = interp4(ds, lerps, alt1, dataset.VarWindU)
	v = interp4(ds, lerps, alt1, dataset.VarWindV)

	return u, v, nil
}
|
||||
188
internal/prediction/models.go
Normal file
188
internal/prediction/models.go
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
package prediction
|
||||
|
||||
import (
|
||||
"math"
|
||||
"time"
|
||||
|
||||
"predictor-refactored/internal/dataset"
|
||||
"predictor-refactored/internal/elevation"
|
||||
)
|
||||
|
||||
// Exact port of the reference flight models (models.py).
|
||||
|
||||
const (
	// Degree/radian conversion factors used by the lateral-drift models.
	pi180  = math.Pi / 180.0 // degrees → radians
	_180pi = 180.0 / math.Pi // radians → degrees
)
|
||||
|
||||
// --- Up/Down Models ---
|
||||
|
||||
// ConstantAscent returns a model with constant vertical velocity (m/s).
|
||||
func ConstantAscent(ascentRate float64) Model {
|
||||
return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
|
||||
return 0, 0, ascentRate
|
||||
}
|
||||
}
|
||||
|
||||
// DragDescent returns a descent-under-parachute model.
|
||||
// seaLevelDescentRate is the descent rate at sea level (m/s, positive value).
|
||||
// Uses the NASA atmosphere model for density at altitude.
|
||||
func DragDescent(seaLevelDescentRate float64) Model {
|
||||
dragCoefficient := seaLevelDescentRate * 1.1045
|
||||
|
||||
return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
|
||||
return 0, 0, -dragCoefficient / math.Sqrt(nasaDensity(alt))
|
||||
}
|
||||
}
|
||||
|
||||
// nasaDensity computes air density (kg/m³) at altitude alt (metres) using the
// NASA three-layer atmosphere model: troposphere (≤11 km), lower stratosphere
// (11–25 km), upper stratosphere (>25 km).
// Reference: http://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html
func nasaDensity(alt float64) float64 {
	var tempC, pressureKPa float64

	if alt > 25000 {
		// Upper stratosphere: temperature rises again with altitude.
		tempC = -131.21 + 0.00299*alt
		pressureKPa = 2.488 * math.Pow((tempC+273.1)/216.6, -11.388)
	} else if alt > 11000 {
		// Lower stratosphere: isothermal layer.
		tempC = -56.46
		pressureKPa = 22.65 * math.Exp(1.73-0.000157*alt)
	} else {
		// Troposphere: linear lapse rate.
		tempC = 15.04 - 0.00649*alt
		pressureKPa = 101.29 * math.Pow((tempC+273.1)/288.08, 5.256)
	}

	// Ideal gas law with R_specific for air (0.2869 kJ/(kg·K)).
	return pressureKPa / (0.2869 * (tempC + 273.1))
}
|
||||
|
||||
// --- Sideways Models ---
|
||||
|
||||
// WindVelocity returns a model that gives lateral movement at the wind velocity.
// ds is the wind dataset, dsEpoch is the dataset start time as UNIX timestamp.
//
// Wind lookup errors (position/time outside the dataset grid) are swallowed
// and yield zero drift — presumably deliberate best-effort behavior inherited
// from the reference predictor; TODO confirm this should not surface a warning.
func WindVelocity(ds *dataset.File, dsEpoch float64, warnings *Warnings) Model {
	return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
		// Dataset time axis is fractional hours since its first frame.
		tHours := (t - dsEpoch) / 3600.0
		u, v, err := GetWind(ds, warnings, tHours, lat, lng, alt)
		if err != nil {
			return 0, 0, 0
		}

		// Convert m/s wind to deg/s on a sphere of radius (mean Earth + alt);
		// the longitude rate grows with 1/cos(lat) toward the poles.
		R := 6371009.0 + alt
		dlat = _180pi * v / R
		dlng = _180pi * u / (R * math.Cos(lat*pi180))
		return dlat, dlng, 0
	}
}
|
||||
|
||||
// --- Model Combinations ---
|
||||
|
||||
// LinearModel returns a model that sums all component models.
|
||||
func LinearModel(models ...Model) Model {
|
||||
return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
|
||||
for _, m := range models {
|
||||
d1, d2, d3 := m(t, lat, lng, alt)
|
||||
dlat += d1
|
||||
dlng += d2
|
||||
dalt += d3
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// --- Termination Criteria ---
|
||||
|
||||
// BurstTermination returns a terminator that fires when altitude >= burstAltitude.
|
||||
func BurstTermination(burstAltitude float64) Terminator {
|
||||
return func(t, lat, lng, alt float64) bool {
|
||||
return alt >= burstAltitude
|
||||
}
|
||||
}
|
||||
|
||||
// SeaLevelTermination fires when altitude <= 0.
|
||||
func SeaLevelTermination(t, lat, lng, alt float64) bool {
|
||||
return alt <= 0
|
||||
}
|
||||
|
||||
// TimeTermination returns a terminator that fires when t > maxTime.
|
||||
func TimeTermination(maxTime float64) Terminator {
|
||||
return func(t, lat, lng, alt float64) bool {
|
||||
return t > maxTime
|
||||
}
|
||||
}
|
||||
|
||||
// ElevationTermination returns a terminator that fires when alt < ground level.
|
||||
// Uses ruaumoko-compatible elevation data. Longitude is normalised internally.
|
||||
func ElevationTermination(elev *elevation.Dataset) Terminator {
|
||||
return func(t, lat, lng, alt float64) bool {
|
||||
return elev.Get(lat, lng) > alt
|
||||
}
|
||||
}
|
||||
|
||||
// --- Pre-Defined Profiles ---
|
||||
|
||||
// Stage pairs a model with its termination criterion. A flight profile is an
// ordered slice of stages; each stage is integrated until its terminator fires,
// and the next stage starts from the resulting state.
type Stage struct {
	Model      Model
	Terminator Terminator
}
|
||||
|
||||
// StandardProfile creates the chain for a standard high-altitude balloon flight:
|
||||
// ascent at constant rate → burst → descent under parachute.
|
||||
// If elev is non-nil, descent terminates at ground level; otherwise at sea level.
|
||||
func StandardProfile(ascentRate, burstAltitude, descentRate float64,
|
||||
ds *dataset.File, dsEpoch float64, warnings *Warnings,
|
||||
elev *elevation.Dataset) []Stage {
|
||||
|
||||
wind := WindVelocity(ds, dsEpoch, warnings)
|
||||
|
||||
modelUp := LinearModel(ConstantAscent(ascentRate), wind)
|
||||
termUp := BurstTermination(burstAltitude)
|
||||
|
||||
modelDown := LinearModel(DragDescent(descentRate), wind)
|
||||
var termDown Terminator
|
||||
if elev != nil {
|
||||
termDown = ElevationTermination(elev)
|
||||
} else {
|
||||
termDown = Terminator(SeaLevelTermination)
|
||||
}
|
||||
|
||||
return []Stage{
|
||||
{Model: modelUp, Terminator: termUp},
|
||||
{Model: modelDown, Terminator: termDown},
|
||||
}
|
||||
}
|
||||
|
||||
// FloatProfile creates the chain for a floating balloon flight:
|
||||
// ascent to float altitude → float until stop time.
|
||||
func FloatProfile(ascentRate, floatAltitude float64, stopTime time.Time,
|
||||
ds *dataset.File, dsEpoch float64, warnings *Warnings) []Stage {
|
||||
|
||||
wind := WindVelocity(ds, dsEpoch, warnings)
|
||||
|
||||
modelUp := LinearModel(ConstantAscent(ascentRate), wind)
|
||||
termUp := BurstTermination(floatAltitude)
|
||||
|
||||
modelFloat := wind
|
||||
termFloat := TimeTermination(float64(stopTime.Unix()))
|
||||
|
||||
return []Stage{
|
||||
{Model: modelUp, Terminator: termUp},
|
||||
{Model: modelFloat, Terminator: termFloat},
|
||||
}
|
||||
}
|
||||
|
||||
// RunPrediction runs a prediction with the given profile stages.
|
||||
// launchTime is a UNIX timestamp.
|
||||
func RunPrediction(launchTime float64, lat, lng, alt float64, stages []Stage) []StageResult {
|
||||
chain := make([]struct {
|
||||
Model Model
|
||||
Terminator Terminator
|
||||
}, len(stages))
|
||||
|
||||
for i, s := range stages {
|
||||
chain[i].Model = s.Model
|
||||
chain[i].Terminator = s.Terminator
|
||||
}
|
||||
|
||||
return Solve(launchTime, lat, lng, alt, chain)
|
||||
}
|
||||
180
internal/prediction/solver.go
Normal file
180
internal/prediction/solver.go
Normal file
|
|
@ -0,0 +1,180 @@
|
|||
package prediction
|
||||
|
||||
import "math"
|
||||
|
||||
// Exact port of the reference RK4 solver (solver.pyx).
|
||||
// Integrates balloon state using RK4 with dt=60 seconds.
|
||||
// Termination uses binary search refinement (tolerance 0.01).
|
||||
|
||||
// Vec holds the balloon state: latitude, longitude, altitude.
// Lat/Lng are in degrees (Lng kept in [0, 360) by vecadd), Alt in metres.
type Vec struct {
	Lat float64
	Lng float64
	Alt float64
}

// Model is a function that returns (dlat/dt, dlng/dt, dalt/dt) given state.
// t is UNIX timestamp, lat/lng in degrees, alt in metres; rates are per second.
type Model func(t float64, lat, lng, alt float64) (dlat, dlng, dalt float64)

// Terminator returns true when integration should stop.
type Terminator func(t float64, lat, lng, alt float64) bool

// StageResult holds the trajectory points for one flight stage.
type StageResult struct {
	Points []TrajectoryPoint
}

// TrajectoryPoint is a single point in a trajectory (used by solver).
type TrajectoryPoint struct {
	T   float64 // UNIX timestamp
	Lat float64 // degrees
	Lng float64 // degrees, [0, 360)
	Alt float64 // metres
}
|
||||
|
||||
// pymod returns a % b with Python semantics (always non-negative when b > 0),
// unlike math.Mod, whose result takes the sign of a.
func pymod(a, b float64) float64 {
	r := math.Mod(a, b)
	if r >= 0 {
		return r
	}
	return r + b
}
|
||||
|
||||
// vecadd returns a + k*b, with lng wrapped to [0, 360).
|
||||
func vecadd(a Vec, k float64, b Vec) Vec {
|
||||
return Vec{
|
||||
Lat: a.Lat + k*b.Lat,
|
||||
Lng: pymod(a.Lng+k*b.Lng, 360.0),
|
||||
Alt: a.Alt + k*b.Alt,
|
||||
}
|
||||
}
|
||||
|
||||
// scalarLerp returns (1-l)*a + l*b: the linear blend of a and b.
func scalarLerp(a, b, l float64) float64 {
	return a*(1-l) + b*l
}
|
||||
|
||||
// lngLerp interpolates longitude handling the 0/360 wrap-around: the blend
// always travels the shorter way around the circle.
func lngLerp(a, b, l float64) float64 {
	l2 := 1 - l

	// Order the endpoints so a <= b, swapping the weights along with them so
	// each weight stays attached to its original endpoint.
	if a > b {
		a, b = b, a
		l, l2 = l2, l
	}

	// distance round one way: b - a
	// distance around other: (a + 360) - b
	if b-a < 180.0 {
		// Short way does not cross 0/360: plain linear blend.
		return l2*a + l*b
	}
	// Short way crosses 0/360: shift a up by a full turn, blend, re-wrap.
	return pymod(l2*(a+360)+l*b, 360.0)
}
|
||||
|
||||
// vecLerp returns (1-l)*a + l*b with proper longitude wrapping.
|
||||
func vecLerp(a, b Vec, l float64) Vec {
|
||||
return Vec{
|
||||
Lat: scalarLerp(a.Lat, b.Lat, l),
|
||||
Lng: lngLerp(a.Lng, b.Lng, l),
|
||||
Alt: scalarLerp(a.Alt, b.Alt, l),
|
||||
}
|
||||
}
|
||||
|
||||
// rk4 integrates from initial conditions using classic 4th-order Runge-Kutta.
// dt=60.0 seconds, terminationTolerance=0.01 (fraction of one step).
//
// Each accepted step is appended to the result; when the terminator first
// fires, the crossing point is refined by bisecting the interpolation
// parameter between the last non-terminated state and the terminated one,
// and that refined point becomes the final trajectory point.
//
// NOTE(review): the loop has no hard iteration cap — a terminator that never
// fires (e.g. a model that balances to zero descent) would loop forever;
// presumably the profiles used in practice always terminate.
func rk4(t float64, lat, lng, alt float64, model Model, terminator Terminator) []TrajectoryPoint {
	const dt = 60.0
	const terminationTolerance = 0.01

	y := Vec{Lat: lat, Lng: lng, Alt: alt}
	result := []TrajectoryPoint{{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt}}

	for {
		// Evaluate model at 4 points (standard RK4)
		k1lat, k1lng, k1alt := model(t, y.Lat, y.Lng, y.Alt)
		k1 := Vec{Lat: k1lat, Lng: k1lng, Alt: k1alt}

		// k2: slope at the midpoint using k1.
		mid1 := vecadd(y, dt/2, k1)
		k2lat, k2lng, k2alt := model(t+dt/2, mid1.Lat, mid1.Lng, mid1.Alt)
		k2 := Vec{Lat: k2lat, Lng: k2lng, Alt: k2alt}

		// k3: slope at the midpoint using k2.
		mid2 := vecadd(y, dt/2, k2)
		k3lat, k3lng, k3alt := model(t+dt/2, mid2.Lat, mid2.Lng, mid2.Alt)
		k3 := Vec{Lat: k3lat, Lng: k3lng, Alt: k3alt}

		// k4: slope at the end of the step using k3.
		end := vecadd(y, dt, k3)
		k4lat, k4lng, k4alt := model(t+dt, end.Lat, end.Lng, end.Alt)
		k4 := Vec{Lat: k4lat, Lng: k4lng, Alt: k4alt}

		// y2 = y + dt/6*k1 + dt/3*k2 + dt/3*k3 + dt/6*k4
		// (each vecadd re-wraps longitude, so the sum stays in [0, 360))
		y2 := y
		y2 = vecadd(y2, dt/6, k1)
		y2 = vecadd(y2, dt/3, k2)
		y2 = vecadd(y2, dt/3, k3)
		y2 = vecadd(y2, dt/6, k4)

		t2 := t + dt

		if terminator(t2, y2.Lat, y2.Lng, y2.Alt) {
			// Binary search to refine the termination point.
			// Find l in [0, 1] such that (t3, y3) = lerp((t, y), (t2, y2), l)
			// is near where the terminator fires.
			left := 0.0
			right := 1.0

			var t3 float64
			var y3 Vec
			// If the loop below runs zero times the endpoint itself is used.
			t3 = t2
			y3 = y2

			for right-left > terminationTolerance {
				mid := (left + right) / 2
				t3 = scalarLerp(t, t2, mid)
				y3 = vecLerp(y, y2, mid)

				// Invariant: terminator fires at `right`, not at `left`.
				if terminator(t3, y3.Lat, y3.Lng, y3.Alt) {
					right = mid
				} else {
					left = mid
				}
			}

			result = append(result, TrajectoryPoint{T: t3, Lat: y3.Lat, Lng: y3.Lng, Alt: y3.Alt})
			break
		}

		// Update current state
		t = t2
		y = y2
		result = append(result, TrajectoryPoint{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt})
	}

	return result
}
|
||||
|
||||
// Solve runs through a chain of (model, terminator) stages.
|
||||
// Returns one StageResult per stage.
|
||||
func Solve(t, lat, lng, alt float64, chain []struct {
|
||||
Model Model
|
||||
Terminator Terminator
|
||||
}) []StageResult {
|
||||
var results []StageResult
|
||||
|
||||
for _, stage := range chain {
|
||||
points := rk4(t, lat, lng, alt, stage.Model, stage.Terminator)
|
||||
results = append(results, StageResult{Points: points})
|
||||
|
||||
// Next stage starts where this one ended
|
||||
if len(points) > 0 {
|
||||
last := points[len(points)-1]
|
||||
t = last.T
|
||||
lat = last.Lat
|
||||
lng = last.Lng
|
||||
alt = last.Alt
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
21
internal/prediction/warnings.go
Normal file
21
internal/prediction/warnings.go
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
package prediction
|
||||
|
||||
import "sync/atomic"
|
||||
|
||||
// Warnings tracks warning conditions during a prediction run.
// Counters are atomic so concurrent model evaluations may bump them safely.
type Warnings struct {
	AltitudeTooHigh atomic.Int64
}

// ToMap returns warnings as a map suitable for JSON serialization.
// Only includes warnings that have fired (count > 0).
func (w *Warnings) ToMap() map[string]any {
	out := map[string]any{}
	if count := w.AltitudeTooHigh.Load(); count > 0 {
		out["altitude_too_high"] = map[string]any{
			"count":       count,
			"description": "The altitude went too high, above the max forecast wind. Wind data will be unreliable",
		}
	}
	return out
}
|
||||
245
internal/service/service.go
Normal file
245
internal/service/service.go
Normal file
|
|
@ -0,0 +1,245 @@
|
|||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"predictor-refactored/internal/dataset"
|
||||
"predictor-refactored/internal/downloader"
|
||||
"predictor-refactored/internal/elevation"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Service orchestrates the dataset lifecycle and provides prediction capabilities.
type Service struct {
	mu sync.RWMutex // guards ds
	ds *dataset.File
	// elev is written by LoadElevation without taking mu and read unguarded
	// by Elevation — presumably LoadElevation only runs during startup before
	// any concurrent readers exist; NOTE(review): confirm that call ordering.
	elev     *elevation.Dataset
	cfg      *downloader.Config
	dl       *downloader.Downloader
	log      *zap.Logger
	updating sync.Mutex // prevents concurrent downloads
}
|
||||
|
||||
// New creates a new Service.
|
||||
func New(cfg *downloader.Config, log *zap.Logger) *Service {
|
||||
return &Service{
|
||||
cfg: cfg,
|
||||
dl: downloader.NewDownloader(cfg, log),
|
||||
log: log,
|
||||
}
|
||||
}
|
||||
|
||||
// LoadElevation loads the ruaumoko-compatible elevation dataset from path.
// If the file doesn't exist (or fails validation), elevation termination is
// disabled (falls back to sea level) — the failure is logged, not returned.
// NOTE(review): s.elev is assigned without holding s.mu; safe only if this is
// called before concurrent use of the service — confirm against callers.
func (s *Service) LoadElevation(path string) {
	ds, err := elevation.Open(path)
	if err != nil {
		s.log.Warn("elevation dataset not available, using sea-level termination",
			zap.String("path", path), zap.Error(err))
		return
	}
	s.elev = ds
	s.log.Info("elevation dataset loaded", zap.String("path", path))
}
|
||||
|
||||
// Elevation returns the elevation dataset (may be nil when LoadElevation
// was never called or failed). Read without locking — see the note on the
// Service.elev field.
func (s *Service) Elevation() *elevation.Dataset {
	return s.elev
}
|
||||
|
||||
// Ready returns true if the service has a loaded dataset.
|
||||
func (s *Service) Ready() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.ds != nil
|
||||
}
|
||||
|
||||
// DatasetTime returns the forecast time of the currently loaded dataset.
|
||||
func (s *Service) DatasetTime() (time.Time, bool) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
if s.ds == nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return s.ds.DSTime, true
|
||||
}
|
||||
|
||||
// Dataset returns the current dataset for reading (nil when none is loaded).
// The returned pointer may be closed by a later setDataset swap; callers
// should hold it only briefly.
func (s *Service) Dataset() *dataset.File {
	s.mu.RLock()
	defer s.mu.RUnlock()
	return s.ds
}
|
||||
|
||||
// Update checks for and downloads new forecast data if needed.
// Order of attempts: (1) skip when the in-memory dataset is still within its
// TTL; (2) reuse a valid on-disk dataset; (3) download the latest model run.
// Concurrent calls are collapsed: a second caller returns nil immediately
// while an update is in flight.
func (s *Service) Update(ctx context.Context) error {
	if !s.updating.TryLock() {
		s.log.Info("update already in progress, skipping")
		return nil
	}
	defer s.updating.Unlock()

	// Check if current dataset is still fresh
	if dsTime, ok := s.DatasetTime(); ok {
		if time.Since(dsTime) < s.cfg.DatasetTTL {
			s.log.Info("dataset still fresh, skipping update",
				zap.Time("dataset_time", dsTime),
				zap.Duration("age", time.Since(dsTime)))
			return nil
		}
	}

	// Try loading an existing dataset from disk first
	if err := s.loadExistingDataset(); err == nil {
		return nil
	}

	// Find latest available model run
	run, err := s.dl.FindLatestRun(ctx)
	if err != nil {
		return err
	}

	// Download and assemble
	path, err := s.dl.Download(ctx, run)
	if err != nil {
		return err
	}

	// Open the new dataset
	ds, err := dataset.Open(path, run)
	if err != nil {
		return err
	}

	// Swap in the new dataset (closes the previous one)
	s.setDataset(ds)
	s.log.Info("dataset loaded", zap.Time("run", run), zap.String("path", path))

	// Clean old datasets
	s.cleanOldDatasets(path)

	return nil
}
|
||||
|
||||
// loadExistingDataset tries to find and load an existing dataset from the data directory.
|
||||
func (s *Service) loadExistingDataset() error {
|
||||
entries, err := os.ReadDir(s.cfg.DataDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Collect valid dataset files (name is YYYYMMDDHH, no extension, correct size)
|
||||
type candidate struct {
|
||||
name string
|
||||
path string
|
||||
run time.Time
|
||||
}
|
||||
var candidates []candidate
|
||||
|
||||
for _, e := range entries {
|
||||
if e.IsDir() || strings.Contains(e.Name(), ".") {
|
||||
continue
|
||||
}
|
||||
if len(e.Name()) != 10 {
|
||||
continue
|
||||
}
|
||||
|
||||
run, err := time.Parse("2006010215", e.Name())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
path := filepath.Join(s.cfg.DataDir, e.Name())
|
||||
info, err := os.Stat(path)
|
||||
if err != nil || info.Size() != dataset.DatasetSize {
|
||||
continue
|
||||
}
|
||||
|
||||
if time.Since(run) > s.cfg.DatasetTTL {
|
||||
continue
|
||||
}
|
||||
|
||||
candidates = append(candidates, candidate{name: e.Name(), path: path, run: run})
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return os.ErrNotExist
|
||||
}
|
||||
|
||||
// Pick the newest
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return candidates[i].run.After(candidates[j].run)
|
||||
})
|
||||
|
||||
best := candidates[0]
|
||||
ds, err := dataset.Open(best.path, best.run)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.setDataset(ds)
|
||||
s.log.Info("loaded existing dataset",
|
||||
zap.Time("run", best.run),
|
||||
zap.String("path", best.path))
|
||||
return nil
|
||||
}
|
||||
|
||||
// setDataset swaps the current dataset with a new one, closing the old one.
|
||||
func (s *Service) setDataset(ds *dataset.File) {
|
||||
s.mu.Lock()
|
||||
old := s.ds
|
||||
s.ds = ds
|
||||
s.mu.Unlock()
|
||||
|
||||
if old != nil {
|
||||
if err := old.Close(); err != nil {
|
||||
s.log.Error("failed to close old dataset", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// cleanOldDatasets removes dataset files other than the one at keepPath.
|
||||
func (s *Service) cleanOldDatasets(keepPath string) {
|
||||
entries, err := os.ReadDir(s.cfg.DataDir)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(s.cfg.DataDir, e.Name())
|
||||
if path == keepPath {
|
||||
continue
|
||||
}
|
||||
// Remove old datasets and temp files
|
||||
if len(e.Name()) == 10 || strings.HasSuffix(e.Name(), ".downloading") {
|
||||
if err := os.Remove(path); err != nil {
|
||||
s.log.Warn("failed to remove old file", zap.String("path", path), zap.Error(err))
|
||||
} else {
|
||||
s.log.Info("removed old dataset", zap.String("path", path))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases all resources.
|
||||
func (s *Service) Close() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.ds != nil {
|
||||
err := s.ds.Close()
|
||||
s.ds = nil
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
30
internal/transport/middleware/log.go
Normal file
30
internal/transport/middleware/log.go
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
package middleware
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/ogen-go/ogen/middleware"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Logging returns an ogen middleware that logs request duration.
|
||||
func Logging(log *zap.Logger) middleware.Middleware {
|
||||
return func(req middleware.Request, next func(req middleware.Request) (middleware.Response, error)) (middleware.Response, error) {
|
||||
lg := log.With(zap.String("operation", req.OperationID))
|
||||
|
||||
start := time.Now()
|
||||
resp, err := next(req)
|
||||
dur := time.Since(start)
|
||||
|
||||
if err != nil {
|
||||
lg.Error("request failed",
|
||||
zap.Duration("duration", dur),
|
||||
zap.Error(err))
|
||||
} else {
|
||||
lg.Info("request completed",
|
||||
zap.Duration("duration", dur))
|
||||
}
|
||||
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
16
internal/transport/rest/handler/deps.go
Normal file
16
internal/transport/rest/handler/deps.go
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"predictor-refactored/internal/dataset"
|
||||
"predictor-refactored/internal/elevation"
|
||||
)
|
||||
|
||||
// Service defines the interface the handler needs from the service layer.
// It decouples the HTTP handler from the concrete dataset service so the
// handler can be tested against a fake implementation.
type Service interface {
	// Ready reports whether a dataset is loaded and predictions can be served.
	Ready() bool
	// DatasetTime returns the forecast time of the loaded dataset;
	// the boolean is false while no dataset is loaded.
	DatasetTime() (time.Time, bool)
	// Dataset returns the current forecast dataset (may be nil before first load).
	Dataset() *dataset.File
	// Elevation returns the terrain elevation dataset.
	Elevation() *elevation.Dataset
}
|
||||
216
internal/transport/rest/handler/handler.go
Normal file
216
internal/transport/rest/handler/handler.go
Normal file
|
|
@ -0,0 +1,216 @@
|
|||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"predictor-refactored/internal/prediction"
|
||||
api "predictor-refactored/pkg/rest"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Compile-time check that Handler satisfies the generated api.Handler interface.
var _ api.Handler = (*Handler)(nil)

// Handler implements the ogen-generated api.Handler interface.
type Handler struct {
	svc Service     // service layer: dataset access and readiness
	log *zap.Logger // structured logger for request/outcome logging
}
|
||||
|
||||
// New creates a new Handler.
|
||||
func New(svc Service, log *zap.Logger) *Handler {
|
||||
return &Handler{svc: svc, log: log}
|
||||
}
|
||||
|
||||
// PerformPrediction implements the prediction endpoint.
//
// It validates readiness, fills in parameter defaults, builds a flight
// profile (standard ascent/burst/descent or float), runs the prediction,
// and converts the trajectory into the generated API response types.
// Returns 503 while no dataset is loaded and 400 for an unknown profile.
func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredictionParams) (*api.PredictionResponse, error) {
	if !h.svc.Ready() {
		return nil, newError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up")
	}

	// Re-check the dataset itself: Ready and Dataset take the lock separately,
	// so the dataset could in principle disappear between the two calls.
	ds := h.svc.Dataset()
	if ds == nil {
		return nil, newError(http.StatusServiceUnavailable, "dataset unavailable")
	}

	// Dataset forecast time as Unix seconds; passed to the profile builders.
	dsEpoch := float64(ds.DSTime.Unix())

	// Parse parameters with defaults
	profile := "standard_profile"
	if p, ok := params.Profile.Get(); ok {
		profile = string(p)
	}

	// Default ascent rate in m/s when the client omits it.
	ascentRate := 5.0
	if v, ok := params.AscentRate.Get(); ok {
		ascentRate = v
	}

	// Default burst altitude in metres.
	burstAltitude := 28000.0
	if v, ok := params.BurstAltitude.Get(); ok {
		burstAltitude = v
	}

	// Default descent rate in m/s.
	descentRate := 5.0
	if v, ok := params.DescentRate.Get(); ok {
		descentRate = v
	}

	// Launch altitude defaults to sea level.
	launchAlt := 0.0
	if v, ok := params.LaunchAltitude.Get(); ok {
		launchAlt = v
	}

	// Normalize longitude to [0, 360)
	// The dataset grid starts at 0 degrees, so west longitudes wrap.
	lng := params.LaunchLongitude
	if lng < 0 {
		lng += 360.0
	}

	launchTime := float64(params.LaunchDatetime.Unix())

	// Collector for non-fatal issues raised during profile building/prediction.
	warnings := &prediction.Warnings{}

	// Build profile chain
	elev := h.svc.Elevation()
	var stages []prediction.Stage
	switch profile {
	case "standard_profile":
		stages = prediction.StandardProfile(
			ascentRate, burstAltitude, descentRate,
			ds, dsEpoch, warnings, elev)
	case "float_profile":
		// Float altitude defaults to 25 km, stop time to 24 h after launch.
		floatAlt := 25000.0
		if v, ok := params.FloatAltitude.Get(); ok {
			floatAlt = v
		}
		stopTime := params.LaunchDatetime.Add(24 * time.Hour)
		if v, ok := params.StopDatetime.Get(); ok {
			stopTime = v
		}
		stages = prediction.FloatProfile(
			ascentRate, floatAlt, stopTime,
			ds, dsEpoch, warnings)
	default:
		return nil, newError(http.StatusBadRequest, "unknown profile: "+profile)
	}

	// Run prediction
	startTime := time.Now().UTC()
	results := prediction.RunPrediction(launchTime, params.LaunchLatitude, lng, launchAlt, stages)
	completeTime := time.Now().UTC()

	// Build response
	// Stage labels by position; NOTE(review): any result beyond index 1
	// silently falls back to "ascent" — confirm RunPrediction never returns
	// more stage results than the profile defines.
	stageNames := []string{"ascent", "descent"}
	if profile == "float_profile" {
		stageNames = []string{"ascent", "float"}
	}

	var predItems []api.PredictionResponsePredictionItem
	for i, sr := range results {
		stageName := "ascent"
		if i < len(stageNames) {
			stageName = stageNames[i]
		}

		var stageEnum api.PredictionResponsePredictionItemStage
		switch stageName {
		case "ascent":
			stageEnum = api.PredictionResponsePredictionItemStageAscent
		case "descent":
			stageEnum = api.PredictionResponsePredictionItemStageDescent
		case "float":
			stageEnum = api.PredictionResponsePredictionItemStageFloat
		}

		var traj []api.PredictionResponsePredictionItemTrajectoryItem
		for _, pt := range sr.Points {
			// Map internal [0, 360) longitudes back to [-180, 180] for clients.
			ptLng := pt.Lng
			if ptLng > 180 {
				ptLng -= 360
			}
			traj = append(traj, api.PredictionResponsePredictionItemTrajectoryItem{
				Datetime:  time.Unix(int64(pt.T), 0).UTC(),
				Latitude:  pt.Lat,
				Longitude: ptLng,
				Altitude:  pt.Alt,
			})
		}

		predItems = append(predItems, api.PredictionResponsePredictionItem{
			Stage:      stageEnum,
			Trajectory: traj,
		})
	}

	resp := &api.PredictionResponse{
		Prediction: predItems,
		Metadata: api.PredictionResponseMetadata{
			StartDatetime:    startTime,
			CompleteDatetime: completeTime,
		},
	}

	// Echo request
	resp.Request = api.NewOptPredictionResponseRequest(api.PredictionResponseRequest{
		Dataset:         api.NewOptString(ds.DSTime.Format("2006-01-02T15:04:05Z")),
		LaunchLatitude:  api.NewOptFloat64(params.LaunchLatitude),
		LaunchLongitude: api.NewOptFloat64(params.LaunchLongitude),
		LaunchDatetime:  api.NewOptString(params.LaunchDatetime.Format(time.RFC3339)),
		LaunchAltitude:  params.LaunchAltitude,
	})

	// Warnings
	// NOTE(review): warnMap is computed but an EMPTY PredictionResponseWarnings
	// is attached — the actual warning contents are discarded. This looks like
	// a bug: the map (or its fields) should be copied into the response struct.
	// Fixing it requires the generated PredictionResponseWarnings shape; confirm
	// against pkg/rest and populate it.
	warnMap := warnings.ToMap()
	if len(warnMap) > 0 {
		resp.Warnings = api.NewOptPredictionResponseWarnings(api.PredictionResponseWarnings{})
	}

	h.log.Info("prediction complete",
		zap.String("profile", profile),
		zap.Int("stages", len(results)),
		zap.Duration("elapsed", completeTime.Sub(startTime)))

	return resp, nil
}
|
||||
|
||||
// ReadinessCheck implements the health check endpoint.
|
||||
func (h *Handler) ReadinessCheck(ctx context.Context) (*api.ReadinessResponse, error) {
|
||||
resp := &api.ReadinessResponse{}
|
||||
|
||||
if h.svc.Ready() {
|
||||
resp.Status = api.ReadinessResponseStatusOk
|
||||
if dsTime, ok := h.svc.DatasetTime(); ok {
|
||||
resp.DatasetTime = api.NewOptDateTime(dsTime)
|
||||
}
|
||||
} else {
|
||||
resp.Status = api.ReadinessResponseStatusNotReady
|
||||
resp.ErrorMessage = api.NewOptString("no dataset loaded")
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// NewError creates an ErrorStatusCode from an error returned by a handler.
|
||||
func (h *Handler) NewError(ctx context.Context, err error) *api.ErrorStatusCode {
|
||||
if statusErr, ok := err.(*api.ErrorStatusCode); ok {
|
||||
return statusErr
|
||||
}
|
||||
|
||||
h.log.Error("unhandled error", zap.Error(err))
|
||||
return newError(http.StatusInternalServerError, err.Error())
|
||||
}
|
||||
|
||||
func newError(status int, description string) *api.ErrorStatusCode {
|
||||
return &api.ErrorStatusCode{
|
||||
StatusCode: status,
|
||||
Response: api.Error{
|
||||
Error: api.ErrorError{
|
||||
Type: http.StatusText(status),
|
||||
Description: description,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
75
internal/transport/rest/transport.go
Normal file
75
internal/transport/rest/transport.go
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
package rest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"predictor-refactored/internal/transport/middleware"
|
||||
"predictor-refactored/internal/transport/rest/handler"
|
||||
api "predictor-refactored/pkg/rest"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// Transport wraps the ogen HTTP server.
|
||||
type Transport struct {
|
||||
srv *api.Server
|
||||
handler *handler.Handler
|
||||
port int
|
||||
log *zap.Logger
|
||||
}
|
||||
|
||||
// New creates a new REST transport.
|
||||
func New(h *handler.Handler, port int, log *zap.Logger) (*Transport, error) {
|
||||
srv, err := api.NewServer(
|
||||
h,
|
||||
api.WithMiddleware(middleware.Logging(log)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create ogen server: %w", err)
|
||||
}
|
||||
|
||||
return &Transport{
|
||||
srv: srv,
|
||||
handler: h,
|
||||
port: port,
|
||||
log: log,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Run starts the HTTP server. Blocks until the server stops.
|
||||
func (t *Transport) Run() error {
|
||||
mux := http.NewServeMux()
|
||||
mux.Handle("/", t.srv)
|
||||
|
||||
httpSrv := &http.Server{
|
||||
Addr: fmt.Sprintf(":%d", t.port),
|
||||
Handler: corsMiddleware(mux),
|
||||
}
|
||||
|
||||
t.log.Info("starting HTTP server", zap.Int("port", t.port))
|
||||
return httpSrv.ListenAndServe()
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops the HTTP server.
|
||||
func (t *Transport) Shutdown(ctx context.Context) error {
|
||||
// The ogen server doesn't have a shutdown method;
|
||||
// shutdown is handled by the http.Server in main.go
|
||||
return nil
|
||||
}
|
||||
|
||||
func corsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
|
||||
|
||||
if r.Method == http.MethodOptions {
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue