feat: refactor

Anatoly Antonov 2026-03-28 03:07:13 +09:00
parent 82ef1cb3b8
commit 51bbf3c579
44 changed files with 8589 additions and 0 deletions

5
.gitignore vendored Normal file

@@ -0,0 +1,5 @@
predictor
*.exe
*.test
*.out
/tmp/

40
Makefile Normal file

@@ -0,0 +1,40 @@
.PHONY: build run test fmt lint clean generate-ogen help

# Build the application
build:
	go build -o predictor ./cmd/api

# Run locally
run:
	go run ./cmd/api

# Run tests
test:
	go test ./...

# Format code
fmt:
	go fmt ./...

# Lint code
lint:
	golangci-lint run

# Generate ogen API code from swagger spec
generate-ogen:
	go run github.com/ogen-go/ogen/cmd/ogen@latest --target pkg/rest --package rest --clean api/rest/predictor.swagger.yml

# Clean build artifacts
clean:
	rm -f predictor

# Show help
help:
	@echo "Available commands:"
	@echo "  build         - Build binary"
	@echo "  run           - Run locally"
	@echo "  test          - Run tests"
	@echo "  fmt           - Format code"
	@echo "  lint          - Lint code"
	@echo "  generate-ogen - Generate API code from swagger spec"
	@echo "  clean         - Remove build artifacts"

261
README.md Normal file

@@ -0,0 +1,261 @@
# Balloon Trajectory Predictor
High-altitude balloon trajectory prediction service. Predicts ascent, burst, and descent trajectories using GFS wind forecast data from NOAA.
The prediction algorithms are an exact port of [Tawhiri](https://github.com/cuspaceflight/tawhiri) (Cambridge University Spaceflight) to Go, verified to produce identical results.
## Quick Start
```bash
# Build
make build
# Run (downloads ~9 GB of GFS data on first start, takes 30-60 min)
PREDICTOR_DATA_DIR=/tmp/predictor-data go run ./cmd/api
# Check readiness
curl http://localhost:8080/ready
# Run a prediction
curl 'http://localhost:8080/api/v1/prediction?launch_latitude=52.2&launch_longitude=0.1&launch_datetime=2026-03-28T12:00:00Z&launch_altitude=0&ascent_rate=5&burst_altitude=30000&descent_rate=5'
```
## Configuration
All configuration is via environment variables.
| Variable | Default | Description |
|---|---|---|
| `PREDICTOR_PORT` | `8080` | HTTP server port |
| `PREDICTOR_DATA_DIR` | `/tmp/predictor-data` | Directory for wind datasets and temp files |
| `PREDICTOR_DOWNLOAD_PARALLEL` | `8` | Max concurrent GRIB download goroutines |
| `PREDICTOR_UPDATE_INTERVAL` | `6h` | How often to check for new forecasts |
| `PREDICTOR_DATASET_TTL` | `48h` | Max age before a dataset is considered stale |
| `PREDICTOR_ELEVATION_DATASET` | `/srv/ruaumoko-dataset` | Path to elevation dataset (optional) |
## API
### `GET /api/v1/prediction`
Run a balloon trajectory prediction.
**Parameters** (query string):
| Parameter | Required | Description |
|---|---|---|
| `launch_latitude` | yes | Launch latitude in degrees (-90 to 90) |
| `launch_longitude` | yes | Launch longitude in degrees (-180 to 180 or 0 to 360) |
| `launch_datetime` | yes | Launch time in RFC 3339 format |
| `launch_altitude` | no | Launch altitude in metres ASL (default: 0) |
| `profile` | no | `standard_profile` (default) or `float_profile` |
| `ascent_rate` | no | Ascent rate in m/s (default: 5) |
| `burst_altitude` | no | Burst altitude in metres (default: 28000) |
| `descent_rate` | no | Sea-level descent rate in m/s (default: 5) |
| `float_altitude` | no | Float altitude in metres (float_profile only) |
| `stop_datetime` | no | Float end time (float_profile only, default: +24h) |
**Response** (Tawhiri-compatible):
```json
{
"prediction": [
{
"stage": "ascent",
"trajectory": [
{"datetime": "2026-03-28T12:00:00Z", "latitude": 52.2, "longitude": 0.1, "altitude": 0},
...
]
},
{
"stage": "descent",
"trajectory": [...]
}
],
"metadata": {
"start_datetime": "...",
"complete_datetime": "..."
},
"request": {
"dataset": "2026-03-28T06:00:00Z",
"launch_latitude": 52.2,
...
}
}
```
### `GET /ready`
Health check. Returns `{"status": "ok"}` when a dataset is loaded.
## Elevation Dataset
Without elevation data, descent terminates at sea level (altitude <= 0). With elevation data, descent terminates at ground level, matching Tawhiri's behaviour.
### Building the elevation dataset
The elevation dataset uses ETOPO 2022 at 30 arc-second resolution, converted to a ruaumoko-compatible binary format (21601 x 43200 grid of int16 little-endian elevation values in metres).
**Requirements**: Python 3, xarray, netcdf4, numpy.
```bash
pip install xarray netcdf4 numpy
# Downloads ~1.1 GB from NOAA, produces ~1.74 GB binary file
python3 scripts/build_elevation.py /tmp/predictor-data/ruaumoko-dataset
```
To skip the download if you already have the ETOPO NetCDF file:
```bash
ETOPO_NC_PATH=/path/to/ETOPO_2022_v1_30s_N90W180_surface.nc \
python3 scripts/build_elevation.py /tmp/predictor-data/ruaumoko-dataset
```
The ETOPO 2022 NetCDF can be manually downloaded from:
https://www.ncei.noaa.gov/products/etopo-global-relief-model
### Using the elevation dataset
```bash
PREDICTOR_ELEVATION_DATASET=/tmp/predictor-data/ruaumoko-dataset go run ./cmd/api
```
If the file doesn't exist or can't be read, the service starts normally with a warning and falls back to sea-level termination.
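The lookup into this grid is a single row-major index computation. The sketch below shows the arithmetic only; the row/column orientation chosen here (row 0 at 90°N, column 0 at 180°W) is an assumption for illustration, and the authoritative layout is whatever `scripts/build_elevation.py` writes:

```go
package main

import "fmt"

const (
	elevRows = 21601 // 30 arc-second latitude rows, both poles included
	elevCols = 43200 // 30 arc-second longitude columns
)

// elevOffset returns the byte offset of the little-endian int16 sample
// for (lat, lon). The orientation (row 0 at 90N, column 0 at 180W) is
// assumed for illustration, not taken from the real ruaumoko layout.
func elevOffset(lat, lon float64) int64 {
	row := int64((90 - lat) * 120) // 120 samples per degree at 30 arc-seconds
	col := int64((lon + 180) * 120)
	return (row*elevCols + col) * 2 // 2 bytes per int16
}

func main() {
	// Total file size: rows * cols * 2 bytes, i.e. the ~1.74 GB above.
	fmt.Println(elevRows * elevCols * 2)
	fmt.Println(elevOffset(52.2, 0.1))
}
```

In the service, this offset would index directly into the mmapped file, so a lookup is one memory read with no parsing.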
## Architecture
```
cmd/api/main.go Entry point, config, scheduler, HTTP server
internal/
dataset/
dataset.go Shape constants, pressure levels, S3 URLs
file.go mmap-backed dataset file (read/write/blit)
downloader/
downloader.go S3 partial GRIB download (idx + range requests)
idx.go NOAA .idx file parser
config.go Environment-based configuration
elevation/
elevation.go Ruaumoko-compatible elevation dataset (mmap int16)
prediction/
interpolate.go 4D wind interpolation (time, lat, lon, altitude)
solver.go RK4 integrator with binary search termination
models.go Ascent, descent, wind models; flight profiles
warnings.go Prediction warning counters
service/
service.go Dataset lifecycle, concurrent-safe access
transport/
middleware/log.go Request logging middleware
rest/
handler/handler.go ogen API handler implementation
handler/deps.go Service interface
transport.go ogen HTTP server, CORS
api/rest/predictor.swagger.yml OpenAPI 3.0 spec
pkg/rest/ Generated ogen code (17 files)
scripts/
build_elevation.py ETOPO 2022 to ruaumoko converter
```
## Wind Dataset
The service downloads GFS 0.5-degree forecast data from NOAA S3:
| Property | Value |
|---|---|
| Source | `noaa-gfs-bdp-pds.s3.amazonaws.com` |
| Resolution | 0.5 degrees |
| Grid | 361 lat x 720 lon |
| Time steps | 65 (every 3 hours, 0-192h) |
| Pressure levels | 47 (1000 to 1 hPa) |
| Variables | Geopotential height, U-wind, V-wind |
| Dataset size | 9,528,667,200 bytes (~8.87 GiB) |
| Update cadence | Every 6 hours (GFS runs at 00, 06, 12, 18 UTC) |
Data is downloaded with HTTP Range requests, using each file's `.idx` index to locate and fetch only the needed GRIB messages (HGT, UGRD, VGRD at 47 pressure levels). A full download takes 30-60 minutes depending on bandwidth.
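The `.idx`-driven range selection can be illustrated standalone. This sketch is not the code in `internal/downloader/idx.go`; the field layout follows the public NOAA `.idx` format, and the byte offsets in the example are made up:

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// One line of a NOAA GRIB .idx file looks like:
//   3:17188:d=2026032806:HGT:1000 mb:anl:
// Fields: message number, start byte, run date, variable, level, forecast.
type idxEntry struct {
	start int64
	end   int64 // exclusive; -1 means "to end of file"
	vari  string
	level string
}

// parseIdx parses .idx text; each message ends where the next one starts.
func parseIdx(text string) []idxEntry {
	var entries []idxEntry
	for _, line := range strings.Split(strings.TrimSpace(text), "\n") {
		f := strings.Split(line, ":")
		if len(f) < 5 {
			continue
		}
		start, err := strconv.ParseInt(f[1], 10, 64)
		if err != nil {
			continue
		}
		entries = append(entries, idxEntry{start: start, end: -1, vari: f[3], level: f[4]})
	}
	for i := 0; i+1 < len(entries); i++ {
		entries[i].end = entries[i+1].start
	}
	return entries
}

func main() {
	// Made-up offsets, real field layout.
	idx := "1:0:d=2026032806:PRMSL:mean sea level:anl:\n" +
		"2:984842:d=2026032806:HGT:1000 mb:anl:\n" +
		"3:1723516:d=2026032806:UGRD:1000 mb:anl:\n"
	for _, e := range parseIdx(idx) {
		if e.vari == "HGT" && e.level == "1000 mb" {
			// The Range header we would send for just this message.
			fmt.Printf("Range: bytes=%d-%d\n", e.start, e.end-1)
		}
	}
}
```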
The dataset is stored as a memory-mapped flat binary file of float32 values in C-order with shape `(65, 47, 3, 361, 720)`.
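With that shape, locating a value is pure stride arithmetic. A minimal sketch (constant and function names are illustrative, not the internal API):

```go
package main

import "fmt"

// Dataset shape: (hours, levels, variables, lats, lons) in C-order.
const (
	nHours  = 65
	nLevels = 47
	nVars   = 3
	nLats   = 361
	nLons   = 720
)

// flatIndex returns the float32 offset of (hour, level, variable, lat, lon)
// in the row-major flat file; the byte offset is 4x this value.
func flatIndex(hour, level, variable, lat, lon int) int {
	return (((hour*nLevels+level)*nVars+variable)*nLats+lat)*nLons + lon
}

func main() {
	// Last grid point maps to the last float32 in the file.
	fmt.Println(flatIndex(64, 46, 2, 360, 719)) // 2382166799
	// Total size in bytes matches the table above: 9,528,667,200.
	fmt.Println(nHours * nLevels * nVars * nLats * nLons * 4)
}
```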
## Prediction Algorithms
All algorithms are exact ports of the reference implementations in Tawhiri. The following sections describe the key components.
### Interpolation (`internal/prediction/interpolate.go`)
4D wind interpolation from the dataset grid to arbitrary coordinates.
1. **Trilinear weights** (`pick3`): compute 8 interpolation weights for the (hour, lat, lon) cube corners.
2. **Altitude search** (`search`): binary search on interpolated geopotential height to find the two pressure levels bracketing the target altitude.
3. **Wind extraction** (`interp4`): 8-point weighted sum at each bracket level, then linear interpolation between levels.
Reference: `tawhiri/interpolate.pyx`
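The weight computation in step 1 can be sketched independently of the dataset. This is the textbook trilinear construction the section describes, not the actual `pick3` code:

```go
package main

import "fmt"

// axisWeight is the 1-D interpolation weight for one side of a cell:
// the far side gets weight f, the near side 1-f.
func axisWeight(f float64, side int) float64 {
	if side == 1 {
		return f
	}
	return 1 - f
}

// corner is one of the 8 corners of the (hour, lat, lon) cube around a
// query point, with its trilinear weight.
type corner struct {
	dHour, dLat, dLon int
	weight            float64
}

// trilinearWeights mirrors the idea behind pick3: fractional positions
// (0..1) along hour, lat and lon produce 8 corner weights summing to 1.
func trilinearWeights(fh, fla, flo float64) [8]corner {
	var cs [8]corner
	i := 0
	for _, h := range [2]int{0, 1} {
		for _, la := range [2]int{0, 1} {
			for _, lo := range [2]int{0, 1} {
				cs[i] = corner{h, la, lo,
					axisWeight(fh, h) * axisWeight(fla, la) * axisWeight(flo, lo)}
				i++
			}
		}
	}
	return cs
}

func main() {
	for _, c := range trilinearWeights(0.25, 0.5, 0.75) {
		fmt.Printf("corner (+%d,+%d,+%d): weight %.5f\n", c.dHour, c.dLat, c.dLon, c.weight)
	}
}
```

Steps 2 and 3 then reuse these same 8 weights: once against geopotential height to bracket the altitude, and once per bracket level against the wind components.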
### Solver (`internal/prediction/solver.go`)
4th-order Runge-Kutta integrator with dt = 60 seconds.
- State vector: (latitude, longitude, altitude) in degrees and metres.
- Time: UNIX timestamp in seconds.
- Longitude is kept in [0, 360) via Python-style modulo after each `vecadd`.
- When a terminator fires, binary search refinement (tolerance 0.01) finds the precise termination point between the last good step and the first terminated step.
- Longitude interpolation (`lngLerp`) handles the 0/360 wrap-around.
Reference: `tawhiri/solver.pyx`
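The wrap-around handling can be sketched as follows; this is the standard shortest-arc interpolation the section describes, and the real `lngLerp` may differ in details:

```go
package main

import (
	"fmt"
	"math"
)

// pymod is Python-style modulo: the result always has the sign of m,
// so longitudes land in [0, 360).
func pymod(x, m float64) float64 {
	return math.Mod(math.Mod(x, m)+m, m)
}

// lngLerpSketch interpolates between two longitudes in [0, 360),
// taking the shorter way around the 0/360 seam.
func lngLerpSketch(a, b, t float64) float64 {
	d := b - a
	if d > 180 {
		d -= 360
	} else if d < -180 {
		d += 360
	}
	return pymod(a+t*d, 360)
}

func main() {
	fmt.Println(lngLerpSketch(10, 20, 0.5)) // 15: no seam involved
	fmt.Println(lngLerpSketch(359, 1, 0.5)) // 0: midpoint across the seam
}
```

Without the seam handling, the midpoint of 359° and 1° would come out as 180°, on the opposite side of the planet; this matters because the binary-search refinement interpolates between solver states that may straddle the seam.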
### Models (`internal/prediction/models.go`)
- **Constant ascent**: vertical velocity = ascent_rate m/s.
- **Drag descent**: NASA atmosphere density model with drag coefficient = sea_level_rate * 1.1045. Descent rate increases with altitude due to thinner air.
- **Wind velocity**: u, v components from interpolation converted to degrees/second: `dlat = (180/pi) * v / R`, `dlng = (180/pi) * u / (R * cos(lat))` where R = 6371009 + altitude.
- **Linear model**: sum of component models (e.g., wind + ascent).
- **Elevation termination**: `ground_elevation > altitude` using ruaumoko dataset.
Reference: `tawhiri/models.py`
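The wind-velocity formulas above can be checked numerically. A small sketch (the function name is illustrative, not the internal API):

```go
package main

import (
	"fmt"
	"math"
)

// windToAngular converts u (eastward) and v (northward) wind in m/s
// into dlat/dlng in degrees per second, using the formulas above with
// R = 6371009 m plus altitude.
func windToAngular(u, v, latDeg, alt float64) (dlat, dlng float64) {
	r := 6371009.0 + alt
	dlat = (180 / math.Pi) * v / r
	dlng = (180 / math.Pi) * u / (r * math.Cos(latDeg*math.Pi/180))
	return dlat, dlng
}

func main() {
	// A 10 m/s westerly at 52N and 10 km altitude: the same wind covers
	// more degrees of longitude than at the equator, because meridians
	// converge toward the poles (the cos(lat) term).
	dlat, dlng := windToAngular(10, 0, 52, 10000)
	fmt.Printf("dlat=%.3e deg/s  dlng=%.3e deg/s\n", dlat, dlng)
}
```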
### Profiles
- **standard_profile**: ascent (constant rate + wind) until burst altitude, then descent (drag + wind) until ground level.
- **float_profile**: ascent to float altitude, then drift at constant altitude until stop time.
## Verification
The predictor has been verified against the reference Tawhiri implementation:
| Test | Result |
|---|---|
| Dataset (step 0): 36.6M float32 values vs Python/cfgrib | 0 mismatches, max diff = 0.0 |
| Prediction burst point vs public Tawhiri API | Identical (lat, lon, alt all match) |
| Prediction landing point vs public Tawhiri API | Identical lat/lon, 5m altitude diff (different elevation datasets) |
| Descent point count | Identical (46 points) |
| Ascent point count | Identical (101 points) |
## Development
```bash
# Regenerate ogen API code after modifying the swagger spec
make generate-ogen
# Run tests
make test
# Format
make fmt
```
### Comparison tools
```bash
# Compare single dataset step against Python/cfgrib reference
go run ./cmd/compare_step0 <run_YYYYMMDDHH> <output_path>
# Run prediction and compare against public Tawhiri API
go run ./cmd/compare_prediction
```
## References
- [Tawhiri](https://github.com/cuspaceflight/tawhiri) — Reference Python/Cython predictor (Cambridge University Spaceflight)
- [tawhiri-downloader](https://github.com/cuspaceflight/tawhiri-downloader) — OCaml dataset downloader
- [ruaumoko](https://github.com/cuspaceflight/ruaumoko) — Global elevation dataset
- [NOAA GFS](https://www.ncei.noaa.gov/products/weather-climate-models/global-forecast) — Global Forecast System
- [NOAA GFS on S3](https://noaa-gfs-bdp-pds.s3.amazonaws.com/index.html) — Public S3 bucket
- [ETOPO 2022](https://www.ncei.noaa.gov/products/etopo-global-relief-model) — Global relief model for elevation data
- [SondeHub Tawhiri API](https://api.v2.sondehub.org/tawhiri) — Public Tawhiri instance for comparison

199
api/rest/predictor.swagger.yml Normal file

@@ -0,0 +1,199 @@
openapi: 3.0.4
info:
  title: Predictor API
  version: 0.0.1
paths:
  /api/v1/prediction:
    get:
      tags:
        - Prediction
      summary: Perform prediction
      operationId: performPrediction
      parameters:
        - in: query
          name: launch_latitude
          required: true
          schema:
            type: number
        - in: query
          name: launch_longitude
          required: true
          schema:
            type: number
        - in: query
          name: launch_datetime
          required: true
          schema:
            type: string
            format: date-time
        - in: query
          name: launch_altitude
          schema:
            type: number
        - in: query
          name: profile
          schema:
            type: string
            enum: [standard_profile, float_profile]
            default: standard_profile
        - in: query
          name: ascent_rate
          schema:
            type: number
        - in: query
          name: burst_altitude
          schema:
            type: number
        - in: query
          name: descent_rate
          schema:
            type: number
        - in: query
          name: float_altitude
          schema:
            type: number
        - in: query
          name: stop_datetime
          schema:
            type: string
            format: date-time
        - in: query
          name: dataset
          schema:
            type: string
            format: date-time
      responses:
        "200":
          description: Prediction response
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/PredictionResponse'
        default:
          description: Error
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/Error'
  /ready:
    get:
      tags:
        - Health
      summary: Readiness check
      operationId: readinessCheck
      responses:
        "200":
          description: Readiness status
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/ReadinessResponse'
        default:
          description: Error
          content:
            application/json:
              schema:
                $ref: '#/components/schemas/Error'
components:
  schemas:
    Error:
      type: object
      required:
        - error
      properties:
        error:
          type: object
          required:
            - type
            - description
          properties:
            type:
              type: string
            description:
              type: string
    PredictionResponse:
      type: object
      required:
        - prediction
        - metadata
      properties:
        request:
          type: object
          properties:
            dataset:
              type: string
            launch_latitude:
              type: number
            launch_longitude:
              type: number
            launch_datetime:
              type: string
            launch_altitude:
              type: number
            profile:
              type: string
            ascent_rate:
              type: number
            burst_altitude:
              type: number
            descent_rate:
              type: number
        prediction:
          type: array
          items:
            type: object
            required:
              - stage
              - trajectory
            properties:
              stage:
                type: string
                enum: ["ascent", "descent", "float"]
              trajectory:
                type: array
                items:
                  type: object
                  required:
                    - datetime
                    - latitude
                    - longitude
                    - altitude
                  properties:
                    datetime:
                      type: string
                      format: date-time
                    latitude:
                      type: number
                    longitude:
                      type: number
                    altitude:
                      type: number
        metadata:
          type: object
          required:
            - start_datetime
            - complete_datetime
          properties:
            start_datetime:
              type: string
              format: date-time
            complete_datetime:
              type: string
              format: date-time
        warnings:
          type: object
          additionalProperties: true
    ReadinessResponse:
      type: object
      required:
        - status
      properties:
        status:
          type: string
          enum: [ok, not_ready, error]
        dataset_time:
          type: string
          format: date-time
        error_message:
          type: string

98
cmd/api/main.go Normal file

@@ -0,0 +1,98 @@
package main

import (
	"context"
	"fmt"
	"os"
	"os/signal"
	"syscall"
	"time"

	"predictor-refactored/internal/downloader"
	"predictor-refactored/internal/service"
	"predictor-refactored/internal/transport/rest"
	"predictor-refactored/internal/transport/rest/handler"

	"github.com/go-co-op/gocron"
	"go.uber.org/zap"
)

func main() {
	log, err := zap.NewProduction()
	if err != nil {
		panic(err)
	}
	defer log.Sync()

	cfg := downloader.LoadConfig()
	log.Info("configuration loaded",
		zap.String("data_dir", cfg.DataDir),
		zap.Int("parallel", cfg.Parallel),
		zap.Duration("update_interval", cfg.UpdateInterval),
		zap.Duration("dataset_ttl", cfg.DatasetTTL))

	if err := os.MkdirAll(cfg.DataDir, 0o755); err != nil {
		log.Fatal("failed to create data directory", zap.Error(err))
	}

	svc := service.New(cfg, log)
	defer svc.Close()

	// Load elevation dataset (optional — falls back to sea-level termination)
	elevPath := "/srv/ruaumoko-dataset"
	if v := os.Getenv("PREDICTOR_ELEVATION_DATASET"); v != "" {
		elevPath = v
	}
	svc.LoadElevation(elevPath)

	// Initial dataset load (async so the server starts immediately)
	go func() {
		log.Info("performing initial dataset update...")
		if err := svc.Update(context.Background()); err != nil {
			log.Error("initial dataset update failed", zap.Error(err))
		} else {
			log.Info("initial dataset update complete")
		}
	}()

	// Scheduler for periodic dataset updates
	scheduler := gocron.NewScheduler(time.UTC)
	scheduler.Every(cfg.UpdateInterval).Do(func() {
		ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute)
		defer cancel()
		log.Info("scheduled dataset update starting")
		if err := svc.Update(ctx); err != nil {
			log.Error("scheduled dataset update failed", zap.Error(err))
		} else {
			log.Info("scheduled dataset update complete")
		}
	})
	scheduler.StartAsync()
	defer scheduler.Stop()

	// HTTP transport (ogen)
	port := 8080
	if p := os.Getenv("PREDICTOR_PORT"); p != "" {
		fmt.Sscanf(p, "%d", &port)
	}
	h := handler.New(svc, log)
	transport, err := rest.New(h, port, log)
	if err != nil {
		log.Fatal("failed to create transport", zap.Error(err))
	}
	go func() {
		if err := transport.Run(); err != nil {
			log.Fatal("HTTP server error", zap.Error(err))
		}
	}()
	log.Info("service started")

	// Graceful shutdown
	sigChan := make(chan os.Signal, 1)
	signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
	sig := <-sigChan
	log.Info("received shutdown signal", zap.String("signal", sig.String()))
}

195
cmd/compare_prediction/main.go Normal file

@@ -0,0 +1,195 @@
package main

import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"math"
	"net/http"
	"os"
	"time"

	"predictor-refactored/internal/dataset"
	"predictor-refactored/internal/downloader"
	"predictor-refactored/internal/prediction"

	"go.uber.org/zap"
)

// Downloads a few forecast steps and runs a prediction, then compares
// against the public Tawhiri API.
func main() {
	log, _ := zap.NewDevelopment()
	cfg := &downloader.Config{
		DataDir:  os.TempDir(),
		Parallel: 4,
	}
	dl := downloader.NewDownloader(cfg, log)
	ctx := context.Background()

	// Find latest run
	run, err := dl.FindLatestRun(ctx)
	if err != nil {
		fmt.Fprintf(os.Stderr, "FindLatestRun: %v\n", err)
		os.Exit(1)
	}
	fmt.Printf("Using run: %s\n", run.Format("2006010215"))

	// Create dataset and download first 10 steps (0-27 hours, enough for a prediction)
	dsPath := fmt.Sprintf("/tmp/pred_test_%s.bin", run.Format("2006010215"))
	defer os.Remove(dsPath)
	ds, err := dataset.Create(dsPath)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Create: %v\n", err)
		os.Exit(1)
	}

	date := run.Format("20060102")
	runHour := run.Hour()
	stepsToDownload := []int{0, 3, 6, 9, 12, 15, 18, 21, 24, 27}
	fmt.Printf("Downloading %d steps...\n", len(stepsToDownload))
	for _, step := range stepsToDownload {
		hourIdx := dataset.HourIndex(step)
		fmt.Printf("  step %d (hour idx %d)...\n", step, hourIdx)
		urlA := dataset.GribURL(date, runHour, step)
		if err := dl.DownloadAndBlit(ctx, ds, urlA, hourIdx, dataset.LevelSetA); err != nil {
			fmt.Fprintf(os.Stderr, "  pgrb2 step %d: %v\n", step, err)
			os.Exit(1)
		}
		urlB := dataset.GribURLB(date, runHour, step)
		if err := dl.DownloadAndBlit(ctx, ds, urlB, hourIdx, dataset.LevelSetB); err != nil {
			fmt.Fprintf(os.Stderr, "  pgrb2b step %d: %v\n", step, err)
			os.Exit(1)
		}
	}
	ds.Flush()
	fmt.Println("Download complete")

	// Set dataset time
	ds.DSTime = run

	// Run our prediction
	launchLat := 52.2135
	launchLon := 0.0964 // already in [0, 360)
	launchAlt := 0.0
	ascentRate := 5.0
	burstAlt := 30000.0
	descentRate := 5.0

	// Launch 3 hours into the forecast
	launchTime := run.Add(3 * time.Hour)
	launchTimestamp := float64(launchTime.Unix())
	dsEpoch := float64(run.Unix())

	warnings := &prediction.Warnings{}
	stages := prediction.StandardProfile(ascentRate, burstAlt, descentRate, ds, dsEpoch, warnings, nil)
	results := prediction.RunPrediction(launchTimestamp, launchLat, launchLon, launchAlt, stages)

	fmt.Printf("\n=== Our prediction ===\n")
	for i, sr := range results {
		stage := "ascent"
		if i == 1 {
			stage = "descent"
		}
		first := sr.Points[0]
		last := sr.Points[len(sr.Points)-1]
		fmt.Printf("  %s: %d points, start=(%.4f, %.4f, %.0fm) end=(%.4f, %.4f, %.0fm)\n",
			stage, len(sr.Points),
			first.Lat, first.Lng, first.Alt,
			last.Lat, last.Lng, last.Alt)
	}

	// Get landing point
	var ourLandLat, ourLandLon float64
	if len(results) >= 2 {
		last := results[1].Points[len(results[1].Points)-1]
		ourLandLat = last.Lat
		ourLandLon = last.Lng
		if ourLandLon > 180 {
			ourLandLon -= 360
		}
	}
	fmt.Printf("  Landing: lat=%.4f, lon=%.4f\n", ourLandLat, ourLandLon)

	// Compare against public Tawhiri API
	fmt.Printf("\n=== Tawhiri API comparison ===\n")
	tawhiriLandLat, tawhiriLandLon, err := queryTawhiri(launchLat, launchLon, launchAlt, launchTime, ascentRate, burstAlt, descentRate)
	if err != nil {
		fmt.Printf("  Tawhiri API error: %v\n", err)
		fmt.Println("  (Cannot compare — Tawhiri may use a different dataset)")
		ds.Close()
		return
	}
	fmt.Printf("  Tawhiri landing: lat=%.4f, lon=%.4f\n", tawhiriLandLat, tawhiriLandLon)

	dist := haversine(ourLandLat, ourLandLon, tawhiriLandLat, tawhiriLandLon)
	fmt.Printf("  Distance between landing points: %.2f km\n", dist/1000)
	if dist < 1000 {
		fmt.Println("  CLOSE MATCH (< 1 km)")
	} else if dist < 50000 {
		fmt.Printf("  MODERATE DIFFERENCE (%.1f km) — likely different datasets\n", dist/1000)
	} else {
		fmt.Printf("  LARGE DIFFERENCE (%.1f km) — possible bug\n", dist/1000)
	}
	ds.Close()
}

func queryTawhiri(lat, lon, alt float64, launchTime time.Time, ascentRate, burstAlt, descentRate float64) (landLat, landLon float64, err error) {
	url := fmt.Sprintf(
		"https://api.v2.sondehub.org/tawhiri?launch_latitude=%.4f&launch_longitude=%.4f&launch_altitude=%.0f&launch_datetime=%s&ascent_rate=%.1f&burst_altitude=%.0f&descent_rate=%.1f",
		lat, lon, alt, launchTime.Format(time.RFC3339), ascentRate, burstAlt, descentRate)
	resp, err := http.Get(url)
	if err != nil {
		return 0, 0, err
	}
	defer resp.Body.Close()
	body, _ := io.ReadAll(resp.Body)
	if resp.StatusCode != 200 {
		return 0, 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body))
	}
	var result struct {
		Prediction []struct {
			Stage      string `json:"stage"`
			Trajectory []struct {
				Latitude  float64 `json:"latitude"`
				Longitude float64 `json:"longitude"`
				Altitude  float64 `json:"altitude"`
			} `json:"trajectory"`
		} `json:"prediction"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return 0, 0, err
	}
	for _, stage := range result.Prediction {
		if stage.Stage == "descent" && len(stage.Trajectory) > 0 {
			last := stage.Trajectory[len(stage.Trajectory)-1]
			return last.Latitude, last.Longitude, nil
		}
	}
	return 0, 0, fmt.Errorf("no descent stage found")
}

func haversine(lat1, lon1, lat2, lon2 float64) float64 {
	const R = 6371000.0
	phi1 := lat1 * math.Pi / 180
	phi2 := lat2 * math.Pi / 180
	dphi := (lat2 - lat1) * math.Pi / 180
	dlam := (lon2 - lon1) * math.Pi / 180
	a := math.Sin(dphi/2)*math.Sin(dphi/2) + math.Cos(phi1)*math.Cos(phi2)*math.Sin(dlam/2)*math.Sin(dlam/2)
	return R * 2 * math.Atan2(math.Sqrt(a), math.Sqrt(1-a))
}

104
cmd/compare_step0/main.go Normal file

@@ -0,0 +1,104 @@
package main

import (
	"context"
	"fmt"
	"os"
	"time"

	"predictor-refactored/internal/dataset"
	"predictor-refactored/internal/downloader"

	"go.uber.org/zap"
)

// Downloads step 0 of a given run and writes a minimal dataset for comparison.
// Usage: go run ./cmd/compare_step0 <run_YYYYMMDDHH> <output_path>
func main() {
	if len(os.Args) < 3 {
		fmt.Fprintf(os.Stderr, "Usage: %s <run_YYYYMMDDHH> <output_path>\n", os.Args[0])
		os.Exit(1)
	}
	runStr := os.Args[1]
	outPath := os.Args[2]

	run, err := time.Parse("2006010215", runStr)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Invalid run time %q: %v\n", runStr, err)
		os.Exit(1)
	}

	log, _ := zap.NewDevelopment()

	// Create a full-size dataset (we only fill step 0)
	fmt.Printf("Creating dataset at %s (%d bytes)...\n", outPath, dataset.DatasetSize)
	ds, err := dataset.Create(outPath)
	if err != nil {
		fmt.Fprintf(os.Stderr, "Create dataset: %v\n", err)
		os.Exit(1)
	}
	defer ds.Close()

	cfg := &downloader.Config{
		DataDir:  os.TempDir(),
		Parallel: 4,
	}
	dl := downloader.NewDownloader(cfg, log)
	ctx := context.Background()

	date := run.Format("20060102")
	runHour := run.Hour()

	// Download and blit step 0 from pgrb2
	fmt.Println("Downloading pgrb2 step 0...")
	urlA := dataset.GribURL(date, runHour, 0)
	if err := dl.DownloadAndBlit(ctx, ds, urlA, 0, dataset.LevelSetA); err != nil {
		fmt.Fprintf(os.Stderr, "pgrb2: %v\n", err)
		os.Exit(1)
	}
	fmt.Println("  done")

	// Download and blit step 0 from pgrb2b
	fmt.Println("Downloading pgrb2b step 0...")
	urlB := dataset.GribURLB(date, runHour, 0)
	if err := dl.DownloadAndBlit(ctx, ds, urlB, 0, dataset.LevelSetB); err != nil {
		fmt.Fprintf(os.Stderr, "pgrb2b: %v\n", err)
		os.Exit(1)
	}
	fmt.Println("  done")

	if err := ds.Flush(); err != nil {
		fmt.Fprintf(os.Stderr, "Flush: %v\n", err)
		os.Exit(1)
	}

	// Spot-check: print same values as the Python script
	fmt.Println("\n=== Go dataset values (spot check) ===")
	type testPoint struct {
		varName  string
		varIdx   int
		levelIdx int
		lat, lon int
	}
	points := []testPoint{
		{"HGT", 0, 0, 0, 0},      // HGT @ 1000mb, lat=-90, lon=0
		{"HGT", 0, 0, 180, 0},    // HGT @ 1000mb, lat=0, lon=0
		{"HGT", 0, 0, 360, 0},    // HGT @ 1000mb, lat=+90, lon=0
		{"HGT", 0, 20, 180, 360}, // HGT @ 500mb, lat=0, lon=180
		{"UGRD", 1, 0, 180, 0},   // UGRD @ 1000mb, lat=0, lon=0
		{"VGRD", 2, 0, 180, 0},   // VGRD @ 1000mb, lat=0, lon=0
		{"UGRD", 1, 20, 284, 0},  // UGRD @ 500mb, lat=52N, lon=0
	}
	for _, p := range points {
		val := ds.Val(0, p.levelIdx, p.varIdx, p.lat, p.lon)
		actualLat := -90.0 + float64(p.lat)*0.5
		actualLon := float64(p.lon) * 0.5
		fmt.Printf("  %-4s %4dmb lat=%+7.1f lon=%6.1f: %12.4f\n",
			p.varName, dataset.Pressures[p.levelIdx], actualLat, actualLon, val)
	}
	fmt.Printf("\nDataset written to %s\n", outPath)
}

41
go.mod Normal file

@@ -0,0 +1,41 @@
module predictor-refactored

go 1.25.0

require (
	github.com/edsrzf/mmap-go v1.2.0
	github.com/go-co-op/gocron v1.37.0
	github.com/go-faster/errors v0.7.1
	github.com/go-faster/jx v1.2.0
	github.com/nilsmagnus/grib v1.2.8
	github.com/ogen-go/ogen v1.20.2
	go.opentelemetry.io/otel v1.42.0
	go.opentelemetry.io/otel/metric v1.42.0
	go.opentelemetry.io/otel/trace v1.42.0
	go.uber.org/zap v1.27.1
	golang.org/x/sync v0.20.0
)

require (
	github.com/cespare/xxhash/v2 v2.3.0 // indirect
	github.com/dlclark/regexp2 v1.11.5 // indirect
	github.com/fatih/color v1.19.0 // indirect
	github.com/ghodss/yaml v1.0.0 // indirect
	github.com/go-faster/yaml v0.4.6 // indirect
	github.com/go-logr/logr v1.4.3 // indirect
	github.com/go-logr/stdr v1.2.2 // indirect
	github.com/google/uuid v1.6.0 // indirect
	github.com/mattn/go-colorable v0.1.14 // indirect
	github.com/mattn/go-isatty v0.0.20 // indirect
	github.com/robfig/cron/v3 v3.0.1 // indirect
	github.com/segmentio/asm v1.2.1 // indirect
	github.com/shopspring/decimal v1.4.0 // indirect
	go.opentelemetry.io/auto/sdk v1.2.1 // indirect
	go.uber.org/atomic v1.9.0 // indirect
	go.uber.org/multierr v1.11.0 // indirect
	golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 // indirect
	golang.org/x/net v0.52.0 // indirect
	golang.org/x/sys v0.42.0 // indirect
	golang.org/x/text v0.35.0 // indirect
	gopkg.in/yaml.v2 v2.4.0 // indirect
)

108
go.sum Normal file

@@ -0,0 +1,108 @@
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/edsrzf/mmap-go v1.2.0 h1:hXLYlkbaPzt1SaQk+anYwKSRNhufIDCchSPkUD6dD84=
github.com/edsrzf/mmap-go v1.2.0/go.mod h1:19H/e8pUPLicwkyNgOykDXkJ9F0MHE+Z52B8EIth78Q=
github.com/fatih/color v1.19.0 h1:Zp3PiM21/9Ld6FzSKyL5c/BULoe/ONr9KlbYVOfG8+w=
github.com/fatih/color v1.19.0/go.mod h1:zNk67I0ZUT1bEGsSGyCZYZNrHuTkJJB+r6Q9VuMi0LE=
github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/go-co-op/gocron v1.37.0 h1:ZYDJGtQ4OMhTLKOKMIch+/CY70Brbb1dGdooLEhh7b0=
github.com/go-co-op/gocron v1.37.0/go.mod h1:3L/n6BkO7ABj+TrfSVXLRzsP26zmikL4ISkLQ0O8iNY=
github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg=
github.com/go-faster/errors v0.7.1/go.mod h1:5ySTjWFiphBs07IKuiL69nxdfd5+fzh1u7FPGZP2quo=
github.com/go-faster/jx v1.2.0 h1:T2YHJPrFaYu21fJtUxC9GzmluKu8rVIFDwwGBKTDseI=
github.com/go-faster/jx v1.2.0/go.mod h1:UWLOVDmMG597a5tBFPLIWJdUxz5/2emOpfsj9Neg0PE=
github.com/go-faster/yaml v0.4.6 h1:lOK/EhI04gCpPgPhgt0bChS6bvw7G3WwI8xxVe0sw9I=
github.com/go-faster/yaml v0.4.6/go.mod h1:390dRIvV4zbnO7qC9FGo6YYutc+wyyUSHBgbXL52eXk=
github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE=
github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/nilsmagnus/grib v1.2.8 h1:H7ch/1/agaCqM3MC8hW1Ft+EJ+q2XB757uml/IfPvp4=
github.com/nilsmagnus/grib v1.2.8/go.mod h1:XHm+5zuoOk0NSIWaGmA3JaAxI4i50YvD1L1vz+aqPOQ=
github.com/ogen-go/ogen v1.20.2 h1:mEZGPST7ZeX84AkqRlFawDLwcwuzcLO5PtYpAXLT1YE=
github.com/ogen-go/ogen v1.20.2/go.mod h1:sJ1pJVp4S1RcSZlYIiMLo0QSMSt2pls4zfrc+hNKnzk=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs=
github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0=
github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho=
go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc=
go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4=
go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI=
go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY=
go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc=
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto=
go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE=
go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0=
go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y=
go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc=
go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY=
golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc=
golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0=
golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw=
golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4=
golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo=
golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw=
golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8=
golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

internal/dataset/dataset.go
package dataset
import "fmt"
// Dataset shape constants.
// Shape: (65, 47, 3, 361, 720) = (hour, pressure_level, variable, latitude, longitude)
// This matches the reference predictor exactly.
const (
NumHours = 65 // 0, 3, 6, ..., 192
NumLevels = 47 // pressure levels
NumVariables = 3 // height, wind_u, wind_v
NumLatitudes = 361 // -90.0 to +90.0 in 0.5 degree steps
NumLongitudes = 720 // 0.0 to 359.5 in 0.5 degree steps
HourStep = 3 // hours between forecast time steps
MaxHour = 192 // maximum forecast hour
Resolution = 0.5 // grid resolution in degrees
LatStart = -90.0 // first latitude in the dataset
LonStart = 0.0 // first longitude in the dataset
// Variable indices within the dataset.
VarHeight = 0
VarWindU = 1
VarWindV = 2
ElementSize = 4 // float32 = 4 bytes
)
// DatasetSize is the total size of the dataset file in bytes.
// 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200
const DatasetSize int64 = int64(NumHours) * int64(NumLevels) * int64(NumVariables) *
int64(NumLatitudes) * int64(NumLongitudes) * int64(ElementSize)
// LevelSet identifies which GRIB file set a pressure level belongs to.
type LevelSet int
const (
LevelSetA LevelSet = iota // pgrb2 (primary)
LevelSetB // pgrb2b (secondary)
)
// Pressures contains the 47 pressure levels in hPa, sorted descending.
// Index 0 = 1000 hPa (near surface), Index 46 = 1 hPa (high atmosphere).
var Pressures = [NumLevels]int{
1000, 975, 950, 925, 900, 875, 850, 825, 800, 775,
750, 725, 700, 675, 650, 625, 600, 575, 550, 525,
500, 475, 450, 425, 400, 375, 350, 325, 300, 275,
250, 225, 200, 175, 150, 125, 100, 70, 50, 30,
20, 10, 7, 5, 3, 2, 1,
}
// pressureIndex maps pressure in hPa to its index in the Pressures array.
var pressureIndex map[int]int
// pressureLevelSet maps pressure in hPa to its GRIB file set.
var pressureLevelSet map[int]LevelSet
func init() {
pressureIndex = make(map[int]int, NumLevels)
for i, p := range Pressures {
pressureIndex[p] = i
}
pressureLevelSet = make(map[int]LevelSet, NumLevels)
for _, p := range PressuresPgrb2 {
pressureLevelSet[p] = LevelSetA
}
for _, p := range PressuresPgrb2b {
pressureLevelSet[p] = LevelSetB
}
}
// PressuresPgrb2 contains levels found in the primary pgrb2 file (26 levels).
var PressuresPgrb2 = []int{
10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 350, 400,
450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 925,
950, 975, 1000,
}
// PressuresPgrb2b contains levels found in the secondary pgrb2b file (21 levels).
var PressuresPgrb2b = []int{
1, 2, 3, 5, 7, 125, 175, 225, 275, 325, 375, 425,
475, 525, 575, 625, 675, 725, 775, 825, 875,
}
// PressureIndex returns the dataset index for a given pressure level in hPa.
// Returns -1 if the level is not found.
func PressureIndex(hPa int) int {
idx, ok := pressureIndex[hPa]
if !ok {
return -1
}
return idx
}
// PressureLevelSet returns which GRIB file set a pressure level belongs to.
func PressureLevelSet(hPa int) (LevelSet, bool) {
ls, ok := pressureLevelSet[hPa]
return ls, ok
}
// HourIndex returns the dataset time index for a forecast hour.
// Returns -1 if the hour is invalid (not a multiple of HourStep or out of range).
func HourIndex(hour int) int {
if hour < 0 || hour > MaxHour || hour%HourStep != 0 {
return -1
}
return hour / HourStep
}
// Hours returns all forecast hours as a slice: [0, 3, 6, ..., 192].
func Hours() []int {
out := make([]int, 0, NumHours)
for h := 0; h <= MaxHour; h += HourStep {
out = append(out, h)
}
return out
}
// S3 URL configuration for NOAA GFS data.
const S3BaseURL = "https://noaa-gfs-bdp-pds.s3.amazonaws.com"
// GribURL returns the S3 URL for a primary (pgrb2) GRIB file.
func GribURL(date string, runHour, forecastStep int) string {
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d",
S3BaseURL, date, runHour, runHour, forecastStep)
}
// GribURLB returns the S3 URL for a secondary (pgrb2b) GRIB file.
func GribURLB(date string, runHour, forecastStep int) string {
return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2b.0p50.f%03d",
S3BaseURL, date, runHour, runHour, forecastStep)
}
// GribFileName returns the local filename for a primary GRIB file.
func GribFileName(runHour, forecastStep int) string {
return fmt.Sprintf("gfs.t%02dz.pgrb2.0p50.f%03d", runHour, forecastStep)
}
// GribFileNameB returns the local filename for a secondary GRIB file.
func GribFileNameB(runHour, forecastStep int) string {
return fmt.Sprintf("gfs.t%02dz.pgrb2b.0p50.f%03d", runHour, forecastStep)
}
// VariableIndex returns the dataset variable index for a GRIB parameter.
// Returns -1 if the parameter is not recognized.
func VariableIndex(parameterCategory, parameterNumber int) int {
switch {
case parameterCategory == 3 && parameterNumber == 5:
return VarHeight // Geopotential Height
case parameterCategory == 2 && parameterNumber == 2:
return VarWindU // U-component of wind
case parameterCategory == 2 && parameterNumber == 3:
return VarWindV // V-component of wind
default:
return -1
}
}


package dataset
import (
"testing"
)
func TestDatasetShape(t *testing.T) {
if NumHours != 65 {
t.Errorf("NumHours = %d, want 65", NumHours)
}
if NumLevels != 47 {
t.Errorf("NumLevels = %d, want 47", NumLevels)
}
if NumVariables != 3 {
t.Errorf("NumVariables = %d, want 3", NumVariables)
}
if NumLatitudes != 361 {
t.Errorf("NumLatitudes = %d, want 361", NumLatitudes)
}
if NumLongitudes != 720 {
t.Errorf("NumLongitudes = %d, want 720", NumLongitudes)
}
}
func TestDatasetSize(t *testing.T) {
// 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200
want := int64(9_528_667_200)
if DatasetSize != want {
t.Errorf("DatasetSize = %d, want %d", DatasetSize, want)
}
}
func TestPressureLevels(t *testing.T) {
if len(Pressures) != 47 {
t.Fatalf("len(Pressures) = %d, want 47", len(Pressures))
}
// First should be 1000 (highest pressure, near surface)
if Pressures[0] != 1000 {
t.Errorf("Pressures[0] = %d, want 1000", Pressures[0])
}
// Last should be 1 (lowest pressure, high atmosphere)
if Pressures[46] != 1 {
t.Errorf("Pressures[46] = %d, want 1", Pressures[46])
}
// Should be sorted descending
for i := 1; i < len(Pressures); i++ {
if Pressures[i] >= Pressures[i-1] {
t.Errorf("Pressures not descending at [%d]: %d >= %d", i, Pressures[i], Pressures[i-1])
}
}
// Total levels: 26 from pgrb2 + 21 from pgrb2b = 47
if len(PressuresPgrb2) != 26 {
t.Errorf("len(PressuresPgrb2) = %d, want 26", len(PressuresPgrb2))
}
if len(PressuresPgrb2b) != 21 {
t.Errorf("len(PressuresPgrb2b) = %d, want 21", len(PressuresPgrb2b))
}
}
func TestPressureIndex(t *testing.T) {
if PressureIndex(1000) != 0 {
t.Errorf("PressureIndex(1000) = %d, want 0", PressureIndex(1000))
}
if PressureIndex(1) != 46 {
t.Errorf("PressureIndex(1) = %d, want 46", PressureIndex(1))
}
if PressureIndex(500) != 20 {
t.Errorf("PressureIndex(500) = %d, want 20", PressureIndex(500))
}
if PressureIndex(9999) != -1 {
t.Errorf("PressureIndex(9999) = %d, want -1", PressureIndex(9999))
}
}
func TestPressureLevelSet(t *testing.T) {
// 1000 mb should be in pgrb2 (A)
ls, ok := PressureLevelSet(1000)
if !ok || ls != LevelSetA {
t.Errorf("PressureLevelSet(1000) = %v, %v; want A, true", ls, ok)
}
// 125 mb should be in pgrb2b (B)
ls, ok = PressureLevelSet(125)
if !ok || ls != LevelSetB {
t.Errorf("PressureLevelSet(125) = %v, %v; want B, true", ls, ok)
}
// 1, 2, 3, 5, 7 should be in pgrb2b (B)
for _, p := range []int{1, 2, 3, 5, 7} {
ls, ok := PressureLevelSet(p)
if !ok || ls != LevelSetB {
t.Errorf("PressureLevelSet(%d) = %v, %v; want B, true", p, ls, ok)
}
}
// Every pressure level should have a level set assignment
for _, p := range Pressures {
_, ok := PressureLevelSet(p)
if !ok {
t.Errorf("PressureLevelSet(%d) not found", p)
}
}
}
func TestHourIndex(t *testing.T) {
if HourIndex(0) != 0 {
t.Errorf("HourIndex(0) = %d, want 0", HourIndex(0))
}
if HourIndex(3) != 1 {
t.Errorf("HourIndex(3) = %d, want 1", HourIndex(3))
}
if HourIndex(192) != 64 {
t.Errorf("HourIndex(192) = %d, want 64", HourIndex(192))
}
if HourIndex(1) != -1 {
t.Errorf("HourIndex(1) = %d, want -1 (not multiple of 3)", HourIndex(1))
}
if HourIndex(195) != -1 {
t.Errorf("HourIndex(195) = %d, want -1 (out of range)", HourIndex(195))
}
}
func TestHours(t *testing.T) {
hours := Hours()
if len(hours) != NumHours {
t.Fatalf("len(Hours()) = %d, want %d", len(hours), NumHours)
}
if hours[0] != 0 {
t.Errorf("Hours()[0] = %d, want 0", hours[0])
}
if hours[len(hours)-1] != MaxHour {
t.Errorf("Hours()[last] = %d, want %d", hours[len(hours)-1], MaxHour)
}
}
func TestVariableIndex(t *testing.T) {
if VariableIndex(3, 5) != VarHeight {
t.Errorf("HGT: got %d, want %d", VariableIndex(3, 5), VarHeight)
}
if VariableIndex(2, 2) != VarWindU {
t.Errorf("UGRD: got %d, want %d", VariableIndex(2, 2), VarWindU)
}
if VariableIndex(2, 3) != VarWindV {
t.Errorf("VGRD: got %d, want %d", VariableIndex(2, 3), VarWindV)
}
if VariableIndex(0, 0) != -1 {
t.Errorf("unknown: got %d, want -1", VariableIndex(0, 0))
}
}

internal/dataset/file.go
package dataset
import (
"encoding/binary"
"fmt"
"math"
"os"
"time"
mmap "github.com/edsrzf/mmap-go"
)
// File represents an mmap-backed wind dataset file.
type File struct {
mm mmap.MMap
file *os.File
writable bool
DSTime time.Time // forecast run time (UTC)
}
// Open opens an existing dataset file for reading.
func Open(path string, dsTime time.Time) (*File, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open dataset: %w", err)
}
info, err := f.Stat()
if err != nil {
f.Close()
return nil, fmt.Errorf("stat dataset: %w", err)
}
if info.Size() != DatasetSize {
f.Close()
return nil, fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, info.Size())
}
mm, err := mmap.Map(f, mmap.RDONLY, 0)
if err != nil {
f.Close()
return nil, fmt.Errorf("mmap dataset: %w", err)
}
return &File{mm: mm, file: f, writable: false, DSTime: dsTime}, nil
}
// Create creates a new dataset file for writing.
// The file is truncated to the correct size and mmap'd read-write.
func Create(path string) (*File, error) {
f, err := os.Create(path)
if err != nil {
return nil, fmt.Errorf("create dataset: %w", err)
}
if err := f.Truncate(DatasetSize); err != nil {
f.Close()
return nil, fmt.Errorf("truncate dataset: %w", err)
}
mm, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0)
if err != nil {
f.Close()
return nil, fmt.Errorf("mmap dataset: %w", err)
}
return &File{mm: mm, file: f, writable: true}, nil
}
// offset computes the byte offset for element [hour][level][variable][lat][lon].
// Row-major C-order indexing matching the reference implementation:
// shape = (65, 47, 3, 361, 720)
func offset(hour, level, variable, lat, lon int) int64 {
idx := int64(hour)
idx = idx*int64(NumLevels) + int64(level)
idx = idx*int64(NumVariables) + int64(variable)
idx = idx*int64(NumLatitudes) + int64(lat)
idx = idx*int64(NumLongitudes) + int64(lon)
return idx * int64(ElementSize)
}
// Val reads a float32 value from the dataset at [hour][level][variable][lat][lon].
func (d *File) Val(hour, level, variable, lat, lon int) float32 {
off := offset(hour, level, variable, lat, lon)
bits := binary.LittleEndian.Uint32(d.mm[off : off+4])
return math.Float32frombits(bits)
}
// SetVal writes a float32 value to the dataset at [hour][level][variable][lat][lon].
// Only valid on writable (created) datasets.
func (d *File) SetVal(hour, level, variable, lat, lon int, val float32) {
off := offset(hour, level, variable, lat, lon)
binary.LittleEndian.PutUint32(d.mm[off:off+4], math.Float32bits(val))
}
// BlitGribData copies decoded GRIB grid data into the dataset at the given position.
// gribData is 361*720 float64 values in GRIB scan order (north-to-south, west-to-east).
// This function flips the latitude so that dataset index 0 = -90 (south) and 360 = +90 (north).
func (d *File) BlitGribData(hourIdx, levelIdx, varIdx int, gribData []float64) error {
expected := NumLatitudes * NumLongitudes
if len(gribData) != expected {
return fmt.Errorf("grib data has %d values, expected %d", len(gribData), expected)
}
	for lat := 0; lat < NumLatitudes; lat++ {
		for lon := 0; lon < NumLongitudes; lon++ {
			// GRIB scans north-to-south: row 0 = 90N, row NumLatitudes-1 = 90S.
			// The dataset stores south-to-north: index 0 = 90S, index NumLatitudes-1 = 90N.
			gribIdx := (NumLatitudes-1-lat)*NumLongitudes + lon
			val := float32(gribData[gribIdx])
			d.SetVal(hourIdx, levelIdx, varIdx, lat, lon, val)
		}
	}
return nil
}
// Flush flushes the mmap to disk.
func (d *File) Flush() error {
if d.mm != nil {
return d.mm.Flush()
}
return nil
}
// Close unmaps and closes the dataset file.
func (d *File) Close() error {
if d.mm != nil {
if err := d.mm.Unmap(); err != nil {
d.file.Close()
return fmt.Errorf("unmap: %w", err)
}
d.mm = nil
}
if d.file != nil {
err := d.file.Close()
d.file = nil
return err
}
return nil
}

package downloader
import (
"os"
"strconv"
"time"
)
// Config holds downloader configuration, loaded from environment variables.
type Config struct {
// DataDir is the directory for storing dataset files and temporary GRIB data.
DataDir string
// Parallel is the maximum number of concurrent GRIB downloads.
Parallel int
// UpdateInterval is how often the scheduler checks for new forecast data.
UpdateInterval time.Duration
// DatasetTTL is how long a dataset is considered fresh before a new one is needed.
DatasetTTL time.Duration
}
// DefaultConfig returns the default configuration.
func DefaultConfig() *Config {
return &Config{
DataDir: "/tmp/predictor-data",
Parallel: 8,
UpdateInterval: 6 * time.Hour,
DatasetTTL: 48 * time.Hour,
}
}
// LoadConfig loads configuration from environment variables, falling back to defaults.
func LoadConfig() *Config {
cfg := DefaultConfig()
if v := os.Getenv("PREDICTOR_DATA_DIR"); v != "" {
cfg.DataDir = v
}
if v := os.Getenv("PREDICTOR_DOWNLOAD_PARALLEL"); v != "" {
if n, err := strconv.Atoi(v); err == nil && n > 0 {
cfg.Parallel = n
}
}
if v := os.Getenv("PREDICTOR_UPDATE_INTERVAL"); v != "" {
if d, err := time.ParseDuration(v); err == nil {
cfg.UpdateInterval = d
}
}
if v := os.Getenv("PREDICTOR_DATASET_TTL"); v != "" {
if d, err := time.ParseDuration(v); err == nil {
cfg.DatasetTTL = d
}
}
return cfg
}

package downloader
import (
"context"
"fmt"
"io"
"math"
"net/http"
"os"
"path/filepath"
"time"
"predictor-refactored/internal/dataset"
"github.com/nilsmagnus/grib/griblib"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
)
// Downloader handles fetching GFS forecast data from S3 and assembling dataset files.
type Downloader struct {
cfg *Config
client *http.Client
log *zap.Logger
}
// NewDownloader creates a new Downloader.
func NewDownloader(cfg *Config, log *zap.Logger) *Downloader {
return &Downloader{
cfg: cfg,
client: &http.Client{
Timeout: 2 * time.Minute,
},
log: log,
}
}
// neededVariables is the set of GRIB variable names we need.
var neededVariables = map[string]bool{
"HGT": true,
"UGRD": true,
"VGRD": true,
}
// FindLatestRun finds the most recent available GFS model run on S3.
// It checks the last forecast step of each run to confirm availability.
func (d *Downloader) FindLatestRun(ctx context.Context) (time.Time, error) {
now := time.Now().UTC()
hour := now.Hour() - (now.Hour() % 6)
current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC)
for i := 0; i < 8; i++ {
date := current.Format("20060102")
url := dataset.GribURL(date, current.Hour(), dataset.MaxHour) + ".idx"
req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil)
if err != nil {
current = current.Add(-6 * time.Hour)
continue
}
resp, err := d.client.Do(req)
if err == nil {
resp.Body.Close()
if resp.StatusCode == http.StatusOK {
d.log.Info("found latest model run",
zap.Time("run", current),
zap.String("verified_url", url))
return current, nil
}
}
current = current.Add(-6 * time.Hour)
}
return time.Time{}, fmt.Errorf("no recent GFS forecast found (checked 8 runs)")
}
// Download downloads a complete forecast and assembles a dataset file.
// Returns the path to the completed dataset file.
func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error) {
date := run.Format("20060102")
runHour := run.Hour()
finalPath := filepath.Join(d.cfg.DataDir, run.Format("2006010215"))
tempPath := finalPath + ".downloading"
// Check if final dataset already exists
if info, err := os.Stat(finalPath); err == nil && info.Size() == dataset.DatasetSize {
d.log.Info("dataset already exists", zap.String("path", finalPath))
return finalPath, nil
}
d.log.Info("starting dataset download",
zap.Time("run", run),
zap.String("temp_path", tempPath))
// Create the dataset file
ds, err := dataset.Create(tempPath)
if err != nil {
return "", fmt.Errorf("create dataset: %w", err)
}
defer ds.Close()
steps := dataset.Hours()
	totalSteps := len(steps) * 2 // pgrb2 + pgrb2b per step
	// Process each forecast step with bounded concurrency. Completion is logged
	// per step; incrementing a shared progress counter from these goroutines
	// without synchronization would be a data race.
	g, ctx := errgroup.WithContext(ctx)
	sem := make(chan struct{}, d.cfg.Parallel)
	for _, step := range steps {
		step := step
		hourIdx := dataset.HourIndex(step)
		if hourIdx < 0 {
			continue
		}
		// Download pgrb2 (level set A)
		sem <- struct{}{}
		g.Go(func() error {
			defer func() { <-sem }()
			url := dataset.GribURL(date, runHour, step)
			if err := d.DownloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetA); err != nil {
				return fmt.Errorf("step %d pgrb2: %w", step, err)
			}
			d.log.Debug("step complete",
				zap.Int("step", step),
				zap.String("set", "pgrb2"),
				zap.Int("total_steps", totalSteps))
			return nil
		})
		// Download pgrb2b (level set B)
		sem <- struct{}{}
		g.Go(func() error {
			defer func() { <-sem }()
			url := dataset.GribURLB(date, runHour, step)
			if err := d.DownloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetB); err != nil {
				return fmt.Errorf("step %d pgrb2b: %w", step, err)
			}
			d.log.Debug("step complete",
				zap.Int("step", step),
				zap.String("set", "pgrb2b"),
				zap.Int("total_steps", totalSteps))
			return nil
		})
	}
if err := g.Wait(); err != nil {
os.Remove(tempPath)
return "", err
}
// Flush to disk
if err := ds.Flush(); err != nil {
os.Remove(tempPath)
return "", fmt.Errorf("flush dataset: %w", err)
}
// Close before rename
ds.Close()
// Atomic rename
if err := os.Rename(tempPath, finalPath); err != nil {
os.Remove(tempPath)
return "", fmt.Errorf("rename dataset: %w", err)
}
d.log.Info("dataset download complete", zap.String("path", finalPath))
return finalPath, nil
}
// DownloadAndBlit downloads needed GRIB fields from a URL and writes them into the dataset.
func (d *Downloader) DownloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet) error {
// 1. Download .idx
idxURL := baseURL + ".idx"
idxBody, err := d.httpGet(ctx, idxURL)
if err != nil {
return fmt.Errorf("download idx: %w", err)
}
// 2. Parse and filter
entries := ParseIdx(idxBody)
filtered := FilterIdx(entries, neededVariables)
// Further filter to only levels in this level set
var relevant []IdxEntry
for _, e := range filtered {
ls, ok := dataset.PressureLevelSet(e.LevelMB)
if ok && ls == levelSet {
relevant = append(relevant, e)
}
}
if len(relevant) == 0 {
d.log.Warn("no relevant entries found in idx",
zap.String("url", idxURL),
zap.Int("total_entries", len(entries)),
zap.Int("filtered", len(filtered)))
return nil
}
// 3. Download byte ranges and write to temp file
ranges := EntriesToRanges(relevant)
tmpFile, err := d.downloadRangesToTempFile(ctx, baseURL, ranges)
if err != nil {
return fmt.Errorf("download ranges: %w", err)
}
defer os.Remove(tmpFile)
// 4. Read GRIB messages from temp file
f, err := os.Open(tmpFile)
if err != nil {
return fmt.Errorf("open temp grib: %w", err)
}
messages, err := griblib.ReadMessages(f)
f.Close()
if err != nil {
return fmt.Errorf("read grib messages: %w", err)
}
// 5. Decode and blit each message into the dataset
for _, msg := range messages {
if msg.Section4.ProductDefinitionTemplateNumber != 0 {
continue
}
product := msg.Section4.ProductDefinitionTemplate
varIdx := dataset.VariableIndex(int(product.ParameterCategory), int(product.ParameterNumber))
if varIdx < 0 {
continue
}
if product.FirstSurface.Type != 100 { // isobaric surface
continue
}
pressurePa := float64(product.FirstSurface.Value)
pressureMB := int(math.Round(pressurePa / 100.0))
levelIdx := dataset.PressureIndex(pressureMB)
if levelIdx < 0 {
continue
}
data := msg.Data()
if err := ds.BlitGribData(hourIdx, levelIdx, varIdx, data); err != nil {
d.log.Warn("blit failed",
zap.Int("var", varIdx),
zap.Int("level_mb", pressureMB),
zap.Error(err))
continue
}
}
return nil
}
// downloadRangesToTempFile downloads multiple byte ranges from a URL and
// concatenates them into a single temp file; back-to-back GRIB messages form a valid GRIB stream.
func (d *Downloader) downloadRangesToTempFile(ctx context.Context, baseURL string, ranges []ByteRange) (string, error) {
tmpFile, err := os.CreateTemp(d.cfg.DataDir, "grib-*.tmp")
if err != nil {
return "", fmt.Errorf("create temp file: %w", err)
}
tmpPath := tmpFile.Name()
for _, r := range ranges {
data, err := d.httpGetRange(ctx, baseURL, r.Start, r.End)
if err != nil {
tmpFile.Close()
os.Remove(tmpPath)
return "", fmt.Errorf("download range %d-%d: %w", r.Start, r.End, err)
}
if _, err := tmpFile.Write(data); err != nil {
tmpFile.Close()
os.Remove(tmpPath)
return "", fmt.Errorf("write temp: %w", err)
}
}
if err := tmpFile.Close(); err != nil {
os.Remove(tmpPath)
return "", err
}
return tmpPath, nil
}
// httpGet downloads a URL and returns the body bytes.
func (d *Downloader) httpGet(ctx context.Context, url string) ([]byte, error) {
var lastErr error
for attempt := 0; attempt < 3; attempt++ {
if attempt > 0 {
select {
case <-time.After(time.Duration(attempt*2) * time.Second):
case <-ctx.Done():
return nil, ctx.Err()
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := d.client.Do(req)
if err != nil {
lastErr = err
continue
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url)
continue
}
if err != nil {
lastErr = err
continue
}
return body, nil
}
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}
// httpGetRange downloads a byte range from a URL.
func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64) ([]byte, error) {
var lastErr error
for attempt := 0; attempt < 3; attempt++ {
if attempt > 0 {
select {
case <-time.After(time.Duration(attempt*2) * time.Second):
case <-ctx.Done():
return nil, ctx.Err()
}
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end))
resp, err := d.client.Do(req)
if err != nil {
lastErr = err
continue
}
body, err := io.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK {
lastErr = fmt.Errorf("HTTP %d for range %d-%d of %s", resp.StatusCode, start, end, url)
continue
}
if err != nil {
lastErr = err
continue
}
return body, nil
}
return nil, fmt.Errorf("after 3 attempts: %w", lastErr)
}

internal/downloader/idx.go
package downloader
import (
"fmt"
"strconv"
"strings"
)
// IdxEntry represents a single parsed line from a GRIB .idx file.
// Example line: "15:1207405:d=2024010100:HGT:1000 mb:0 hour fcst:"
type IdxEntry struct {
Index int
Offset int64
Variable string // "HGT", "UGRD", "VGRD", etc.
LevelMB int // pressure level in mb (0 if not a pressure level)
Hour int // forecast hour
EndOffset int64 // byte after this message (from next entry's offset, or -1 if last)
}
// Length returns the byte length of this GRIB message, or -1 if unknown.
func (e *IdxEntry) Length() int64 {
if e.EndOffset <= 0 {
return -1
}
return e.EndOffset - e.Offset
}
// ParseIdx parses a .idx file body and returns all entries.
// Lines that can't be parsed are silently skipped.
func ParseIdx(body []byte) []IdxEntry {
lines := strings.Split(string(body), "\n")
var entries []IdxEntry
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
parts := strings.Split(line, ":")
if len(parts) < 7 {
continue
}
idx, err := strconv.Atoi(parts[0])
if err != nil {
continue
}
offset, err := strconv.ParseInt(parts[1], 10, 64)
if err != nil {
continue
}
variable := parts[3]
levelStr := parts[4]
hourStr := parts[5]
levelMB := parseLevelMB(levelStr)
hour := parseHour(hourStr)
entries = append(entries, IdxEntry{
Index: idx,
Offset: offset,
Variable: variable,
LevelMB: levelMB,
Hour: hour,
EndOffset: -1, // filled in below
})
}
// Fill in EndOffset from the next entry's Offset.
for i := 0; i < len(entries)-1; i++ {
entries[i].EndOffset = entries[i+1].Offset
}
return entries
}
// FilterIdx returns entries matching the given variables at pressure levels.
// Only entries with a recognized pressure level (levelMB > 0) are returned.
func FilterIdx(entries []IdxEntry, variables map[string]bool) []IdxEntry {
var filtered []IdxEntry
for _, e := range entries {
if !variables[e.Variable] {
continue
}
if e.LevelMB <= 0 {
continue
}
		// The last entry's length is unknown (there is no following offset), so skip it.
if e.Length() <= 0 {
continue
}
filtered = append(filtered, e)
}
return filtered
}
// parseLevelMB parses a level string like "1000 mb" and returns the pressure in mb.
// Returns 0 if not a pressure level.
func parseLevelMB(s string) int {
s = strings.TrimSpace(s)
if !strings.HasSuffix(s, " mb") {
return 0
}
numStr := strings.TrimSuffix(s, " mb")
n, err := strconv.Atoi(numStr)
if err != nil {
return 0
}
return n
}
// parseHour parses a forecast hour string like "0 hour fcst" or "anl".
// Returns -1 if it can't be parsed.
func parseHour(s string) int {
s = strings.TrimSpace(s)
if s == "anl" {
return 0
}
s = strings.TrimSuffix(s, " hour fcst")
n, err := strconv.Atoi(s)
if err != nil {
return -1
}
return n
}
// GroupByRange groups idx entries into byte ranges suitable for HTTP Range downloads.
// Each range covers one contiguous GRIB message.
type ByteRange struct {
Start int64
End int64 // inclusive
Entry IdxEntry
}
// EntriesToRanges converts filtered idx entries to byte ranges.
func EntriesToRanges(entries []IdxEntry) []ByteRange {
ranges := make([]ByteRange, 0, len(entries))
for _, e := range entries {
if e.Length() <= 0 {
continue
}
ranges = append(ranges, ByteRange{
Start: e.Offset,
End: e.EndOffset - 1, // inclusive
Entry: e,
})
}
return ranges
}
// FormatRange returns an HTTP Range header value for a byte range.
func (r ByteRange) FormatRange() string {
return fmt.Sprintf("bytes=%d-%d", r.Start, r.End)
}

package downloader
import (
"testing"
)
const sampleIdx = `1:0:d=2024010100:HGT:1000 mb:0 hour fcst:
2:289012:d=2024010100:HGT:975 mb:0 hour fcst:
3:541876:d=2024010100:TMP:1000 mb:0 hour fcst:
4:789012:d=2024010100:UGRD:1000 mb:0 hour fcst:
5:1045678:d=2024010100:VGRD:1000 mb:0 hour fcst:
6:1298765:d=2024010100:UGRD:975 mb:0 hour fcst:
7:1567890:d=2024010100:UGRD:2 m above ground:0 hour fcst:
8:1812345:d=2024010100:VGRD:975 mb:0 hour fcst:
9:2098765:d=2024010100:HGT:500 mb:3 hour fcst:
`
func TestParseIdx(t *testing.T) {
entries := ParseIdx([]byte(sampleIdx))
if len(entries) != 9 {
t.Fatalf("expected 9 entries, got %d", len(entries))
}
// Check first entry
e := entries[0]
if e.Index != 1 || e.Offset != 0 || e.Variable != "HGT" || e.LevelMB != 1000 || e.Hour != 0 {
t.Errorf("entry 0: got %+v", e)
}
if e.EndOffset != 289012 {
t.Errorf("entry 0 EndOffset: got %d, want 289012", e.EndOffset)
}
// Check "2 m above ground" is not a pressure level
e = entries[6] // UGRD at "2 m above ground"
if e.LevelMB != 0 {
t.Errorf("non-pressure level should have LevelMB=0, got %d", e.LevelMB)
}
// Last entry should have EndOffset = -1
last := entries[len(entries)-1]
if last.EndOffset != -1 {
t.Errorf("last entry EndOffset: got %d, want -1", last.EndOffset)
}
}
func TestFilterIdx(t *testing.T) {
entries := ParseIdx([]byte(sampleIdx))
filtered := FilterIdx(entries, neededVariables)
// Should include HGT/UGRD/VGRD at pressure levels, exclude TMP and "above ground"
// And exclude last entry (no EndOffset)
for _, e := range filtered {
if !neededVariables[e.Variable] {
t.Errorf("unexpected variable %s", e.Variable)
}
if e.LevelMB <= 0 {
t.Errorf("non-pressure level included: %+v", e)
}
if e.Length() <= 0 {
t.Errorf("entry with unknown length included: %+v", e)
}
}
// Count expected: HGT@1000, HGT@975, UGRD@1000, VGRD@1000, UGRD@975, VGRD@975 = 6
// But HGT@500 at 3hr fcst is the last entry (no EndOffset), so excluded
if len(filtered) != 6 {
t.Errorf("expected 6 filtered entries, got %d", len(filtered))
for _, e := range filtered {
t.Logf(" %s %d mb (offset %d, len %d)", e.Variable, e.LevelMB, e.Offset, e.Length())
}
}
}
func TestParseLevelMB(t *testing.T) {
tests := []struct {
input string
want int
}{
{"1000 mb", 1000},
{"975 mb", 975},
{"1 mb", 1},
{"2 m above ground", 0},
{"surface", 0},
{"tropopause", 0},
}
for _, tt := range tests {
got := parseLevelMB(tt.input)
if got != tt.want {
t.Errorf("parseLevelMB(%q) = %d, want %d", tt.input, got, tt.want)
}
}
}
func TestParseHour(t *testing.T) {
tests := []struct {
input string
want int
}{
{"0 hour fcst", 0},
{"3 hour fcst", 3},
{"192 hour fcst", 192},
{"anl", 0},
}
for _, tt := range tests {
got := parseHour(tt.input)
if got != tt.want {
t.Errorf("parseHour(%q) = %d, want %d", tt.input, got, tt.want)
}
}
}



@ -0,0 +1,113 @@
package elevation
import (
"encoding/binary"
"fmt"
"math"
"os"
mmap "github.com/edsrzf/mmap-go"
)
// Dataset provides global elevation lookup, compatible with ruaumoko.
// Binary format: int16 little-endian elevation values in metres, row-major (lat, lon).
// Latitude axis: -90 to +90 (south to north), Longitude axis: 0 to 360 (wraps).
// Resolution: 120 cells per degree (30 arc-seconds).
const (
CellsPerDegree = 120
NumLats = 180*CellsPerDegree + 1 // 21601
NumLons = 360 * CellsPerDegree // 43200
DataSize = NumLats * NumLons * 2 // 1,866,326,400 bytes (~1.74 GiB)
)
// Dataset is a memory-mapped global elevation grid.
type Dataset struct {
mm mmap.MMap
file *os.File
}
// Open opens an existing elevation dataset file.
func Open(path string) (*Dataset, error) {
f, err := os.Open(path)
if err != nil {
return nil, fmt.Errorf("open elevation: %w", err)
}
info, err := f.Stat()
if err != nil {
f.Close()
return nil, fmt.Errorf("stat elevation: %w", err)
}
if info.Size() != DataSize {
f.Close()
return nil, fmt.Errorf("elevation dataset should be %d bytes (was %d)", DataSize, info.Size())
}
mm, err := mmap.Map(f, mmap.RDONLY, 0)
if err != nil {
f.Close()
return nil, fmt.Errorf("mmap elevation: %w", err)
}
return &Dataset{mm: mm, file: f}, nil
}
// getCell reads the int16 elevation at grid indices (latIdx, lngIdx).
func (d *Dataset) getCell(latIdx, lngIdx int) int16 {
// Clamp latitude
if latIdx < 0 {
latIdx = 0
}
if latIdx >= NumLats {
latIdx = NumLats - 1
}
// Wrap longitude
lngIdx = lngIdx % NumLons
if lngIdx < 0 {
lngIdx += NumLons
}
off := (latIdx*NumLons + lngIdx) * 2
return int16(binary.LittleEndian.Uint16(d.mm[off : off+2]))
}
// Get returns the interpolated elevation in metres at the given coordinates.
// lat: -90 to +90, lng: 0 to 360 (or -180 to 180, will be normalised).
func (d *Dataset) Get(lat, lng float64) float64 {
// Normalise longitude to [0, 360)
if lng < 0 {
lng += 360
}
// Convert to cell coordinates
latCell := (lat + 90.0) * CellsPerDegree
lngCell := lng * CellsPerDegree
lat0 := int(math.Floor(latCell))
lng0 := int(math.Floor(lngCell))
latFrac := latCell - float64(lat0)
lngFrac := lngCell - float64(lng0)
// Bilinear interpolation
v00 := float64(d.getCell(lat0, lng0))
v10 := float64(d.getCell(lat0+1, lng0))
v01 := float64(d.getCell(lat0, lng0+1))
v11 := float64(d.getCell(lat0+1, lng0+1))
return (1-latFrac)*((1-lngFrac)*v00+lngFrac*v01) +
latFrac*((1-lngFrac)*v10+lngFrac*v11)
}
// Close unmaps and closes the dataset.
func (d *Dataset) Close() error {
if d.mm != nil {
d.mm.Unmap()
d.mm = nil
}
if d.file != nil {
err := d.file.Close()
d.file = nil
return err
}
return nil
}


@ -0,0 +1,153 @@
package prediction
import (
"fmt"
"predictor-refactored/internal/dataset"
)
// Exact port of the reference interpolation logic (interpolate.pyx).
// 4D interpolation: time, latitude, longitude, altitude (via geopotential height).
// lerp1 holds an index and interpolation weight for one axis.
type lerp1 struct {
index int
lerp float64
}
// lerp3 holds indices and a combined weight for the (hour, lat, lon) axes.
type lerp3 struct {
hour, lat, lng int
lerp float64
}
// RangeError indicates a coordinate is outside the dataset bounds.
type RangeError struct {
Variable string
Value float64
}
func (e *RangeError) Error() string {
return fmt.Sprintf("%s=%f out of range", e.Variable, e.Value)
}
// pick computes interpolation indices and weights for a single axis.
// left: axis start, step: axis spacing, n: number of points, value: query value.
// Returns two lerp1 values (lower and upper bracket).
func pick(left, step float64, n int, value float64, variableName string) ([2]lerp1, error) {
a := (value - left) / step
b := int(a) // truncation toward zero, same as Cython <long> cast
if b < 0 || b >= n-1 {
return [2]lerp1{}, &RangeError{Variable: variableName, Value: value}
}
l := a - float64(b)
return [2]lerp1{
{index: b, lerp: 1 - l},
{index: b + 1, lerp: l},
}, nil
}
// pick3 computes 8 trilinear interpolation weights for (hour, lat, lng).
func pick3(hour, lat, lng float64) ([8]lerp3, error) {
lhour, err := pick(0, 3, 65, hour, "hour")
if err != nil {
return [8]lerp3{}, err
}
llat, err := pick(-90, 0.5, 361, lat, "lat")
if err != nil {
return [8]lerp3{}, err
}
// Longitude wraps: tell pick the axis is one larger, then wrap index 720 → 0
llng, err := pick(0, 0.5, 720+1, lng, "lng")
if err != nil {
return [8]lerp3{}, err
}
if llng[1].index == 720 {
llng[1].index = 0
}
var out [8]lerp3
i := 0
for _, a := range lhour {
for _, b := range llat {
for _, c := range llng {
out[i] = lerp3{
hour: a.index,
lat: b.index,
lng: c.index,
lerp: a.lerp * b.lerp * c.lerp,
}
i++
}
}
}
return out, nil
}
// interp3 performs 8-point weighted interpolation at a given variable and pressure level.
func interp3(ds *dataset.File, lerps [8]lerp3, variable, level int) float64 {
var r float64
for i := 0; i < 8; i++ {
v := ds.Val(lerps[i].hour, level, variable, lerps[i].lat, lerps[i].lng)
r += float64(v) * lerps[i].lerp
}
return r
}
// search finds the largest pressure level index where interpolated geopotential
// height is less than the target altitude. Searches levels 0..45 (excludes topmost).
func search(ds *dataset.File, lerps [8]lerp3, target float64) int {
lower, upper := 0, 45
for lower < upper {
mid := (lower + upper + 1) / 2
test := interp3(ds, lerps, dataset.VarHeight, mid)
if target <= test {
upper = mid - 1
} else {
lower = mid
}
}
return lower
}
// interp4 performs altitude-interpolated wind lookup using two bracketing levels.
func interp4(ds *dataset.File, lerps [8]lerp3, altLerp lerp1, variable int) float64 {
lower := interp3(ds, lerps, variable, altLerp.index)
upper := interp3(ds, lerps, variable, altLerp.index+1)
return lower*altLerp.lerp + upper*(1-altLerp.lerp)
}
// GetWind returns interpolated (u, v) wind components for the given position.
// hour: fractional hours since dataset start.
// lat: latitude in degrees (-90 to +90).
// lng: longitude in degrees (0 to 360).
// alt: altitude in metres above sea level.
func GetWind(ds *dataset.File, warnings *Warnings, hour, lat, lng, alt float64) (u, v float64, err error) {
lerps, err := pick3(hour, lat, lng)
if err != nil {
return 0, 0, err
}
altidx := search(ds, lerps, alt)
lower := interp3(ds, lerps, dataset.VarHeight, altidx)
upper := interp3(ds, lerps, dataset.VarHeight, altidx+1)
var altLerp float64
if lower != upper {
altLerp = (upper - alt) / (upper - lower)
} else {
altLerp = 0.5
}
if altLerp < 0 {
warnings.AltitudeTooHigh.Add(1)
}
alt1 := lerp1{index: altidx, lerp: altLerp}
u = interp4(ds, lerps, alt1, dataset.VarWindU)
v = interp4(ds, lerps, alt1, dataset.VarWindV)
return u, v, nil
}


@ -0,0 +1,188 @@
package prediction
import (
"math"
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/elevation"
)
// Exact port of the reference flight models (models.py).
const (
pi180 = math.Pi / 180.0
_180pi = 180.0 / math.Pi
)
// --- Up/Down Models ---
// ConstantAscent returns a model with constant vertical velocity (m/s).
func ConstantAscent(ascentRate float64) Model {
return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
return 0, 0, ascentRate
}
}
// DragDescent returns a descent-under-parachute model.
// seaLevelDescentRate is the descent rate at sea level (m/s, positive value).
// Uses the NASA atmosphere model for density at altitude.
func DragDescent(seaLevelDescentRate float64) Model {
dragCoefficient := seaLevelDescentRate * 1.1045
return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
return 0, 0, -dragCoefficient / math.Sqrt(nasaDensity(alt))
}
}
// nasaDensity computes air density using the NASA atmosphere model.
// Reference: http://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html
func nasaDensity(alt float64) float64 {
var temp, pressure float64
switch {
case alt > 25000:
temp = -131.21 + 0.00299*alt
pressure = 2.488 * math.Pow((temp+273.1)/216.6, -11.388)
case alt > 11000:
temp = -56.46
pressure = 22.65 * math.Exp(1.73-0.000157*alt)
default:
temp = 15.04 - 0.00649*alt
pressure = 101.29 * math.Pow((temp+273.1)/288.08, 5.256)
}
return pressure / (0.2869 * (temp + 273.1))
}
// --- Sideways Models ---
// WindVelocity returns a model that gives lateral movement at the wind velocity.
// ds is the wind dataset, dsEpoch is the dataset start time as UNIX timestamp.
func WindVelocity(ds *dataset.File, dsEpoch float64, warnings *Warnings) Model {
return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
tHours := (t - dsEpoch) / 3600.0
u, v, err := GetWind(ds, warnings, tHours, lat, lng, alt)
if err != nil {
return 0, 0, 0
}
R := 6371009.0 + alt
dlat = _180pi * v / R
dlng = _180pi * u / (R * math.Cos(lat*pi180))
return dlat, dlng, 0
}
}
// --- Model Combinations ---
// LinearModel returns a model that sums all component models.
func LinearModel(models ...Model) Model {
return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) {
for _, m := range models {
d1, d2, d3 := m(t, lat, lng, alt)
dlat += d1
dlng += d2
dalt += d3
}
return
}
}
// --- Termination Criteria ---
// BurstTermination returns a terminator that fires when altitude >= burstAltitude.
func BurstTermination(burstAltitude float64) Terminator {
return func(t, lat, lng, alt float64) bool {
return alt >= burstAltitude
}
}
// SeaLevelTermination fires when altitude <= 0.
func SeaLevelTermination(t, lat, lng, alt float64) bool {
return alt <= 0
}
// TimeTermination returns a terminator that fires when t > maxTime.
func TimeTermination(maxTime float64) Terminator {
return func(t, lat, lng, alt float64) bool {
return t > maxTime
}
}
// ElevationTermination returns a terminator that fires when alt < ground level.
// Uses ruaumoko-compatible elevation data. Longitude is normalised internally.
func ElevationTermination(elev *elevation.Dataset) Terminator {
return func(t, lat, lng, alt float64) bool {
return elev.Get(lat, lng) > alt
}
}
// --- Pre-Defined Profiles ---
// Stage pairs a model with its termination criterion.
type Stage struct {
Model Model
Terminator Terminator
}
// StandardProfile creates the chain for a standard high-altitude balloon flight:
// ascent at constant rate → burst → descent under parachute.
// If elev is non-nil, descent terminates at ground level; otherwise at sea level.
func StandardProfile(ascentRate, burstAltitude, descentRate float64,
ds *dataset.File, dsEpoch float64, warnings *Warnings,
elev *elevation.Dataset) []Stage {
wind := WindVelocity(ds, dsEpoch, warnings)
modelUp := LinearModel(ConstantAscent(ascentRate), wind)
termUp := BurstTermination(burstAltitude)
modelDown := LinearModel(DragDescent(descentRate), wind)
var termDown Terminator
if elev != nil {
termDown = ElevationTermination(elev)
} else {
termDown = Terminator(SeaLevelTermination)
}
return []Stage{
{Model: modelUp, Terminator: termUp},
{Model: modelDown, Terminator: termDown},
}
}
// FloatProfile creates the chain for a floating balloon flight:
// ascent to float altitude → float until stop time.
func FloatProfile(ascentRate, floatAltitude float64, stopTime time.Time,
ds *dataset.File, dsEpoch float64, warnings *Warnings) []Stage {
wind := WindVelocity(ds, dsEpoch, warnings)
modelUp := LinearModel(ConstantAscent(ascentRate), wind)
termUp := BurstTermination(floatAltitude)
modelFloat := wind
termFloat := TimeTermination(float64(stopTime.Unix()))
return []Stage{
{Model: modelUp, Terminator: termUp},
{Model: modelFloat, Terminator: termFloat},
}
}
// RunPrediction runs a prediction with the given profile stages.
// launchTime is a UNIX timestamp.
func RunPrediction(launchTime float64, lat, lng, alt float64, stages []Stage) []StageResult {
chain := make([]struct {
Model Model
Terminator Terminator
}, len(stages))
for i, s := range stages {
chain[i].Model = s.Model
chain[i].Terminator = s.Terminator
}
return Solve(launchTime, lat, lng, alt, chain)
}


@ -0,0 +1,180 @@
package prediction
import "math"
// Exact port of the reference RK4 solver (solver.pyx).
// Integrates balloon state using RK4 with dt=60 seconds.
// Termination uses binary search refinement (tolerance 0.01).
// Vec holds the balloon state: latitude, longitude, altitude.
type Vec struct {
Lat float64
Lng float64
Alt float64
}
// Model is a function that returns (dlat/dt, dlng/dt, dalt/dt) given state.
// t is UNIX timestamp, lat/lng in degrees, alt in metres.
type Model func(t float64, lat, lng, alt float64) (dlat, dlng, dalt float64)
// Terminator returns true when integration should stop.
type Terminator func(t float64, lat, lng, alt float64) bool
// StageResult holds the trajectory points for one flight stage.
type StageResult struct {
Points []TrajectoryPoint
}
// TrajectoryPoint is a single point in a trajectory (used by solver).
type TrajectoryPoint struct {
T float64 // UNIX timestamp
Lat float64
Lng float64
Alt float64
}
// pymod returns a % b with Python semantics (always non-negative when b > 0).
func pymod(a, b float64) float64 {
r := math.Mod(a, b)
if r < 0 {
r += b
}
return r
}
// vecadd returns a + k*b, with lng wrapped to [0, 360).
func vecadd(a Vec, k float64, b Vec) Vec {
return Vec{
Lat: a.Lat + k*b.Lat,
Lng: pymod(a.Lng+k*b.Lng, 360.0),
Alt: a.Alt + k*b.Alt,
}
}
// scalarLerp returns (1-l)*a + l*b.
func scalarLerp(a, b, l float64) float64 {
return (1-l)*a + l*b
}
// lngLerp interpolates longitude handling the 0/360 wrap-around.
func lngLerp(a, b, l float64) float64 {
l2 := 1 - l
if a > b {
a, b = b, a
l, l2 = l2, l
}
// distance around one way: b - a
// distance around other: (a + 360) - b
if b-a < 180.0 {
return l2*a + l*b
}
return pymod(l2*(a+360)+l*b, 360.0)
}
// vecLerp returns (1-l)*a + l*b with proper longitude wrapping.
func vecLerp(a, b Vec, l float64) Vec {
return Vec{
Lat: scalarLerp(a.Lat, b.Lat, l),
Lng: lngLerp(a.Lng, b.Lng, l),
Alt: scalarLerp(a.Alt, b.Alt, l),
}
}
// rk4 integrates from initial conditions using RK4.
// dt=60.0 seconds, terminationTolerance=0.01.
func rk4(t float64, lat, lng, alt float64, model Model, terminator Terminator) []TrajectoryPoint {
const dt = 60.0
const terminationTolerance = 0.01
y := Vec{Lat: lat, Lng: lng, Alt: alt}
result := []TrajectoryPoint{{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt}}
for {
// Evaluate model at 4 points (standard RK4)
k1lat, k1lng, k1alt := model(t, y.Lat, y.Lng, y.Alt)
k1 := Vec{Lat: k1lat, Lng: k1lng, Alt: k1alt}
mid1 := vecadd(y, dt/2, k1)
k2lat, k2lng, k2alt := model(t+dt/2, mid1.Lat, mid1.Lng, mid1.Alt)
k2 := Vec{Lat: k2lat, Lng: k2lng, Alt: k2alt}
mid2 := vecadd(y, dt/2, k2)
k3lat, k3lng, k3alt := model(t+dt/2, mid2.Lat, mid2.Lng, mid2.Alt)
k3 := Vec{Lat: k3lat, Lng: k3lng, Alt: k3alt}
end := vecadd(y, dt, k3)
k4lat, k4lng, k4alt := model(t+dt, end.Lat, end.Lng, end.Alt)
k4 := Vec{Lat: k4lat, Lng: k4lng, Alt: k4alt}
// y2 = y + dt/6*k1 + dt/3*k2 + dt/3*k3 + dt/6*k4
y2 := y
y2 = vecadd(y2, dt/6, k1)
y2 = vecadd(y2, dt/3, k2)
y2 = vecadd(y2, dt/3, k3)
y2 = vecadd(y2, dt/6, k4)
t2 := t + dt
if terminator(t2, y2.Lat, y2.Lng, y2.Alt) {
// Binary search to refine the termination point.
// Find l in [0, 1] such that (t3, y3) = lerp((t, y), (t2, y2), l)
// is near where the terminator fires.
left := 0.0
right := 1.0
t3, y3 := t2, y2
for right-left > terminationTolerance {
mid := (left + right) / 2
t3 = scalarLerp(t, t2, mid)
y3 = vecLerp(y, y2, mid)
if terminator(t3, y3.Lat, y3.Lng, y3.Alt) {
right = mid
} else {
left = mid
}
}
result = append(result, TrajectoryPoint{T: t3, Lat: y3.Lat, Lng: y3.Lng, Alt: y3.Alt})
break
}
// Update current state
t = t2
y = y2
result = append(result, TrajectoryPoint{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt})
}
return result
}
// Solve runs through a chain of (model, terminator) stages.
// Returns one StageResult per stage.
func Solve(t, lat, lng, alt float64, chain []struct {
Model Model
Terminator Terminator
}) []StageResult {
var results []StageResult
for _, stage := range chain {
points := rk4(t, lat, lng, alt, stage.Model, stage.Terminator)
results = append(results, StageResult{Points: points})
// Next stage starts where this one ended
if len(points) > 0 {
last := points[len(points)-1]
t = last.T
lat = last.Lat
lng = last.Lng
alt = last.Alt
}
}
return results
}


@ -0,0 +1,21 @@
package prediction
import "sync/atomic"
// Warnings tracks warning conditions during a prediction run.
type Warnings struct {
AltitudeTooHigh atomic.Int64
}
// ToMap returns warnings as a map suitable for JSON serialization.
// Only includes warnings that have fired.
func (w *Warnings) ToMap() map[string]any {
result := make(map[string]any)
if n := w.AltitudeTooHigh.Load(); n > 0 {
result["altitude_too_high"] = map[string]any{
"count": n,
"description": "The altitude went too high, above the max forecast wind. Wind data will be unreliable",
}
}
return result
}

245
internal/service/service.go Normal file

@ -0,0 +1,245 @@
package service
import (
"context"
"os"
"path/filepath"
"sort"
"strings"
"sync"
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/downloader"
"predictor-refactored/internal/elevation"
"go.uber.org/zap"
)
// Service orchestrates the dataset lifecycle and provides prediction capabilities.
type Service struct {
mu sync.RWMutex
ds *dataset.File
elev *elevation.Dataset
cfg *downloader.Config
dl *downloader.Downloader
log *zap.Logger
updating sync.Mutex // prevents concurrent downloads
}
// New creates a new Service.
func New(cfg *downloader.Config, log *zap.Logger) *Service {
return &Service{
cfg: cfg,
dl: downloader.NewDownloader(cfg, log),
log: log,
}
}
// LoadElevation loads the ruaumoko-compatible elevation dataset from path.
// If the file doesn't exist, elevation termination is disabled (falls back to sea level).
func (s *Service) LoadElevation(path string) {
ds, err := elevation.Open(path)
if err != nil {
s.log.Warn("elevation dataset not available, using sea-level termination",
zap.String("path", path), zap.Error(err))
return
}
s.elev = ds
s.log.Info("elevation dataset loaded", zap.String("path", path))
}
// Elevation returns the elevation dataset (may be nil).
func (s *Service) Elevation() *elevation.Dataset {
return s.elev
}
// Ready returns true if the service has a loaded dataset.
func (s *Service) Ready() bool {
s.mu.RLock()
defer s.mu.RUnlock()
return s.ds != nil
}
// DatasetTime returns the forecast time of the currently loaded dataset.
func (s *Service) DatasetTime() (time.Time, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
if s.ds == nil {
return time.Time{}, false
}
return s.ds.DSTime, true
}
// Dataset returns the current dataset for reading.
func (s *Service) Dataset() *dataset.File {
s.mu.RLock()
defer s.mu.RUnlock()
return s.ds
}
// Update checks for and downloads new forecast data if needed.
func (s *Service) Update(ctx context.Context) error {
if !s.updating.TryLock() {
s.log.Info("update already in progress, skipping")
return nil
}
defer s.updating.Unlock()
// Check if current dataset is still fresh
if dsTime, ok := s.DatasetTime(); ok {
if time.Since(dsTime) < s.cfg.DatasetTTL {
s.log.Info("dataset still fresh, skipping update",
zap.Time("dataset_time", dsTime),
zap.Duration("age", time.Since(dsTime)))
return nil
}
}
// Try loading an existing dataset from disk first
if err := s.loadExistingDataset(); err == nil {
return nil
}
// Find latest available model run
run, err := s.dl.FindLatestRun(ctx)
if err != nil {
return err
}
// Download and assemble
path, err := s.dl.Download(ctx, run)
if err != nil {
return err
}
// Open the new dataset
ds, err := dataset.Open(path, run)
if err != nil {
return err
}
// Swap in the new dataset
s.setDataset(ds)
s.log.Info("dataset loaded", zap.Time("run", run), zap.String("path", path))
// Clean old datasets
s.cleanOldDatasets(path)
return nil
}
// loadExistingDataset tries to find and load an existing dataset from the data directory.
func (s *Service) loadExistingDataset() error {
entries, err := os.ReadDir(s.cfg.DataDir)
if err != nil {
return err
}
// Collect valid dataset files (name is YYYYMMDDHH, no extension, correct size)
type candidate struct {
name string
path string
run time.Time
}
var candidates []candidate
for _, e := range entries {
if e.IsDir() || strings.Contains(e.Name(), ".") {
continue
}
if len(e.Name()) != 10 {
continue
}
run, err := time.Parse("2006010215", e.Name())
if err != nil {
continue
}
path := filepath.Join(s.cfg.DataDir, e.Name())
info, err := os.Stat(path)
if err != nil || info.Size() != dataset.DatasetSize {
continue
}
if time.Since(run) > s.cfg.DatasetTTL {
continue
}
candidates = append(candidates, candidate{name: e.Name(), path: path, run: run})
}
if len(candidates) == 0 {
return os.ErrNotExist
}
// Pick the newest
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].run.After(candidates[j].run)
})
best := candidates[0]
ds, err := dataset.Open(best.path, best.run)
if err != nil {
return err
}
s.setDataset(ds)
s.log.Info("loaded existing dataset",
zap.Time("run", best.run),
zap.String("path", best.path))
return nil
}
// setDataset swaps the current dataset with a new one, closing the old one.
func (s *Service) setDataset(ds *dataset.File) {
s.mu.Lock()
old := s.ds
s.ds = ds
s.mu.Unlock()
if old != nil {
if err := old.Close(); err != nil {
s.log.Error("failed to close old dataset", zap.Error(err))
}
}
}
// cleanOldDatasets removes dataset files other than the one at keepPath.
func (s *Service) cleanOldDatasets(keepPath string) {
entries, err := os.ReadDir(s.cfg.DataDir)
if err != nil {
return
}
for _, e := range entries {
if e.IsDir() {
continue
}
path := filepath.Join(s.cfg.DataDir, e.Name())
if path == keepPath {
continue
}
// Remove old datasets and temp files
if len(e.Name()) == 10 || strings.HasSuffix(e.Name(), ".downloading") {
if err := os.Remove(path); err != nil {
s.log.Warn("failed to remove old file", zap.String("path", path), zap.Error(err))
} else {
s.log.Info("removed old dataset", zap.String("path", path))
}
}
}
}
// Close releases all resources.
func (s *Service) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if s.ds != nil {
err := s.ds.Close()
s.ds = nil
return err
}
return nil
}


@ -0,0 +1,30 @@
package middleware
import (
"time"
"github.com/ogen-go/ogen/middleware"
"go.uber.org/zap"
)
// Logging returns an ogen middleware that logs request duration.
func Logging(log *zap.Logger) middleware.Middleware {
return func(req middleware.Request, next func(req middleware.Request) (middleware.Response, error)) (middleware.Response, error) {
lg := log.With(zap.String("operation", req.OperationID))
start := time.Now()
resp, err := next(req)
dur := time.Since(start)
if err != nil {
lg.Error("request failed",
zap.Duration("duration", dur),
zap.Error(err))
} else {
lg.Info("request completed",
zap.Duration("duration", dur))
}
return resp, err
}
}


@ -0,0 +1,16 @@
package handler
import (
"time"
"predictor-refactored/internal/dataset"
"predictor-refactored/internal/elevation"
)
// Service defines the interface the handler needs from the service layer.
type Service interface {
Ready() bool
DatasetTime() (time.Time, bool)
Dataset() *dataset.File
Elevation() *elevation.Dataset
}


@ -0,0 +1,216 @@
package handler
import (
"context"
"net/http"
"time"
"predictor-refactored/internal/prediction"
api "predictor-refactored/pkg/rest"
"go.uber.org/zap"
)
var _ api.Handler = (*Handler)(nil)
// Handler implements the ogen-generated api.Handler interface.
type Handler struct {
svc Service
log *zap.Logger
}
// New creates a new Handler.
func New(svc Service, log *zap.Logger) *Handler {
return &Handler{svc: svc, log: log}
}
// PerformPrediction implements the prediction endpoint.
func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredictionParams) (*api.PredictionResponse, error) {
if !h.svc.Ready() {
return nil, newError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up")
}
ds := h.svc.Dataset()
if ds == nil {
return nil, newError(http.StatusServiceUnavailable, "dataset unavailable")
}
dsEpoch := float64(ds.DSTime.Unix())
// Parse parameters with defaults
profile := "standard_profile"
if p, ok := params.Profile.Get(); ok {
profile = string(p)
}
ascentRate := 5.0
if v, ok := params.AscentRate.Get(); ok {
ascentRate = v
}
burstAltitude := 28000.0
if v, ok := params.BurstAltitude.Get(); ok {
burstAltitude = v
}
descentRate := 5.0
if v, ok := params.DescentRate.Get(); ok {
descentRate = v
}
launchAlt := 0.0
if v, ok := params.LaunchAltitude.Get(); ok {
launchAlt = v
}
// Normalize longitude to [0, 360)
lng := params.LaunchLongitude
if lng < 0 {
lng += 360.0
}
launchTime := float64(params.LaunchDatetime.Unix())
warnings := &prediction.Warnings{}
// Build profile chain
elev := h.svc.Elevation()
var stages []prediction.Stage
switch profile {
case "standard_profile":
stages = prediction.StandardProfile(
ascentRate, burstAltitude, descentRate,
ds, dsEpoch, warnings, elev)
case "float_profile":
floatAlt := 25000.0
if v, ok := params.FloatAltitude.Get(); ok {
floatAlt = v
}
stopTime := params.LaunchDatetime.Add(24 * time.Hour)
if v, ok := params.StopDatetime.Get(); ok {
stopTime = v
}
stages = prediction.FloatProfile(
ascentRate, floatAlt, stopTime,
ds, dsEpoch, warnings)
default:
return nil, newError(http.StatusBadRequest, "unknown profile: "+profile)
}
// Run prediction
startTime := time.Now().UTC()
results := prediction.RunPrediction(launchTime, params.LaunchLatitude, lng, launchAlt, stages)
completeTime := time.Now().UTC()
// Build response
stageNames := []string{"ascent", "descent"}
if profile == "float_profile" {
stageNames = []string{"ascent", "float"}
}
var predItems []api.PredictionResponsePredictionItem
for i, sr := range results {
stageName := "ascent"
if i < len(stageNames) {
stageName = stageNames[i]
}
var stageEnum api.PredictionResponsePredictionItemStage
switch stageName {
case "ascent":
stageEnum = api.PredictionResponsePredictionItemStageAscent
case "descent":
stageEnum = api.PredictionResponsePredictionItemStageDescent
case "float":
stageEnum = api.PredictionResponsePredictionItemStageFloat
}
var traj []api.PredictionResponsePredictionItemTrajectoryItem
for _, pt := range sr.Points {
ptLng := pt.Lng
if ptLng > 180 {
ptLng -= 360
}
traj = append(traj, api.PredictionResponsePredictionItemTrajectoryItem{
Datetime: time.Unix(int64(pt.T), 0).UTC(),
Latitude: pt.Lat,
Longitude: ptLng,
Altitude: pt.Alt,
})
}
predItems = append(predItems, api.PredictionResponsePredictionItem{
Stage: stageEnum,
Trajectory: traj,
})
}
resp := &api.PredictionResponse{
Prediction: predItems,
Metadata: api.PredictionResponseMetadata{
StartDatetime: startTime,
CompleteDatetime: completeTime,
},
}
// Echo request
resp.Request = api.NewOptPredictionResponseRequest(api.PredictionResponseRequest{
Dataset: api.NewOptString(ds.DSTime.Format("2006-01-02T15:04:05Z")),
LaunchLatitude: api.NewOptFloat64(params.LaunchLatitude),
LaunchLongitude: api.NewOptFloat64(params.LaunchLongitude),
LaunchDatetime: api.NewOptString(params.LaunchDatetime.Format(time.RFC3339)),
LaunchAltitude: params.LaunchAltitude,
})
// Warnings
warnMap := warnings.ToMap()
if len(warnMap) > 0 {
resp.Warnings = api.NewOptPredictionResponseWarnings(api.PredictionResponseWarnings{})
}
h.log.Info("prediction complete",
zap.String("profile", profile),
zap.Int("stages", len(results)),
zap.Duration("elapsed", completeTime.Sub(startTime)))
return resp, nil
}
// ReadinessCheck implements the health check endpoint.
func (h *Handler) ReadinessCheck(ctx context.Context) (*api.ReadinessResponse, error) {
resp := &api.ReadinessResponse{}
if h.svc.Ready() {
resp.Status = api.ReadinessResponseStatusOk
if dsTime, ok := h.svc.DatasetTime(); ok {
resp.DatasetTime = api.NewOptDateTime(dsTime)
}
} else {
resp.Status = api.ReadinessResponseStatusNotReady
resp.ErrorMessage = api.NewOptString("no dataset loaded")
}
return resp, nil
}
// NewError creates an ErrorStatusCode from an error returned by a handler.
func (h *Handler) NewError(ctx context.Context, err error) *api.ErrorStatusCode {
if statusErr, ok := err.(*api.ErrorStatusCode); ok {
return statusErr
}
h.log.Error("unhandled error", zap.Error(err))
return newError(http.StatusInternalServerError, err.Error())
}
func newError(status int, description string) *api.ErrorStatusCode {
return &api.ErrorStatusCode{
StatusCode: status,
Response: api.Error{
Error: api.ErrorError{
Type: http.StatusText(status),
Description: description,
},
},
}
}


@ -0,0 +1,75 @@
package rest
import (
"context"
"fmt"
"net/http"
"predictor-refactored/internal/transport/middleware"
"predictor-refactored/internal/transport/rest/handler"
api "predictor-refactored/pkg/rest"
"go.uber.org/zap"
)
// Transport wraps the ogen HTTP server.
type Transport struct {
srv *api.Server
handler *handler.Handler
port int
log *zap.Logger
}
// New creates a new REST transport.
func New(h *handler.Handler, port int, log *zap.Logger) (*Transport, error) {
srv, err := api.NewServer(
h,
api.WithMiddleware(middleware.Logging(log)),
)
if err != nil {
return nil, fmt.Errorf("create ogen server: %w", err)
}
return &Transport{
srv: srv,
handler: h,
port: port,
log: log,
}, nil
}
// Run starts the HTTP server. Blocks until the server stops.
func (t *Transport) Run() error {
mux := http.NewServeMux()
mux.Handle("/", t.srv)
httpSrv := &http.Server{
Addr: fmt.Sprintf(":%d", t.port),
Handler: corsMiddleware(mux),
}
t.log.Info("starting HTTP server", zap.Int("port", t.port))
return httpSrv.ListenAndServe()
}
// Shutdown gracefully stops the HTTP server.
func (t *Transport) Shutdown(ctx context.Context) error {
// The ogen server doesn't have a shutdown method;
// shutdown is handled by the http.Server in main.go
return nil
}
func corsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Access-Control-Allow-Origin", "*")
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
if r.Method == http.MethodOptions {
w.WriteHeader(http.StatusNoContent)
return
}
next.ServeHTTP(w, r)
})
}

318
pkg/rest/oas_cfg_gen.go Normal file

@ -0,0 +1,318 @@
// Code generated by ogen, DO NOT EDIT.
package rest
import (
"net/http"
"strings"
ht "github.com/ogen-go/ogen/http"
"github.com/ogen-go/ogen/middleware"
"github.com/ogen-go/ogen/ogenerrors"
"github.com/ogen-go/ogen/otelogen"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/metric"
"go.opentelemetry.io/otel/trace"
)
var (
// Allocate option closure once.
clientSpanKind = trace.WithSpanKind(trace.SpanKindClient)
// Allocate option closure once.
serverSpanKind = trace.WithSpanKind(trace.SpanKindServer)
)
type (
optionFunc[C any] func(*C)
otelOptionFunc func(*otelConfig)
)
type otelConfig struct {
TracerProvider trace.TracerProvider
Tracer trace.Tracer
MeterProvider metric.MeterProvider
Meter metric.Meter
Attributes []attribute.KeyValue
}
func (cfg *otelConfig) initOTEL() {
if cfg.TracerProvider == nil {
cfg.TracerProvider = otel.GetTracerProvider()
}
if cfg.MeterProvider == nil {
cfg.MeterProvider = otel.GetMeterProvider()
}
cfg.Tracer = cfg.TracerProvider.Tracer(otelogen.Name,
trace.WithInstrumentationVersion(otelogen.SemVersion()),
)
cfg.Meter = cfg.MeterProvider.Meter(otelogen.Name,
metric.WithInstrumentationVersion(otelogen.SemVersion()),
)
}
// ErrorHandler is error handler.
type ErrorHandler = ogenerrors.ErrorHandler
type serverConfig struct {
otelConfig
NotFound http.HandlerFunc
MethodNotAllowed func(w http.ResponseWriter, r *http.Request, allowed string)
ErrorHandler ErrorHandler
Prefix string
Middleware Middleware
MaxMultipartMemory int64
}
// ServerOption is server config option.
type ServerOption interface {
applyServer(*serverConfig)
}
var _ ServerOption = (optionFunc[serverConfig])(nil)
func (o optionFunc[C]) applyServer(c *C) {
o(c)
}
var _ ServerOption = (otelOptionFunc)(nil)
func (o otelOptionFunc) applyServer(c *serverConfig) {
o(&c.otelConfig)
}
func newServerConfig(opts ...ServerOption) serverConfig {
cfg := serverConfig{
NotFound: http.NotFound,
MethodNotAllowed: nil,
ErrorHandler: ogenerrors.DefaultErrorHandler,
Middleware: nil,
MaxMultipartMemory: 32 << 20, // 32 MB
}
for _, opt := range opts {
opt.applyServer(&cfg)
}
cfg.initOTEL()
return cfg
}
type baseServer struct {
cfg serverConfig
requests metric.Int64Counter
errors metric.Int64Counter
duration metric.Float64Histogram
}
func (s baseServer) notFound(w http.ResponseWriter, r *http.Request) {
s.cfg.NotFound(w, r)
}
type notAllowedParams struct {
allowedMethods string
allowedHeaders map[string]string
acceptPost string
acceptPatch string
}
func (s baseServer) notAllowed(w http.ResponseWriter, r *http.Request, params notAllowedParams) {
h := w.Header()
isOptions := r.Method == "OPTIONS"
if isOptions {
h.Set("Access-Control-Allow-Methods", params.allowedMethods)
if params.allowedHeaders != nil {
m := r.Header.Get("Access-Control-Request-Method")
if m != "" {
allowedHeaders, ok := params.allowedHeaders[strings.ToUpper(m)]
if ok {
h.Set("Access-Control-Allow-Headers", allowedHeaders)
}
}
}
if params.acceptPost != "" {
h.Set("Accept-Post", params.acceptPost)
}
if params.acceptPatch != "" {
h.Set("Accept-Patch", params.acceptPatch)
}
}
if s.cfg.MethodNotAllowed != nil {
s.cfg.MethodNotAllowed(w, r, params.allowedMethods)
return
}
status := http.StatusNoContent
if !isOptions {
h.Set("Allow", params.allowedMethods)
status = http.StatusMethodNotAllowed
}
w.WriteHeader(status)
}
func (cfg serverConfig) baseServer() (s baseServer, err error) {
s = baseServer{cfg: cfg}
if s.requests, err = otelogen.ServerRequestCountCounter(s.cfg.Meter); err != nil {
return s, err
}
if s.errors, err = otelogen.ServerErrorsCountCounter(s.cfg.Meter); err != nil {
return s, err
}
if s.duration, err = otelogen.ServerDurationHistogram(s.cfg.Meter); err != nil {
return s, err
}
return s, nil
}
type clientConfig struct {
otelConfig
Client ht.Client
}
// ClientOption is client config option.
type ClientOption interface {
applyClient(*clientConfig)
}
var _ ClientOption = (optionFunc[clientConfig])(nil)
func (o optionFunc[C]) applyClient(c *C) {
o(c)
}
var _ ClientOption = (otelOptionFunc)(nil)
func (o otelOptionFunc) applyClient(c *clientConfig) {
o(&c.otelConfig)
}
func newClientConfig(opts ...ClientOption) clientConfig {
cfg := clientConfig{
Client: http.DefaultClient,
}
for _, opt := range opts {
opt.applyClient(&cfg)
}
cfg.initOTEL()
return cfg
}
type baseClient struct {
cfg clientConfig
requests metric.Int64Counter
errors metric.Int64Counter
duration metric.Float64Histogram
}
func (cfg clientConfig) baseClient() (c baseClient, err error) {
c = baseClient{cfg: cfg}
if c.requests, err = otelogen.ClientRequestCountCounter(c.cfg.Meter); err != nil {
return c, err
}
if c.errors, err = otelogen.ClientErrorsCountCounter(c.cfg.Meter); err != nil {
return c, err
}
if c.duration, err = otelogen.ClientDurationHistogram(c.cfg.Meter); err != nil {
return c, err
}
return c, nil
}
// Option is config option.
type Option interface {
ServerOption
ClientOption
}
// WithTracerProvider specifies a tracer provider to use for creating a tracer.
//
// If none is specified, the global provider is used.
func WithTracerProvider(provider trace.TracerProvider) Option {
return otelOptionFunc(func(cfg *otelConfig) {
if provider != nil {
cfg.TracerProvider = provider
}
})
}
// WithMeterProvider specifies a meter provider to use for creating a meter.
//
// If none is specified, the otel.GetMeterProvider() is used.
func WithMeterProvider(provider metric.MeterProvider) Option {
return otelOptionFunc(func(cfg *otelConfig) {
if provider != nil {
cfg.MeterProvider = provider
}
})
}
// WithAttributes specifies default otel attributes.
func WithAttributes(attributes ...attribute.KeyValue) Option {
return otelOptionFunc(func(cfg *otelConfig) {
cfg.Attributes = attributes
})
}
// WithClient specifies http client to use.
func WithClient(client ht.Client) ClientOption {
return optionFunc[clientConfig](func(cfg *clientConfig) {
if client != nil {
cfg.Client = client
}
})
}
// WithNotFound specifies Not Found handler to use.
func WithNotFound(notFound http.HandlerFunc) ServerOption {
return optionFunc[serverConfig](func(cfg *serverConfig) {
if notFound != nil {
cfg.NotFound = notFound
}
})
}
// WithMethodNotAllowed specifies Method Not Allowed handler to use.
func WithMethodNotAllowed(methodNotAllowed func(w http.ResponseWriter, r *http.Request, allowed string)) ServerOption {
return optionFunc[serverConfig](func(cfg *serverConfig) {
if methodNotAllowed != nil {
cfg.MethodNotAllowed = methodNotAllowed
}
})
}
// WithErrorHandler specifies error handler to use.
func WithErrorHandler(h ErrorHandler) ServerOption {
return optionFunc[serverConfig](func(cfg *serverConfig) {
if h != nil {
cfg.ErrorHandler = h
}
})
}
// WithPathPrefix specifies server path prefix.
func WithPathPrefix(prefix string) ServerOption {
return optionFunc[serverConfig](func(cfg *serverConfig) {
cfg.Prefix = prefix
})
}
// WithMiddleware specifies middlewares to use.
func WithMiddleware(m ...Middleware) ServerOption {
return optionFunc[serverConfig](func(cfg *serverConfig) {
switch len(m) {
case 0:
cfg.Middleware = nil
case 1:
cfg.Middleware = m[0]
default:
cfg.Middleware = middleware.ChainMiddlewares(m...)
}
})
}
// WithMaxMultipartMemory specifies limit of memory for storing file parts.
// File parts which can't be stored in memory will be stored on disk in temporary files.
func WithMaxMultipartMemory(max int64) ServerOption {
return optionFunc[serverConfig](func(cfg *serverConfig) {
if max > 0 {
cfg.MaxMultipartMemory = max
}
})
}

411
pkg/rest/oas_client_gen.go Normal file

@@ -0,0 +1,411 @@
// Code generated by ogen, DO NOT EDIT.
package rest
import (
"context"
"net/url"
"strings"
"time"
"github.com/go-faster/errors"
"github.com/ogen-go/ogen/conv"
ht "github.com/ogen-go/ogen/http"
"github.com/ogen-go/ogen/otelogen"
"github.com/ogen-go/ogen/uri"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/metric"
semconv "go.opentelemetry.io/otel/semconv/v1.39.0"
"go.opentelemetry.io/otel/trace"
)
func trimTrailingSlashes(u *url.URL) {
u.Path = strings.TrimRight(u.Path, "/")
u.RawPath = strings.TrimRight(u.RawPath, "/")
}
// Invoker invokes operations described by OpenAPI v3 specification.
type Invoker interface {
// PerformPrediction invokes performPrediction operation.
//
// Perform prediction.
//
// GET /api/v1/prediction
PerformPrediction(ctx context.Context, params PerformPredictionParams) (*PredictionResponse, error)
// ReadinessCheck invokes readinessCheck operation.
//
// Readiness check.
//
// GET /ready
ReadinessCheck(ctx context.Context) (*ReadinessResponse, error)
}
// Client implements OAS client.
type Client struct {
serverURL *url.URL
baseClient
}
// NewClient initializes new Client defined by OAS.
func NewClient(serverURL string, opts ...ClientOption) (*Client, error) {
u, err := url.Parse(serverURL)
if err != nil {
return nil, err
}
trimTrailingSlashes(u)
c, err := newClientConfig(opts...).baseClient()
if err != nil {
return nil, err
}
return &Client{
serverURL: u,
baseClient: c,
}, nil
}
type serverURLKey struct{}
// WithServerURL sets context key to override server URL.
func WithServerURL(ctx context.Context, u *url.URL) context.Context {
return context.WithValue(ctx, serverURLKey{}, u)
}
func (c *Client) requestURL(ctx context.Context) *url.URL {
u, ok := ctx.Value(serverURLKey{}).(*url.URL)
if !ok {
return c.serverURL
}
return u
}
// PerformPrediction invokes performPrediction operation.
//
// Perform prediction.
//
// GET /api/v1/prediction
func (c *Client) PerformPrediction(ctx context.Context, params PerformPredictionParams) (*PredictionResponse, error) {
res, err := c.sendPerformPrediction(ctx, params)
return res, err
}
func (c *Client) sendPerformPrediction(ctx context.Context, params PerformPredictionParams) (res *PredictionResponse, err error) {
otelAttrs := []attribute.KeyValue{
otelogen.OperationID("performPrediction"),
semconv.HTTPRequestMethodKey.String("GET"),
semconv.URLTemplateKey.String("/api/v1/prediction"),
}
otelAttrs = append(otelAttrs, c.cfg.Attributes...)
// Run stopwatch.
startTime := time.Now()
defer func() {
// Use floating point division here for higher precision (instead of Millisecond method).
elapsedDuration := time.Since(startTime)
c.duration.Record(ctx, float64(elapsedDuration)/float64(time.Millisecond), metric.WithAttributes(otelAttrs...))
}()
// Increment request counter.
c.requests.Add(ctx, 1, metric.WithAttributes(otelAttrs...))
// Start a span for this request.
ctx, span := c.cfg.Tracer.Start(ctx, PerformPredictionOperation,
trace.WithAttributes(otelAttrs...),
clientSpanKind,
)
// Track stage for error reporting.
var stage string
defer func() {
if err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, stage)
c.errors.Add(ctx, 1, metric.WithAttributes(otelAttrs...))
}
span.End()
}()
stage = "BuildURL"
u := uri.Clone(c.requestURL(ctx))
var pathParts [1]string
pathParts[0] = "/api/v1/prediction"
uri.AddPathParts(u, pathParts[:]...)
stage = "EncodeQueryParams"
q := uri.NewQueryEncoder()
{
// Encode "launch_latitude" parameter.
cfg := uri.QueryParameterEncodingConfig{
Name: "launch_latitude",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
return e.EncodeValue(conv.Float64ToString(params.LaunchLatitude))
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
}
{
// Encode "launch_longitude" parameter.
cfg := uri.QueryParameterEncodingConfig{
Name: "launch_longitude",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
return e.EncodeValue(conv.Float64ToString(params.LaunchLongitude))
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
}
{
// Encode "launch_datetime" parameter.
cfg := uri.QueryParameterEncodingConfig{
Name: "launch_datetime",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
return e.EncodeValue(conv.DateTimeToString(params.LaunchDatetime))
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
}
{
// Encode "launch_altitude" parameter.
cfg := uri.QueryParameterEncodingConfig{
Name: "launch_altitude",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
if val, ok := params.LaunchAltitude.Get(); ok {
return e.EncodeValue(conv.Float64ToString(val))
}
return nil
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
}
{
// Encode "profile" parameter.
cfg := uri.QueryParameterEncodingConfig{
Name: "profile",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
if val, ok := params.Profile.Get(); ok {
return e.EncodeValue(conv.StringToString(string(val)))
}
return nil
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
}
{
// Encode "ascent_rate" parameter.
cfg := uri.QueryParameterEncodingConfig{
Name: "ascent_rate",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
if val, ok := params.AscentRate.Get(); ok {
return e.EncodeValue(conv.Float64ToString(val))
}
return nil
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
}
{
// Encode "burst_altitude" parameter.
cfg := uri.QueryParameterEncodingConfig{
Name: "burst_altitude",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
if val, ok := params.BurstAltitude.Get(); ok {
return e.EncodeValue(conv.Float64ToString(val))
}
return nil
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
}
{
// Encode "descent_rate" parameter.
cfg := uri.QueryParameterEncodingConfig{
Name: "descent_rate",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
if val, ok := params.DescentRate.Get(); ok {
return e.EncodeValue(conv.Float64ToString(val))
}
return nil
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
}
{
// Encode "float_altitude" parameter.
cfg := uri.QueryParameterEncodingConfig{
Name: "float_altitude",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
if val, ok := params.FloatAltitude.Get(); ok {
return e.EncodeValue(conv.Float64ToString(val))
}
return nil
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
}
{
// Encode "stop_datetime" parameter.
cfg := uri.QueryParameterEncodingConfig{
Name: "stop_datetime",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
if val, ok := params.StopDatetime.Get(); ok {
return e.EncodeValue(conv.DateTimeToString(val))
}
return nil
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
}
{
// Encode "dataset" parameter.
cfg := uri.QueryParameterEncodingConfig{
Name: "dataset",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.EncodeParam(cfg, func(e uri.Encoder) error {
if val, ok := params.Dataset.Get(); ok {
return e.EncodeValue(conv.DateTimeToString(val))
}
return nil
}); err != nil {
return res, errors.Wrap(err, "encode query")
}
}
u.RawQuery = q.Values().Encode()
stage = "EncodeRequest"
r, err := ht.NewRequest(ctx, "GET", u)
if err != nil {
return res, errors.Wrap(err, "create request")
}
stage = "SendRequest"
resp, err := c.cfg.Client.Do(r)
if err != nil {
return res, errors.Wrap(err, "do request")
}
body := resp.Body
defer body.Close()
stage = "DecodeResponse"
result, err := decodePerformPredictionResponse(resp)
if err != nil {
return res, errors.Wrap(err, "decode response")
}
return result, nil
}
// ReadinessCheck invokes readinessCheck operation.
//
// Readiness check.
//
// GET /ready
func (c *Client) ReadinessCheck(ctx context.Context) (*ReadinessResponse, error) {
res, err := c.sendReadinessCheck(ctx)
return res, err
}
func (c *Client) sendReadinessCheck(ctx context.Context) (res *ReadinessResponse, err error) {
otelAttrs := []attribute.KeyValue{
otelogen.OperationID("readinessCheck"),
semconv.HTTPRequestMethodKey.String("GET"),
semconv.URLTemplateKey.String("/ready"),
}
otelAttrs = append(otelAttrs, c.cfg.Attributes...)
// Run stopwatch.
startTime := time.Now()
defer func() {
// Use floating point division here for higher precision (instead of Millisecond method).
elapsedDuration := time.Since(startTime)
c.duration.Record(ctx, float64(elapsedDuration)/float64(time.Millisecond), metric.WithAttributes(otelAttrs...))
}()
// Increment request counter.
c.requests.Add(ctx, 1, metric.WithAttributes(otelAttrs...))
// Start a span for this request.
ctx, span := c.cfg.Tracer.Start(ctx, ReadinessCheckOperation,
trace.WithAttributes(otelAttrs...),
clientSpanKind,
)
// Track stage for error reporting.
var stage string
defer func() {
if err != nil {
span.RecordError(err)
span.SetStatus(codes.Error, stage)
c.errors.Add(ctx, 1, metric.WithAttributes(otelAttrs...))
}
span.End()
}()
stage = "BuildURL"
u := uri.Clone(c.requestURL(ctx))
var pathParts [1]string
pathParts[0] = "/ready"
uri.AddPathParts(u, pathParts[:]...)
stage = "EncodeRequest"
r, err := ht.NewRequest(ctx, "GET", u)
if err != nil {
return res, errors.Wrap(err, "create request")
}
stage = "SendRequest"
resp, err := c.cfg.Client.Do(r)
if err != nil {
return res, errors.Wrap(err, "do request")
}
body := resp.Body
defer body.Close()
stage = "DecodeResponse"
result, err := decodeReadinessCheckResponse(resp)
if err != nil {
return res, errors.Wrap(err, "decode response")
}
return result, nil
}


@@ -0,0 +1,363 @@
// Code generated by ogen, DO NOT EDIT.
package rest
import (
"context"
"net/http"
"time"
"github.com/go-faster/errors"
ht "github.com/ogen-go/ogen/http"
"github.com/ogen-go/ogen/middleware"
"github.com/ogen-go/ogen/ogenerrors"
"github.com/ogen-go/ogen/otelogen"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/metric"
semconv "go.opentelemetry.io/otel/semconv/v1.39.0"
"go.opentelemetry.io/otel/trace"
)
type codeRecorder struct {
http.ResponseWriter
status int
}
func (c *codeRecorder) WriteHeader(status int) {
c.status = status
c.ResponseWriter.WriteHeader(status)
}
func (c *codeRecorder) Unwrap() http.ResponseWriter {
return c.ResponseWriter
}
// handlePerformPredictionRequest handles performPrediction operation.
//
// Perform prediction.
//
// GET /api/v1/prediction
func (s *Server) handlePerformPredictionRequest(args [0]string, argsEscaped bool, w http.ResponseWriter, r *http.Request) {
statusWriter := &codeRecorder{ResponseWriter: w}
w = statusWriter
otelAttrs := []attribute.KeyValue{
otelogen.OperationID("performPrediction"),
semconv.HTTPRequestMethodKey.String("GET"),
semconv.HTTPRouteKey.String("/api/v1/prediction"),
}
// Add attributes from config.
otelAttrs = append(otelAttrs, s.cfg.Attributes...)
// Start a span for this request.
ctx, span := s.cfg.Tracer.Start(r.Context(), PerformPredictionOperation,
trace.WithAttributes(otelAttrs...),
serverSpanKind,
)
defer span.End()
// Add Labeler to context.
labeler := &Labeler{attrs: otelAttrs}
ctx = contextWithLabeler(ctx, labeler)
// Run stopwatch.
startTime := time.Now()
defer func() {
elapsedDuration := time.Since(startTime)
attrSet := labeler.AttributeSet()
attrs := attrSet.ToSlice()
code := statusWriter.status
if code != 0 {
codeAttr := semconv.HTTPResponseStatusCode(code)
attrs = append(attrs, codeAttr)
span.SetAttributes(codeAttr)
}
attrOpt := metric.WithAttributes(attrs...)
// Increment request counter.
s.requests.Add(ctx, 1, attrOpt)
// Use floating point division here for higher precision (instead of Millisecond method).
s.duration.Record(ctx, float64(elapsedDuration)/float64(time.Millisecond), attrOpt)
}()
var (
recordError = func(stage string, err error) {
span.RecordError(err)
// https://opentelemetry.io/docs/specs/semconv/http/http-spans/#status
// Span Status MUST be left unset if HTTP status code was in the 1xx, 2xx or 3xx ranges,
// unless there was another error (e.g., network error receiving the response body; or 3xx codes with
// max redirects exceeded), in which case status MUST be set to Error.
code := statusWriter.status
if code < 100 || code >= 500 {
span.SetStatus(codes.Error, stage)
}
attrSet := labeler.AttributeSet()
attrs := attrSet.ToSlice()
if code != 0 {
attrs = append(attrs, semconv.HTTPResponseStatusCode(code))
}
s.errors.Add(ctx, 1, metric.WithAttributes(attrs...))
}
err error
opErrContext = ogenerrors.OperationContext{
Name: PerformPredictionOperation,
ID: "performPrediction",
}
)
params, err := decodePerformPredictionParams(args, argsEscaped, r)
if err != nil {
err = &ogenerrors.DecodeParamsError{
OperationContext: opErrContext,
Err: err,
}
defer recordError("DecodeParams", err)
s.cfg.ErrorHandler(ctx, w, r, err)
return
}
var rawBody []byte
var response *PredictionResponse
if m := s.cfg.Middleware; m != nil {
mreq := middleware.Request{
Context: ctx,
OperationName: PerformPredictionOperation,
OperationSummary: "Perform prediction",
OperationID: "performPrediction",
Body: nil,
RawBody: rawBody,
Params: middleware.Parameters{
{
Name: "launch_latitude",
In: "query",
}: params.LaunchLatitude,
{
Name: "launch_longitude",
In: "query",
}: params.LaunchLongitude,
{
Name: "launch_datetime",
In: "query",
}: params.LaunchDatetime,
{
Name: "launch_altitude",
In: "query",
}: params.LaunchAltitude,
{
Name: "profile",
In: "query",
}: params.Profile,
{
Name: "ascent_rate",
In: "query",
}: params.AscentRate,
{
Name: "burst_altitude",
In: "query",
}: params.BurstAltitude,
{
Name: "descent_rate",
In: "query",
}: params.DescentRate,
{
Name: "float_altitude",
In: "query",
}: params.FloatAltitude,
{
Name: "stop_datetime",
In: "query",
}: params.StopDatetime,
{
Name: "dataset",
In: "query",
}: params.Dataset,
},
Raw: r,
}
type (
Request = struct{}
Params = PerformPredictionParams
Response = *PredictionResponse
)
response, err = middleware.HookMiddleware[
Request,
Params,
Response,
](
m,
mreq,
unpackPerformPredictionParams,
func(ctx context.Context, request Request, params Params) (response Response, err error) {
response, err = s.h.PerformPrediction(ctx, params)
return response, err
},
)
} else {
response, err = s.h.PerformPrediction(ctx, params)
}
if err != nil {
if errRes, ok := errors.Into[*ErrorStatusCode](err); ok {
if err := encodeErrorResponse(errRes, w, span); err != nil {
defer recordError("Internal", err)
}
return
}
if errors.Is(err, ht.ErrNotImplemented) {
s.cfg.ErrorHandler(ctx, w, r, err)
return
}
if err := encodeErrorResponse(s.h.NewError(ctx, err), w, span); err != nil {
defer recordError("Internal", err)
}
return
}
if err := encodePerformPredictionResponse(response, w, span); err != nil {
defer recordError("EncodeResponse", err)
if !errors.Is(err, ht.ErrInternalServerErrorResponse) {
s.cfg.ErrorHandler(ctx, w, r, err)
}
return
}
}
// handleReadinessCheckRequest handles readinessCheck operation.
//
// Readiness check.
//
// GET /ready
func (s *Server) handleReadinessCheckRequest(args [0]string, argsEscaped bool, w http.ResponseWriter, r *http.Request) {
statusWriter := &codeRecorder{ResponseWriter: w}
w = statusWriter
otelAttrs := []attribute.KeyValue{
otelogen.OperationID("readinessCheck"),
semconv.HTTPRequestMethodKey.String("GET"),
semconv.HTTPRouteKey.String("/ready"),
}
// Add attributes from config.
otelAttrs = append(otelAttrs, s.cfg.Attributes...)
// Start a span for this request.
ctx, span := s.cfg.Tracer.Start(r.Context(), ReadinessCheckOperation,
trace.WithAttributes(otelAttrs...),
serverSpanKind,
)
defer span.End()
// Add Labeler to context.
labeler := &Labeler{attrs: otelAttrs}
ctx = contextWithLabeler(ctx, labeler)
// Run stopwatch.
startTime := time.Now()
defer func() {
elapsedDuration := time.Since(startTime)
attrSet := labeler.AttributeSet()
attrs := attrSet.ToSlice()
code := statusWriter.status
if code != 0 {
codeAttr := semconv.HTTPResponseStatusCode(code)
attrs = append(attrs, codeAttr)
span.SetAttributes(codeAttr)
}
attrOpt := metric.WithAttributes(attrs...)
// Increment request counter.
s.requests.Add(ctx, 1, attrOpt)
// Use floating point division here for higher precision (instead of Millisecond method).
s.duration.Record(ctx, float64(elapsedDuration)/float64(time.Millisecond), attrOpt)
}()
var (
recordError = func(stage string, err error) {
span.RecordError(err)
// https://opentelemetry.io/docs/specs/semconv/http/http-spans/#status
// Span Status MUST be left unset if HTTP status code was in the 1xx, 2xx or 3xx ranges,
// unless there was another error (e.g., network error receiving the response body; or 3xx codes with
// max redirects exceeded), in which case status MUST be set to Error.
code := statusWriter.status
if code < 100 || code >= 500 {
span.SetStatus(codes.Error, stage)
}
attrSet := labeler.AttributeSet()
attrs := attrSet.ToSlice()
if code != 0 {
attrs = append(attrs, semconv.HTTPResponseStatusCode(code))
}
s.errors.Add(ctx, 1, metric.WithAttributes(attrs...))
}
err error
)
var rawBody []byte
var response *ReadinessResponse
if m := s.cfg.Middleware; m != nil {
mreq := middleware.Request{
Context: ctx,
OperationName: ReadinessCheckOperation,
OperationSummary: "Readiness check",
OperationID: "readinessCheck",
Body: nil,
RawBody: rawBody,
Params: middleware.Parameters{},
Raw: r,
}
type (
Request = struct{}
Params = struct{}
Response = *ReadinessResponse
)
response, err = middleware.HookMiddleware[
Request,
Params,
Response,
](
m,
mreq,
nil,
func(ctx context.Context, request Request, params Params) (response Response, err error) {
response, err = s.h.ReadinessCheck(ctx)
return response, err
},
)
} else {
response, err = s.h.ReadinessCheck(ctx)
}
if err != nil {
if errRes, ok := errors.Into[*ErrorStatusCode](err); ok {
if err := encodeErrorResponse(errRes, w, span); err != nil {
defer recordError("Internal", err)
}
return
}
if errors.Is(err, ht.ErrNotImplemented) {
s.cfg.ErrorHandler(ctx, w, r, err)
return
}
if err := encodeErrorResponse(s.h.NewError(ctx, err), w, span); err != nil {
defer recordError("Internal", err)
}
return
}
if err := encodeReadinessCheckResponse(response, w, span); err != nil {
defer recordError("EncodeResponse", err)
if !errors.Is(err, ht.ErrInternalServerErrorResponse) {
s.cfg.ErrorHandler(ctx, w, r, err)
}
return
}
}

1398
pkg/rest/oas_json_gen.go Normal file

File diff suppressed because it is too large


@@ -0,0 +1,42 @@
// Code generated by ogen, DO NOT EDIT.
package rest
import (
"context"
"go.opentelemetry.io/otel/attribute"
)
// Labeler is used to allow adding custom attributes to the server request metrics.
type Labeler struct {
attrs []attribute.KeyValue
}
// Add attributes to the Labeler.
func (l *Labeler) Add(attrs ...attribute.KeyValue) {
l.attrs = append(l.attrs, attrs...)
}
// AttributeSet returns the attributes added to the Labeler as an attribute.Set.
func (l *Labeler) AttributeSet() attribute.Set {
return attribute.NewSet(l.attrs...)
}
type labelerContextKey struct{}
// LabelerFromContext retrieves the Labeler from the provided context, if present.
//
// If no Labeler was found in the provided context a new, empty Labeler is returned and the second
// return value is false. In this case it is safe to use the Labeler but any attributes added to
// it will not be used.
func LabelerFromContext(ctx context.Context) (*Labeler, bool) {
if l, ok := ctx.Value(labelerContextKey{}).(*Labeler); ok {
return l, true
}
return &Labeler{}, false
}
func contextWithLabeler(ctx context.Context, l *Labeler) context.Context {
return context.WithValue(ctx, labelerContextKey{}, l)
}


@@ -0,0 +1,10 @@
// Code generated by ogen, DO NOT EDIT.
package rest
import (
"github.com/ogen-go/ogen/middleware"
)
// Middleware is middleware type.
type Middleware = middleware.Middleware


@@ -0,0 +1,11 @@
// Code generated by ogen, DO NOT EDIT.
package rest
// OperationName is the ogen operation name
type OperationName = string
const (
PerformPredictionOperation OperationName = "PerformPrediction"
ReadinessCheckOperation OperationName = "ReadinessCheck"
)


@@ -0,0 +1,679 @@
// Code generated by ogen, DO NOT EDIT.
package rest
import (
"net/http"
"time"
"github.com/go-faster/errors"
"github.com/ogen-go/ogen/conv"
"github.com/ogen-go/ogen/middleware"
"github.com/ogen-go/ogen/ogenerrors"
"github.com/ogen-go/ogen/uri"
"github.com/ogen-go/ogen/validate"
)
// PerformPredictionParams is parameters of performPrediction operation.
type PerformPredictionParams struct {
LaunchLatitude float64
LaunchLongitude float64
LaunchDatetime time.Time
LaunchAltitude OptFloat64 `json:",omitempty,omitzero"`
Profile OptPerformPredictionProfile `json:",omitempty,omitzero"`
AscentRate OptFloat64 `json:",omitempty,omitzero"`
BurstAltitude OptFloat64 `json:",omitempty,omitzero"`
DescentRate OptFloat64 `json:",omitempty,omitzero"`
FloatAltitude OptFloat64 `json:",omitempty,omitzero"`
StopDatetime OptDateTime `json:",omitempty,omitzero"`
Dataset OptDateTime `json:",omitempty,omitzero"`
}
func unpackPerformPredictionParams(packed middleware.Parameters) (params PerformPredictionParams) {
{
key := middleware.ParameterKey{
Name: "launch_latitude",
In: "query",
}
params.LaunchLatitude = packed[key].(float64)
}
{
key := middleware.ParameterKey{
Name: "launch_longitude",
In: "query",
}
params.LaunchLongitude = packed[key].(float64)
}
{
key := middleware.ParameterKey{
Name: "launch_datetime",
In: "query",
}
params.LaunchDatetime = packed[key].(time.Time)
}
{
key := middleware.ParameterKey{
Name: "launch_altitude",
In: "query",
}
if v, ok := packed[key]; ok {
params.LaunchAltitude = v.(OptFloat64)
}
}
{
key := middleware.ParameterKey{
Name: "profile",
In: "query",
}
if v, ok := packed[key]; ok {
params.Profile = v.(OptPerformPredictionProfile)
}
}
{
key := middleware.ParameterKey{
Name: "ascent_rate",
In: "query",
}
if v, ok := packed[key]; ok {
params.AscentRate = v.(OptFloat64)
}
}
{
key := middleware.ParameterKey{
Name: "burst_altitude",
In: "query",
}
if v, ok := packed[key]; ok {
params.BurstAltitude = v.(OptFloat64)
}
}
{
key := middleware.ParameterKey{
Name: "descent_rate",
In: "query",
}
if v, ok := packed[key]; ok {
params.DescentRate = v.(OptFloat64)
}
}
{
key := middleware.ParameterKey{
Name: "float_altitude",
In: "query",
}
if v, ok := packed[key]; ok {
params.FloatAltitude = v.(OptFloat64)
}
}
{
key := middleware.ParameterKey{
Name: "stop_datetime",
In: "query",
}
if v, ok := packed[key]; ok {
params.StopDatetime = v.(OptDateTime)
}
}
{
key := middleware.ParameterKey{
Name: "dataset",
In: "query",
}
if v, ok := packed[key]; ok {
params.Dataset = v.(OptDateTime)
}
}
return params
}
func decodePerformPredictionParams(args [0]string, argsEscaped bool, r *http.Request) (params PerformPredictionParams, _ error) {
q := uri.NewQueryDecoder(r.URL.Query())
// Decode query: launch_latitude.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
Name: "launch_latitude",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToFloat64(val)
if err != nil {
return err
}
params.LaunchLatitude = c
return nil
}); err != nil {
return err
}
if err := func() error {
if err := (validate.Float{}).Validate(float64(params.LaunchLatitude)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
} else {
return err
}
return nil
}(); err != nil {
return params, &ogenerrors.DecodeParamError{
Name: "launch_latitude",
In: "query",
Err: err,
}
}
// Decode query: launch_longitude.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
Name: "launch_longitude",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToFloat64(val)
if err != nil {
return err
}
params.LaunchLongitude = c
return nil
}); err != nil {
return err
}
if err := func() error {
if err := (validate.Float{}).Validate(float64(params.LaunchLongitude)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
} else {
return err
}
return nil
}(); err != nil {
return params, &ogenerrors.DecodeParamError{
Name: "launch_longitude",
In: "query",
Err: err,
}
}
// Decode query: launch_datetime.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
Name: "launch_datetime",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToDateTime(val)
if err != nil {
return err
}
params.LaunchDatetime = c
return nil
}); err != nil {
return err
}
} else {
return err
}
return nil
}(); err != nil {
return params, &ogenerrors.DecodeParamError{
Name: "launch_datetime",
In: "query",
Err: err,
}
}
// Decode query: launch_altitude.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
Name: "launch_altitude",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
var paramsDotLaunchAltitudeVal float64
if err := func() error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToFloat64(val)
if err != nil {
return err
}
paramsDotLaunchAltitudeVal = c
return nil
}(); err != nil {
return err
}
params.LaunchAltitude.SetTo(paramsDotLaunchAltitudeVal)
return nil
}); err != nil {
return err
}
if err := func() error {
if value, ok := params.LaunchAltitude.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
return params, &ogenerrors.DecodeParamError{
Name: "launch_altitude",
In: "query",
Err: err,
}
}
// Set default value for query: profile.
{
val := PerformPredictionProfile("standard_profile")
params.Profile.SetTo(val)
}
// Decode query: profile.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
Name: "profile",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
var paramsDotProfileVal PerformPredictionProfile
if err := func() error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToString(val)
if err != nil {
return err
}
paramsDotProfileVal = PerformPredictionProfile(c)
return nil
}(); err != nil {
return err
}
params.Profile.SetTo(paramsDotProfileVal)
return nil
}); err != nil {
return err
}
if err := func() error {
if value, ok := params.Profile.Get(); ok {
if err := func() error {
if err := value.Validate(); err != nil {
return err
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
return params, &ogenerrors.DecodeParamError{
Name: "profile",
In: "query",
Err: err,
}
}
// Decode query: ascent_rate.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
Name: "ascent_rate",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
var paramsDotAscentRateVal float64
if err := func() error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToFloat64(val)
if err != nil {
return err
}
paramsDotAscentRateVal = c
return nil
}(); err != nil {
return err
}
params.AscentRate.SetTo(paramsDotAscentRateVal)
return nil
}); err != nil {
return err
}
if err := func() error {
if value, ok := params.AscentRate.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
return params, &ogenerrors.DecodeParamError{
Name: "ascent_rate",
In: "query",
Err: err,
}
}
// Decode query: burst_altitude.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
Name: "burst_altitude",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
var paramsDotBurstAltitudeVal float64
if err := func() error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToFloat64(val)
if err != nil {
return err
}
paramsDotBurstAltitudeVal = c
return nil
}(); err != nil {
return err
}
params.BurstAltitude.SetTo(paramsDotBurstAltitudeVal)
return nil
}); err != nil {
return err
}
if err := func() error {
if value, ok := params.BurstAltitude.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
return params, &ogenerrors.DecodeParamError{
Name: "burst_altitude",
In: "query",
Err: err,
}
}
// Decode query: descent_rate.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
Name: "descent_rate",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
var paramsDotDescentRateVal float64
if err := func() error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToFloat64(val)
if err != nil {
return err
}
paramsDotDescentRateVal = c
return nil
}(); err != nil {
return err
}
params.DescentRate.SetTo(paramsDotDescentRateVal)
return nil
}); err != nil {
return err
}
if err := func() error {
if value, ok := params.DescentRate.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
return params, &ogenerrors.DecodeParamError{
Name: "descent_rate",
In: "query",
Err: err,
}
}
// Decode query: float_altitude.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
Name: "float_altitude",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
var paramsDotFloatAltitudeVal float64
if err := func() error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToFloat64(val)
if err != nil {
return err
}
paramsDotFloatAltitudeVal = c
return nil
}(); err != nil {
return err
}
params.FloatAltitude.SetTo(paramsDotFloatAltitudeVal)
return nil
}); err != nil {
return err
}
if err := func() error {
if value, ok := params.FloatAltitude.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
return params, &ogenerrors.DecodeParamError{
Name: "float_altitude",
In: "query",
Err: err,
}
}
// Decode query: stop_datetime.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
Name: "stop_datetime",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
var paramsDotStopDatetimeVal time.Time
if err := func() error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToDateTime(val)
if err != nil {
return err
}
paramsDotStopDatetimeVal = c
return nil
}(); err != nil {
return err
}
params.StopDatetime.SetTo(paramsDotStopDatetimeVal)
return nil
}); err != nil {
return err
}
}
return nil
}(); err != nil {
return params, &ogenerrors.DecodeParamError{
Name: "stop_datetime",
In: "query",
Err: err,
}
}
// Decode query: dataset.
if err := func() error {
cfg := uri.QueryParameterDecodingConfig{
Name: "dataset",
Style: uri.QueryStyleForm,
Explode: true,
}
if err := q.HasParam(cfg); err == nil {
if err := q.DecodeParam(cfg, func(d uri.Decoder) error {
var paramsDotDatasetVal time.Time
if err := func() error {
val, err := d.DecodeValue()
if err != nil {
return err
}
c, err := conv.ToDateTime(val)
if err != nil {
return err
}
paramsDotDatasetVal = c
return nil
}(); err != nil {
return err
}
params.Dataset.SetTo(paramsDotDatasetVal)
return nil
}); err != nil {
return err
}
}
return nil
}(); err != nil {
return params, &ogenerrors.DecodeParamError{
Name: "dataset",
In: "query",
Err: err,
}
}
return params, nil
}


@ -0,0 +1,3 @@
// Code generated by ogen, DO NOT EDIT.
package rest


@ -0,0 +1,3 @@
// Code generated by ogen, DO NOT EDIT.
package rest


@ -0,0 +1,198 @@
// Code generated by ogen, DO NOT EDIT.
package rest
import (
"io"
"mime"
"net/http"
"github.com/go-faster/errors"
"github.com/go-faster/jx"
"github.com/ogen-go/ogen/ogenerrors"
"github.com/ogen-go/ogen/validate"
)
func decodePerformPredictionResponse(resp *http.Response) (res *PredictionResponse, _ error) {
switch resp.StatusCode {
case 200:
// Code 200.
ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
if err != nil {
return res, errors.Wrap(err, "parse media type")
}
switch {
case ct == "application/json":
buf, err := io.ReadAll(resp.Body)
if err != nil {
return res, err
}
d := jx.DecodeBytes(buf)
var response PredictionResponse
if err := func() error {
if err := response.Decode(d); err != nil {
return err
}
if err := d.Skip(); err != io.EOF {
return errors.New("unexpected trailing data")
}
return nil
}(); err != nil {
err = &ogenerrors.DecodeBodyError{
ContentType: ct,
Body: buf,
Err: err,
}
return res, err
}
// Validate response.
if err := func() error {
if err := response.Validate(); err != nil {
return err
}
return nil
}(); err != nil {
return res, errors.Wrap(err, "validate")
}
return &response, nil
default:
return res, validate.InvalidContentType(ct)
}
}
// Convenient error response.
defRes, err := func() (res *ErrorStatusCode, err error) {
ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
if err != nil {
return res, errors.Wrap(err, "parse media type")
}
switch {
case ct == "application/json":
buf, err := io.ReadAll(resp.Body)
if err != nil {
return res, err
}
d := jx.DecodeBytes(buf)
var response Error
if err := func() error {
if err := response.Decode(d); err != nil {
return err
}
if err := d.Skip(); err != io.EOF {
return errors.New("unexpected trailing data")
}
return nil
}(); err != nil {
err = &ogenerrors.DecodeBodyError{
ContentType: ct,
Body: buf,
Err: err,
}
return res, err
}
return &ErrorStatusCode{
StatusCode: resp.StatusCode,
Response: response,
}, nil
default:
return res, validate.InvalidContentType(ct)
}
}()
if err != nil {
return res, errors.Wrapf(err, "default (code %d)", resp.StatusCode)
}
return res, errors.Wrap(defRes, "error")
}
func decodeReadinessCheckResponse(resp *http.Response) (res *ReadinessResponse, _ error) {
switch resp.StatusCode {
case 200:
// Code 200.
ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
if err != nil {
return res, errors.Wrap(err, "parse media type")
}
switch {
case ct == "application/json":
buf, err := io.ReadAll(resp.Body)
if err != nil {
return res, err
}
d := jx.DecodeBytes(buf)
var response ReadinessResponse
if err := func() error {
if err := response.Decode(d); err != nil {
return err
}
if err := d.Skip(); err != io.EOF {
return errors.New("unexpected trailing data")
}
return nil
}(); err != nil {
err = &ogenerrors.DecodeBodyError{
ContentType: ct,
Body: buf,
Err: err,
}
return res, err
}
// Validate response.
if err := func() error {
if err := response.Validate(); err != nil {
return err
}
return nil
}(); err != nil {
return res, errors.Wrap(err, "validate")
}
return &response, nil
default:
return res, validate.InvalidContentType(ct)
}
}
// Convenient error response.
defRes, err := func() (res *ErrorStatusCode, err error) {
ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type"))
if err != nil {
return res, errors.Wrap(err, "parse media type")
}
switch {
case ct == "application/json":
buf, err := io.ReadAll(resp.Body)
if err != nil {
return res, err
}
d := jx.DecodeBytes(buf)
var response Error
if err := func() error {
if err := response.Decode(d); err != nil {
return err
}
if err := d.Skip(); err != io.EOF {
return errors.New("unexpected trailing data")
}
return nil
}(); err != nil {
err = &ogenerrors.DecodeBodyError{
ContentType: ct,
Body: buf,
Err: err,
}
return res, err
}
return &ErrorStatusCode{
StatusCode: resp.StatusCode,
Response: response,
}, nil
default:
return res, validate.InvalidContentType(ct)
}
}()
if err != nil {
return res, errors.Wrapf(err, "default (code %d)", resp.StatusCode)
}
return res, errors.Wrap(defRes, "error")
}


@ -0,0 +1,68 @@
// Code generated by ogen, DO NOT EDIT.
package rest
import (
"net/http"
"github.com/go-faster/errors"
"github.com/go-faster/jx"
ht "github.com/ogen-go/ogen/http"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
func encodePerformPredictionResponse(response *PredictionResponse, w http.ResponseWriter, span trace.Span) error {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(200)
span.SetStatus(codes.Ok, http.StatusText(200))
e := new(jx.Encoder)
response.Encode(e)
if _, err := e.WriteTo(w); err != nil {
return errors.Wrap(err, "write")
}
return nil
}
func encodeReadinessCheckResponse(response *ReadinessResponse, w http.ResponseWriter, span trace.Span) error {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
w.WriteHeader(200)
span.SetStatus(codes.Ok, http.StatusText(200))
e := new(jx.Encoder)
response.Encode(e)
if _, err := e.WriteTo(w); err != nil {
return errors.Wrap(err, "write")
}
return nil
}
func encodeErrorResponse(response *ErrorStatusCode, w http.ResponseWriter, span trace.Span) error {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
code := response.StatusCode
if code == 0 {
// Set default status code.
code = http.StatusOK
}
w.WriteHeader(code)
if st := http.StatusText(code); code >= http.StatusBadRequest {
span.SetStatus(codes.Error, st)
} else {
span.SetStatus(codes.Ok, st)
}
e := new(jx.Encoder)
response.Response.Encode(e)
if _, err := e.WriteTo(w); err != nil {
return errors.Wrap(err, "write")
}
if code >= http.StatusInternalServerError {
return errors.Wrapf(ht.ErrInternalServerErrorResponse, "code: %d, message: %s", code, http.StatusText(code))
}
return nil
}

268
pkg/rest/oas_router_gen.go Normal file

@ -0,0 +1,268 @@
// Code generated by ogen, DO NOT EDIT.
package rest
import (
"net/http"
"net/url"
"strings"
"github.com/ogen-go/ogen/uri"
)
func (s *Server) cutPrefix(path string) (string, bool) {
prefix := s.cfg.Prefix
if prefix == "" {
return path, true
}
if !strings.HasPrefix(path, prefix) {
// Prefix doesn't match.
return "", false
}
// Cut prefix from the path.
return strings.TrimPrefix(path, prefix), true
}
// ServeHTTP serves http request as defined by OpenAPI v3 specification,
// calling handler that matches the path or returning not found error.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
elem := r.URL.Path
elemIsEscaped := false
if rawPath := r.URL.RawPath; rawPath != "" {
if normalized, ok := uri.NormalizeEscapedPath(rawPath); ok {
elem = normalized
elemIsEscaped = strings.ContainsRune(elem, '%')
}
}
elem, ok := s.cutPrefix(elem)
if !ok || len(elem) == 0 {
s.notFound(w, r)
return
}
// Static code generated router with unwrapped path search.
switch {
default:
if len(elem) == 0 {
break
}
switch elem[0] {
case '/': // Prefix: "/"
if l := len("/"); len(elem) >= l && elem[0:l] == "/" {
elem = elem[l:]
} else {
break
}
if len(elem) == 0 {
break
}
switch elem[0] {
case 'a': // Prefix: "api/v1/prediction"
if l := len("api/v1/prediction"); len(elem) >= l && elem[0:l] == "api/v1/prediction" {
elem = elem[l:]
} else {
break
}
if len(elem) == 0 {
// Leaf node.
switch r.Method {
case "GET":
s.handlePerformPredictionRequest([0]string{}, elemIsEscaped, w, r)
default:
s.notAllowed(w, r, notAllowedParams{
allowedMethods: "GET",
allowedHeaders: nil,
acceptPost: "",
acceptPatch: "",
})
}
return
}
case 'r': // Prefix: "ready"
if l := len("ready"); len(elem) >= l && elem[0:l] == "ready" {
elem = elem[l:]
} else {
break
}
if len(elem) == 0 {
// Leaf node.
switch r.Method {
case "GET":
s.handleReadinessCheckRequest([0]string{}, elemIsEscaped, w, r)
default:
s.notAllowed(w, r, notAllowedParams{
allowedMethods: "GET",
allowedHeaders: nil,
acceptPost: "",
acceptPatch: "",
})
}
return
}
}
}
}
s.notFound(w, r)
}
// Route is route object.
type Route struct {
name string
summary string
operationID string
operationGroup string
pathPattern string
count int
args [0]string
}
// Name returns ogen operation name.
//
// It is guaranteed to be unique and not empty.
func (r Route) Name() string {
return r.name
}
// Summary returns OpenAPI summary.
func (r Route) Summary() string {
return r.summary
}
// OperationID returns OpenAPI operationId.
func (r Route) OperationID() string {
return r.operationID
}
// OperationGroup returns the x-ogen-operation-group value.
func (r Route) OperationGroup() string {
return r.operationGroup
}
// PathPattern returns OpenAPI path.
func (r Route) PathPattern() string {
return r.pathPattern
}
// Args returns parsed arguments.
func (r Route) Args() []string {
return r.args[:r.count]
}
// FindRoute finds Route for given method and path.
//
// Note: this method does not unescape path or handle reserved characters in path properly. Use FindPath instead.
func (s *Server) FindRoute(method, path string) (Route, bool) {
return s.FindPath(method, &url.URL{Path: path})
}
// FindPath finds Route for given method and URL.
func (s *Server) FindPath(method string, u *url.URL) (r Route, _ bool) {
var (
elem = u.Path
args = r.args
)
if rawPath := u.RawPath; rawPath != "" {
if normalized, ok := uri.NormalizeEscapedPath(rawPath); ok {
elem = normalized
}
defer func() {
for i, arg := range r.args[:r.count] {
if unescaped, err := url.PathUnescape(arg); err == nil {
r.args[i] = unescaped
}
}
}()
}
elem, ok := s.cutPrefix(elem)
if !ok {
return r, false
}
// Static code generated router with unwrapped path search.
switch {
default:
if len(elem) == 0 {
break
}
switch elem[0] {
case '/': // Prefix: "/"
if l := len("/"); len(elem) >= l && elem[0:l] == "/" {
elem = elem[l:]
} else {
break
}
if len(elem) == 0 {
break
}
switch elem[0] {
case 'a': // Prefix: "api/v1/prediction"
if l := len("api/v1/prediction"); len(elem) >= l && elem[0:l] == "api/v1/prediction" {
elem = elem[l:]
} else {
break
}
if len(elem) == 0 {
// Leaf node.
switch method {
case "GET":
r.name = PerformPredictionOperation
r.summary = "Perform prediction"
r.operationID = "performPrediction"
r.operationGroup = ""
r.pathPattern = "/api/v1/prediction"
r.args = args
r.count = 0
return r, true
default:
return
}
}
case 'r': // Prefix: "ready"
if l := len("ready"); len(elem) >= l && elem[0:l] == "ready" {
elem = elem[l:]
} else {
break
}
if len(elem) == 0 {
// Leaf node.
switch method {
case "GET":
r.name = ReadinessCheckOperation
r.summary = "Readiness check"
r.operationID = "readinessCheck"
r.operationGroup = ""
r.pathPattern = "/ready"
r.args = args
r.count = 0
return r, true
default:
return
}
}
}
}
}
return r, false
}

789
pkg/rest/oas_schemas_gen.go Normal file

@ -0,0 +1,789 @@
// Code generated by ogen, DO NOT EDIT.
package rest
import (
"fmt"
"time"
"github.com/go-faster/errors"
"github.com/go-faster/jx"
)
func (s *ErrorStatusCode) Error() string {
return fmt.Sprintf("code %d: %+v", s.StatusCode, s.Response)
}
// Ref: #/components/schemas/Error
type Error struct {
Error ErrorError `json:"error"`
}
// GetError returns the value of Error.
func (s *Error) GetError() ErrorError {
return s.Error
}
// SetError sets the value of Error.
func (s *Error) SetError(val ErrorError) {
s.Error = val
}
type ErrorError struct {
Type string `json:"type"`
Description string `json:"description"`
}
// GetType returns the value of Type.
func (s *ErrorError) GetType() string {
return s.Type
}
// GetDescription returns the value of Description.
func (s *ErrorError) GetDescription() string {
return s.Description
}
// SetType sets the value of Type.
func (s *ErrorError) SetType(val string) {
s.Type = val
}
// SetDescription sets the value of Description.
func (s *ErrorError) SetDescription(val string) {
s.Description = val
}
// ErrorStatusCode wraps Error with StatusCode.
type ErrorStatusCode struct {
StatusCode int
Response Error
}
// GetStatusCode returns the value of StatusCode.
func (s *ErrorStatusCode) GetStatusCode() int {
return s.StatusCode
}
// GetResponse returns the value of Response.
func (s *ErrorStatusCode) GetResponse() Error {
return s.Response
}
// SetStatusCode sets the value of StatusCode.
func (s *ErrorStatusCode) SetStatusCode(val int) {
s.StatusCode = val
}
// SetResponse sets the value of Response.
func (s *ErrorStatusCode) SetResponse(val Error) {
s.Response = val
}
// NewOptDateTime returns new OptDateTime with value set to v.
func NewOptDateTime(v time.Time) OptDateTime {
return OptDateTime{
Value: v,
Set: true,
}
}
// OptDateTime is optional time.Time.
type OptDateTime struct {
Value time.Time
Set bool
}
// IsSet returns true if OptDateTime was set.
func (o OptDateTime) IsSet() bool { return o.Set }
// Reset unsets value.
func (o *OptDateTime) Reset() {
var v time.Time
o.Value = v
o.Set = false
}
// SetTo sets value to v.
func (o *OptDateTime) SetTo(v time.Time) {
o.Set = true
o.Value = v
}
// Get returns value and boolean that denotes whether value was set.
func (o OptDateTime) Get() (v time.Time, ok bool) {
if !o.Set {
return v, false
}
return o.Value, true
}
// Or returns value if set, or given parameter if does not.
func (o OptDateTime) Or(d time.Time) time.Time {
if v, ok := o.Get(); ok {
return v
}
return d
}
// NewOptFloat64 returns new OptFloat64 with value set to v.
func NewOptFloat64(v float64) OptFloat64 {
return OptFloat64{
Value: v,
Set: true,
}
}
// OptFloat64 is optional float64.
type OptFloat64 struct {
Value float64
Set bool
}
// IsSet returns true if OptFloat64 was set.
func (o OptFloat64) IsSet() bool { return o.Set }
// Reset unsets value.
func (o *OptFloat64) Reset() {
var v float64
o.Value = v
o.Set = false
}
// SetTo sets value to v.
func (o *OptFloat64) SetTo(v float64) {
o.Set = true
o.Value = v
}
// Get returns value and boolean that denotes whether value was set.
func (o OptFloat64) Get() (v float64, ok bool) {
if !o.Set {
return v, false
}
return o.Value, true
}
// Or returns value if set, or given parameter if does not.
func (o OptFloat64) Or(d float64) float64 {
if v, ok := o.Get(); ok {
return v
}
return d
}
// NewOptPerformPredictionProfile returns new OptPerformPredictionProfile with value set to v.
func NewOptPerformPredictionProfile(v PerformPredictionProfile) OptPerformPredictionProfile {
return OptPerformPredictionProfile{
Value: v,
Set: true,
}
}
// OptPerformPredictionProfile is optional PerformPredictionProfile.
type OptPerformPredictionProfile struct {
Value PerformPredictionProfile
Set bool
}
// IsSet returns true if OptPerformPredictionProfile was set.
func (o OptPerformPredictionProfile) IsSet() bool { return o.Set }
// Reset unsets value.
func (o *OptPerformPredictionProfile) Reset() {
var v PerformPredictionProfile
o.Value = v
o.Set = false
}
// SetTo sets value to v.
func (o *OptPerformPredictionProfile) SetTo(v PerformPredictionProfile) {
o.Set = true
o.Value = v
}
// Get returns value and boolean that denotes whether value was set.
func (o OptPerformPredictionProfile) Get() (v PerformPredictionProfile, ok bool) {
if !o.Set {
return v, false
}
return o.Value, true
}
// Or returns value if set, or given parameter if does not.
func (o OptPerformPredictionProfile) Or(d PerformPredictionProfile) PerformPredictionProfile {
if v, ok := o.Get(); ok {
return v
}
return d
}
// NewOptPredictionResponseRequest returns new OptPredictionResponseRequest with value set to v.
func NewOptPredictionResponseRequest(v PredictionResponseRequest) OptPredictionResponseRequest {
return OptPredictionResponseRequest{
Value: v,
Set: true,
}
}
// OptPredictionResponseRequest is optional PredictionResponseRequest.
type OptPredictionResponseRequest struct {
Value PredictionResponseRequest
Set bool
}
// IsSet returns true if OptPredictionResponseRequest was set.
func (o OptPredictionResponseRequest) IsSet() bool { return o.Set }
// Reset unsets value.
func (o *OptPredictionResponseRequest) Reset() {
var v PredictionResponseRequest
o.Value = v
o.Set = false
}
// SetTo sets value to v.
func (o *OptPredictionResponseRequest) SetTo(v PredictionResponseRequest) {
o.Set = true
o.Value = v
}
// Get returns value and boolean that denotes whether value was set.
func (o OptPredictionResponseRequest) Get() (v PredictionResponseRequest, ok bool) {
if !o.Set {
return v, false
}
return o.Value, true
}
// Or returns value if set, or given parameter if does not.
func (o OptPredictionResponseRequest) Or(d PredictionResponseRequest) PredictionResponseRequest {
if v, ok := o.Get(); ok {
return v
}
return d
}
// NewOptPredictionResponseWarnings returns new OptPredictionResponseWarnings with value set to v.
func NewOptPredictionResponseWarnings(v PredictionResponseWarnings) OptPredictionResponseWarnings {
return OptPredictionResponseWarnings{
Value: v,
Set: true,
}
}
// OptPredictionResponseWarnings is optional PredictionResponseWarnings.
type OptPredictionResponseWarnings struct {
Value PredictionResponseWarnings
Set bool
}
// IsSet returns true if OptPredictionResponseWarnings was set.
func (o OptPredictionResponseWarnings) IsSet() bool { return o.Set }
// Reset unsets value.
func (o *OptPredictionResponseWarnings) Reset() {
var v PredictionResponseWarnings
o.Value = v
o.Set = false
}
// SetTo sets value to v.
func (o *OptPredictionResponseWarnings) SetTo(v PredictionResponseWarnings) {
o.Set = true
o.Value = v
}
// Get returns value and boolean that denotes whether value was set.
func (o OptPredictionResponseWarnings) Get() (v PredictionResponseWarnings, ok bool) {
if !o.Set {
return v, false
}
return o.Value, true
}
// Or returns value if set, or given parameter if does not.
func (o OptPredictionResponseWarnings) Or(d PredictionResponseWarnings) PredictionResponseWarnings {
if v, ok := o.Get(); ok {
return v
}
return d
}
// NewOptString returns new OptString with value set to v.
func NewOptString(v string) OptString {
return OptString{
Value: v,
Set: true,
}
}
// OptString is optional string.
type OptString struct {
Value string
Set bool
}
// IsSet returns true if OptString was set.
func (o OptString) IsSet() bool { return o.Set }
// Reset unsets value.
func (o *OptString) Reset() {
var v string
o.Value = v
o.Set = false
}
// SetTo sets value to v.
func (o *OptString) SetTo(v string) {
o.Set = true
o.Value = v
}
// Get returns value and boolean that denotes whether value was set.
func (o OptString) Get() (v string, ok bool) {
if !o.Set {
return v, false
}
return o.Value, true
}
// Or returns value if set, or given parameter if does not.
func (o OptString) Or(d string) string {
if v, ok := o.Get(); ok {
return v
}
return d
}
type PerformPredictionProfile string
const (
PerformPredictionProfileStandardProfile PerformPredictionProfile = "standard_profile"
PerformPredictionProfileFloatProfile PerformPredictionProfile = "float_profile"
)
// AllValues returns all PerformPredictionProfile values.
func (PerformPredictionProfile) AllValues() []PerformPredictionProfile {
return []PerformPredictionProfile{
PerformPredictionProfileStandardProfile,
PerformPredictionProfileFloatProfile,
}
}
// MarshalText implements encoding.TextMarshaler.
func (s PerformPredictionProfile) MarshalText() ([]byte, error) {
switch s {
case PerformPredictionProfileStandardProfile:
return []byte(s), nil
case PerformPredictionProfileFloatProfile:
return []byte(s), nil
default:
return nil, errors.Errorf("invalid value: %q", s)
}
}
// UnmarshalText implements encoding.TextUnmarshaler.
func (s *PerformPredictionProfile) UnmarshalText(data []byte) error {
switch PerformPredictionProfile(data) {
case PerformPredictionProfileStandardProfile:
*s = PerformPredictionProfileStandardProfile
return nil
case PerformPredictionProfileFloatProfile:
*s = PerformPredictionProfileFloatProfile
return nil
default:
return errors.Errorf("invalid value: %q", data)
}
}
// Ref: #/components/schemas/PredictionResponse
type PredictionResponse struct {
Request OptPredictionResponseRequest `json:"request"`
Prediction []PredictionResponsePredictionItem `json:"prediction"`
Metadata PredictionResponseMetadata `json:"metadata"`
Warnings OptPredictionResponseWarnings `json:"warnings"`
}
// GetRequest returns the value of Request.
func (s *PredictionResponse) GetRequest() OptPredictionResponseRequest {
return s.Request
}
// GetPrediction returns the value of Prediction.
func (s *PredictionResponse) GetPrediction() []PredictionResponsePredictionItem {
return s.Prediction
}
// GetMetadata returns the value of Metadata.
func (s *PredictionResponse) GetMetadata() PredictionResponseMetadata {
return s.Metadata
}
// GetWarnings returns the value of Warnings.
func (s *PredictionResponse) GetWarnings() OptPredictionResponseWarnings {
return s.Warnings
}
// SetRequest sets the value of Request.
func (s *PredictionResponse) SetRequest(val OptPredictionResponseRequest) {
s.Request = val
}
// SetPrediction sets the value of Prediction.
func (s *PredictionResponse) SetPrediction(val []PredictionResponsePredictionItem) {
s.Prediction = val
}
// SetMetadata sets the value of Metadata.
func (s *PredictionResponse) SetMetadata(val PredictionResponseMetadata) {
s.Metadata = val
}
// SetWarnings sets the value of Warnings.
func (s *PredictionResponse) SetWarnings(val OptPredictionResponseWarnings) {
s.Warnings = val
}
type PredictionResponseMetadata struct {
StartDatetime time.Time `json:"start_datetime"`
CompleteDatetime time.Time `json:"complete_datetime"`
}
// GetStartDatetime returns the value of StartDatetime.
func (s *PredictionResponseMetadata) GetStartDatetime() time.Time {
return s.StartDatetime
}
// GetCompleteDatetime returns the value of CompleteDatetime.
func (s *PredictionResponseMetadata) GetCompleteDatetime() time.Time {
return s.CompleteDatetime
}
// SetStartDatetime sets the value of StartDatetime.
func (s *PredictionResponseMetadata) SetStartDatetime(val time.Time) {
s.StartDatetime = val
}
// SetCompleteDatetime sets the value of CompleteDatetime.
func (s *PredictionResponseMetadata) SetCompleteDatetime(val time.Time) {
s.CompleteDatetime = val
}
type PredictionResponsePredictionItem struct {
Stage PredictionResponsePredictionItemStage `json:"stage"`
Trajectory []PredictionResponsePredictionItemTrajectoryItem `json:"trajectory"`
}
// GetStage returns the value of Stage.
func (s *PredictionResponsePredictionItem) GetStage() PredictionResponsePredictionItemStage {
return s.Stage
}
// GetTrajectory returns the value of Trajectory.
func (s *PredictionResponsePredictionItem) GetTrajectory() []PredictionResponsePredictionItemTrajectoryItem {
return s.Trajectory
}
// SetStage sets the value of Stage.
func (s *PredictionResponsePredictionItem) SetStage(val PredictionResponsePredictionItemStage) {
s.Stage = val
}
// SetTrajectory sets the value of Trajectory.
func (s *PredictionResponsePredictionItem) SetTrajectory(val []PredictionResponsePredictionItemTrajectoryItem) {
s.Trajectory = val
}
type PredictionResponsePredictionItemStage string
const (
PredictionResponsePredictionItemStageAscent PredictionResponsePredictionItemStage = "ascent"
PredictionResponsePredictionItemStageDescent PredictionResponsePredictionItemStage = "descent"
PredictionResponsePredictionItemStageFloat PredictionResponsePredictionItemStage = "float"
)
// AllValues returns all PredictionResponsePredictionItemStage values.
func (PredictionResponsePredictionItemStage) AllValues() []PredictionResponsePredictionItemStage {
return []PredictionResponsePredictionItemStage{
PredictionResponsePredictionItemStageAscent,
PredictionResponsePredictionItemStageDescent,
PredictionResponsePredictionItemStageFloat,
}
}
// MarshalText implements encoding.TextMarshaler.
func (s PredictionResponsePredictionItemStage) MarshalText() ([]byte, error) {
switch s {
case PredictionResponsePredictionItemStageAscent:
return []byte(s), nil
case PredictionResponsePredictionItemStageDescent:
return []byte(s), nil
case PredictionResponsePredictionItemStageFloat:
return []byte(s), nil
default:
return nil, errors.Errorf("invalid value: %q", s)
}
}
// UnmarshalText implements encoding.TextUnmarshaler.
func (s *PredictionResponsePredictionItemStage) UnmarshalText(data []byte) error {
switch PredictionResponsePredictionItemStage(data) {
case PredictionResponsePredictionItemStageAscent:
*s = PredictionResponsePredictionItemStageAscent
return nil
case PredictionResponsePredictionItemStageDescent:
*s = PredictionResponsePredictionItemStageDescent
return nil
case PredictionResponsePredictionItemStageFloat:
*s = PredictionResponsePredictionItemStageFloat
return nil
default:
return errors.Errorf("invalid value: %q", data)
}
}
type PredictionResponsePredictionItemTrajectoryItem struct {
Datetime time.Time `json:"datetime"`
Latitude float64 `json:"latitude"`
Longitude float64 `json:"longitude"`
Altitude float64 `json:"altitude"`
}
// GetDatetime returns the value of Datetime.
func (s *PredictionResponsePredictionItemTrajectoryItem) GetDatetime() time.Time {
return s.Datetime
}
// GetLatitude returns the value of Latitude.
func (s *PredictionResponsePredictionItemTrajectoryItem) GetLatitude() float64 {
return s.Latitude
}
// GetLongitude returns the value of Longitude.
func (s *PredictionResponsePredictionItemTrajectoryItem) GetLongitude() float64 {
return s.Longitude
}
// GetAltitude returns the value of Altitude.
func (s *PredictionResponsePredictionItemTrajectoryItem) GetAltitude() float64 {
return s.Altitude
}
// SetDatetime sets the value of Datetime.
func (s *PredictionResponsePredictionItemTrajectoryItem) SetDatetime(val time.Time) {
s.Datetime = val
}
// SetLatitude sets the value of Latitude.
func (s *PredictionResponsePredictionItemTrajectoryItem) SetLatitude(val float64) {
s.Latitude = val
}
// SetLongitude sets the value of Longitude.
func (s *PredictionResponsePredictionItemTrajectoryItem) SetLongitude(val float64) {
s.Longitude = val
}
// SetAltitude sets the value of Altitude.
func (s *PredictionResponsePredictionItemTrajectoryItem) SetAltitude(val float64) {
s.Altitude = val
}
type PredictionResponseRequest struct {
Dataset OptString `json:"dataset"`
LaunchLatitude OptFloat64 `json:"launch_latitude"`
LaunchLongitude OptFloat64 `json:"launch_longitude"`
LaunchDatetime OptString `json:"launch_datetime"`
LaunchAltitude OptFloat64 `json:"launch_altitude"`
Profile OptString `json:"profile"`
AscentRate OptFloat64 `json:"ascent_rate"`
BurstAltitude OptFloat64 `json:"burst_altitude"`
DescentRate OptFloat64 `json:"descent_rate"`
}
// GetDataset returns the value of Dataset.
func (s *PredictionResponseRequest) GetDataset() OptString {
return s.Dataset
}
// GetLaunchLatitude returns the value of LaunchLatitude.
func (s *PredictionResponseRequest) GetLaunchLatitude() OptFloat64 {
return s.LaunchLatitude
}
// GetLaunchLongitude returns the value of LaunchLongitude.
func (s *PredictionResponseRequest) GetLaunchLongitude() OptFloat64 {
return s.LaunchLongitude
}
// GetLaunchDatetime returns the value of LaunchDatetime.
func (s *PredictionResponseRequest) GetLaunchDatetime() OptString {
return s.LaunchDatetime
}
// GetLaunchAltitude returns the value of LaunchAltitude.
func (s *PredictionResponseRequest) GetLaunchAltitude() OptFloat64 {
return s.LaunchAltitude
}
// GetProfile returns the value of Profile.
func (s *PredictionResponseRequest) GetProfile() OptString {
return s.Profile
}
// GetAscentRate returns the value of AscentRate.
func (s *PredictionResponseRequest) GetAscentRate() OptFloat64 {
return s.AscentRate
}
// GetBurstAltitude returns the value of BurstAltitude.
func (s *PredictionResponseRequest) GetBurstAltitude() OptFloat64 {
return s.BurstAltitude
}
// GetDescentRate returns the value of DescentRate.
func (s *PredictionResponseRequest) GetDescentRate() OptFloat64 {
return s.DescentRate
}
// SetDataset sets the value of Dataset.
func (s *PredictionResponseRequest) SetDataset(val OptString) {
s.Dataset = val
}
// SetLaunchLatitude sets the value of LaunchLatitude.
func (s *PredictionResponseRequest) SetLaunchLatitude(val OptFloat64) {
s.LaunchLatitude = val
}
// SetLaunchLongitude sets the value of LaunchLongitude.
func (s *PredictionResponseRequest) SetLaunchLongitude(val OptFloat64) {
s.LaunchLongitude = val
}
// SetLaunchDatetime sets the value of LaunchDatetime.
func (s *PredictionResponseRequest) SetLaunchDatetime(val OptString) {
s.LaunchDatetime = val
}
// SetLaunchAltitude sets the value of LaunchAltitude.
func (s *PredictionResponseRequest) SetLaunchAltitude(val OptFloat64) {
s.LaunchAltitude = val
}
// SetProfile sets the value of Profile.
func (s *PredictionResponseRequest) SetProfile(val OptString) {
s.Profile = val
}
// SetAscentRate sets the value of AscentRate.
func (s *PredictionResponseRequest) SetAscentRate(val OptFloat64) {
s.AscentRate = val
}
// SetBurstAltitude sets the value of BurstAltitude.
func (s *PredictionResponseRequest) SetBurstAltitude(val OptFloat64) {
s.BurstAltitude = val
}
// SetDescentRate sets the value of DescentRate.
func (s *PredictionResponseRequest) SetDescentRate(val OptFloat64) {
s.DescentRate = val
}
type PredictionResponseWarnings map[string]jx.Raw
func (s *PredictionResponseWarnings) init() PredictionResponseWarnings {
m := *s
if m == nil {
m = map[string]jx.Raw{}
*s = m
}
return m
}
// Ref: #/components/schemas/ReadinessResponse
type ReadinessResponse struct {
Status ReadinessResponseStatus `json:"status"`
DatasetTime OptDateTime `json:"dataset_time"`
ErrorMessage OptString `json:"error_message"`
}
// GetStatus returns the value of Status.
func (s *ReadinessResponse) GetStatus() ReadinessResponseStatus {
return s.Status
}
// GetDatasetTime returns the value of DatasetTime.
func (s *ReadinessResponse) GetDatasetTime() OptDateTime {
return s.DatasetTime
}
// GetErrorMessage returns the value of ErrorMessage.
func (s *ReadinessResponse) GetErrorMessage() OptString {
return s.ErrorMessage
}
// SetStatus sets the value of Status.
func (s *ReadinessResponse) SetStatus(val ReadinessResponseStatus) {
s.Status = val
}
// SetDatasetTime sets the value of DatasetTime.
func (s *ReadinessResponse) SetDatasetTime(val OptDateTime) {
s.DatasetTime = val
}
// SetErrorMessage sets the value of ErrorMessage.
func (s *ReadinessResponse) SetErrorMessage(val OptString) {
s.ErrorMessage = val
}
type ReadinessResponseStatus string
const (
ReadinessResponseStatusOk ReadinessResponseStatus = "ok"
ReadinessResponseStatusNotReady ReadinessResponseStatus = "not_ready"
ReadinessResponseStatusError ReadinessResponseStatus = "error"
)
// AllValues returns all ReadinessResponseStatus values.
func (ReadinessResponseStatus) AllValues() []ReadinessResponseStatus {
return []ReadinessResponseStatus{
ReadinessResponseStatusOk,
ReadinessResponseStatusNotReady,
ReadinessResponseStatusError,
}
}
// MarshalText implements encoding.TextMarshaler.
func (s ReadinessResponseStatus) MarshalText() ([]byte, error) {
switch s {
case ReadinessResponseStatusOk:
return []byte(s), nil
case ReadinessResponseStatusNotReady:
return []byte(s), nil
case ReadinessResponseStatusError:
return []byte(s), nil
default:
return nil, errors.Errorf("invalid value: %q", s)
}
}
// UnmarshalText implements encoding.TextUnmarshaler.
func (s *ReadinessResponseStatus) UnmarshalText(data []byte) error {
switch ReadinessResponseStatus(data) {
case ReadinessResponseStatusOk:
*s = ReadinessResponseStatusOk
return nil
case ReadinessResponseStatusNotReady:
*s = ReadinessResponseStatusNotReady
return nil
case ReadinessResponseStatusError:
*s = ReadinessResponseStatusError
return nil
default:
return errors.Errorf("invalid value: %q", data)
}
}

// Code generated by ogen, DO NOT EDIT.
package rest
import (
"context"
)
// Handler handles operations described by OpenAPI v3 specification.
type Handler interface {
// PerformPrediction implements performPrediction operation.
//
// Perform prediction.
//
// GET /api/v1/prediction
PerformPrediction(ctx context.Context, params PerformPredictionParams) (*PredictionResponse, error)
// ReadinessCheck implements readinessCheck operation.
//
// Readiness check.
//
// GET /ready
ReadinessCheck(ctx context.Context) (*ReadinessResponse, error)
// NewError creates *ErrorStatusCode from error returned by handler.
//
// Used for common default response.
NewError(ctx context.Context, err error) *ErrorStatusCode
}
// Server implements http server based on OpenAPI v3 specification and
// calls Handler to handle requests.
type Server struct {
h Handler
baseServer
}
// NewServer creates new Server.
func NewServer(h Handler, opts ...ServerOption) (*Server, error) {
s, err := newServerConfig(opts...).baseServer()
if err != nil {
return nil, err
}
return &Server{
h: h,
baseServer: s,
}, nil
}

// Code generated by ogen, DO NOT EDIT.
package rest
import (
"context"
ht "github.com/ogen-go/ogen/http"
)
// UnimplementedHandler is no-op Handler which returns http.ErrNotImplemented.
type UnimplementedHandler struct{}
var _ Handler = UnimplementedHandler{}
// PerformPrediction implements performPrediction operation.
//
// Perform prediction.
//
// GET /api/v1/prediction
func (UnimplementedHandler) PerformPrediction(ctx context.Context, params PerformPredictionParams) (r *PredictionResponse, _ error) {
return r, ht.ErrNotImplemented
}
// ReadinessCheck implements readinessCheck operation.
//
// Readiness check.
//
// GET /ready
func (UnimplementedHandler) ReadinessCheck(ctx context.Context) (r *ReadinessResponse, _ error) {
return r, ht.ErrNotImplemented
}
// NewError creates *ErrorStatusCode from error returned by handler.
//
// Used for common default response.
func (UnimplementedHandler) NewError(ctx context.Context, err error) (r *ErrorStatusCode) {
r = new(ErrorStatusCode)
return r
}

// Code generated by ogen, DO NOT EDIT.
package rest
import (
"fmt"
"github.com/go-faster/errors"
"github.com/ogen-go/ogen/validate"
)
func (s PerformPredictionProfile) Validate() error {
switch s {
case "standard_profile":
return nil
case "float_profile":
return nil
default:
return errors.Errorf("invalid value: %v", s)
}
}
func (s *PredictionResponse) Validate() error {
if s == nil {
return validate.ErrNilPointer
}
var failures []validate.FieldError
if err := func() error {
if value, ok := s.Request.Get(); ok {
if err := func() error {
if err := value.Validate(); err != nil {
return err
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "request",
Error: err,
})
}
if err := func() error {
if s.Prediction == nil {
return errors.New("nil is invalid value")
}
var failures []validate.FieldError
for i, elem := range s.Prediction {
if err := func() error {
if err := elem.Validate(); err != nil {
return err
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: fmt.Sprintf("[%d]", i),
Error: err,
})
}
}
if len(failures) > 0 {
return &validate.Error{Fields: failures}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "prediction",
Error: err,
})
}
if len(failures) > 0 {
return &validate.Error{Fields: failures}
}
return nil
}
func (s *PredictionResponsePredictionItem) Validate() error {
if s == nil {
return validate.ErrNilPointer
}
var failures []validate.FieldError
if err := func() error {
if err := s.Stage.Validate(); err != nil {
return err
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "stage",
Error: err,
})
}
if err := func() error {
if s.Trajectory == nil {
return errors.New("nil is invalid value")
}
var failures []validate.FieldError
for i, elem := range s.Trajectory {
if err := func() error {
if err := elem.Validate(); err != nil {
return err
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: fmt.Sprintf("[%d]", i),
Error: err,
})
}
}
if len(failures) > 0 {
return &validate.Error{Fields: failures}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "trajectory",
Error: err,
})
}
if len(failures) > 0 {
return &validate.Error{Fields: failures}
}
return nil
}
func (s PredictionResponsePredictionItemStage) Validate() error {
switch s {
case "ascent":
return nil
case "descent":
return nil
case "float":
return nil
default:
return errors.Errorf("invalid value: %v", s)
}
}
func (s *PredictionResponsePredictionItemTrajectoryItem) Validate() error {
if s == nil {
return validate.ErrNilPointer
}
var failures []validate.FieldError
if err := func() error {
if err := (validate.Float{}).Validate(float64(s.Latitude)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "latitude",
Error: err,
})
}
if err := func() error {
if err := (validate.Float{}).Validate(float64(s.Longitude)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "longitude",
Error: err,
})
}
if err := func() error {
if err := (validate.Float{}).Validate(float64(s.Altitude)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "altitude",
Error: err,
})
}
if len(failures) > 0 {
return &validate.Error{Fields: failures}
}
return nil
}
func (s *PredictionResponseRequest) Validate() error {
if s == nil {
return validate.ErrNilPointer
}
var failures []validate.FieldError
if err := func() error {
if value, ok := s.LaunchLatitude.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "launch_latitude",
Error: err,
})
}
if err := func() error {
if value, ok := s.LaunchLongitude.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "launch_longitude",
Error: err,
})
}
if err := func() error {
if value, ok := s.LaunchAltitude.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "launch_altitude",
Error: err,
})
}
if err := func() error {
if value, ok := s.AscentRate.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "ascent_rate",
Error: err,
})
}
if err := func() error {
if value, ok := s.BurstAltitude.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "burst_altitude",
Error: err,
})
}
if err := func() error {
if value, ok := s.DescentRate.Get(); ok {
if err := func() error {
if err := (validate.Float{}).Validate(float64(value)); err != nil {
return errors.Wrap(err, "float")
}
return nil
}(); err != nil {
return err
}
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "descent_rate",
Error: err,
})
}
if len(failures) > 0 {
return &validate.Error{Fields: failures}
}
return nil
}
func (s *ReadinessResponse) Validate() error {
if s == nil {
return validate.ErrNilPointer
}
var failures []validate.FieldError
if err := func() error {
if err := s.Status.Validate(); err != nil {
return err
}
return nil
}(); err != nil {
failures = append(failures, validate.FieldError{
Name: "status",
Error: err,
})
}
if len(failures) > 0 {
return &validate.Error{Fields: failures}
}
return nil
}
func (s ReadinessResponseStatus) Validate() error {
switch s {
case "ok":
return nil
case "not_ready":
return nil
case "error":
return nil
default:
return errors.Errorf("invalid value: %v", s)
}
}

scripts/build_elevation.py
#!/usr/bin/env python3
"""
Download ETOPO 2022 30-arc-second elevation data and convert to ruaumoko-compatible
binary format (int16 little-endian, 21601 lat x 43200 lon, south-to-north).
Output: ~1.74 GiB binary file.
Usage: python3 build_elevation.py [output_path]
Default output: /srv/ruaumoko-dataset
"""
import sys
import os
import tempfile
import numpy as np
CELLS_PER_DEGREE = 120
NUM_LATS = 180 * CELLS_PER_DEGREE + 1 # 21601
NUM_LONS = 360 * CELLS_PER_DEGREE # 43200
EXPECTED_SIZE = NUM_LATS * NUM_LONS * 2 # 1,866,326,400
ETOPO_URL = "https://www.ngdc.noaa.gov/thredds/fileServer/global/ETOPO2022/30s/30s_surface_elev_netcdf/ETOPO_2022_v1_30s_N90W180_surface.nc"
def download_etopo(output_path):
"""Download ETOPO 2022 NetCDF and convert to ruaumoko binary format."""
try:
import xarray as xr
except ImportError:
print("ERROR: xarray is required. Install with: pip install xarray netcdf4")
sys.exit(1)
# Check if we can download directly or need a local file
nc_path = os.environ.get("ETOPO_NC_PATH")
if nc_path and os.path.exists(nc_path):
print(f"Using local ETOPO file: {nc_path}")
else:
print("Downloading ETOPO 2022 30-second data (~1.1 GB)...")
print(f" URL: {ETOPO_URL}")
print(" (Set ETOPO_NC_PATH env var to use a pre-downloaded file)")
import urllib.request
with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as f:
nc_path = f.name
try:
urllib.request.urlretrieve(ETOPO_URL, nc_path, _progress)
print()
except Exception as e:
os.unlink(nc_path)
print(f"\nDownload failed: {e}")
print("\nAlternative: manually download ETOPO 2022 30s NetCDF from:")
print(" https://www.ncei.noaa.gov/products/etopo-global-relief-model")
print(f" Then set ETOPO_NC_PATH=/path/to/file.nc and re-run")
sys.exit(1)
print("Opening NetCDF dataset...")
ds = xr.open_dataset(nc_path)
# ETOPO 2022 30s has:
# - lat: -90 to +90, 21601 points (south to north)
# - lon: -180 to +180, 43201 points
# We need:
# - lat: -90 to +90, 21601 points (south to north) ← same
# - lon: 0 to 360 (exclusive), 43200 points ← need to shift and drop last
z = ds["z"] # elevation variable
print(f" Shape: {z.shape}")
print(f" Lat range: {float(z.lat.min())} to {float(z.lat.max())}")
print(f" Lon range: {float(z.lon.min())} to {float(z.lon.max())}")
# Sort latitude south-to-north (should already be, but ensure)
z = z.sortby("lat")
# Shift longitude from [-180, 180] to [0, 360)
print("Shifting longitude to [0, 360)...")
z = z.assign_coords(lon=(z.lon % 360))
z = z.sortby("lon")
data = z.values
print(f" Raw shape after sort: {data.shape}")
# Handle longitude dimension: drop last col if it wraps (43201 → 43200)
if data.shape[1] == NUM_LONS + 1:
data = data[:, :NUM_LONS]
elif data.shape[1] != NUM_LONS:
print(f"ERROR: unexpected lon dimension: {data.shape[1]}, expected {NUM_LONS} or {NUM_LONS+1}")
sys.exit(1)
# Handle latitude dimension: ETOPO 2022 is cell-centered (21600 rows),
# ruaumoko expects grid-registered (21601 rows including both poles).
# Pad by duplicating edge rows for the poles.
if data.shape[0] == NUM_LATS - 1:
print(f" Padding latitude from {data.shape[0]} to {NUM_LATS} (adding north pole row)")
north_pole = data[-1:, :] # duplicate +89.99... as +90
data = np.concatenate([data, north_pole], axis=0)
elif data.shape[0] != NUM_LATS:
print(f"ERROR: unexpected lat dimension: {data.shape[0]}, expected {NUM_LATS} or {NUM_LATS-1}")
sys.exit(1)
print(f"Final grid shape: {data.shape}")
print(f"Elevation range: {data.min():.1f} to {data.max():.1f} metres")
# Write as int16 little-endian
print(f"Writing to {output_path}...")
elev_int16 = np.clip(data, -32768, 32767).astype(np.dtype("<i2"))
elev_int16.tofile(output_path)
actual_size = os.path.getsize(output_path)
print(f"Written {actual_size:,} bytes (expected {EXPECTED_SIZE:,})")
if actual_size == EXPECTED_SIZE:
print("SUCCESS")
else:
print("WARNING: size mismatch!")
ds.close()
# Spot check
verify(output_path)
def verify(path):
"""Quick spot-check of the elevation dataset."""
data = np.memmap(path, dtype="<i2", mode="r", shape=(NUM_LATS, NUM_LONS))
tests = [
("Mt Everest (~28.0N, 86.9E)", 28.0, 86.9, 8000, 9000),
("Dead Sea (~31.5N, 35.5E)", 31.5, 35.5, -500, 0),
("Pacific Ocean (~0N, 180E)", 0.0, 180.0, -6000, 0),
("Auburn AU (~34.0S, 138.7E)", -34.03, 138.69, 200, 400),
]
print("\n Spot-check:")
for name, lat, lon, lo, hi in tests:
lat_idx = int((lat + 90) * CELLS_PER_DEGREE)
lon_idx = int(lon * CELLS_PER_DEGREE)
val = int(data[lat_idx, lon_idx])
ok = "OK" if lo <= val <= hi else "FAIL"
print(f" {name}: {val}m [{ok}] (expected {lo}-{hi})")
_last_pct = -1
def _progress(block_num, block_size, total_size):
global _last_pct
if total_size > 0:
pct = int(block_num * block_size * 100 / total_size)
if pct != _last_pct and pct % 5 == 0:
_last_pct = pct
print(f" {pct}%...", end="", flush=True)
if __name__ == "__main__":
output = sys.argv[1] if len(sys.argv) > 1 else "/srv/ruaumoko-dataset"
download_etopo(output)