diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..287fc68 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,55 @@ +# Git +.git +.gitignore + +# Docker +Dockerfile +docker-compose.yml +.dockerignore + +# Documentation +README.md +*.md + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Environment files +.env +.env.local +.env.*.local + +# Build artifacts +predictor +*.exe +*.exe~ +*.dll +*.so +*.dylib +*.test + +# Logs +*.log + +# Temporary files +/tmp/ +/temp/ + +# Test coverage +*.out + +# Go workspace +go.work \ No newline at end of file diff --git a/.gitignore b/.gitignore index f000dae..8519b69 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,62 @@ -predictor +# Binaries for programs and plugins *.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` *.test + +# Output of the go coverage tool, specifically when used with LiteIDE *.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work + +# Environment variables +.env +.env.local +.env.*.local + +# IDE files +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# OS generated files +.DS_Store +.DS_Store? +._* +.Spotlight-V100 +.Trashes +ehthumbs.db +Thumbs.db + +# Logs +*.log + +# Temporary files /tmp/ +/temp/ + +# Build artifacts +/build/ +/dist/ + +# GRIB files +/grib_data/ +/grib_data/* + +# Leaflet WebUI +/leaflet_predictor +/leaflet_predictor/* + +# Tawhiri +/tawhiri +/tawhiri/* \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..082b194 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "makefile.configureOnOpen": false +} \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..ede1ec1 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,57 @@ +# Build stage +FROM golang:1.24.4-alpine AS builder + +# Install build dependencies +RUN apk add --no-cache git ca-certificates tzdata + +# Set working directory +WORKDIR /app + +# Copy go mod files +COPY go.mod go.sum ./ + +# Download dependencies +RUN go mod download + +# Copy source code +COPY . . + +# Build the application +RUN CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build \ + -ldflags="-w -s" \ + -o predictor \ + ./cmd/api + +# Runtime stage +FROM alpine:3.19 + +# Install runtime dependencies +RUN apk add --no-cache ca-certificates tzdata + +# Create non-root user +RUN addgroup -g 1001 -S appgroup && \ + adduser -u 1001 -S appuser -G appgroup + +# Set working directory +WORKDIR /app + +# Copy binary from builder stage +COPY --from=builder /app/predictor . + +# Create necessary directories +RUN mkdir -p /tmp/grib && \ + chown -R appuser:appgroup /app && \ + chmod -R 777 /tmp/grib + +# Switch to non-root user +USER appuser + +# Expose port +EXPOSE 8080 + +# Health check +HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ + CMD wget --no-verbose --tries=1 --spider http://localhost:8080/ready || exit 1 + +# Run the application +CMD ["./predictor"] \ No newline at end of file diff --git a/Makefile b/Makefile index 7c14792..c6c35d9 100644 --- a/Makefile +++ b/Makefile @@ -1,40 +1,111 @@ -.PHONY: build run test fmt lint clean generate-ogen help +# Variables +IMAGE_NAME = predictor +TAG = latest +COMPOSE_FILE = docker-compose.yml -# Build the application +# Validate Docker configuration +.PHONY: validate-docker +validate-docker: + ./scripts/validate-docker.sh + +# Build the Docker image +.PHONY: build build: - go build -o predictor ./cmd/api + docker build -t $(IMAGE_NAME):$(TAG) . -# Run locally -run: - go run ./cmd/api +# Run the application with docker-compose +.PHONY: up +up: + docker-compose -f $(COMPOSE_FILE) up -d + +# Run the application with docker-compose and rebuild +.PHONY: up-build +up-build: + docker-compose -f $(COMPOSE_FILE) up -d --build + +# Stop the application +.PHONY: down +down: + docker-compose -f $(COMPOSE_FILE) down + +# Stop the application and remove volumes +.PHONY: down-volumes +down-volumes: + docker-compose -f $(COMPOSE_FILE) down -v + +# View logs +.PHONY: logs +logs: + docker-compose -f $(COMPOSE_FILE) logs -f + +# View logs for specific service +.PHONY: logs-predictor +logs-predictor: + docker-compose -f $(COMPOSE_FILE) logs -f predictor + + +# Check service status +.PHONY: ps +ps: + docker-compose -f $(COMPOSE_FILE) ps + +# Execute command in predictor container +.PHONY: exec +exec: + docker-compose -f $(COMPOSE_FILE) exec predictor sh + +# Clean up Docker resources +.PHONY: clean +clean: + docker-compose -f $(COMPOSE_FILE) down -v --rmi all + docker system prune -f # Run tests +.PHONY: test test: go test ./... +# Build locally +.PHONY: build-local +build-local: + go build -o predictor ./cmd/api + +# Run locally +.PHONY: run-local +run-local: + cd cmd/api && go run . + # Format code +.PHONY: fmt fmt: go fmt ./... # Lint code +.PHONY: lint lint: golangci-lint run -# Generate ogen API code from swagger spec -generate-ogen: - go run github.com/ogen-go/ogen/cmd/ogen@latest --target pkg/rest --package rest --clean api/rest/predictor.swagger.yml - -# Clean build artifacts -clean: - rm -f predictor - # Show help +.PHONY: help help: @echo "Available commands:" - @echo " build - Build binary" - @echo " run - Run locally" + @echo " validate-docker - Validate Docker configuration" + @echo " build - Build Docker image" + @echo " up - Start services with docker-compose" + @echo " up-build - Start services and rebuild images" + @echo " down - Stop services" + @echo " down-volumes - Stop services and remove volumes" + @echo " logs - View all logs" + @echo " logs-predictor - View predictor logs" + @echo " ps - Show service status" + @echo " exec - Execute shell in predictor container" + @echo " clean - Clean up Docker resources" @echo " test - Run tests" + @echo " build-local - Build locally" + @echo " run-local - Run locally" @echo " fmt - Format code" @echo " lint - Lint code" - @echo " generate-ogen - Generate API code from swagger spec" - @echo " clean - Remove build artifacts" + @echo " help - Show this help" + +generate-ogen: + go run github.com/ogen-go/ogen/cmd/ogen@latest --target pkg/rest -package gsn --clean api/rest/predictor.swagger.yml \ No newline at end of file diff --git a/README.md b/README.md deleted file mode 100644 index 579a255..0000000 --- a/README.md +++ /dev/null @@ -1,261 +0,0 @@ -# Balloon Trajectory Predictor - -High-altitude balloon trajectory prediction service. Predicts ascent, burst, and descent trajectories using GFS wind forecast data from NOAA. - -The prediction algorithms are an exact port of [Tawhiri](https://github.com/cuspaceflight/tawhiri) (Cambridge University Spaceflight) to Go, verified to produce identical results. - -## Quick Start - -```bash -# Build -make build - -# Run (downloads ~9 GB of GFS data on first start, takes 30-60 min) -PREDICTOR_DATA_DIR=/tmp/predictor-data go run ./cmd/api - -# Check readiness -curl http://localhost:8080/ready - -# Run a prediction -curl 'http://localhost:8080/api/v1/prediction?launch_latitude=52.2&launch_longitude=0.1&launch_datetime=2026-03-28T12:00:00Z&launch_altitude=0&ascent_rate=5&burst_altitude=30000&descent_rate=5' -``` - -## Configuration - -All configuration is via environment variables. - -| Variable | Default | Description | -|---|---|---| -| `PREDICTOR_PORT` | `8080` | HTTP server port | -| `PREDICTOR_DATA_DIR` | `/tmp/predictor-data` | Directory for wind datasets and temp files | -| `PREDICTOR_DOWNLOAD_PARALLEL` | `8` | Max concurrent GRIB download goroutines | -| `PREDICTOR_UPDATE_INTERVAL` | `6h` | How often to check for new forecasts | -| `PREDICTOR_DATASET_TTL` | `48h` | Max age before a dataset is considered stale | -| `PREDICTOR_ELEVATION_DATASET` | `/srv/ruaumoko-dataset` | Path to elevation dataset (optional) | - -## API - -### `GET /api/v1/prediction` - -Run a balloon trajectory prediction. - -**Parameters** (query string): - -| Parameter | Required | Description | -|---|---|---| -| `launch_latitude` | yes | Launch latitude in degrees (-90 to 90) | -| `launch_longitude` | yes | Launch longitude in degrees (-180 to 180 or 0 to 360) | -| `launch_datetime` | yes | Launch time in RFC 3339 format | -| `launch_altitude` | no | Launch altitude in metres ASL (default: 0) | -| `profile` | no | `standard_profile` (default) or `float_profile` | -| `ascent_rate` | no | Ascent rate in m/s (default: 5) | -| `burst_altitude` | no | Burst altitude in metres (default: 28000) | -| `descent_rate` | no | Sea-level descent rate in m/s (default: 5) | -| `float_altitude` | no | Float altitude in metres (float_profile only) | -| `stop_datetime` | no | Float end time (float_profile only, default: +24h) | - -**Response** (Tawhiri-compatible): - -```json -{ - "prediction": [ - { - "stage": "ascent", - "trajectory": [ - {"datetime": "2026-03-28T12:00:00Z", "latitude": 52.2, "longitude": 0.1, "altitude": 0}, - ... - ] - }, - { - "stage": "descent", - "trajectory": [...] - } - ], - "metadata": { - "start_datetime": "...", - "complete_datetime": "..." - }, - "request": { - "dataset": "2026-03-28T06:00:00Z", - "launch_latitude": 52.2, - ... - } -} -``` - -### `GET /ready` - -Health check. Returns `{"status": "ok"}` when a dataset is loaded. - -## Elevation Dataset - -Without elevation data, descent terminates at sea level (altitude <= 0). With elevation data, descent terminates at ground level, matching Tawhiri's behaviour. - -### Building the elevation dataset - -The elevation dataset uses ETOPO 2022 at 30 arc-second resolution, converted to a ruaumoko-compatible binary format (21601 x 43200 grid of int16 little-endian elevation values in metres). - -**Requirements**: Python 3, xarray, netcdf4, numpy. - -```bash -pip install xarray netcdf4 numpy - -# Downloads ~1.1 GB from NOAA, produces ~1.74 GB binary file -python3 scripts/build_elevation.py /tmp/predictor-data/ruaumoko-dataset -``` - -To skip the download if you already have the ETOPO NetCDF file: - -```bash -ETOPO_NC_PATH=/path/to/ETOPO_2022_v1_30s_N90W180_surface.nc \ - python3 scripts/build_elevation.py /tmp/predictor-data/ruaumoko-dataset -``` - -The ETOPO 2022 NetCDF can be manually downloaded from: -https://www.ncei.noaa.gov/products/etopo-global-relief-model - -### Using the elevation dataset - -```bash -PREDICTOR_ELEVATION_DATASET=/tmp/predictor-data/ruaumoko-dataset go run ./cmd/api -``` - -If the file doesn't exist or can't be read, the service starts normally with a warning and falls back to sea-level termination. - -## Architecture - -``` -cmd/api/main.go Entry point, config, scheduler, HTTP server -internal/ - dataset/ - dataset.go Shape constants, pressure levels, S3 URLs - file.go mmap-backed dataset file (read/write/blit) - downloader/ - downloader.go S3 partial GRIB download (idx + range requests) - idx.go NOAA .idx file parser - config.go Environment-based configuration - elevation/ - elevation.go Ruaumoko-compatible elevation dataset (mmap int16) - prediction/ - interpolate.go 4D wind interpolation (time, lat, lon, altitude) - solver.go RK4 integrator with binary search termination - models.go Ascent, descent, wind models; flight profiles - warnings.go Prediction warning counters - service/ - service.go Dataset lifecycle, concurrent-safe access - transport/ - middleware/log.go Request logging middleware - rest/ - handler/handler.go ogen API handler implementation - handler/deps.go Service interface - transport.go ogen HTTP server, CORS -api/rest/predictor.swagger.yml OpenAPI 3.0 spec -pkg/rest/ Generated ogen code (17 files) -scripts/ - build_elevation.py ETOPO 2022 to ruaumoko converter -``` - -## Wind Dataset - -The service downloads GFS 0.5-degree forecast data from NOAA S3: - -| Property | Value | -|---|---| -| Source | `noaa-gfs-bdp-pds.s3.amazonaws.com` | -| Resolution | 0.5 degrees | -| Grid | 361 lat x 720 lon | -| Time steps | 65 (every 3 hours, 0-192h) | -| Pressure levels | 47 (1000 to 1 hPa) | -| Variables | Geopotential height, U-wind, V-wind | -| Dataset size | 9,528,667,200 bytes (~8.87 GiB) | -| Update cadence | Every 6 hours (GFS runs at 00, 06, 12, 18 UTC) | - -Data is downloaded using HTTP Range requests against `.idx` index files, fetching only the needed GRIB messages (HGT, UGRD, VGRD at 47 pressure levels). Full download takes 30-60 minutes depending on bandwidth. - -The dataset is stored as a memory-mapped flat binary file of float32 values in C-order with shape `(65, 47, 3, 361, 720)`. - -## Prediction Algorithms - -All algorithms are exact ports of the reference implementations in Tawhiri. The following sections describe the key components. - -### Interpolation (`internal/prediction/interpolate.go`) - -4D wind interpolation from the dataset grid to arbitrary coordinates. - -1. **Trilinear weights** (`pick3`): compute 8 interpolation weights for the (hour, lat, lon) cube corners. -2. **Altitude search** (`search`): binary search on interpolated geopotential height to find the two pressure levels bracketing the target altitude. -3. **Wind extraction** (`interp4`): 8-point weighted sum at each bracket level, then linear interpolation between levels. - -Reference: `tawhiri/interpolate.pyx` - -### Solver (`internal/prediction/solver.go`) - -4th-order Runge-Kutta integrator with dt = 60 seconds. - -- State vector: (latitude, longitude, altitude) in degrees and metres. -- Time: UNIX timestamp in seconds. -- Longitude is kept in [0, 360) via Python-style modulo after each `vecadd`. -- When a terminator fires, binary search refinement (tolerance 0.01) finds the precise termination point between the last good step and the first terminated step. -- Longitude interpolation (`lngLerp`) handles the 0/360 wrap-around. - -Reference: `tawhiri/solver.pyx` - -### Models (`internal/prediction/models.go`) - -- **Constant ascent**: vertical velocity = ascent_rate m/s. -- **Drag descent**: NASA atmosphere density model with drag coefficient = sea_level_rate * 1.1045. Descent rate increases with altitude due to thinner air. -- **Wind velocity**: u, v components from interpolation converted to degrees/second: `dlat = (180/pi) * v / (R)`, `dlng = (180/pi) * u / (R * cos(lat))` where R = 6371009 + altitude. -- **Linear model**: sum of component models (e.g., wind + ascent). -- **Elevation termination**: `ground_elevation > altitude` using ruaumoko dataset. - -Reference: `tawhiri/models.py` - -### Profiles - -- **standard_profile**: ascent (constant rate + wind) until burst altitude, then descent (drag + wind) until ground level. -- **float_profile**: ascent to float altitude, then drift at constant altitude until stop time. - -## Verification - -The predictor has been verified against the reference Tawhiri implementation: - -| Test | Result | -|---|---| -| Dataset (step 0): 36.6M float32 values vs Python/cfgrib | 0 mismatches, max diff = 0.0 | -| Prediction burst point vs public Tawhiri API | Identical (lat, lon, alt all match) | -| Prediction landing point vs public Tawhiri API | Identical lat/lon, 5m altitude diff (different elevation datasets) | -| Descent point count | Identical (46 points) | -| Ascent point count | Identical (101 points) | - -## Development - -```bash -# Regenerate ogen API code after modifying the swagger spec -make generate-ogen - -# Run tests -make test - -# Format -make fmt -``` - -### Comparison tools - -```bash -# Compare single dataset step against Python/cfgrib reference -go run ./cmd/compare_step0 - -# Run prediction and compare against public Tawhiri API -go run ./cmd/compare_prediction -``` - -## References - -- [Tawhiri](https://github.com/cuspaceflight/tawhiri) — Reference Python/Cython predictor (Cambridge University Spaceflight) -- [tawhiri-downloader](https://github.com/cuspaceflight/tawhiri-downloader) — OCaml dataset downloader -- [ruaumoko](https://github.com/cuspaceflight/ruaumoko) — Global elevation dataset -- [NOAA GFS](https://www.ncei.noaa.gov/products/weather-climate-models/global-forecast) — Global Forecast System -- [NOAA GFS on S3](https://noaa-gfs-bdp-pds.s3.amazonaws.com/index.html) — Public S3 bucket -- [ETOPO 2022](https://www.ncei.noaa.gov/products/etopo-global-relief-model) — Global relief model for elevation data -- [SondeHub Tawhiri API](https://api.v2.sondehub.org/tawhiri) — Public Tawhiri instance for comparison diff --git a/api/rest/predictor.swagger.yml b/api/rest/predictor.swagger.yml index 64ec316..73e3381 100644 --- a/api/rest/predictor.swagger.yml +++ b/api/rest/predictor.swagger.yml @@ -1,6 +1,6 @@ openapi: 3.0.4 info: - title: Predictor API + title: GSN Predictor - OpenAPI 3.0 version: 0.0.1 paths: /api/v1/prediction: @@ -12,17 +12,14 @@ paths: parameters: - in: query name: launch_latitude - required: true schema: type: number - in: query name: launch_longitude - required: true schema: type: number - in: query name: launch_datetime - required: true schema: type: string format: date-time @@ -34,8 +31,7 @@ paths: name: profile schema: type: string - enum: [standard_profile, float_profile] - default: standard_profile + enum: [standard_profile, float_profile, reverse_profile, custom_profile] - in: query name: ascent_rate schema: @@ -57,6 +53,23 @@ paths: schema: type: string format: date-time + - in: query + name: ascent_curve + schema: + type: string + - in: query + name: descent_curve + schema: + type: string + - in: query + name: interpolate + schema: + type: boolean + - in: query + name: format + schema: + type: string + enum: [json] - in: query name: dataset schema: @@ -64,17 +77,17 @@ paths: format: date-time responses: "200": - description: Prediction response + description: "Prediction response" content: application/json: schema: - $ref: '#/components/schemas/PredictionResponse' + $ref: '#/components/schemas/PredictionResult' default: description: Error content: application/json: schema: - $ref: '#/components/schemas/Error' + $ref: "#/components/schemas/Error" /ready: get: tags: @@ -93,52 +106,37 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/Error' + $ref: "#/components/schemas/Error" components: schemas: Error: type: object required: - - error + - message properties: - error: - type: object - required: - - type - - description - properties: - type: - type: string - description: - type: string - PredictionResponse: + message: + type: string + details: + type: string + PredictionResult: type: object required: - - prediction - metadata + - prediction properties: - request: + metadata: type: object + required: + - complete_datetime + - start_datetime properties: - dataset: + complete_datetime: type: string - launch_latitude: - type: number - launch_longitude: - type: number - launch_datetime: + format: date-time + start_datetime: type: string - launch_altitude: - type: number - profile: - type: string - ascent_rate: - type: number - burst_altitude: - type: number - descent_rate: - type: number + format: date-time prediction: type: array items: @@ -149,7 +147,7 @@ components: properties: stage: type: string - enum: ["ascent", "descent", "float"] + enum: ["ascent", "descent"] trajectory: type: array items: @@ -169,31 +167,18 @@ components: type: number altitude: type: number - metadata: - type: object - required: - - start_datetime - - complete_datetime - properties: - start_datetime: - type: string - format: date-time - complete_datetime: - type: string - format: date-time - warnings: - type: object - additionalProperties: true ReadinessResponse: type: object - required: - - status properties: status: type: string enum: [ok, not_ready, error] - dataset_time: + last_update: type: string format: date-time + is_fresh: + type: boolean error_message: type: string + required: + - status \ No newline at end of file diff --git a/cmd/api/main.go b/cmd/api/main.go index 08d12e3..2e9bf6b 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -2,97 +2,112 @@ package main import ( "context" - "fmt" "os" "os/signal" "syscall" - "time" - "predictor-refactored/internal/downloader" - "predictor-refactored/internal/service" - "predictor-refactored/internal/transport/rest" - "predictor-refactored/internal/transport/rest/handler" - - "github.com/go-co-op/gocron" + "git.intra.yksa.space/gsn/predictor/internal/jobs/grib/updater" + "git.intra.yksa.space/gsn/predictor/internal/pkg/grib" + "git.intra.yksa.space/gsn/predictor/internal/pkg/log" + "git.intra.yksa.space/gsn/predictor/internal/service" + "git.intra.yksa.space/gsn/predictor/internal/transport/rest" + "git.intra.yksa.space/gsn/predictor/internal/transport/rest/handler" + "git.intra.yksa.space/gsn/predictor/pkg/scheduler" "go.uber.org/zap" ) +const servicePrefix = "GSN_PREDICTOR" + func main() { - log, err := zap.NewProduction() + lg, err := zap.NewProduction() if err != nil { panic(err) } - defer log.Sync() + defer lg.Sync() + ctx := log.ToCtx(context.Background(), lg) - cfg := downloader.LoadConfig() - log.Info("configuration loaded", - zap.String("data_dir", cfg.DataDir), - zap.Int("parallel", cfg.Parallel), - zap.Duration("update_interval", cfg.UpdateInterval), - zap.Duration("dataset_ttl", cfg.DatasetTTL)) - - if err := os.MkdirAll(cfg.DataDir, 0o755); err != nil { - log.Fatal("failed to create data directory", zap.Error(err)) + schedulerConfig, err := scheduler.NewConfig() + if err != nil { + log.Ctx(ctx).Fatal("failed to load scheduler configuration", zap.Error(err)) } - svc := service.New(cfg, log) + gribUpdaterConfig, err := updater.NewConfig() + if err != nil { + log.Ctx(ctx).Fatal("failed to load GRIB updater configuration", zap.Error(err)) + } + + gribCfg, err := grib.NewConfig() + if err != nil { + log.Ctx(ctx).Fatal("failed to load GRIB configuration", zap.Error(err)) + } + + gribService, err := grib.New(gribCfg) + if err != nil { + log.Ctx(ctx).Fatal("failed to initialize GRIB service", zap.Error(err)) + } + defer gribService.Close() + + // Force GRIB update on startup in a goroutine + go func() { + log.Ctx(ctx).Info("Performing initial GRIB update (async)...") + if err := gribService.Update(ctx); err != nil { + log.Ctx(ctx).Error("initial GRIB update failed", zap.Error(err)) + } else { + log.Ctx(ctx).Info("initial GRIB update complete") + } + }() + + svc, err := service.New(gribService) + if err != nil { + log.Ctx(ctx).Fatal("failed to initialize service", zap.Error(err)) + } defer svc.Close() - // Load elevation dataset (optional — falls back to sea-level termination) - elevPath := "/srv/ruaumoko-dataset" - if v := os.Getenv("PREDICTOR_ELEVATION_DATASET"); v != "" { - elevPath = v - } - svc.LoadElevation(elevPath) + var sched *scheduler.Scheduler + if schedulerConfig.Enabled { + sched = scheduler.New() - // Initial dataset load (async so the server starts immediately) - go func() { - log.Info("performing initial dataset update...") - if err := svc.Update(context.Background()); err != nil { - log.Error("initial dataset update failed", zap.Error(err)) - } else { - log.Info("initial dataset update complete") + gribJob := updater.New(gribService, gribUpdaterConfig) + if err := sched.AddJob(gribJob); err != nil { + log.Ctx(ctx).Error("failed to add GRIB update job to scheduler", zap.Error(err)) } - }() - // Scheduler for periodic dataset updates - scheduler := gocron.NewScheduler(time.UTC) - scheduler.Every(cfg.UpdateInterval).Do(func() { - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Minute) - defer cancel() - log.Info("scheduled dataset update starting") - if err := svc.Update(ctx); err != nil { - log.Error("scheduled dataset update failed", zap.Error(err)) - } else { - log.Info("scheduled dataset update complete") - } - }) - scheduler.StartAsync() - defer scheduler.Stop() - - // HTTP transport (ogen) - port := 8080 - if p := os.Getenv("PREDICTOR_PORT"); p != "" { - fmt.Sscanf(p, "%d", &port) + log.Ctx(ctx).Info("scheduler initialized with jobs") } - h := handler.New(svc, log) - transport, err := rest.New(h, port, log) + handler := handler.New(svc) + + restConfig, err := rest.NewConfig() if err != nil { - log.Fatal("failed to create transport", zap.Error(err)) + lg.Fatal("failed to init transport config", zap.Error(err)) } - go func() { - if err := transport.Run(); err != nil { - log.Fatal("HTTP server error", zap.Error(err)) - } - }() + transport, err := rest.New(handler, restConfig) + if err != nil { + lg.Fatal("failed to init transport", zap.Error(err)) + } - log.Info("service started") + svc.Start() + if sched != nil { + sched.Start() + lg.Info("scheduler started") + } + + lg.Info("service started successfully") - // Graceful shutdown sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - sig := <-sigChan - log.Info("received shutdown signal", zap.String("signal", sig.String())) + + go func() { + lg.Info("starting HTTP server on port", zap.Int("port", restConfig.Port)) + transport.Run() + }() + + <-sigChan + lg.Info("received shutdown signal, stopping service") + + if sched != nil { + sched.Stop() + lg.Info("scheduler stopped") + } } diff --git a/cmd/compare_prediction/main.go b/cmd/compare_prediction/main.go deleted file mode 100644 index 13ae2a6..0000000 --- a/cmd/compare_prediction/main.go +++ /dev/null @@ -1,195 +0,0 @@ -package main - -import ( - "context" - "encoding/json" - "fmt" - "io" - "math" - "net/http" - "os" - "time" - - "predictor-refactored/internal/dataset" - "predictor-refactored/internal/downloader" - "predictor-refactored/internal/prediction" - - "go.uber.org/zap" -) - -// Downloads a few forecast steps and runs a prediction, then compares -// against the public Tawhiri API. -func main() { - log, _ := zap.NewDevelopment() - - cfg := &downloader.Config{ - DataDir: os.TempDir(), - Parallel: 4, - } - dl := downloader.NewDownloader(cfg, log) - - ctx := context.Background() - - // Find latest run - run, err := dl.FindLatestRun(ctx) - if err != nil { - fmt.Fprintf(os.Stderr, "FindLatestRun: %v\n", err) - os.Exit(1) - } - fmt.Printf("Using run: %s\n", run.Format("2006010215")) - - // Create dataset and download first 10 steps (0-27 hours, enough for a prediction) - dsPath := fmt.Sprintf("/tmp/pred_test_%s.bin", run.Format("2006010215")) - defer os.Remove(dsPath) - - ds, err := dataset.Create(dsPath) - if err != nil { - fmt.Fprintf(os.Stderr, "Create: %v\n", err) - os.Exit(1) - } - - date := run.Format("20060102") - runHour := run.Hour() - stepsToDownload := []int{0, 3, 6, 9, 12, 15, 18, 21, 24, 27} - - fmt.Printf("Downloading %d steps...\n", len(stepsToDownload)) - for _, step := range stepsToDownload { - hourIdx := dataset.HourIndex(step) - fmt.Printf(" step %d (hour idx %d)...\n", step, hourIdx) - - urlA := dataset.GribURL(date, runHour, step) - if err := dl.DownloadAndBlit(ctx, ds, urlA, hourIdx, dataset.LevelSetA); err != nil { - fmt.Fprintf(os.Stderr, " pgrb2 step %d: %v\n", step, err) - os.Exit(1) - } - - urlB := dataset.GribURLB(date, runHour, step) - if err := dl.DownloadAndBlit(ctx, ds, urlB, hourIdx, dataset.LevelSetB); err != nil { - fmt.Fprintf(os.Stderr, " pgrb2b step %d: %v\n", step, err) - os.Exit(1) - } - } - ds.Flush() - fmt.Println("Download complete") - - // Set dataset time - ds.DSTime = run - - // Run our prediction - launchLat := 52.2135 - launchLon := 0.0964 // already in [0, 360) - launchAlt := 0.0 - ascentRate := 5.0 - burstAlt := 30000.0 - descentRate := 5.0 - - // Launch 3 hours into the forecast - launchTime := run.Add(3 * time.Hour) - launchTimestamp := float64(launchTime.Unix()) - dsEpoch := float64(run.Unix()) - - warnings := &prediction.Warnings{} - stages := prediction.StandardProfile(ascentRate, burstAlt, descentRate, ds, dsEpoch, warnings, nil) - results := prediction.RunPrediction(launchTimestamp, launchLat, launchLon, launchAlt, stages) - - fmt.Printf("\n=== Our prediction ===\n") - for i, sr := range results { - stage := "ascent" - if i == 1 { - stage = "descent" - } - first := sr.Points[0] - last := sr.Points[len(sr.Points)-1] - fmt.Printf(" %s: %d points, start=(%.4f, %.4f, %.0fm) end=(%.4f, %.4f, %.0fm)\n", - stage, len(sr.Points), - first.Lat, first.Lng, first.Alt, - last.Lat, last.Lng, last.Alt) - } - - // Get landing point - var ourLandLat, ourLandLon float64 - if len(results) >= 2 { - last := results[1].Points[len(results[1].Points)-1] - ourLandLat = last.Lat - ourLandLon = last.Lng - if ourLandLon > 180 { - ourLandLon -= 360 - } - } - fmt.Printf(" Landing: lat=%.4f, lon=%.4f\n", ourLandLat, ourLandLon) - - // Compare against public Tawhiri API - fmt.Printf("\n=== Tawhiri API comparison ===\n") - tawhiriLandLat, tawhiriLandLon, err := queryTawhiri(launchLat, launchLon, launchAlt, launchTime, ascentRate, burstAlt, descentRate) - if err != nil { - fmt.Printf(" Tawhiri API error: %v\n", err) - fmt.Println(" (Cannot compare — Tawhiri may use a different dataset)") - ds.Close() - return - } - fmt.Printf(" Tawhiri landing: lat=%.4f, lon=%.4f\n", tawhiriLandLat, tawhiriLandLon) - - dist := haversine(ourLandLat, ourLandLon, tawhiriLandLat, tawhiriLandLon) - fmt.Printf(" Distance between landing points: %.2f km\n", dist/1000) - - if dist < 1000 { - fmt.Println(" CLOSE MATCH (< 1 km)") - } else if dist < 50000 { - fmt.Printf(" MODERATE DIFFERENCE (%.1f km) — likely different datasets\n", dist/1000) - } else { - fmt.Printf(" LARGE DIFFERENCE (%.1f km) — possible bug\n", dist/1000) - } - - ds.Close() -} - -func queryTawhiri(lat, lon, alt float64, launchTime time.Time, ascentRate, burstAlt, descentRate float64) (landLat, landLon float64, err error) { - url := fmt.Sprintf( - "https://api.v2.sondehub.org/tawhiri?launch_latitude=%.4f&launch_longitude=%.4f&launch_altitude=%.0f&launch_datetime=%s&ascent_rate=%.1f&burst_altitude=%.0f&descent_rate=%.1f", - lat, lon, alt, launchTime.Format(time.RFC3339), ascentRate, burstAlt, descentRate) - - resp, err := http.Get(url) - if err != nil { - return 0, 0, err - } - defer resp.Body.Close() - - body, _ := io.ReadAll(resp.Body) - if resp.StatusCode != 200 { - return 0, 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, string(body)) - } - - var result struct { - Prediction []struct { - Stage string `json:"stage"` - Trajectory []struct { - Latitude float64 `json:"latitude"` - Longitude float64 `json:"longitude"` - Altitude float64 `json:"altitude"` - } `json:"trajectory"` - } `json:"prediction"` - } - - if err := json.Unmarshal(body, &result); err != nil { - return 0, 0, err - } - - for _, stage := range result.Prediction { - if stage.Stage == "descent" && len(stage.Trajectory) > 0 { - last := stage.Trajectory[len(stage.Trajectory)-1] - return last.Latitude, last.Longitude, nil - } - } - - return 0, 0, fmt.Errorf("no descent stage found") -} - -func haversine(lat1, lon1, lat2, lon2 float64) float64 { - const R = 6371000.0 - phi1 := lat1 * math.Pi / 180 - phi2 := lat2 * math.Pi / 180 - dphi := (lat2 - lat1) * math.Pi / 180 - dlam := (lon2 - lon1) * math.Pi / 180 - a := math.Sin(dphi/2)*math.Sin(dphi/2) + math.Cos(phi1)*math.Cos(phi2)*math.Sin(dlam/2)*math.Sin(dlam/2) - return R * 2 * math.Atan2(math.Sqrt(a), math.Sqrt(1-a)) -} diff --git a/cmd/compare_step0/main.go b/cmd/compare_step0/main.go deleted file mode 100644 index d0d697d..0000000 --- a/cmd/compare_step0/main.go +++ /dev/null @@ -1,104 +0,0 @@ -package main - -import ( - "context" - "fmt" - "os" - "time" - - "predictor-refactored/internal/dataset" - "predictor-refactored/internal/downloader" - - "go.uber.org/zap" -) - -// Downloads step 0 of a given run and writes a minimal dataset for comparison. -// Usage: go run ./cmd/compare_step0 -func main() { - if len(os.Args) < 3 { - fmt.Fprintf(os.Stderr, "Usage: %s \n", os.Args[0]) - os.Exit(1) - } - - runStr := os.Args[1] - outPath := os.Args[2] - - run, err := time.Parse("2006010215", runStr) - if err != nil { - fmt.Fprintf(os.Stderr, "Invalid run time %q: %v\n", runStr, err) - os.Exit(1) - } - - log, _ := zap.NewDevelopment() - - // Create a full-size dataset (we only fill step 0) - fmt.Printf("Creating dataset at %s (%d bytes)...\n", outPath, dataset.DatasetSize) - ds, err := dataset.Create(outPath) - if err != nil { - fmt.Fprintf(os.Stderr, "Create dataset: %v\n", err) - os.Exit(1) - } - defer ds.Close() - - cfg := &downloader.Config{ - DataDir: os.TempDir(), - Parallel: 4, - } - dl := downloader.NewDownloader(cfg, log) - - ctx := context.Background() - date := run.Format("20060102") - runHour := run.Hour() - - // Download and blit step 0 from pgrb2 - fmt.Println("Downloading pgrb2 step 0...") - urlA := dataset.GribURL(date, runHour, 0) - if err := dl.DownloadAndBlit(ctx, ds, urlA, 0, dataset.LevelSetA); err != nil { - fmt.Fprintf(os.Stderr, "pgrb2: %v\n", err) - os.Exit(1) - } - fmt.Println(" done") - - // Download and blit step 0 from pgrb2b - fmt.Println("Downloading pgrb2b step 0...") - urlB := dataset.GribURLB(date, runHour, 0) - if err := dl.DownloadAndBlit(ctx, ds, urlB, 0, dataset.LevelSetB); err != nil { - fmt.Fprintf(os.Stderr, "pgrb2b: %v\n", err) - os.Exit(1) - } - fmt.Println(" done") - - if err := ds.Flush(); err != nil { - fmt.Fprintf(os.Stderr, "Flush: %v\n", err) - os.Exit(1) - } - - // Spot-check: print same values as the Python script - fmt.Println("\n=== Go dataset values (spot check) ===") - type testPoint struct { - varName string - varIdx int - levelIdx int - lat, lon int - } - - points := []testPoint{ - {"HGT", 0, 0, 0, 0}, // HGT @ 1000mb, lat=-90, lon=0 - {"HGT", 0, 0, 180, 0}, // HGT @ 1000mb, lat=0, lon=0 - {"HGT", 0, 0, 360, 0}, // HGT @ 1000mb, lat=+90, lon=0 - {"HGT", 0, 20, 180, 360}, // HGT @ 500mb, lat=0, lon=180 - {"UGRD", 1, 0, 180, 0}, // UGRD @ 1000mb, lat=0, lon=0 - {"VGRD", 2, 0, 180, 0}, // VGRD @ 1000mb, lat=0, lon=0 - {"UGRD", 1, 20, 284, 0}, // UGRD @ 500mb, lat=52N, lon=0 - } - - for _, p := range points { - val := ds.Val(0, p.levelIdx, p.varIdx, p.lat, p.lon) - actualLat := -90.0 + float64(p.lat)*0.5 - actualLon := float64(p.lon) * 0.5 - fmt.Printf(" %-4s %4dmb lat=%+7.1f lon=%6.1f: %12.4f\n", - p.varName, dataset.Pressures[p.levelIdx], actualLat, actualLon, val) - } - - fmt.Printf("\nDataset written to %s\n", outPath) -} diff --git a/go.mod b/go.mod index 35b9486..079e12c 100644 --- a/go.mod +++ b/go.mod @@ -1,25 +1,44 @@ -module predictor-refactored +module git.intra.yksa.space/gsn/predictor -go 1.25.0 +go 1.24.4 require ( + github.com/caarlos0/env/v11 v11.3.1 github.com/edsrzf/mmap-go v1.2.0 github.com/go-co-op/gocron v1.37.0 github.com/go-faster/errors v0.7.1 - github.com/go-faster/jx v1.2.0 + github.com/go-faster/jx v1.1.0 github.com/nilsmagnus/grib v1.2.8 - github.com/ogen-go/ogen v1.20.2 - go.opentelemetry.io/otel v1.42.0 - go.opentelemetry.io/otel/metric v1.42.0 - go.opentelemetry.io/otel/trace v1.42.0 - go.uber.org/zap v1.27.1 - golang.org/x/sync v0.20.0 + github.com/ogen-go/ogen v1.16.0 + github.com/rs/cors v1.11.1 + go.opentelemetry.io/otel v1.38.0 + go.opentelemetry.io/otel/metric v1.38.0 + go.opentelemetry.io/otel/trace v1.38.0 + go.uber.org/zap v1.27.0 + golang.org/x/sync v0.17.0 ) require ( - github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.39.3 // indirect + github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2 // indirect + github.com/aws/aws-sdk-go-v2/config v1.31.13 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.18.17 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 // indirect + github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.10 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.10 // indirect + github.com/aws/aws-sdk-go-v2/service/s3 v1.88.5 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 // indirect + github.com/aws/smithy-go v1.23.1 // indirect github.com/dlclark/regexp2 v1.11.5 // indirect - github.com/fatih/color v1.19.0 // indirect + github.com/fatih/color v1.18.0 // indirect github.com/ghodss/yaml v1.0.0 // indirect github.com/go-faster/yaml v0.4.6 // indirect github.com/go-logr/logr v1.4.3 // indirect @@ -31,11 +50,11 @@ require ( github.com/segmentio/asm v1.2.1 // indirect github.com/shopspring/decimal v1.4.0 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect - go.uber.org/atomic v1.9.0 // indirect + go.uber.org/atomic v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect - golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 // indirect - golang.org/x/net v0.52.0 // indirect - golang.org/x/sys v0.42.0 // indirect - golang.org/x/text v0.35.0 // indirect + golang.org/x/exp v0.0.0-20251017212417-90e834f514db // indirect + golang.org/x/net v0.46.0 // indirect + golang.org/x/sys v0.37.0 // indirect + golang.org/x/text v0.30.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect ) diff --git a/go.sum b/go.sum index 1de3241..eaa1790 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,41 @@ -github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= -github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/aws/aws-sdk-go-v2 v1.39.3 h1:h7xSsanJ4EQJXG5iuW4UqgP7qBopLpj84mpkNx3wPjM= +github.com/aws/aws-sdk-go-v2 v1.39.3/go.mod h1:yWSxrnioGUZ4WVv9TgMrNUeLV3PFESn/v+6T/Su8gnM= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2 h1:t9yYsydLYNBk9cJ73rgPhPWqOh/52fcWDQB5b1JsKSY= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.7.2/go.mod h1:IusfVNTmiSN3t4rhxWFaBAqn+mcNdwKtPcV16eYdgko= +github.com/aws/aws-sdk-go-v2/config v1.31.13 h1:wcqQB3B0PgRPUF5ZE/QL1JVOyB0mbPevHFoAMpemR9k= +github.com/aws/aws-sdk-go-v2/config v1.31.13/go.mod h1:ySB5D5ybwqGbT6c3GszZ+u+3KvrlYCUQNo62+hkKOFk= +github.com/aws/aws-sdk-go-v2/credentials v1.18.17 h1:skpEwzN/+H8cdrrtT8y+rvWJGiWWv0DeNAe+4VTf+Vs= +github.com/aws/aws-sdk-go-v2/credentials v1.18.17/go.mod h1:Ed+nXsaYa5uBINovJhcAWkALvXw2ZLk36opcuiSZfJM= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10 h1:UuGVOX48oP4vgQ36oiKmW9RuSeT8jlgQgBFQD+HUiHY= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.18.10/go.mod h1:vM/Ini41PzvudT4YkQyE/+WiQJiQ6jzeDyU8pQKwCac= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10 h1:mj/bdWleWEh81DtpdHKkw41IrS+r3uw1J/VQtbwYYp8= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.10/go.mod h1:7+oEMxAZWP8gZCyjcm9VicI0M61Sx4DJtcGfKYv2yKQ= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10 h1:wh+/mn57yhUrFtLIxyFPh2RgxgQz/u+Yrf7hiHGHqKY= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.10/go.mod h1:7zirD+ryp5gitJJ2m1BBux56ai8RIRDykXZrJSp540w= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4 h1:WKuaxf++XKWlHWu9ECbMlha8WOEGm0OUEZqm4K/Gcfk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.4/go.mod h1:ZWy7j6v1vWGmPReu0iSGvRiise4YI5SkR3OHKTZ6Wuc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.10 h1:FHw90xCTsofzk6vjU808TSuDtDfOOKPNdz5Weyc3tUI= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.4.10/go.mod h1:n8jdIE/8F3UYkg8O4IGkQpn2qUmapg/1K1yl29/uf/c= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2 h1:xtuxji5CS0JknaXoACOunXOYOQzgfTvGAc9s2QdCJA4= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.2/go.mod h1:zxwi0DIR0rcRcgdbl7E2MSOvxDyyXGBlScvBkARFaLQ= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.1 h1:ne+eepnDB2Wh5lHKzELgEncIqeVlQ1rSF9fEa4r5I+A= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.9.1/go.mod h1:u0Jkg0L+dcG1ozUq21uFElmpbmjBnhHR5DELHIme4wg= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10 h1:DRND0dkCKtJzCj4Xl4OpVbXZgfttY5q712H9Zj7qc/0= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.10/go.mod h1:tGGNmJKOTernmR2+VJ0fCzQRurcPZj9ut60Zu5Fi6us= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.10 h1:DA+Hl5adieRyFvE7pCvBWm3VOZTRexGVkXw33SUqNoY= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.19.10/go.mod h1:L+A89dH3/gr8L4ecrdzuXUYd1znoko6myzndVGZx/DA= +github.com/aws/aws-sdk-go-v2/service/s3 v1.88.5 h1:FlGScxzCGNzT+2AvHT1ZGMvxTwAMa6gsooFb1pO/AiM= +github.com/aws/aws-sdk-go-v2/service/s3 v1.88.5/go.mod h1:N/iojY+8bW3MYol9NUMuKimpSbPEur75cuI1SmtonFM= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.7 h1:fspVFg6qMx0svs40YgRmE7LZXh9VRZvTT35PfdQR6FM= +github.com/aws/aws-sdk-go-v2/service/sso v1.29.7/go.mod h1:BQTKL3uMECaLaUV3Zc2L4Qybv8C6BIXjuu1dOPyxTQs= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2 h1:scVnW+NLXasGOhy7HhkdT9AGb6kjgW7fJ5xYkUaqHs0= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.35.2/go.mod h1:FRNCY3zTEWZXBKm2h5UBUPvCVDOecTad9KhynDyGBc0= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.7 h1:VEO5dqFkMsl8QZ2yHsFDJAIZLAkEbaYDB+xdKi0Feic= +github.com/aws/aws-sdk-go-v2/service/sts v1.38.7/go.mod h1:L1xxV3zAdB+qVrVW/pBIrIAnHFWHo6FBbFe4xOGsG/o= +github.com/aws/smithy-go v1.23.1 h1:sLvcH6dfAFwGkHLZ7dGiYF7aK6mg4CgKA/iDKjLDt9M= +github.com/aws/smithy-go v1.23.1/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/caarlos0/env/v11 v11.3.1 h1:cArPWC15hWmEt+gWk7YBi7lEXTXCvpaSdCiZE2X5mCA= +github.com/caarlos0/env/v11 v11.3.1/go.mod h1:qupehSf/Y0TUTsxKywqRt/vJjN5nz6vauiYEUUr8P4U= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -8,19 +44,21 @@ github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZ github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/edsrzf/mmap-go v1.2.0 h1:hXLYlkbaPzt1SaQk+anYwKSRNhufIDCchSPkUD6dD84= github.com/edsrzf/mmap-go v1.2.0/go.mod h1:19H/e8pUPLicwkyNgOykDXkJ9F0MHE+Z52B8EIth78Q= -github.com/fatih/color v1.19.0 h1:Zp3PiM21/9Ld6FzSKyL5c/BULoe/ONr9KlbYVOfG8+w= -github.com/fatih/color v1.19.0/go.mod h1:zNk67I0ZUT1bEGsSGyCZYZNrHuTkJJB+r6Q9VuMi0LE= +github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= +github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/go-co-op/gocron v1.37.0 h1:ZYDJGtQ4OMhTLKOKMIch+/CY70Brbb1dGdooLEhh7b0= github.com/go-co-op/gocron v1.37.0/go.mod h1:3L/n6BkO7ABj+TrfSVXLRzsP26zmikL4ISkLQ0O8iNY= github.com/go-faster/errors v0.7.1 h1:MkJTnDoEdi9pDabt1dpWf7AA8/BaSYZqibYyhZ20AYg= github.com/go-faster/errors v0.7.1/go.mod h1:5ySTjWFiphBs07IKuiL69nxdfd5+fzh1u7FPGZP2quo= -github.com/go-faster/jx v1.2.0 h1:T2YHJPrFaYu21fJtUxC9GzmluKu8rVIFDwwGBKTDseI= -github.com/go-faster/jx v1.2.0/go.mod h1:UWLOVDmMG597a5tBFPLIWJdUxz5/2emOpfsj9Neg0PE= +github.com/go-faster/jx v1.1.0 h1:ZsW3wD+snOdmTDy9eIVgQdjUpXRRV4rqW8NS3t+20bg= +github.com/go-faster/jx v1.1.0/go.mod h1:vKDNikrKoyUmpzaJ0OkIkRQClNHFX/nF3dnTJZb3skg= github.com/go-faster/yaml v0.4.6 h1:lOK/EhI04gCpPgPhgt0bChS6bvw7G3WwI8xxVe0sw9I= github.com/go-faster/yaml v0.4.6/go.mod h1:390dRIvV4zbnO7qC9FGo6YYutc+wyyUSHBgbXL52eXk= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= @@ -39,14 +77,19 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/nilsmagnus/grib v1.2.8 h1:H7ch/1/agaCqM3MC8hW1Ft+EJ+q2XB757uml/IfPvp4= github.com/nilsmagnus/grib v1.2.8/go.mod h1:XHm+5zuoOk0NSIWaGmA3JaAxI4i50YvD1L1vz+aqPOQ= -github.com/ogen-go/ogen v1.20.2 h1:mEZGPST7ZeX84AkqRlFawDLwcwuzcLO5PtYpAXLT1YE= -github.com/ogen-go/ogen v1.20.2/go.mod h1:sJ1pJVp4S1RcSZlYIiMLo0QSMSt2pls4zfrc+hNKnzk= +github.com/ogen-go/ogen v1.14.0 h1:TU1Nj4z9UBsAfTkf+IhuNNp7igdFQKqkk9+6/y4XuWg= +github.com/ogen-go/ogen v1.14.0/go.mod h1:Iw1vkqkx6SU7I9th5ceP+fVPJ6Wge4e3kAVzAxJEpPE= +github.com/ogen-go/ogen v1.16.0 h1:fKHEYokW/QrMzVNXId74/6RObRIUs9T2oroGKtR25Iw= +github.com/ogen-go/ogen v1.16.0/go.mod h1:s3nWiMzybSf8fhxckyO+wtto92+QHpEL8FmkPnhL3jI= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= @@ -54,8 +97,13 @@ github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o= +github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= +github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= -github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/cors v1.11.1 h1:eU3gRzXLRK57F5rKMGMZURNdIG4EoAmX8k94r9wXWHA= +github.com/rs/cors v1.11.1/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= +github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= +github.com/segmentio/asm v1.2.0/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/segmentio/asm v1.2.1 h1:DTNbBqs57ioxAD4PrArqftgypG4/qNpXoJx8TVXxPR0= github.com/segmentio/asm v1.2.1/go.mod h1:BqMnlJP91P8d+4ibuonYZw9mfnzI9HfxselHZr5aAcs= github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= @@ -67,35 +115,57 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= -github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= -go.opentelemetry.io/otel v1.42.0 h1:lSQGzTgVR3+sgJDAU/7/ZMjN9Z+vUip7leaqBKy4sho= -go.opentelemetry.io/otel v1.42.0/go.mod h1:lJNsdRMxCUIWuMlVJWzecSMuNjE7dOYyWlqOXWkdqCc= -go.opentelemetry.io/otel/metric v1.42.0 h1:2jXG+3oZLNXEPfNmnpxKDeZsFI5o4J+nz6xUlaFdF/4= -go.opentelemetry.io/otel/metric v1.42.0/go.mod h1:RlUN/7vTU7Ao/diDkEpQpnz3/92J9ko05BIwxYa2SSI= -go.opentelemetry.io/otel/trace v1.42.0 h1:OUCgIPt+mzOnaUTpOQcBiM/PLQ/Op7oq6g4LenLmOYY= -go.opentelemetry.io/otel/trace v1.42.0/go.mod h1:f3K9S+IFqnumBkKhRJMeaZeNk9epyhnCmQh/EysQCdc= +go.opentelemetry.io/otel v1.36.0 h1:UumtzIklRBY6cI/lllNZlALOF5nNIzJVb16APdvgTXg= +go.opentelemetry.io/otel v1.36.0/go.mod h1:/TcFMXYjyRNh8khOAO9ybYkqaDBb/70aVwkNML4pP8E= +go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= +go.opentelemetry.io/otel v1.38.0/go.mod h1:zcmtmQ1+YmQM9wrNsTGV/q/uyusom3P8RxwExxkZhjM= +go.opentelemetry.io/otel/metric v1.36.0 h1:MoWPKVhQvJ+eeXWHFBOPoBOi20jh6Iq2CcCREuTYufE= +go.opentelemetry.io/otel/metric v1.36.0/go.mod h1:zC7Ks+yeyJt4xig9DEw9kuUFe5C3zLbVjV2PzT6qzbs= +go.opentelemetry.io/otel/metric v1.38.0 h1:Kl6lzIYGAh5M159u9NgiRkmoMKjvbsKtYRwgfrA6WpA= +go.opentelemetry.io/otel/metric v1.38.0/go.mod h1:kB5n/QoRM8YwmUahxvI3bO34eVtQf2i4utNVLr9gEmI= +go.opentelemetry.io/otel/trace v1.36.0 h1:ahxWNuqZjpdiFAyrIoQ4GIiAIhxAunQR6MUoKrsNd4w= +go.opentelemetry.io/otel/trace v1.36.0/go.mod h1:gQ+OnDZzrybY4k4seLzPAWNwVBBVlF2szhehOBB/tGA= +go.opentelemetry.io/otel/trace v1.38.0 h1:Fxk5bKrDZJUH+AMyyIXGcFAPah0oRcT+LuNtJrmcNLE= +go.opentelemetry.io/otel/trace v1.38.0/go.mod h1:j1P9ivuFsTceSWe1oY+EeW3sc+Pp42sO++GHkg4wwhs= go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= +go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -go.uber.org/zap v1.27.1 h1:08RqriUEv8+ArZRYSTXy1LeBScaMpVSTBhCeaZYfMYc= -go.uber.org/zap v1.27.1/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090 h1:Di6/M8l0O2lCLc6VVRWhgCiApHV8MnQurBnFSHsQtNY= golang.org/x/exp v0.0.0-20230725093048-515e97ebf090/go.mod h1:FXUEEKJgO7OQYeo8N01OfiKP8RXMtf6e8aTskBGqWdc= -golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= -golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= -golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= -golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= +golang.org/x/exp v0.0.0-20251017212417-90e834f514db h1:by6IehL4BH5k3e3SJmcoNbOobMey2SLpAF79iPOEBvw= +golang.org/x/exp v0.0.0-20251017212417-90e834f514db/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= +golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= +golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= +golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4= +golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210= +golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= +golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= +golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug= +golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= -golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= -golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= -golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= +golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= +golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ= +golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= +golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= +golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= +golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= diff --git a/internal/dataset/dataset.go b/internal/dataset/dataset.go deleted file mode 100644 index 53367db..0000000 --- a/internal/dataset/dataset.go +++ /dev/null @@ -1,158 +0,0 @@ -package dataset - -import "fmt" - -// Dataset shape constants. -// Shape: (65, 47, 3, 361, 720) = (hour, pressure_level, variable, latitude, longitude) -// This matches the reference predictor exactly. -const ( - NumHours = 65 // 0, 3, 6, ..., 192 - NumLevels = 47 // pressure levels - NumVariables = 3 // height, wind_u, wind_v - NumLatitudes = 361 // -90.0 to +90.0 in 0.5 degree steps - NumLongitudes = 720 // 0.0 to 359.5 in 0.5 degree steps - - HourStep = 3 // hours between forecast time steps - MaxHour = 192 // maximum forecast hour - Resolution = 0.5 // grid resolution in degrees - LatStart = -90.0 // first latitude in the dataset - LonStart = 0.0 // first longitude in the dataset - - // Variable indices within the dataset. - VarHeight = 0 - VarWindU = 1 - VarWindV = 2 - - ElementSize = 4 // float32 = 4 bytes -) - -// DatasetSize is the total size of the dataset file in bytes. -// 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200 -const DatasetSize int64 = int64(NumHours) * int64(NumLevels) * int64(NumVariables) * - int64(NumLatitudes) * int64(NumLongitudes) * int64(ElementSize) - -// LevelSet identifies which GRIB file set a pressure level belongs to. -type LevelSet int - -const ( - LevelSetA LevelSet = iota // pgrb2 (primary) - LevelSetB // pgrb2b (secondary) -) - -// Pressures contains the 47 pressure levels in hPa, sorted descending. -// Index 0 = 1000 hPa (near surface), Index 46 = 1 hPa (high atmosphere). -var Pressures = [NumLevels]int{ - 1000, 975, 950, 925, 900, 875, 850, 825, 800, 775, - 750, 725, 700, 675, 650, 625, 600, 575, 550, 525, - 500, 475, 450, 425, 400, 375, 350, 325, 300, 275, - 250, 225, 200, 175, 150, 125, 100, 70, 50, 30, - 20, 10, 7, 5, 3, 2, 1, -} - -// pressureIndex maps pressure in hPa to its index in the Pressures array. -var pressureIndex map[int]int - -// pressureLevelSet maps pressure in hPa to its GRIB file set. -var pressureLevelSet map[int]LevelSet - -func init() { - pressureIndex = make(map[int]int, NumLevels) - for i, p := range Pressures { - pressureIndex[p] = i - } - - pressureLevelSet = make(map[int]LevelSet, NumLevels) - for _, p := range PressuresPgrb2 { - pressureLevelSet[p] = LevelSetA - } - for _, p := range PressuresPgrb2b { - pressureLevelSet[p] = LevelSetB - } -} - -// PressuresPgrb2 contains levels found in the primary pgrb2 file (26 levels). -var PressuresPgrb2 = []int{ - 10, 20, 30, 50, 70, 100, 150, 200, 250, 300, 350, 400, - 450, 500, 550, 600, 650, 700, 750, 800, 850, 900, 925, - 950, 975, 1000, -} - -// PressuresPgrb2b contains levels found in the secondary pgrb2b file (21 levels). -var PressuresPgrb2b = []int{ - 1, 2, 3, 5, 7, 125, 175, 225, 275, 325, 375, 425, - 475, 525, 575, 625, 675, 725, 775, 825, 875, -} - -// PressureIndex returns the dataset index for a given pressure level in hPa. -// Returns -1 if the level is not found. -func PressureIndex(hPa int) int { - idx, ok := pressureIndex[hPa] - if !ok { - return -1 - } - return idx -} - -// PressureLevelSet returns which GRIB file set a pressure level belongs to. -func PressureLevelSet(hPa int) (LevelSet, bool) { - ls, ok := pressureLevelSet[hPa] - return ls, ok -} - -// HourIndex returns the dataset time index for a forecast hour. -// Returns -1 if the hour is invalid (not a multiple of HourStep or out of range). -func HourIndex(hour int) int { - if hour < 0 || hour > MaxHour || hour%HourStep != 0 { - return -1 - } - return hour / HourStep -} - -// Hours returns all forecast hours as a slice: [0, 3, 6, ..., 192]. -func Hours() []int { - out := make([]int, 0, NumHours) - for h := 0; h <= MaxHour; h += HourStep { - out = append(out, h) - } - return out -} - -// S3 URL configuration for NOAA GFS data. -const S3BaseURL = "https://noaa-gfs-bdp-pds.s3.amazonaws.com" - -// GribURL returns the S3 URL for a primary (pgrb2) GRIB file. -func GribURL(date string, runHour, forecastStep int) string { - return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d", - S3BaseURL, date, runHour, runHour, forecastStep) -} - -// GribURLB returns the S3 URL for a secondary (pgrb2b) GRIB file. -func GribURLB(date string, runHour, forecastStep int) string { - return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2b.0p50.f%03d", - S3BaseURL, date, runHour, runHour, forecastStep) -} - -// GribFileName returns the local filename for a primary GRIB file. -func GribFileName(runHour, forecastStep int) string { - return fmt.Sprintf("gfs.t%02dz.pgrb2.0p50.f%03d", runHour, forecastStep) -} - -// GribFileNameB returns the local filename for a secondary GRIB file. -func GribFileNameB(runHour, forecastStep int) string { - return fmt.Sprintf("gfs.t%02dz.pgrb2b.0p50.f%03d", runHour, forecastStep) -} - -// VariableIndex returns the dataset variable index for a GRIB parameter. -// Returns -1 if the parameter is not recognized. -func VariableIndex(parameterCategory, parameterNumber int) int { - switch { - case parameterCategory == 3 && parameterNumber == 5: - return VarHeight // Geopotential Height - case parameterCategory == 2 && parameterNumber == 2: - return VarWindU // U-component of wind - case parameterCategory == 2 && parameterNumber == 3: - return VarWindV // V-component of wind - default: - return -1 - } -} diff --git a/internal/dataset/dataset_test.go b/internal/dataset/dataset_test.go deleted file mode 100644 index 14b36ef..0000000 --- a/internal/dataset/dataset_test.go +++ /dev/null @@ -1,152 +0,0 @@ -package dataset - -import ( - "testing" -) - -func TestDatasetShape(t *testing.T) { - if NumHours != 65 { - t.Errorf("NumHours = %d, want 65", NumHours) - } - if NumLevels != 47 { - t.Errorf("NumLevels = %d, want 47", NumLevels) - } - if NumVariables != 3 { - t.Errorf("NumVariables = %d, want 3", NumVariables) - } - if NumLatitudes != 361 { - t.Errorf("NumLatitudes = %d, want 361", NumLatitudes) - } - if NumLongitudes != 720 { - t.Errorf("NumLongitudes = %d, want 720", NumLongitudes) - } -} - -func TestDatasetSize(t *testing.T) { - // 65 * 47 * 3 * 361 * 720 * 4 = 9,528,667,200 - want := int64(9_528_667_200) - if DatasetSize != want { - t.Errorf("DatasetSize = %d, want %d", DatasetSize, want) - } -} - -func TestPressureLevels(t *testing.T) { - if len(Pressures) != 47 { - t.Fatalf("len(Pressures) = %d, want 47", len(Pressures)) - } - - // First should be 1000 (highest pressure, near surface) - if Pressures[0] != 1000 { - t.Errorf("Pressures[0] = %d, want 1000", Pressures[0]) - } - // Last should be 1 (lowest pressure, high atmosphere) - if Pressures[46] != 1 { - t.Errorf("Pressures[46] = %d, want 1", Pressures[46]) - } - - // Should be sorted descending - for i := 1; i < len(Pressures); i++ { - if Pressures[i] >= Pressures[i-1] { - t.Errorf("Pressures not descending at [%d]: %d >= %d", i, Pressures[i], Pressures[i-1]) - } - } - - // Total levels: 26 from pgrb2 + 21 from pgrb2b = 47 - if len(PressuresPgrb2) != 26 { - t.Errorf("len(PressuresPgrb2) = %d, want 26", len(PressuresPgrb2)) - } - if len(PressuresPgrb2b) != 21 { - t.Errorf("len(PressuresPgrb2b) = %d, want 21", len(PressuresPgrb2b)) - } -} - -func TestPressureIndex(t *testing.T) { - if PressureIndex(1000) != 0 { - t.Errorf("PressureIndex(1000) = %d, want 0", PressureIndex(1000)) - } - if PressureIndex(1) != 46 { - t.Errorf("PressureIndex(1) = %d, want 46", PressureIndex(1)) - } - if PressureIndex(500) != 20 { - t.Errorf("PressureIndex(500) = %d, want 20", PressureIndex(500)) - } - if PressureIndex(9999) != -1 { - t.Errorf("PressureIndex(9999) = %d, want -1", PressureIndex(9999)) - } -} - -func TestPressureLevelSet(t *testing.T) { - // 1000 mb should be in pgrb2 (A) - ls, ok := PressureLevelSet(1000) - if !ok || ls != LevelSetA { - t.Errorf("PressureLevelSet(1000) = %v, %v; want A, true", ls, ok) - } - - // 125 mb should be in pgrb2b (B) - ls, ok = PressureLevelSet(125) - if !ok || ls != LevelSetB { - t.Errorf("PressureLevelSet(125) = %v, %v; want B, true", ls, ok) - } - - // 1, 2, 3, 5, 7 should be in pgrb2b (B) - for _, p := range []int{1, 2, 3, 5, 7} { - ls, ok := PressureLevelSet(p) - if !ok || ls != LevelSetB { - t.Errorf("PressureLevelSet(%d) = %v, %v; want B, true", p, ls, ok) - } - } - - // Every pressure level should have a level set assignment - for _, p := range Pressures { - _, ok := PressureLevelSet(p) - if !ok { - t.Errorf("PressureLevelSet(%d) not found", p) - } - } -} - -func TestHourIndex(t *testing.T) { - if HourIndex(0) != 0 { - t.Errorf("HourIndex(0) = %d, want 0", HourIndex(0)) - } - if HourIndex(3) != 1 { - t.Errorf("HourIndex(3) = %d, want 1", HourIndex(3)) - } - if HourIndex(192) != 64 { - t.Errorf("HourIndex(192) = %d, want 64", HourIndex(192)) - } - if HourIndex(1) != -1 { - t.Errorf("HourIndex(1) = %d, want -1 (not multiple of 3)", HourIndex(1)) - } - if HourIndex(195) != -1 { - t.Errorf("HourIndex(195) = %d, want -1 (out of range)", HourIndex(195)) - } -} - -func TestHours(t *testing.T) { - hours := Hours() - if len(hours) != NumHours { - t.Fatalf("len(Hours()) = %d, want %d", len(hours), NumHours) - } - if hours[0] != 0 { - t.Errorf("Hours()[0] = %d, want 0", hours[0]) - } - if hours[len(hours)-1] != MaxHour { - t.Errorf("Hours()[last] = %d, want %d", hours[len(hours)-1], MaxHour) - } -} - -func TestVariableIndex(t *testing.T) { - if VariableIndex(3, 5) != VarHeight { - t.Errorf("HGT: got %d, want %d", VariableIndex(3, 5), VarHeight) - } - if VariableIndex(2, 2) != VarWindU { - t.Errorf("UGRD: got %d, want %d", VariableIndex(2, 2), VarWindU) - } - if VariableIndex(2, 3) != VarWindV { - t.Errorf("VGRD: got %d, want %d", VariableIndex(2, 3), VarWindV) - } - if VariableIndex(0, 0) != -1 { - t.Errorf("unknown: got %d, want -1", VariableIndex(0, 0)) - } -} diff --git a/internal/dataset/file.go b/internal/dataset/file.go deleted file mode 100644 index 96f14c2..0000000 --- a/internal/dataset/file.go +++ /dev/null @@ -1,140 +0,0 @@ -package dataset - -import ( - "encoding/binary" - "fmt" - "math" - "os" - "time" - - mmap "github.com/edsrzf/mmap-go" -) - -// File represents an mmap-backed wind dataset file. -type File struct { - mm mmap.MMap - file *os.File - writable bool - DSTime time.Time // forecast run time (UTC) -} - -// Open opens an existing dataset file for reading. -func Open(path string, dsTime time.Time) (*File, error) { - f, err := os.Open(path) - if err != nil { - return nil, fmt.Errorf("open dataset: %w", err) - } - - info, err := f.Stat() - if err != nil { - f.Close() - return nil, fmt.Errorf("stat dataset: %w", err) - } - if info.Size() != DatasetSize { - f.Close() - return nil, fmt.Errorf("dataset should be %d bytes (was %d)", DatasetSize, info.Size()) - } - - mm, err := mmap.Map(f, mmap.RDONLY, 0) - if err != nil { - f.Close() - return nil, fmt.Errorf("mmap dataset: %w", err) - } - - return &File{mm: mm, file: f, writable: false, DSTime: dsTime}, nil -} - -// Create creates a new dataset file for writing. -// The file is truncated to the correct size and mmap'd read-write. -func Create(path string) (*File, error) { - f, err := os.Create(path) - if err != nil { - return nil, fmt.Errorf("create dataset: %w", err) - } - - if err := f.Truncate(DatasetSize); err != nil { - f.Close() - return nil, fmt.Errorf("truncate dataset: %w", err) - } - - mm, err := mmap.MapRegion(f, int(DatasetSize), mmap.RDWR, 0, 0) - if err != nil { - f.Close() - return nil, fmt.Errorf("mmap dataset: %w", err) - } - - return &File{mm: mm, file: f, writable: true}, nil -} - -// offset computes the byte offset for element [hour][level][variable][lat][lon]. -// Row-major C-order indexing matching the reference implementation: -// shape = (65, 47, 3, 361, 720) -func offset(hour, level, variable, lat, lon int) int64 { - idx := int64(hour) - idx = idx*int64(NumLevels) + int64(level) - idx = idx*int64(NumVariables) + int64(variable) - idx = idx*int64(NumLatitudes) + int64(lat) - idx = idx*int64(NumLongitudes) + int64(lon) - return idx * int64(ElementSize) -} - -// Val reads a float32 value from the dataset at [hour][level][variable][lat][lon]. -func (d *File) Val(hour, level, variable, lat, lon int) float32 { - off := offset(hour, level, variable, lat, lon) - bits := binary.LittleEndian.Uint32(d.mm[off : off+4]) - return math.Float32frombits(bits) -} - -// SetVal writes a float32 value to the dataset at [hour][level][variable][lat][lon]. -// Only valid on writable (created) datasets. -func (d *File) SetVal(hour, level, variable, lat, lon int, val float32) { - off := offset(hour, level, variable, lat, lon) - binary.LittleEndian.PutUint32(d.mm[off:off+4], math.Float32bits(val)) -} - -// BlitGribData copies decoded GRIB grid data into the dataset at the given position. -// gribData is 361*720 float64 values in GRIB scan order (north-to-south, west-to-east). -// This function flips the latitude so that dataset index 0 = -90 (south) and 360 = +90 (north). -func (d *File) BlitGribData(hourIdx, levelIdx, varIdx int, gribData []float64) error { - expected := NumLatitudes * NumLongitudes - if len(gribData) != expected { - return fmt.Errorf("grib data has %d values, expected %d", len(gribData), expected) - } - - for lat := 0; lat < NumLatitudes; lat++ { - for lon := 0; lon < NumLongitudes; lon++ { - // GRIB scans north-to-south: row 0 = 90N, row 360 = 90S - // Dataset stores south-to-north: index 0 = -90 (90S), index 360 = +90 (90N) - gribIdx := (360-lat)*NumLongitudes + lon - val := float32(gribData[gribIdx]) - d.SetVal(hourIdx, levelIdx, varIdx, lat, lon, val) - } - } - - return nil -} - -// Flush flushes the mmap to disk. -func (d *File) Flush() error { - if d.mm != nil { - return d.mm.Flush() - } - return nil -} - -// Close unmaps and closes the dataset file. -func (d *File) Close() error { - if d.mm != nil { - if err := d.mm.Unmap(); err != nil { - d.file.Close() - return fmt.Errorf("unmap: %w", err) - } - d.mm = nil - } - if d.file != nil { - err := d.file.Close() - d.file = nil - return err - } - return nil -} diff --git a/internal/downloader/config.go b/internal/downloader/config.go deleted file mode 100644 index 91575e1..0000000 --- a/internal/downloader/config.go +++ /dev/null @@ -1,58 +0,0 @@ -package downloader - -import ( - "os" - "strconv" - "time" -) - -// Config holds downloader configuration, loaded from environment variables. -type Config struct { - // DataDir is the directory for storing dataset files and temporary GRIB data. - DataDir string - - // Parallel is the maximum number of concurrent GRIB downloads. - Parallel int - - // UpdateInterval is how often the scheduler checks for new forecast data. - UpdateInterval time.Duration - - // DatasetTTL is how long a dataset is considered fresh before a new one is needed. - DatasetTTL time.Duration -} - -// DefaultConfig returns the default configuration. -func DefaultConfig() *Config { - return &Config{ - DataDir: "/tmp/predictor-data", - Parallel: 8, - UpdateInterval: 6 * time.Hour, - DatasetTTL: 48 * time.Hour, - } -} - -// LoadConfig loads configuration from environment variables, falling back to defaults. -func LoadConfig() *Config { - cfg := DefaultConfig() - - if v := os.Getenv("PREDICTOR_DATA_DIR"); v != "" { - cfg.DataDir = v - } - if v := os.Getenv("PREDICTOR_DOWNLOAD_PARALLEL"); v != "" { - if n, err := strconv.Atoi(v); err == nil && n > 0 { - cfg.Parallel = n - } - } - if v := os.Getenv("PREDICTOR_UPDATE_INTERVAL"); v != "" { - if d, err := time.ParseDuration(v); err == nil { - cfg.UpdateInterval = d - } - } - if v := os.Getenv("PREDICTOR_DATASET_TTL"); v != "" { - if d, err := time.ParseDuration(v); err == nil { - cfg.DatasetTTL = d - } - } - - return cfg -} diff --git a/internal/downloader/downloader.go b/internal/downloader/downloader.go deleted file mode 100644 index 32208a0..0000000 --- a/internal/downloader/downloader.go +++ /dev/null @@ -1,441 +0,0 @@ -package downloader - -import ( - "context" - "fmt" - "io" - "math" - "net/http" - "os" - "path/filepath" - "sync/atomic" - "time" - - "predictor-refactored/internal/dataset" - - "github.com/nilsmagnus/grib/griblib" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" -) - -// Downloader handles fetching GFS forecast data from S3 and assembling dataset files. -type Downloader struct { - cfg *Config - client *http.Client - log *zap.Logger -} - -// NewDownloader creates a new Downloader. -func NewDownloader(cfg *Config, log *zap.Logger) *Downloader { - return &Downloader{ - cfg: cfg, - client: &http.Client{ - Timeout: 2 * time.Minute, - }, - log: log, - } -} - -// neededVariables is the set of GRIB variable names we need. -var neededVariables = map[string]bool{ - "HGT": true, - "UGRD": true, - "VGRD": true, -} - -// FindLatestRun finds the most recent available GFS model run on S3. -// It checks the last forecast step of each run to confirm availability. -func (d *Downloader) FindLatestRun(ctx context.Context) (time.Time, error) { - now := time.Now().UTC() - hour := now.Hour() - (now.Hour() % 6) - current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC) - - for i := 0; i < 8; i++ { - date := current.Format("20060102") - url := dataset.GribURL(date, current.Hour(), dataset.MaxHour) + ".idx" - - req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) - if err != nil { - current = current.Add(-6 * time.Hour) - continue - } - - resp, err := d.client.Do(req) - if err == nil { - resp.Body.Close() - if resp.StatusCode == http.StatusOK { - d.log.Info("found latest model run", - zap.Time("run", current), - zap.String("verified_url", url)) - return current, nil - } - } - - current = current.Add(-6 * time.Hour) - } - - return time.Time{}, fmt.Errorf("no recent GFS forecast found (checked 8 runs)") -} - -// progress tracks download progress across concurrent goroutines. -type progress struct { - bytesDownloaded atomic.Int64 - stepsCompleted atomic.Int64 - totalSteps int64 - startTime time.Time - log *zap.Logger -} - -func newProgress(totalSteps int, log *zap.Logger) *progress { - return &progress{ - totalSteps: int64(totalSteps), - startTime: time.Now(), - log: log, - } -} - -func (p *progress) addBytes(n int64) { - p.bytesDownloaded.Add(n) -} - -func (p *progress) completeStep() { - done := p.stepsCompleted.Add(1) - total := p.totalSteps - bytes := p.bytesDownloaded.Load() - elapsed := time.Since(p.startTime).Seconds() - - pct := float64(done) / float64(total) * 100 - mbDownloaded := float64(bytes) / (1024 * 1024) - mbPerSec := 0.0 - if elapsed > 0 { - mbPerSec = mbDownloaded / elapsed - } - - // Estimate remaining - eta := "" - if done > 0 && done < total { - secsPerStep := elapsed / float64(done) - remaining := secsPerStep * float64(total-done) - if remaining > 60 { - eta = fmt.Sprintf("%.0fm%02.0fs", math.Floor(remaining/60), math.Mod(remaining, 60)) - } else { - eta = fmt.Sprintf("%.0fs", remaining) - } - } - - p.log.Info("download progress", - zap.String("progress", fmt.Sprintf("%d/%d", done, total)), - zap.String("percent", fmt.Sprintf("%.1f%%", pct)), - zap.String("downloaded", fmt.Sprintf("%.1f MB", mbDownloaded)), - zap.String("speed", fmt.Sprintf("%.1f MB/s", mbPerSec)), - zap.String("eta", eta)) -} - -// Download downloads a complete forecast and assembles a dataset file. -// Returns the path to the completed dataset file. -func (d *Downloader) Download(ctx context.Context, run time.Time) (string, error) { - date := run.Format("20060102") - runHour := run.Hour() - - finalPath := filepath.Join(d.cfg.DataDir, run.Format("2006010215")) - tempPath := finalPath + ".downloading" - - // Check if final dataset already exists - if info, err := os.Stat(finalPath); err == nil && info.Size() == dataset.DatasetSize { - d.log.Info("dataset already exists", zap.String("path", finalPath)) - return finalPath, nil - } - - steps := dataset.Hours() - totalSteps := len(steps) * 2 // pgrb2 + pgrb2b per step - prog := newProgress(totalSteps, d.log) - - d.log.Info("starting dataset download", - zap.Time("run", run), - zap.Int("total_steps", totalSteps), - zap.String("temp_path", tempPath)) - - // Create the dataset file - ds, err := dataset.Create(tempPath) - if err != nil { - return "", fmt.Errorf("create dataset: %w", err) - } - defer ds.Close() - - // Process each forecast step with bounded concurrency - g, ctx := errgroup.WithContext(ctx) - sem := make(chan struct{}, d.cfg.Parallel) - - for _, step := range steps { - step := step - hourIdx := dataset.HourIndex(step) - if hourIdx < 0 { - continue - } - - // Download pgrb2 (level set A) - sem <- struct{}{} - g.Go(func() error { - defer func() { <-sem }() - url := dataset.GribURL(date, runHour, step) - err := d.downloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetA, prog) - if err != nil { - return fmt.Errorf("step %d pgrb2: %w", step, err) - } - prog.completeStep() - return nil - }) - - // Download pgrb2b (level set B) - sem <- struct{}{} - g.Go(func() error { - defer func() { <-sem }() - url := dataset.GribURLB(date, runHour, step) - err := d.downloadAndBlit(ctx, ds, url, hourIdx, dataset.LevelSetB, prog) - if err != nil { - return fmt.Errorf("step %d pgrb2b: %w", step, err) - } - prog.completeStep() - return nil - }) - } - - if err := g.Wait(); err != nil { - os.Remove(tempPath) - return "", err - } - - elapsed := time.Since(prog.startTime) - totalMB := float64(prog.bytesDownloaded.Load()) / (1024 * 1024) - d.log.Info("download complete, flushing to disk", - zap.String("downloaded", fmt.Sprintf("%.1f MB", totalMB)), - zap.Duration("elapsed", elapsed), - zap.String("avg_speed", fmt.Sprintf("%.1f MB/s", totalMB/elapsed.Seconds()))) - - // Flush to disk - if err := ds.Flush(); err != nil { - os.Remove(tempPath) - return "", fmt.Errorf("flush dataset: %w", err) - } - - // Close before rename - ds.Close() - - // Atomic rename - if err := os.Rename(tempPath, finalPath); err != nil { - os.Remove(tempPath) - return "", fmt.Errorf("rename dataset: %w", err) - } - - d.log.Info("dataset ready", zap.String("path", finalPath)) - return finalPath, nil -} - -// DownloadAndBlit downloads needed GRIB fields from a URL and writes them into the dataset. -func (d *Downloader) DownloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet) error { - return d.downloadAndBlit(ctx, ds, baseURL, hourIdx, levelSet, nil) -} - -// downloadAndBlit is the internal implementation with optional progress tracking. -func (d *Downloader) downloadAndBlit(ctx context.Context, ds *dataset.File, baseURL string, hourIdx int, levelSet dataset.LevelSet, prog *progress) error { - // 1. Download .idx - idxURL := baseURL + ".idx" - idxBody, err := d.httpGet(ctx, idxURL) - if err != nil { - return fmt.Errorf("download idx: %w", err) - } - - // 2. Parse and filter - entries := ParseIdx(idxBody) - filtered := FilterIdx(entries, neededVariables) - - // Further filter to only levels in this level set - var relevant []IdxEntry - for _, e := range filtered { - ls, ok := dataset.PressureLevelSet(e.LevelMB) - if ok && ls == levelSet { - relevant = append(relevant, e) - } - } - - if len(relevant) == 0 { - d.log.Warn("no relevant entries found in idx", - zap.String("url", idxURL), - zap.Int("total_entries", len(entries)), - zap.Int("filtered", len(filtered))) - return nil - } - - // 3. Download byte ranges and write to temp file - ranges := EntriesToRanges(relevant) - tmpFile, err := d.downloadRangesToTempFile(ctx, baseURL, ranges, prog) - if err != nil { - return fmt.Errorf("download ranges: %w", err) - } - defer os.Remove(tmpFile) - - // 4. Read GRIB messages from temp file - f, err := os.Open(tmpFile) - if err != nil { - return fmt.Errorf("open temp grib: %w", err) - } - - messages, err := griblib.ReadMessages(f) - f.Close() - if err != nil { - return fmt.Errorf("read grib messages: %w", err) - } - - // 5. Decode and blit each message into the dataset - for _, msg := range messages { - if msg.Section4.ProductDefinitionTemplateNumber != 0 { - continue - } - - product := msg.Section4.ProductDefinitionTemplate - - varIdx := dataset.VariableIndex(int(product.ParameterCategory), int(product.ParameterNumber)) - if varIdx < 0 { - continue - } - - if product.FirstSurface.Type != 100 { // isobaric surface - continue - } - - pressurePa := float64(product.FirstSurface.Value) - pressureMB := int(math.Round(pressurePa / 100.0)) - levelIdx := dataset.PressureIndex(pressureMB) - if levelIdx < 0 { - continue - } - - data := msg.Data() - if err := ds.BlitGribData(hourIdx, levelIdx, varIdx, data); err != nil { - d.log.Warn("blit failed", - zap.Int("var", varIdx), - zap.Int("level_mb", pressureMB), - zap.Error(err)) - continue - } - } - - return nil -} - -// downloadRangesToTempFile downloads multiple byte ranges from a URL, -// concatenating them into a single temp file (valid concatenated GRIB messages). -func (d *Downloader) downloadRangesToTempFile(ctx context.Context, baseURL string, ranges []ByteRange, prog *progress) (string, error) { - tmpFile, err := os.CreateTemp(d.cfg.DataDir, "grib-*.tmp") - if err != nil { - return "", fmt.Errorf("create temp file: %w", err) - } - tmpPath := tmpFile.Name() - - for _, r := range ranges { - data, err := d.httpGetRange(ctx, baseURL, r.Start, r.End) - if err != nil { - tmpFile.Close() - os.Remove(tmpPath) - return "", fmt.Errorf("download range %d-%d: %w", r.Start, r.End, err) - } - if _, err := tmpFile.Write(data); err != nil { - tmpFile.Close() - os.Remove(tmpPath) - return "", fmt.Errorf("write temp: %w", err) - } - if prog != nil { - prog.addBytes(int64(len(data))) - } - } - - if err := tmpFile.Close(); err != nil { - os.Remove(tmpPath) - return "", err - } - - return tmpPath, nil -} - -// httpGet downloads a URL and returns the body bytes. -func (d *Downloader) httpGet(ctx context.Context, url string) ([]byte, error) { - var lastErr error - for attempt := 0; attempt < 3; attempt++ { - if attempt > 0 { - select { - case <-time.After(time.Duration(attempt*2) * time.Second): - case <-ctx.Done(): - return nil, ctx.Err() - } - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return nil, err - } - - resp, err := d.client.Do(req) - if err != nil { - lastErr = err - continue - } - - body, err := io.ReadAll(resp.Body) - resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url) - continue - } - if err != nil { - lastErr = err - continue - } - - return body, nil - } - - return nil, fmt.Errorf("after 3 attempts: %w", lastErr) -} - -// httpGetRange downloads a byte range from a URL. -func (d *Downloader) httpGetRange(ctx context.Context, url string, start, end int64) ([]byte, error) { - var lastErr error - for attempt := 0; attempt < 3; attempt++ { - if attempt > 0 { - select { - case <-time.After(time.Duration(attempt*2) * time.Second): - case <-ctx.Done(): - return nil, ctx.Err() - } - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return nil, err - } - req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) - - resp, err := d.client.Do(req) - if err != nil { - lastErr = err - continue - } - - body, err := io.ReadAll(resp.Body) - resp.Body.Close() - - if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { - lastErr = fmt.Errorf("HTTP %d for range %d-%d of %s", resp.StatusCode, start, end, url) - continue - } - if err != nil { - lastErr = err - continue - } - - return body, nil - } - - return nil, fmt.Errorf("after 3 attempts: %w", lastErr) -} diff --git a/internal/downloader/idx.go b/internal/downloader/idx.go deleted file mode 100644 index 2e09bc4..0000000 --- a/internal/downloader/idx.go +++ /dev/null @@ -1,157 +0,0 @@ -package downloader - -import ( - "fmt" - "strconv" - "strings" -) - -// IdxEntry represents a single parsed line from a GRIB .idx file. -// Example line: "15:1207405:d=2024010100:HGT:1000 mb:0 hour fcst:" -type IdxEntry struct { - Index int - Offset int64 - Variable string // "HGT", "UGRD", "VGRD", etc. - LevelMB int // pressure level in mb (0 if not a pressure level) - Hour int // forecast hour - EndOffset int64 // byte after this message (from next entry's offset, or -1 if last) -} - -// Length returns the byte length of this GRIB message, or -1 if unknown. -func (e *IdxEntry) Length() int64 { - if e.EndOffset <= 0 { - return -1 - } - return e.EndOffset - e.Offset -} - -// ParseIdx parses a .idx file body and returns all entries. -// Lines that can't be parsed are silently skipped. -func ParseIdx(body []byte) []IdxEntry { - lines := strings.Split(string(body), "\n") - var entries []IdxEntry - - for _, line := range lines { - line = strings.TrimSpace(line) - if line == "" { - continue - } - - parts := strings.Split(line, ":") - if len(parts) < 7 { - continue - } - - idx, err := strconv.Atoi(parts[0]) - if err != nil { - continue - } - - offset, err := strconv.ParseInt(parts[1], 10, 64) - if err != nil { - continue - } - - variable := parts[3] - levelStr := parts[4] - hourStr := parts[5] - - levelMB := parseLevelMB(levelStr) - hour := parseHour(hourStr) - - entries = append(entries, IdxEntry{ - Index: idx, - Offset: offset, - Variable: variable, - LevelMB: levelMB, - Hour: hour, - EndOffset: -1, // filled in below - }) - } - - // Fill in EndOffset from the next entry's Offset. - for i := 0; i < len(entries)-1; i++ { - entries[i].EndOffset = entries[i+1].Offset - } - - return entries -} - -// FilterIdx returns entries matching the given variables at pressure levels. -// Only entries with a recognized pressure level (levelMB > 0) are returned. -func FilterIdx(entries []IdxEntry, variables map[string]bool) []IdxEntry { - var filtered []IdxEntry - for _, e := range entries { - if !variables[e.Variable] { - continue - } - if e.LevelMB <= 0 { - continue - } - // Must have a known length (not the last entry) or be handled specially - if e.Length() <= 0 { - continue - } - filtered = append(filtered, e) - } - return filtered -} - -// parseLevelMB parses a level string like "1000 mb" and returns the pressure in mb. -// Returns 0 if not a pressure level. -func parseLevelMB(s string) int { - s = strings.TrimSpace(s) - if !strings.HasSuffix(s, " mb") { - return 0 - } - numStr := strings.TrimSuffix(s, " mb") - n, err := strconv.Atoi(numStr) - if err != nil { - return 0 - } - return n -} - -// parseHour parses a forecast hour string like "0 hour fcst" or "anl". -// Returns -1 if it can't be parsed. -func parseHour(s string) int { - s = strings.TrimSpace(s) - if s == "anl" { - return 0 - } - s = strings.TrimSuffix(s, " hour fcst") - n, err := strconv.Atoi(s) - if err != nil { - return -1 - } - return n -} - -// GroupByRange groups idx entries into byte ranges suitable for HTTP Range downloads. -// Each range covers one contiguous GRIB message. -type ByteRange struct { - Start int64 - End int64 // inclusive - Entry IdxEntry -} - -// EntriesToRanges converts filtered idx entries to byte ranges. -func EntriesToRanges(entries []IdxEntry) []ByteRange { - ranges := make([]ByteRange, 0, len(entries)) - for _, e := range entries { - if e.Length() <= 0 { - continue - } - ranges = append(ranges, ByteRange{ - Start: e.Offset, - End: e.EndOffset - 1, // inclusive - Entry: e, - }) - } - return ranges -} - -// FormatRange returns an HTTP Range header value for a byte range. -func (r ByteRange) FormatRange() string { - return fmt.Sprintf("bytes=%d-%d", r.Start, r.End) -} diff --git a/internal/downloader/idx_test.go b/internal/downloader/idx_test.go deleted file mode 100644 index 71a7224..0000000 --- a/internal/downloader/idx_test.go +++ /dev/null @@ -1,110 +0,0 @@ -package downloader - -import ( - "testing" -) - -const sampleIdx = `1:0:d=2024010100:HGT:1000 mb:0 hour fcst: -2:289012:d=2024010100:HGT:975 mb:0 hour fcst: -3:541876:d=2024010100:TMP:1000 mb:0 hour fcst: -4:789012:d=2024010100:UGRD:1000 mb:0 hour fcst: -5:1045678:d=2024010100:VGRD:1000 mb:0 hour fcst: -6:1298765:d=2024010100:UGRD:975 mb:0 hour fcst: -7:1567890:d=2024010100:UGRD:2 m above ground:0 hour fcst: -8:1812345:d=2024010100:VGRD:975 mb:0 hour fcst: -9:2098765:d=2024010100:HGT:500 mb:3 hour fcst: -` - -func TestParseIdx(t *testing.T) { - entries := ParseIdx([]byte(sampleIdx)) - if len(entries) != 9 { - t.Fatalf("expected 9 entries, got %d", len(entries)) - } - - // Check first entry - e := entries[0] - if e.Index != 1 || e.Offset != 0 || e.Variable != "HGT" || e.LevelMB != 1000 || e.Hour != 0 { - t.Errorf("entry 0: got %+v", e) - } - if e.EndOffset != 289012 { - t.Errorf("entry 0 EndOffset: got %d, want 289012", e.EndOffset) - } - - // Check "2 m above ground" is not a pressure level - e = entries[6] // UGRD at "2 m above ground" - if e.LevelMB != 0 { - t.Errorf("non-pressure level should have LevelMB=0, got %d", e.LevelMB) - } - - // Last entry should have EndOffset = -1 - last := entries[len(entries)-1] - if last.EndOffset != -1 { - t.Errorf("last entry EndOffset: got %d, want -1", last.EndOffset) - } -} - -func TestFilterIdx(t *testing.T) { - entries := ParseIdx([]byte(sampleIdx)) - filtered := FilterIdx(entries, neededVariables) - - // Should include HGT/UGRD/VGRD at pressure levels, exclude TMP and "above ground" - // And exclude last entry (no EndOffset) - for _, e := range filtered { - if !neededVariables[e.Variable] { - t.Errorf("unexpected variable %s", e.Variable) - } - if e.LevelMB <= 0 { - t.Errorf("non-pressure level included: %+v", e) - } - if e.Length() <= 0 { - t.Errorf("entry with unknown length included: %+v", e) - } - } - - // Count expected: HGT@1000, HGT@975, UGRD@1000, VGRD@1000, UGRD@975, VGRD@975 = 6 - // But HGT@500 at 3hr fcst is the last entry (no EndOffset), so excluded - if len(filtered) != 6 { - t.Errorf("expected 6 filtered entries, got %d", len(filtered)) - for _, e := range filtered { - t.Logf(" %s %d mb (offset %d, len %d)", e.Variable, e.LevelMB, e.Offset, e.Length()) - } - } -} - -func TestParseLevelMB(t *testing.T) { - tests := []struct { - input string - want int - }{ - {"1000 mb", 1000}, - {"975 mb", 975}, - {"1 mb", 1}, - {"2 m above ground", 0}, - {"surface", 0}, - {"tropopause", 0}, - } - for _, tt := range tests { - got := parseLevelMB(tt.input) - if got != tt.want { - t.Errorf("parseLevelMB(%q) = %d, want %d", tt.input, got, tt.want) - } - } -} - -func TestParseHour(t *testing.T) { - tests := []struct { - input string - want int - }{ - {"0 hour fcst", 0}, - {"3 hour fcst", 3}, - {"192 hour fcst", 192}, - {"anl", 0}, - } - for _, tt := range tests { - got := parseHour(tt.input) - if got != tt.want { - t.Errorf("parseHour(%q) = %d, want %d", tt.input, got, tt.want) - } - } -} diff --git a/internal/elevation/elevation.go b/internal/elevation/elevation.go deleted file mode 100644 index 9fde295..0000000 --- a/internal/elevation/elevation.go +++ /dev/null @@ -1,113 +0,0 @@ -package elevation - -import ( - "encoding/binary" - "fmt" - "math" - "os" - - mmap "github.com/edsrzf/mmap-go" -) - -// Dataset provides global elevation lookup, compatible with ruaumoko. -// Binary format: int16 little-endian elevation values in metres, row-major (lat, lon). -// Latitude axis: -90 to +90 (south to north), Longitude axis: 0 to 360 (wraps). -// Resolution: 120 cells per degree (30 arc-seconds). -const ( - CellsPerDegree = 120 - NumLats = 180*CellsPerDegree + 1 // 21601 - NumLons = 360 * CellsPerDegree // 43200 - DataSize = NumLats * NumLons * 2 // 1,866,326,400 bytes (~1.74 GiB) -) - -// Dataset is a memory-mapped global elevation grid. -type Dataset struct { - mm mmap.MMap - file *os.File -} - -// Open opens an existing elevation dataset file. -func Open(path string) (*Dataset, error) { - f, err := os.Open(path) - if err != nil { - return nil, fmt.Errorf("open elevation: %w", err) - } - - info, err := f.Stat() - if err != nil { - f.Close() - return nil, fmt.Errorf("stat elevation: %w", err) - } - if info.Size() != DataSize { - f.Close() - return nil, fmt.Errorf("elevation dataset should be %d bytes (was %d)", DataSize, info.Size()) - } - - mm, err := mmap.Map(f, mmap.RDONLY, 0) - if err != nil { - f.Close() - return nil, fmt.Errorf("mmap elevation: %w", err) - } - - return &Dataset{mm: mm, file: f}, nil -} - -// getCell reads the int16 elevation at grid indices (latIdx, lngIdx). -func (d *Dataset) getCell(latIdx, lngIdx int) int16 { - // Clamp latitude - if latIdx < 0 { - latIdx = 0 - } - if latIdx >= NumLats { - latIdx = NumLats - 1 - } - // Wrap longitude - lngIdx = lngIdx % NumLons - if lngIdx < 0 { - lngIdx += NumLons - } - - off := (latIdx*NumLons + lngIdx) * 2 - return int16(binary.LittleEndian.Uint16(d.mm[off : off+2])) -} - -// Get returns the interpolated elevation in metres at the given coordinates. -// lat: -90 to +90, lng: 0 to 360 (or -180 to 180, will be normalised). -func (d *Dataset) Get(lat, lng float64) float64 { - // Normalise longitude to [0, 360) - if lng < 0 { - lng += 360 - } - - // Convert to cell coordinates - latCell := (lat + 90.0) * CellsPerDegree - lngCell := lng * CellsPerDegree - - lat0 := int(math.Floor(latCell)) - lng0 := int(math.Floor(lngCell)) - latFrac := latCell - float64(lat0) - lngFrac := lngCell - float64(lng0) - - // Bilinear interpolation - v00 := float64(d.getCell(lat0, lng0)) - v10 := float64(d.getCell(lat0+1, lng0)) - v01 := float64(d.getCell(lat0, lng0+1)) - v11 := float64(d.getCell(lat0+1, lng0+1)) - - return (1-latFrac)*((1-lngFrac)*v00+lngFrac*v01) + - latFrac*((1-lngFrac)*v10+lngFrac*v11) -} - -// Close unmaps and closes the dataset. -func (d *Dataset) Close() error { - if d.mm != nil { - d.mm.Unmap() - d.mm = nil - } - if d.file != nil { - err := d.file.Close() - d.file = nil - return err - } - return nil -} diff --git a/internal/jobs/grib/updater/config.go b/internal/jobs/grib/updater/config.go new file mode 100644 index 0000000..fa5132b --- /dev/null +++ b/internal/jobs/grib/updater/config.go @@ -0,0 +1,23 @@ +package updater + +import ( + "time" + + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + env "github.com/caarlos0/env/v11" +) + +type Config struct { + Interval time.Duration `env:"INTERVAL" envDefault:"6h"` + Timeout time.Duration `env:"TIMEOUT" envDefault:"45m"` +} + +func NewConfig() (*Config, error) { + cfg := &Config{} + if err := env.ParseWithOptions(cfg, env.Options{ + PrefixTagName: "GSN_PREDICTOR_GRIB_UPDATER_", + }); err != nil { + return nil, errcodes.Wrap(err, "failed to parse GRIB updater config") + } + return cfg, nil +} diff --git a/internal/jobs/grib/updater/deps.go b/internal/jobs/grib/updater/deps.go new file mode 100644 index 0000000..f2e3ced --- /dev/null +++ b/internal/jobs/grib/updater/deps.go @@ -0,0 +1,8 @@ +package updater + +import "context" + +// GribService defines the interface for GRIB operations needed by the updater job +type GribService interface { + Update(ctx context.Context) error +} diff --git a/internal/jobs/grib/updater/updater.go b/internal/jobs/grib/updater/updater.go new file mode 100644 index 0000000..ce28d02 --- /dev/null +++ b/internal/jobs/grib/updater/updater.go @@ -0,0 +1,51 @@ +package updater + +import ( + "context" + "time" + + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + "git.intra.yksa.space/gsn/predictor/internal/pkg/log" + "go.uber.org/zap" +) + +type Job struct { + service GribService + config *Config +} + +func New(service GribService, config *Config) *Job { + return &Job{ + service: service, + config: config, + } +} + +func (j *Job) GetInterval() time.Duration { + return j.config.Interval +} + +func (j *Job) GetTimeout() time.Duration { + return j.config.Timeout +} + +func (j *Job) GetCount() int { + return 1 +} + +func (j *Job) GetAsync() bool { + return false +} + +func (j *Job) Execute(ctx context.Context) error { + log := log.Ctx(ctx) + log.Info("executing GRIB update job") + + if err := j.service.Update(ctx); err != nil { + log.Error("GRIB update failed", zap.Error(err)) + return errcodes.Wrap(err, "failed to update GRIB data") + } + + log.Info("GRIB update completed successfully") + return nil +} diff --git a/internal/pkg/ds/predictor.go b/internal/pkg/ds/predictor.go new file mode 100644 index 0000000..753b0b2 --- /dev/null +++ b/internal/pkg/ds/predictor.go @@ -0,0 +1,89 @@ +package ds + +import ( + "time" + + api "git.intra.yksa.space/gsn/predictor/pkg/rest" +) + +type PredictionParameters struct { + LaunchLatitude *float64 + LaunchLongitude *float64 + LaunchDatetime *time.Time + LaunchAltitude *float64 + Profile *string + AscentRate *float64 + BurstAltitude *float64 + DescentRate *float64 + FloatAltitude *float64 + StopDatetime *time.Time + AscentCurve *string // base64 + DescentCurve *string // base64 + Interpolate *bool + Format *string + Dataset *time.Time + // Add other parameters as needed +} + +type PredicitonResult struct { + Latitude *float64 + Longitude *float64 + Altitude *float64 + Timestamp *time.Time + WindU *float64 + WindV *float64 + // Add other result fields as needed +} + +// Converts flat ogen params to internal pointer-based model +func ConvertFlatPredictionParams(params api.PerformPredictionParams) *PredictionParameters { + out := &PredictionParameters{} + if v, ok := params.LaunchLatitude.Get(); ok { + out.LaunchLatitude = &v + } + if v, ok := params.LaunchLongitude.Get(); ok { + out.LaunchLongitude = &v + } + if v, ok := params.LaunchDatetime.Get(); ok { + out.LaunchDatetime = &v + } + if v, ok := params.LaunchAltitude.Get(); ok { + out.LaunchAltitude = &v + } + if v, ok := params.Profile.Get(); ok { + s := string(v) + out.Profile = &s + } + if v, ok := params.AscentRate.Get(); ok { + out.AscentRate = &v + } + if v, ok := params.BurstAltitude.Get(); ok { + out.BurstAltitude = &v + } + if v, ok := params.DescentRate.Get(); ok { + out.DescentRate = &v + } + if v, ok := params.FloatAltitude.Get(); ok { + out.FloatAltitude = &v + } + if v, ok := params.StopDatetime.Get(); ok { + out.StopDatetime = &v + } + if v, ok := params.AscentCurve.Get(); ok { + out.AscentCurve = &v + } + if v, ok := params.DescentCurve.Get(); ok { + out.DescentCurve = &v + } + if v, ok := params.Interpolate.Get(); ok { + out.Interpolate = &v + } + if v, ok := params.Format.Get(); ok { + s := string(v) + out.Format = &s + } + if v, ok := params.Dataset.Get(); ok { + out.Dataset = &v + } + return out +} diff --git a/internal/pkg/errcodes/errcodes.go b/internal/pkg/errcodes/errcodes.go new file mode 100644 index 0000000..fca56f9 --- /dev/null +++ b/internal/pkg/errcodes/errcodes.go @@ -0,0 +1,102 @@ +package errcodes + +import ( + "net/http" + "strings" +) + +type ErrorCode struct { + StatusCode int + Message string + Details string +} + +func New(statusCode int, message string, details ...string) *ErrorCode { + return &ErrorCode{ + StatusCode: statusCode, + Message: message, + Details: strings.Join(details, " "), + } +} + +func (e *ErrorCode) Error() string { + return e.Message +} + +func IsErr(err error) bool { + _, ok := err.(*ErrorCode) + return ok +} + +func AsErr(err error) (*ErrorCode, bool) { + if err == nil { + return nil, false + } + errcode, ok := err.(*ErrorCode) + return errcode, ok +} + +func Join(errs ...error) error { + if len(errs) == 0 { + return nil + } + + var messages []string + var details []string + + for _, err := range errs { + if err == nil { + continue + } + + if errcode, ok := AsErr(err); ok { + messages = append(messages, errcode.Message) + if errcode.Details != "" { + details = append(details, errcode.Details) + } + } else { + messages = append(messages, err.Error()) + } + } + + if len(messages) == 0 { + return nil + } + + statusCode := http.StatusInternalServerError + if len(errs) > 0 { + if errcode, ok := AsErr(errs[0]); ok { + statusCode = errcode.StatusCode + } + } + + return New(statusCode, strings.Join(messages, "; "), details...) +} + +func Wrap(err error, message string) error { + if err == nil { + return nil + } + + if errcode, ok := AsErr(err); ok { + return New(errcode.StatusCode, message, errcode.Message, errcode.Details) + } + + return New(http.StatusInternalServerError, message, err.Error()) +} + +var ( + ErrNoDataset = New(http.StatusNotFound, "no grib dataset found") + ErrOutOfBounds = New(http.StatusBadRequest, "requested time is out of bounds") + ErrConfig = New(http.StatusInternalServerError, "configuration error") + ErrConfigInvalidEnv = New(http.StatusInternalServerError, "invalid environment configuration") + ErrConfigMissingRequired = New(http.StatusInternalServerError, "missing required configuration") + ErrDownload = New(http.StatusInternalServerError, "download error") + ErrProcessing = New(http.StatusInternalServerError, "data processing error") + ErrNoCubeFilesFound = New(http.StatusNotFound, "no cube files found") + ErrNoValidCubeFilesFound = New(http.StatusNotFound, "no valid cube files found") + ErrLatestCubeFileIsTooOld = New(http.StatusNotFound, "latest cube file is too old") + ErrScheduler = New(http.StatusInternalServerError, "scheduler error") + ErrSchedulerInvalidJob = New(http.StatusBadRequest, "invalid job configuration") + ErrSchedulerTimeoutTooLong = New(http.StatusBadRequest, "job timeout too long", "timeout cannot exceed interval") +) diff --git a/internal/pkg/grib/README.md b/internal/pkg/grib/README.md new file mode 100644 index 0000000..bab933f --- /dev/null +++ b/internal/pkg/grib/README.md @@ -0,0 +1,100 @@ +# GRIB Module + +Этот модуль реализует функциональность для работы с GRIB-файлами, аналогичную tawhiri-downloader и tawhiri, но на Go. + +## Основные возможности + +- **Скачивание GRIB-файлов** с NOMADS (GFS прогнозы) +- **Сборка 5D-куба** (время, давление, широта, долгота, переменные u/v) +- **Эффективное хранение** с использованием mmap +- **Интерполяция** ветровых данных для произвольных координат и времени +- **Кэширование** результатов (in-memory) +- **Распределенные блокировки** для предотвращения дублирования загрузок + +## Архитектура + +### Основные компоненты + +- **Downloader** - скачивает GRIB-файлы с NOMADS +- **Cube** - управляет 5D-массивом данных через mmap +- **Extractor** - выполняет интерполяцию данных +- **Cache** - кэширует результаты запросов +- **Service** - основной интерфейс для работы с модулем + +### Структура данных + +5D-куб содержит: +- **Время**: 17 временных срезов (0, 3, 6, ..., 48 часов) +- **Давление**: 34 уровня давления (1000, 975, 950, ..., 2 hPa) +- **Широта**: 361 точка (-90° до +90°) +- **Долгота**: 720 точек (0° до 359.5°) +- **Переменные**: u-ветер и v-ветер + +## Использование + +```go +// Создание сервиса +cfg := grib.ServiceConfig{ + Dir: "/tmp/grib", + TTL: 24 * time.Hour, + CacheTTL: 1 * time.Hour, + Parallel: 4, + Client: &http.Client{Timeout: 30 * time.Second}, +} + +service, err := grib.New(cfg) +if err != nil { + log.Fatal(err) +} +defer service.Close() + +// Обновление данных +err = service.Update(ctx) + +// Извлечение ветровых данных +wind, err := service.Extract(ctx, lat, lon, alt, timestamp) +// wind[0] - u-компонента ветра +// wind[1] - v-компонента ветра +``` + +## Интерполяция + +Модуль выполняет 16-точечную интерполяцию: +1. **Временная интерполяция** между двумя ближайшими срезами +2. **Интерполяция по давлению** между двумя ближайшими уровнями +3. **Билинейная интерполяция** по широте и долготе + +## Кэширование + +- **In-memory кэш**: быстрый доступ к недавно запрошенным данным + +## Расписание обновлений + +Рекомендуемая частота вызова `Update()`: +- **Каждые 6 часов** - для получения свежих GFS прогнозов +- **При запуске** - для загрузки начальных данных +- **По требованию** - при отсутствии данных для запрашиваемого времени + +## Отличия от tawhiri + +### Преимущества Go-реализации: +- **Высокая производительность** (mmap, конкурентные загрузки) +- **Эффективное использование памяти** (не загружает весь массив в RAM) +- **Горизонтальное масштабирование** (stateless, множество реплик) +- **Встроенное кэширование** (in-memory) + +### Особенности: +- Использует `github.com/nilsmagnus/grib` вместо pygrib +- Реализует собственную логику интерполяции + +## Конфигурация + +### Переменные окружения: +- `PREDICTOR_GRIB_DATASET_URL` - URL источника данных (опционально) + +### Параметры ServiceConfig: +- `Dir` - директория для хранения файлов +- `TTL` - время жизни данных (по умолчанию 24 часа) +- `CacheTTL` - время жизни кэша (по умолчанию 1 час) +- `Parallel` - количество параллельных загрузок +- `Client` - HTTP клиент для загрузок \ No newline at end of file diff --git a/internal/pkg/grib/cache.go b/internal/pkg/grib/cache.go new file mode 100644 index 0000000..7d40f43 --- /dev/null +++ b/internal/pkg/grib/cache.go @@ -0,0 +1,36 @@ +package grib + +import ( + "sync" + "time" +) + +type vec [2]float64 + +type item struct { + v vec + exp time.Time +} + +type memCache struct { + ttl time.Duration + m sync.Map +} + +func (c *memCache) get(k uint64) (vec, bool) { + if v, ok := c.m.Load(k); ok { + it := v.(item) + + if time.Now().Before(it.exp) { + return it.v, true + } + + c.m.Delete(k) + } + + return vec{}, false +} + +func (c *memCache) set(k uint64, v vec) { + c.m.Store(k, item{v, time.Now().Add(c.ttl)}) +} diff --git a/internal/pkg/grib/config.go b/internal/pkg/grib/config.go new file mode 100644 index 0000000..009645d --- /dev/null +++ b/internal/pkg/grib/config.go @@ -0,0 +1,32 @@ +package grib + +import ( + "time" + + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + env "github.com/caarlos0/env/v11" +) + +type Config struct { + Dir string `env:"DIR" envDefault:"/tmp/grib"` + TTL time.Duration `env:"TTL" envDefault:"24h"` + CacheTTL time.Duration `env:"CACHE_TTL" envDefault:"1h"` + Parallel int `env:"PARALLEL" envDefault:"8"` + DatasetURL string `env:"DATASET_URL" envDefault:"https://nomads.ncep.noaa.gov/pub/data/nccf/com/gfs/prod"` + // S3 configuration + UseS3 bool `env:"USE_S3" envDefault:"true"` + S3Bucket string `env:"S3_BUCKET" envDefault:"noaa-gfs-bdp-pds"` + S3Region string `env:"S3_REGION" envDefault:"us-east-1"` + S3Timeout time.Duration `env:"S3_TIMEOUT" envDefault:"300s"` +} + +func NewConfig() (*Config, error) { + cfg := &Config{} + if err := env.ParseWithOptions(cfg, env.Options{ + PrefixTagName: "GSN_PREDICTOR_GRIB_", + }); err != nil { + return nil, errcodes.Wrap(err, "failed to parse GRIB config") + } + + return cfg, nil +} diff --git a/internal/pkg/grib/cube.go b/internal/pkg/grib/cube.go new file mode 100644 index 0000000..d2015ec --- /dev/null +++ b/internal/pkg/grib/cube.go @@ -0,0 +1,55 @@ +package grib + +import ( + "encoding/binary" + "math" + "os" + + mmap "github.com/edsrzf/mmap-go" +) + +type cube struct { + mm mmap.MMap + t, p, lat, lon int + bytesPerVar int64 + file *os.File +} + +func openCube(path string) (*cube, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + + mm, err := mmap.Map(f, mmap.RDONLY, 0) + if err != nil { + f.Close() + return nil, err + } + + const ( + nT = 97 // 0-96 hours with step 1 hour + nP = 47 // 47 pressure levels matching tawhiri + nLat = 361 + nLon = 720 + ) + + return &cube{mm: mm, t: nT, p: nP, lat: nLat, lon: nLon, bytesPerVar: int64(nT * nP * nLat * nLon * 4), file: f}, nil +} + +func (c *cube) val(varIdx, ti, pi, y, x int) float32 { + idx := (((ti*c.p+pi)*c.lat + y) * c.lon) + x + off := int64(varIdx)*c.bytesPerVar + int64(idx)*4 + bits := binary.LittleEndian.Uint32(c.mm[off : off+4]) + return math.Float32frombits(bits) +} + +func (c *cube) Close() error { + if c.mm != nil { + c.mm.Unmap() + } + if c.file != nil { + return c.file.Close() + } + return nil +} diff --git a/internal/pkg/grib/dataset.go b/internal/pkg/grib/dataset.go new file mode 100644 index 0000000..e539f65 --- /dev/null +++ b/internal/pkg/grib/dataset.go @@ -0,0 +1,13 @@ +package grib + +type dataset struct { + cube *cube + runUTC int64 // unix seconds +} + +func (d *dataset) Close() error { + if d.cube != nil { + return d.cube.Close() + } + return nil +} diff --git a/internal/pkg/grib/downloader.go b/internal/pkg/grib/downloader.go new file mode 100644 index 0000000..a93a64c --- /dev/null +++ b/internal/pkg/grib/downloader.go @@ -0,0 +1,91 @@ +package grib + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "time" + + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + "golang.org/x/sync/errgroup" +) + +type Downloader struct { + Dir string + Parallel int + Client *http.Client + DatasetURL string +} + +func (d *Downloader) fileURL(run string, hour int, step int) string { + return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d", d.DatasetURL, run, hour, hour, step) +} + +func (d *Downloader) fetch(ctx context.Context, url, dst string) (err error) { + // Check if final file already exists + if _, err := os.Stat(dst); err == nil { + return nil + } + + tmp := dst + ".part" + + // Remove old .part file if it exists (fixes race condition) + os.Remove(tmp) + + f, err := os.Create(tmp) + if err != nil { + return err + } + + // Cleanup .part file on any error (using named return value) + defer func() { + f.Close() + if err != nil { + os.Remove(tmp) + } + }() + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + resp, err := d.Client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errcodes.Wrap(errcodes.ErrDownload, "bad status: "+resp.Status) + } + + if _, err := io.Copy(f, resp.Body); err != nil { + return err + } + + // Close file before rename + if err := f.Close(); err != nil { + return err + } + + // If rename fails, err will be set and defer will cleanup .part file + return os.Rename(tmp, dst) +} + +func (d *Downloader) Run(ctx context.Context, run time.Time) error { + runStr := run.Format("20060102") + hour := run.Hour() + g, ctx := errgroup.WithContext(ctx) + sem := make(chan struct{}, d.Parallel) + for _, step := range steps { + step := step + sem <- struct{}{} + g.Go(func() error { + defer func() { <-sem }() + url := d.fileURL(runStr, hour, step) + dst := filepath.Join(d.Dir, fileName(run, step)) + return d.fetch(ctx, url, dst) + }) + } + return g.Wait() +} diff --git a/internal/pkg/grib/extractor.go b/internal/pkg/grib/extractor.go new file mode 100644 index 0000000..769b7cd --- /dev/null +++ b/internal/pkg/grib/extractor.go @@ -0,0 +1,54 @@ +package grib + +import "math" + +func lerp(a, b, t float64) float64 { return a + t*(b-a) } + +// Interpolate 16‑point (time, p, lat, lon) +func (d *dataset) uv(lat, lon, alt float64, tHours float64) (float64, float64) { + if lon < 0 { + lon += 360 + } + iy := (lat + 90) * 2 + y0 := int(math.Floor(iy)) + y1 := y0 + 1 + wy := iy - float64(y0) + ix := lon * 2 + x0 := int(math.Floor(ix)) % d.cube.lon + x1 := (x0 + 1) % d.cube.lon + wx := ix - float64(x0) + // For hourly data (step = 1 hour) + it0 := int(math.Floor(tHours)) + wt := tHours - float64(it0) + p := pressureFromAlt(alt) + ip0 := 0 + for ip0+1 < len(pressureLevels) && pressureLevels[ip0+1] > p { + ip0++ + } + ip1 := ip0 + 1 + wp := (pressureLevels[ip0] - p) / (pressureLevels[ip0] - pressureLevels[ip1]) + fetch := func(ti, pi int) (float64, float64) { + u00 := d.cube.val(1, ti, pi, y0, x0) + u10 := d.cube.val(1, ti, pi, y0, x1) + u01 := d.cube.val(1, ti, pi, y1, x0) + u11 := d.cube.val(1, ti, pi, y1, x1) + v00 := d.cube.val(2, ti, pi, y0, x0) + v10 := d.cube.val(2, ti, pi, y0, x1) + v01 := d.cube.val(2, ti, pi, y1, x0) + v11 := d.cube.val(2, ti, pi, y1, x1) + uxy := (1-wy)*((1-wx)*float64(u00)+wx*float64(u10)) + wy*((1-wx)*float64(u01)+wx*float64(u11)) + vxy := (1-wy)*((1-wx)*float64(v00)+wx*float64(v10)) + wy*((1-wx)*float64(v01)+wx*float64(v11)) + return uxy, vxy + } + u0p0, v0p0 := fetch(it0, ip0) + u0p1, v0p1 := fetch(it0, ip1) + u1p0, v1p0 := fetch(it0+1, ip0) + u1p1, v1p1 := fetch(it0+1, ip1) + uLow := lerp(u0p0, u0p1, wp) + vLow := lerp(v0p0, v0p1, wp) + uHig := lerp(u1p0, u1p1, wp) + vHig := lerp(v1p0, v1p1, wp) + u := lerp(uLow, uHig, wt) + v := lerp(vLow, vHig, wt) + return u, v +} diff --git a/internal/pkg/grib/grib.go b/internal/pkg/grib/grib.go new file mode 100644 index 0000000..32c4466 --- /dev/null +++ b/internal/pkg/grib/grib.go @@ -0,0 +1,321 @@ +package grib + +import ( + "context" + "encoding/binary" + "math" + "net/http" + "os" + "path/filepath" + "strings" + "sync/atomic" + "time" + + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + "github.com/edsrzf/mmap-go" + "github.com/nilsmagnus/grib/griblib" +) + +type Service interface { + Update(ctx context.Context) error + Extract(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error) + Close() error + GetStatus() (ready bool, lastUpdate time.Time, isFresh bool, errMsg string) +} + +type service struct { + cfg *Config + cache memCache + data atomic.Pointer[dataset] +} + +func New(cfg *Config) (Service, error) { + if cfg.TTL == 0 { + cfg.TTL = 24 * time.Hour + } + if err := os.MkdirAll(cfg.Dir, 0o755); err != nil { + return nil, err + } + s := &service{cfg: cfg, cache: memCache{ttl: cfg.CacheTTL}} + + // Try to load existing dataset on startup + if err := s.loadExistingDataset(); err != nil { + // Log error but don't fail startup - dataset will be loaded on first Update() + // This allows the service to start even if no data is available yet + } + + return s, nil +} + +// loadExistingDataset tries to load the most recent available dataset +func (s *service) loadExistingDataset() error { + // Find the most recent cube file + pattern := filepath.Join(s.cfg.Dir, "*.cube") + matches, err := filepath.Glob(pattern) + if err != nil { + return err + } + + if len(matches) == 0 { + return errcodes.ErrNoCubeFilesFound + } + + // Sort by modification time (newest first) + var latestFile string + var latestTime time.Time + + for _, match := range matches { + info, err := os.Stat(match) + if err != nil { + continue + } + + if info.ModTime().After(latestTime) { + latestTime = info.ModTime() + latestFile = match + } + } + + if latestFile == "" { + return errcodes.ErrNoValidCubeFilesFound + } + + // Check if the file is fresh enough + if time.Since(latestTime) > s.cfg.TTL { + return errcodes.Wrap(errcodes.ErrLatestCubeFileIsTooOld, "latest cube file is too old") + } + + // Load the dataset + c, err := openCube(latestFile) + if err != nil { + return err + } + + // Extract run time from filename + base := filepath.Base(latestFile) + runStr := strings.TrimSuffix(base, ".cube") + run, err := time.Parse("20060102_15", runStr) + if err != nil { + c.Close() + return err + } + + ds := &dataset{cube: c, runUTC: run.Unix()} + s.data.Store(ds) + + return nil +} + +// Update() downloads missing GRIBs, assembles cube into a single mmap‑file. +func (s *service) Update(ctx context.Context) error { + // Check if we already have fresh data + if d := s.data.Load(); d != nil { + runTime := time.Unix(d.runUTC, 0) + if time.Since(runTime) < s.cfg.TTL { + // Data is still fresh, no need to update + return nil + } + } + + // Check again after acquiring lock (double-checked locking pattern) + if d := s.data.Load(); d != nil { + runTime := time.Unix(d.runUTC, 0) + if time.Since(runTime) < s.cfg.TTL { + // Another instance already updated the data + return nil + } + } + + run := nearestRun(time.Now().UTC().Add(-24 * time.Hour)) + + // Check if we already have this run + cubePath := filepath.Join(s.cfg.Dir, run.Format("20060102_15")) + ".cube" + if _, err := os.Stat(cubePath); err == nil { + // File exists, check if it's fresh + info, err := os.Stat(cubePath) + if err == nil && time.Since(info.ModTime()) < s.cfg.TTL { + // File is fresh, just load it + c, err := openCube(cubePath) + if err != nil { + return err + } + ds := &dataset{cube: c, runUTC: run.Unix()} + s.data.Store(ds) + s.cache = memCache{ttl: s.cfg.CacheTTL} + return nil + } + } + + // Download new data using S3 or HTTP + var downloadErr error + if s.cfg.UseS3 { + s3dl, err := NewS3Downloader(s.cfg.Dir, s.cfg.Parallel, s.cfg.S3Bucket, s.cfg.S3Region) + if err != nil { + return errcodes.Wrap(err, "failed to create S3 downloader") + } + downloadErr = s3dl.Run(ctx, run) + } else { + dl := Downloader{ + Dir: s.cfg.Dir, + Parallel: s.cfg.Parallel, + Client: http.DefaultClient, + DatasetURL: s.cfg.DatasetURL, + } + downloadErr = dl.Run(ctx, run) + } + + if downloadErr != nil { + return downloadErr + } + + // Assemble cube if it doesn't exist + if _, err := os.Stat(cubePath); err != nil { + if err := assembleCube(s.cfg.Dir, run, cubePath); err != nil { + return err + } + } + + c, err := openCube(cubePath) + if err != nil { + return err + } + ds := &dataset{cube: c, runUTC: run.Unix()} + s.data.Store(ds) + s.cache = memCache{ttl: s.cfg.CacheTTL} + return nil +} + +func assembleCube(dir string, run time.Time, cubePath string) error { + const sizePerVar = 97 * 47 * 361 * 720 * 4 // 97 time steps (0-96 hours), 47 pressure levels + total := int64(sizePerVar * 3) // 3 variables: gh, u, v + f, err := os.Create(cubePath) + if err != nil { + return err + } + if err := f.Truncate(total); err != nil { + return err + } + mm, err := mmap.MapRegion(f, int(total), mmap.RDWR, 0, 0) + if err != nil { + return err + } + defer mm.Unmap() + defer f.Close() + + pIndex := make(map[int]int) + for i, p := range pressureLevels { + pIndex[int(math.Round(p))] = i + } + + for ti, step := range steps { + fn := filepath.Join(dir, fileName(run, step)) + file, err := os.Open(fn) + if err != nil { + return err + } + + messages, err := griblib.ReadMessages(file) + file.Close() // Close immediately after reading + if err != nil { + return err + } + + for _, m := range messages { + // Check if this is a wind component (u or v) or geopotential height + // ParameterCategory 2 = momentum, ParameterNumber 2 = u-wind, 3 = v-wind + // ParameterCategory 3 = mass, ParameterNumber 5 = geopotential height + if m.Section4.ProductDefinitionTemplateNumber != 0 { + continue + } + + product := m.Section4.ProductDefinitionTemplate + + var varIdx int + // Match tawhiri variable order: ['gh', 'u', 'v'] (indices 0, 1, 2) + if product.ParameterCategory == 2 { + switch product.ParameterNumber { + case 2: // u-wind + varIdx = 1 + case 3: // v-wind + varIdx = 2 + default: + continue + } + } else if product.ParameterCategory == 3 && product.ParameterNumber == 5 { + // geopotential height + varIdx = 0 + } else { + continue + } + + // Check if this is a pressure level (type 100) + if product.FirstSurface.Type != 100 { + continue + } + + // Get pressure level in hPa + pressure := float64(product.FirstSurface.Value) / 100.0 + pIdx, ok := pIndex[int(math.Round(pressure))] + if !ok { + continue + } + + vals := m.Data() + // GRIB library returns scan north->south, west->east already in row-major order + raw := make([]byte, len(vals)*4) + for i, v := range vals { + binary.LittleEndian.PutUint32(raw[i*4:], math.Float32bits(float32(v))) + } + base := int64(varIdx*sizePerVar + (ti*47+pIdx)*361*720*4) + copy(mm[base:base+int64(len(raw))], raw) + } + } + return mm.Flush() +} + +func (s *service) Extract(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error) { + var zero [2]float64 + d := s.data.Load() + if d == nil { + return zero, errcodes.ErrNoDataset + } + if ts.Before(time.Unix(d.runUTC, 0)) || ts.After(time.Unix(d.runUTC, 0).Add(96*time.Hour)) { + return zero, errcodes.ErrOutOfBounds + } + + // Try memory cache first + key := encodeKey(lat, lon, alt, ts) + if v, ok := s.cache.get(key); ok { + return [2]float64(v), nil + } + + // Calculate result + td := ts.Sub(time.Unix(d.runUTC, 0)).Hours() + u, v := d.uv(lat, lon, alt, td) + out := [2]float64{u, v} + + // Cache in memory + s.cache.set(key, vec(out)) + + return out, nil +} + +func (s *service) Close() error { + if d := s.data.Load(); d != nil { + return d.Close() + } + return nil +} + +func (s *service) GetStatus() (ready bool, lastUpdate time.Time, isFresh bool, errMsg string) { + d := s.data.Load() + if d == nil { + return false, time.Time{}, false, "no dataset loaded" + } + runTime := time.Unix(d.runUTC, 0) + fresh := time.Since(runTime) < s.cfg.TTL + if !fresh { + return false, runTime, false, "dataset is too old" + } + return true, runTime, true, "" +} diff --git a/internal/pkg/grib/pressure.go b/internal/pkg/grib/pressure.go new file mode 100644 index 0000000..add6ff0 --- /dev/null +++ b/internal/pkg/grib/pressure.go @@ -0,0 +1,16 @@ +package grib + +import "math" + +// 47 pressure levels matching tawhiri configuration +var pressureLevels = []float64{ + 1000, 975, 950, 925, 900, 875, 850, 825, 800, 775, + 750, 725, 700, 675, 650, 625, 600, 575, 550, 525, + 500, 475, 450, 425, 400, 375, 350, 325, 300, 275, + 250, 225, 200, 175, 150, 125, 100, 70, 50, 30, + 20, 10, 7, 5, 3, 2, 1, +} + +func pressureFromAlt(alt float64) float64 { // ICAO ISA + return 1013.25 * math.Pow(1-alt/44307.69396, 5.255877) +} diff --git a/internal/pkg/grib/s3_downloader.go b/internal/pkg/grib/s3_downloader.go new file mode 100644 index 0000000..0fa4c70 --- /dev/null +++ b/internal/pkg/grib/s3_downloader.go @@ -0,0 +1,265 @@ +package grib + +import ( + "context" + "fmt" + "io" + "os" + "path/filepath" + "time" + + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/s3" + "golang.org/x/sync/errgroup" +) + +// S3Downloader downloads GRIB files from AWS S3 +type S3Downloader struct { + Dir string + Parallel int + Bucket string + Region string + Client *s3.Client +} + +// NewS3Downloader creates a new S3 downloader with anonymous access +func NewS3Downloader(dir string, parallel int, bucket, region string) (*S3Downloader, error) { + // Create AWS config with anonymous credentials for public bucket + cfg, err := config.LoadDefaultConfig(context.Background(), + config.WithRegion(region), + config.WithCredentialsProvider(aws.AnonymousCredentials{}), + ) + if err != nil { + return nil, errcodes.Wrap(err, "failed to load AWS config") + } + + client := s3.NewFromConfig(cfg) + + return &S3Downloader{ + Dir: dir, + Parallel: parallel, + Bucket: bucket, + Region: region, + Client: client, + }, nil +} + +// s3Key generates the S3 key for a GRIB file +// Path format: gfs.YYYYMMDD/HH/atmos/gfs.tHHz.pgrb2.0p50.fFFF +func (d *S3Downloader) s3Key(run string, hour int, step int) string { + return fmt.Sprintf("gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d", run, hour, hour, step) +} + +// CheckFileExists checks if a file exists in S3 using HeadObject +func (d *S3Downloader) CheckFileExists(ctx context.Context, key string) (bool, int64, error) { + input := &s3.HeadObjectInput{ + Bucket: aws.String(d.Bucket), + Key: aws.String(key), + } + + result, err := d.Client.HeadObject(ctx, input) + if err != nil { + // Check if error is NotFound + // AWS SDK v2 doesn't export specific error types, check error string + if isNotFoundError(err) { + return false, 0, nil + } + return false, 0, errcodes.Wrap(err, "failed to check file existence") + } + + size := int64(0) + if result.ContentLength != nil { + size = *result.ContentLength + } + + return true, size, nil +} + +// isNotFoundError checks if error is a NotFound error +func isNotFoundError(err error) bool { + if err == nil { + return false + } + // AWS SDK v2 error handling + errStr := err.Error() + return contains(errStr, "NotFound") || contains(errStr, "404") || contains(errStr, "NoSuchKey") +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > len(substr) && findSubstring(s, substr)) +} + +func findSubstring(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +// ListAvailableFiles lists all available files for a given run +func (d *S3Downloader) ListAvailableFiles(ctx context.Context, run string, hour int) ([]string, error) { + prefix := fmt.Sprintf("gfs.%s/%02d/atmos/", run, hour) + + input := &s3.ListObjectsV2Input{ + Bucket: aws.String(d.Bucket), + Prefix: aws.String(prefix), + } + + var files []string + paginator := s3.NewListObjectsV2Paginator(d.Client, input) + + for paginator.HasMorePages() { + page, err := paginator.NextPage(ctx) + if err != nil { + return nil, errcodes.Wrap(err, "failed to list S3 objects") + } + + for _, obj := range page.Contents { + if obj.Key != nil { + files = append(files, *obj.Key) + } + } + } + + return files, nil +} + +// fetchFromS3 downloads a file from S3 to local disk with retry logic +func (d *S3Downloader) fetchFromS3(ctx context.Context, key, dst string) (err error) { + // Check if final file already exists + if _, err := os.Stat(dst); err == nil { + return nil + } + + const maxRetries = 3 + var lastErr error + + for attempt := 0; attempt < maxRetries; attempt++ { + if attempt > 0 { + // Exponential backoff: 2s, 4s, 8s + waitTime := time.Duration(1< 0 && written != size { + return errcodes.Wrap(errcodes.ErrDownload, fmt.Sprintf("size mismatch: got %d bytes, expected %d", written, size)) + } + + // Close file before rename + if err := f.Close(); err != nil { + return err + } + fileClosed = true + + // If rename fails, err will be set and defer will cleanup .part file + return os.Rename(tmp, dst) +} + +// Run downloads all required GRIB files for a forecast run +func (d *S3Downloader) Run(ctx context.Context, run time.Time) error { + runStr := run.Format("20060102") + hour := run.Hour() + + // First, list available files to verify they exist + availableFiles, err := d.ListAvailableFiles(ctx, runStr, hour) + if err != nil { + return errcodes.Wrap(err, "failed to list available files") + } + + if len(availableFiles) == 0 { + return errcodes.Wrap(errcodes.ErrDownload, fmt.Sprintf("no files found for run %s/%02d", runStr, hour)) + } + + // Build a map of available files for quick lookup + availableMap := make(map[string]bool) + for _, file := range availableFiles { + availableMap[file] = true + } + + g, ctx := errgroup.WithContext(ctx) + sem := make(chan struct{}, d.Parallel) + + for _, step := range steps { + step := step + key := d.s3Key(runStr, hour, step) + + // Check if file is available in S3 + if !availableMap[key] { + // Log warning but don't fail - some forecast hours might not be available yet + continue + } + + sem <- struct{}{} + g.Go(func() error { + defer func() { <-sem }() + dst := filepath.Join(d.Dir, fileName(run, step)) + return d.fetchFromS3(ctx, key, dst) + }) + } + + return g.Wait() +} diff --git a/internal/pkg/grib/util.go b/internal/pkg/grib/util.go new file mode 100644 index 0000000..8de4af7 --- /dev/null +++ b/internal/pkg/grib/util.go @@ -0,0 +1,34 @@ +package grib + +import ( + "fmt" + "hash/fnv" + "time" +) + +// Generate steps from 0 to 96 with step 1 hour (97 steps total) +// GFS provides hourly data for 0-120 hours, we use first 96 hours +var steps = func() []int { + result := make([]int, 0, 97) + for i := 0; i <= 96; i++ { + result = append(result, i) + } + return result +}() + +func nearestRun(t time.Time) time.Time { + h := t.UTC().Hour() - t.UTC().Hour()%6 + return time.Date(t.Year(), t.Month(), t.Day(), h, 0, 0, 0, time.UTC) +} + +func fileName(run time.Time, step int) string { + return fmt.Sprintf("gfs.t%02dz.pgrb2.0p50.f%03d", run.Hour(), step) +} + +func encodeKey(a ...any) uint64 { + h := fnv.New64a() + for _, v := range a { + fmt.Fprint(h, v) + } + return h.Sum64() +} diff --git a/internal/pkg/log/log.go b/internal/pkg/log/log.go new file mode 100644 index 0000000..58db255 --- /dev/null +++ b/internal/pkg/log/log.go @@ -0,0 +1,23 @@ +package log + +import ( + "context" + + "go.uber.org/zap" +) + +type ctxLogKey struct{} + +func ToCtx(ctx context.Context, lg *zap.Logger) context.Context { + return context.WithValue(ctx, ctxLogKey{}, lg) +} + +func Ctx(ctx context.Context) *zap.Logger { + lg, ok := ctx.Value(ctxLogKey{}).(*zap.Logger) + if !ok || lg == nil { + zap.L().Error("no logger in context, using global") + return zap.L() + } + + return lg +} diff --git a/internal/prediction/interpolate.go b/internal/prediction/interpolate.go deleted file mode 100644 index 5ef0d14..0000000 --- a/internal/prediction/interpolate.go +++ /dev/null @@ -1,153 +0,0 @@ -package prediction - -import ( - "fmt" - - "predictor-refactored/internal/dataset" -) - -// Exact port of the reference interpolation logic (interpolate.pyx). -// 4D interpolation: time, latitude, longitude, altitude (via geopotential height). - -// lerp1 holds an index and interpolation weight for one axis. -type lerp1 struct { - index int - lerp float64 -} - -// lerp3 holds indices and a combined weight for the (hour, lat, lon) axes. -type lerp3 struct { - hour, lat, lng int - lerp float64 -} - -// RangeError indicates a coordinate is outside the dataset bounds. -type RangeError struct { - Variable string - Value float64 -} - -func (e *RangeError) Error() string { - return fmt.Sprintf("%s=%f out of range", e.Variable, e.Value) -} - -// pick computes interpolation indices and weights for a single axis. -// left: axis start, step: axis spacing, n: number of points, value: query value. -// Returns two lerp1 values (lower and upper bracket). -func pick(left, step float64, n int, value float64, variableName string) ([2]lerp1, error) { - a := (value - left) / step - b := int(a) // truncation toward zero, same as Cython cast - if b < 0 || b >= n-1 { - return [2]lerp1{}, &RangeError{Variable: variableName, Value: value} - } - l := a - float64(b) - return [2]lerp1{ - {index: b, lerp: 1 - l}, - {index: b + 1, lerp: l}, - }, nil -} - -// pick3 computes 8 trilinear interpolation weights for (hour, lat, lng). -func pick3(hour, lat, lng float64) ([8]lerp3, error) { - lhour, err := pick(0, 3, 65, hour, "hour") - if err != nil { - return [8]lerp3{}, err - } - llat, err := pick(-90, 0.5, 361, lat, "lat") - if err != nil { - return [8]lerp3{}, err - } - // Longitude wraps: tell pick the axis is one larger, then wrap index 720 → 0 - llng, err := pick(0, 0.5, 720+1, lng, "lng") - if err != nil { - return [8]lerp3{}, err - } - if llng[1].index == 720 { - llng[1].index = 0 - } - - var out [8]lerp3 - i := 0 - for _, a := range lhour { - for _, b := range llat { - for _, c := range llng { - out[i] = lerp3{ - hour: a.index, - lat: b.index, - lng: c.index, - lerp: a.lerp * b.lerp * c.lerp, - } - i++ - } - } - } - return out, nil -} - -// interp3 performs 8-point weighted interpolation at a given variable and pressure level. -func interp3(ds *dataset.File, lerps [8]lerp3, variable, level int) float64 { - var r float64 - for i := 0; i < 8; i++ { - v := ds.Val(lerps[i].hour, level, variable, lerps[i].lat, lerps[i].lng) - r += float64(v) * lerps[i].lerp - } - return r -} - -// search finds the largest pressure level index where interpolated geopotential -// height is less than the target altitude. Searches levels 0..45 (excludes topmost). -func search(ds *dataset.File, lerps [8]lerp3, target float64) int { - lower, upper := 0, 45 - - for lower < upper { - mid := (lower + upper + 1) / 2 - test := interp3(ds, lerps, dataset.VarHeight, mid) - if target <= test { - upper = mid - 1 - } else { - lower = mid - } - } - - return lower -} - -// interp4 performs altitude-interpolated wind lookup using two bracketing levels. -func interp4(ds *dataset.File, lerps [8]lerp3, altLerp lerp1, variable int) float64 { - lower := interp3(ds, lerps, variable, altLerp.index) - upper := interp3(ds, lerps, variable, altLerp.index+1) - return lower*altLerp.lerp + upper*(1-altLerp.lerp) -} - -// GetWind returns interpolated (u, v) wind components for the given position. -// hour: fractional hours since dataset start. -// lat: latitude in degrees (-90 to +90). -// lng: longitude in degrees (0 to 360). -// alt: altitude in metres above sea level. -func GetWind(ds *dataset.File, warnings *Warnings, hour, lat, lng, alt float64) (u, v float64, err error) { - lerps, err := pick3(hour, lat, lng) - if err != nil { - return 0, 0, err - } - - altidx := search(ds, lerps, alt) - lower := interp3(ds, lerps, dataset.VarHeight, altidx) - upper := interp3(ds, lerps, dataset.VarHeight, altidx+1) - - var altLerp float64 - if lower != upper { - altLerp = (upper - alt) / (upper - lower) - } else { - altLerp = 0.5 - } - - if altLerp < 0 { - warnings.AltitudeTooHigh.Add(1) - } - - alt1 := lerp1{index: altidx, lerp: altLerp} - u = interp4(ds, lerps, alt1, dataset.VarWindU) - v = interp4(ds, lerps, alt1, dataset.VarWindV) - - return u, v, nil -} diff --git a/internal/prediction/models.go b/internal/prediction/models.go deleted file mode 100644 index 8048c46..0000000 --- a/internal/prediction/models.go +++ /dev/null @@ -1,188 +0,0 @@ -package prediction - -import ( - "math" - "time" - - "predictor-refactored/internal/dataset" - "predictor-refactored/internal/elevation" -) - -// Exact port of the reference flight models (models.py). - -const ( - pi180 = math.Pi / 180.0 - _180pi = 180.0 / math.Pi -) - -// --- Up/Down Models --- - -// ConstantAscent returns a model with constant vertical velocity (m/s). -func ConstantAscent(ascentRate float64) Model { - return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) { - return 0, 0, ascentRate - } -} - -// DragDescent returns a descent-under-parachute model. -// seaLevelDescentRate is the descent rate at sea level (m/s, positive value). -// Uses the NASA atmosphere model for density at altitude. -func DragDescent(seaLevelDescentRate float64) Model { - dragCoefficient := seaLevelDescentRate * 1.1045 - - return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) { - return 0, 0, -dragCoefficient / math.Sqrt(nasaDensity(alt)) - } -} - -// nasaDensity computes air density using the NASA atmosphere model. -// Reference: http://www.grc.nasa.gov/WWW/K-12/airplane/atmosmet.html -func nasaDensity(alt float64) float64 { - var temp, pressure float64 - - switch { - case alt > 25000: - temp = -131.21 + 0.00299*alt - pressure = 2.488 * math.Pow((temp+273.1)/216.6, -11.388) - case alt > 11000: - temp = -56.46 - pressure = 22.65 * math.Exp(1.73-0.000157*alt) - default: - temp = 15.04 - 0.00649*alt - pressure = 101.29 * math.Pow((temp+273.1)/288.08, 5.256) - } - - return pressure / (0.2869 * (temp + 273.1)) -} - -// --- Sideways Models --- - -// WindVelocity returns a model that gives lateral movement at the wind velocity. -// ds is the wind dataset, dsEpoch is the dataset start time as UNIX timestamp. -func WindVelocity(ds *dataset.File, dsEpoch float64, warnings *Warnings) Model { - return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) { - tHours := (t - dsEpoch) / 3600.0 - u, v, err := GetWind(ds, warnings, tHours, lat, lng, alt) - if err != nil { - return 0, 0, 0 - } - - R := 6371009.0 + alt - dlat = _180pi * v / R - dlng = _180pi * u / (R * math.Cos(lat*pi180)) - return dlat, dlng, 0 - } -} - -// --- Model Combinations --- - -// LinearModel returns a model that sums all component models. -func LinearModel(models ...Model) Model { - return func(t, lat, lng, alt float64) (dlat, dlng, dalt float64) { - for _, m := range models { - d1, d2, d3 := m(t, lat, lng, alt) - dlat += d1 - dlng += d2 - dalt += d3 - } - return - } -} - -// --- Termination Criteria --- - -// BurstTermination returns a terminator that fires when altitude >= burstAltitude. -func BurstTermination(burstAltitude float64) Terminator { - return func(t, lat, lng, alt float64) bool { - return alt >= burstAltitude - } -} - -// SeaLevelTermination fires when altitude <= 0. -func SeaLevelTermination(t, lat, lng, alt float64) bool { - return alt <= 0 -} - -// TimeTermination returns a terminator that fires when t > maxTime. -func TimeTermination(maxTime float64) Terminator { - return func(t, lat, lng, alt float64) bool { - return t > maxTime - } -} - -// ElevationTermination returns a terminator that fires when alt < ground level. -// Uses ruaumoko-compatible elevation data. Longitude is normalised internally. -func ElevationTermination(elev *elevation.Dataset) Terminator { - return func(t, lat, lng, alt float64) bool { - return elev.Get(lat, lng) > alt - } -} - -// --- Pre-Defined Profiles --- - -// Stage pairs a model with its termination criterion. -type Stage struct { - Model Model - Terminator Terminator -} - -// StandardProfile creates the chain for a standard high-altitude balloon flight: -// ascent at constant rate → burst → descent under parachute. -// If elev is non-nil, descent terminates at ground level; otherwise at sea level. -func StandardProfile(ascentRate, burstAltitude, descentRate float64, - ds *dataset.File, dsEpoch float64, warnings *Warnings, - elev *elevation.Dataset) []Stage { - - wind := WindVelocity(ds, dsEpoch, warnings) - - modelUp := LinearModel(ConstantAscent(ascentRate), wind) - termUp := BurstTermination(burstAltitude) - - modelDown := LinearModel(DragDescent(descentRate), wind) - var termDown Terminator - if elev != nil { - termDown = ElevationTermination(elev) - } else { - termDown = Terminator(SeaLevelTermination) - } - - return []Stage{ - {Model: modelUp, Terminator: termUp}, - {Model: modelDown, Terminator: termDown}, - } -} - -// FloatProfile creates the chain for a floating balloon flight: -// ascent to float altitude → float until stop time. -func FloatProfile(ascentRate, floatAltitude float64, stopTime time.Time, - ds *dataset.File, dsEpoch float64, warnings *Warnings) []Stage { - - wind := WindVelocity(ds, dsEpoch, warnings) - - modelUp := LinearModel(ConstantAscent(ascentRate), wind) - termUp := BurstTermination(floatAltitude) - - modelFloat := wind - termFloat := TimeTermination(float64(stopTime.Unix())) - - return []Stage{ - {Model: modelUp, Terminator: termUp}, - {Model: modelFloat, Terminator: termFloat}, - } -} - -// RunPrediction runs a prediction with the given profile stages. -// launchTime is a UNIX timestamp. -func RunPrediction(launchTime float64, lat, lng, alt float64, stages []Stage) []StageResult { - chain := make([]struct { - Model Model - Terminator Terminator - }, len(stages)) - - for i, s := range stages { - chain[i].Model = s.Model - chain[i].Terminator = s.Terminator - } - - return Solve(launchTime, lat, lng, alt, chain) -} diff --git a/internal/prediction/solver.go b/internal/prediction/solver.go deleted file mode 100644 index 62e29a7..0000000 --- a/internal/prediction/solver.go +++ /dev/null @@ -1,180 +0,0 @@ -package prediction - -import "math" - -// Exact port of the reference RK4 solver (solver.pyx). -// Integrates balloon state using RK4 with dt=60 seconds. -// Termination uses binary search refinement (tolerance 0.01). - -// Vec holds the balloon state: latitude, longitude, altitude. -type Vec struct { - Lat float64 - Lng float64 - Alt float64 -} - -// Model is a function that returns (dlat/dt, dlng/dt, dalt/dt) given state. -// t is UNIX timestamp, lat/lng in degrees, alt in metres. -type Model func(t float64, lat, lng, alt float64) (dlat, dlng, dalt float64) - -// Terminator returns true when integration should stop. -type Terminator func(t float64, lat, lng, alt float64) bool - -// StageResult holds the trajectory points for one flight stage. -type StageResult struct { - Points []TrajectoryPoint -} - -// TrajectoryPoint is a single point in a trajectory (used by solver). -type TrajectoryPoint struct { - T float64 // UNIX timestamp - Lat float64 - Lng float64 - Alt float64 -} - -// pymod returns a % b with Python semantics (always non-negative when b > 0). -func pymod(a, b float64) float64 { - r := math.Mod(a, b) - if r < 0 { - r += b - } - return r -} - -// vecadd returns a + k*b, with lng wrapped to [0, 360). -func vecadd(a Vec, k float64, b Vec) Vec { - return Vec{ - Lat: a.Lat + k*b.Lat, - Lng: pymod(a.Lng+k*b.Lng, 360.0), - Alt: a.Alt + k*b.Alt, - } -} - -// scalarLerp returns (1-l)*a + l*b. -func scalarLerp(a, b, l float64) float64 { - return (1-l)*a + l*b -} - -// lngLerp interpolates longitude handling the 0/360 wrap-around. -func lngLerp(a, b, l float64) float64 { - l2 := 1 - l - - if a > b { - a, b = b, a - l, l2 = l2, l - } - - // distance round one way: b - a - // distance around other: (a + 360) - b - if b-a < 180.0 { - return l2*a + l*b - } - return pymod(l2*(a+360)+l*b, 360.0) -} - -// vecLerp returns (1-l)*a + l*b with proper longitude wrapping. -func vecLerp(a, b Vec, l float64) Vec { - return Vec{ - Lat: scalarLerp(a.Lat, b.Lat, l), - Lng: lngLerp(a.Lng, b.Lng, l), - Alt: scalarLerp(a.Alt, b.Alt, l), - } -} - -// rk4 integrates from initial conditions using RK4. -// dt=60.0 seconds, terminationTolerance=0.01. -func rk4(t float64, lat, lng, alt float64, model Model, terminator Terminator) []TrajectoryPoint { - const dt = 60.0 - const terminationTolerance = 0.01 - - y := Vec{Lat: lat, Lng: lng, Alt: alt} - result := []TrajectoryPoint{{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt}} - - for { - // Evaluate model at 4 points (standard RK4) - k1lat, k1lng, k1alt := model(t, y.Lat, y.Lng, y.Alt) - k1 := Vec{Lat: k1lat, Lng: k1lng, Alt: k1alt} - - mid1 := vecadd(y, dt/2, k1) - k2lat, k2lng, k2alt := model(t+dt/2, mid1.Lat, mid1.Lng, mid1.Alt) - k2 := Vec{Lat: k2lat, Lng: k2lng, Alt: k2alt} - - mid2 := vecadd(y, dt/2, k2) - k3lat, k3lng, k3alt := model(t+dt/2, mid2.Lat, mid2.Lng, mid2.Alt) - k3 := Vec{Lat: k3lat, Lng: k3lng, Alt: k3alt} - - end := vecadd(y, dt, k3) - k4lat, k4lng, k4alt := model(t+dt, end.Lat, end.Lng, end.Alt) - k4 := Vec{Lat: k4lat, Lng: k4lng, Alt: k4alt} - - // y2 = y + dt/6*k1 + dt/3*k2 + dt/3*k3 + dt/6*k4 - y2 := y - y2 = vecadd(y2, dt/6, k1) - y2 = vecadd(y2, dt/3, k2) - y2 = vecadd(y2, dt/3, k3) - y2 = vecadd(y2, dt/6, k4) - - t2 := t + dt - - if terminator(t2, y2.Lat, y2.Lng, y2.Alt) { - // Binary search to refine the termination point. - // Find l in [0, 1] such that (t3, y3) = lerp((t, y), (t2, y2), l) - // is near where the terminator fires. - left := 0.0 - right := 1.0 - - var t3 float64 - var y3 Vec - t3 = t2 - y3 = y2 - - for right-left > terminationTolerance { - mid := (left + right) / 2 - t3 = scalarLerp(t, t2, mid) - y3 = vecLerp(y, y2, mid) - - if terminator(t3, y3.Lat, y3.Lng, y3.Alt) { - right = mid - } else { - left = mid - } - } - - result = append(result, TrajectoryPoint{T: t3, Lat: y3.Lat, Lng: y3.Lng, Alt: y3.Alt}) - break - } - - // Update current state - t = t2 - y = y2 - result = append(result, TrajectoryPoint{T: t, Lat: y.Lat, Lng: y.Lng, Alt: y.Alt}) - } - - return result -} - -// Solve runs through a chain of (model, terminator) stages. -// Returns one StageResult per stage. -func Solve(t, lat, lng, alt float64, chain []struct { - Model Model - Terminator Terminator -}) []StageResult { - var results []StageResult - - for _, stage := range chain { - points := rk4(t, lat, lng, alt, stage.Model, stage.Terminator) - results = append(results, StageResult{Points: points}) - - // Next stage starts where this one ended - if len(points) > 0 { - last := points[len(points)-1] - t = last.T - lat = last.Lat - lng = last.Lng - alt = last.Alt - } - } - - return results -} diff --git a/internal/prediction/warnings.go b/internal/prediction/warnings.go deleted file mode 100644 index 1beeb1a..0000000 --- a/internal/prediction/warnings.go +++ /dev/null @@ -1,21 +0,0 @@ -package prediction - -import "sync/atomic" - -// Warnings tracks warning conditions during a prediction run. -type Warnings struct { - AltitudeTooHigh atomic.Int64 -} - -// ToMap returns warnings as a map suitable for JSON serialization. -// Only includes warnings that have fired. -func (w *Warnings) ToMap() map[string]any { - result := make(map[string]any) - if n := w.AltitudeTooHigh.Load(); n > 0 { - result["altitude_too_high"] = map[string]any{ - "count": n, - "description": "The altitude went too high, above the max forecast wind. Wind data will be unreliable", - } - } - return result -} diff --git a/internal/service/deps.go b/internal/service/deps.go new file mode 100644 index 0000000..80a945c --- /dev/null +++ b/internal/service/deps.go @@ -0,0 +1,12 @@ +package service + +import ( + "context" + "time" +) + +type Grib interface { + Update(ctx context.Context) error + Extract(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error) + Close() error +} diff --git a/internal/service/predictor.go b/internal/service/predictor.go new file mode 100644 index 0000000..e5c6fad --- /dev/null +++ b/internal/service/predictor.go @@ -0,0 +1,516 @@ +package service + +import ( + "context" + "encoding/base64" + "encoding/json" + "math" + "time" + + "git.intra.yksa.space/gsn/predictor/internal/pkg/ds" + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + "git.intra.yksa.space/gsn/predictor/internal/pkg/log" + "go.uber.org/zap" +) + +var ErrInvalidParameters = errcodes.New(400, "missing required prediction parameters") + +// Stage represents a prediction stage (ascent, descent, float) +type Stage struct { + Name string + Results []ds.PredicitonResult + StartTime time.Time + EndTime time.Time +} + +// CustomCurve represents a custom ascent/descent curve +type CustomCurve struct { + Altitude []float64 `json:"altitude"` + Time []float64 `json:"time"` // seconds from start +} + +func (s *Service) PerformPrediction(ctx context.Context, params ds.PredictionParameters) ([]ds.PredicitonResult, error) { + // Validate required parameters + if params.LaunchLatitude == nil || params.LaunchLongitude == nil || params.LaunchAltitude == nil || params.LaunchDatetime == nil { + return nil, ErrInvalidParameters + } + + // Get default values + profile := "standard_profile" + if params.Profile != nil { + profile = *params.Profile + } + + ascentRate := 5.0 + if params.AscentRate != nil { + ascentRate = *params.AscentRate + } + + burstAltitude := 30000.0 + if params.BurstAltitude != nil { + burstAltitude = *params.BurstAltitude + } + + descentRate := 5.0 + if params.DescentRate != nil { + descentRate = *params.DescentRate + } + + floatAltitude := 0.0 + if params.FloatAltitude != nil { + floatAltitude = *params.FloatAltitude + } + + // Parse custom curves if provided + var ascentCurve, descentCurve *CustomCurve + if params.AscentCurve != nil && *params.AscentCurve != "" { + if curve, err := parseCustomCurve(*params.AscentCurve); err == nil { + ascentCurve = curve + } + } + if params.DescentCurve != nil && *params.DescentCurve != "" { + if curve, err := parseCustomCurve(*params.DescentCurve); err == nil { + descentCurve = curve + } + } + + log.Ctx(ctx).Info("Starting prediction", + zap.String("profile", profile), + zap.Float64("lat", *params.LaunchLatitude), + zap.Float64("lon", *params.LaunchLongitude), + zap.Float64("alt", *params.LaunchAltitude), + zap.Time("time", *params.LaunchDatetime), + ) + + var allResults []ds.PredicitonResult + + switch profile { + case "standard_profile": + allResults = s.standardProfile(ctx, params, ascentRate, burstAltitude, descentRate, ascentCurve, descentCurve) + case "float_profile": + allResults = s.floatProfile(ctx, params, ascentRate, burstAltitude, floatAltitude, descentRate, ascentCurve, descentCurve) + case "reverse_profile": + allResults = s.reverseProfile(ctx, params, ascentRate, burstAltitude, descentRate, ascentCurve, descentCurve) + case "custom_profile": + allResults = s.customProfile(ctx, params, ascentCurve, descentCurve) + default: + return nil, errcodes.New(400, "unsupported profile: "+profile) + } + + log.Ctx(ctx).Info("Prediction complete", zap.Int("total_steps", len(allResults))) + return allResults, nil +} + +func (s *Service) standardProfile(ctx context.Context, params ds.PredictionParameters, ascentRate, burstAltitude, descentRate float64, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult { + var results []ds.PredicitonResult + + // Stage 1: Ascent + ascentResults := s.simulateAscent(ctx, params, ascentRate, burstAltitude, ascentCurve) + results = append(results, ascentResults...) + + if len(ascentResults) > 0 { + // Get final position from ascent + lastResult := ascentResults[len(ascentResults)-1] + + // Stage 2: Descent + descentParams := ds.PredictionParameters{ + LaunchLatitude: lastResult.Latitude, + LaunchLongitude: lastResult.Longitude, + LaunchAltitude: lastResult.Altitude, + LaunchDatetime: lastResult.Timestamp, + } + + descentResults := s.simulateDescent(ctx, descentParams, descentRate, 0, descentCurve) + results = append(results, descentResults...) + } + + return results +} + +func (s *Service) floatProfile(ctx context.Context, params ds.PredictionParameters, ascentRate, burstAltitude, floatAltitude, descentRate float64, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult { + var results []ds.PredicitonResult + + // Stage 1: Ascent to float altitude + ascentResults := s.simulateAscent(ctx, params, ascentRate, floatAltitude, ascentCurve) + results = append(results, ascentResults...) + + if len(ascentResults) > 0 { + // Stage 2: Float (simulate for some time) + lastResult := ascentResults[len(ascentResults)-1] + floatResults := s.simulateFloat(ctx, lastResult, 30*time.Minute) // Float for 30 minutes + results = append(results, floatResults...) + + if len(floatResults) > 0 { + // Stage 3: Descent + finalFloat := floatResults[len(floatResults)-1] + descentParams := ds.PredictionParameters{ + LaunchLatitude: finalFloat.Latitude, + LaunchLongitude: finalFloat.Longitude, + LaunchAltitude: finalFloat.Altitude, + LaunchDatetime: finalFloat.Timestamp, + } + + descentResults := s.simulateDescent(ctx, descentParams, descentRate, 0, descentCurve) + results = append(results, descentResults...) + } + } + + return results +} + +func (s *Service) reverseProfile(ctx context.Context, params ds.PredictionParameters, ascentRate, burstAltitude, descentRate float64, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult { + var results []ds.PredicitonResult + + // Stage 1: Ascent + ascentResults := s.simulateAscent(ctx, params, ascentRate, burstAltitude, ascentCurve) + results = append(results, ascentResults...) + + if len(ascentResults) > 0 { + // Stage 2: Descent to float altitude + lastResult := ascentResults[len(ascentResults)-1] + descentParams := ds.PredictionParameters{ + LaunchLatitude: lastResult.Latitude, + LaunchLongitude: lastResult.Longitude, + LaunchAltitude: lastResult.Altitude, + LaunchDatetime: lastResult.Timestamp, + } + + // Descent to float altitude (if specified) + floatAlt := 0.0 + if params.FloatAltitude != nil { + floatAlt = *params.FloatAltitude + } + + descentResults := s.simulateDescent(ctx, descentParams, descentRate, floatAlt, descentCurve) + results = append(results, descentResults...) + + if floatAlt > 0 && len(descentResults) > 0 { + // Stage 3: Float + finalDescent := descentResults[len(descentResults)-1] + floatResults := s.simulateFloat(ctx, finalDescent, 30*time.Minute) + results = append(results, floatResults...) + } + } + + return results +} + +func (s *Service) customProfile(ctx context.Context, params ds.PredictionParameters, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult { + var results []ds.PredicitonResult + + if ascentCurve != nil { + ascentResults := s.simulateCustomAscent(ctx, params, ascentCurve) + results = append(results, ascentResults...) + } + + if descentCurve != nil && len(results) > 0 { + lastResult := results[len(results)-1] + descentParams := ds.PredictionParameters{ + LaunchLatitude: lastResult.Latitude, + LaunchLongitude: lastResult.Longitude, + LaunchAltitude: lastResult.Altitude, + LaunchDatetime: lastResult.Timestamp, + } + + descentResults := s.simulateCustomDescent(ctx, descentParams, descentCurve) + results = append(results, descentResults...) + } + + return results +} + +func rk4Step(lat, lon, alt float64, t time.Time, dt float64, windFunc func(lat, lon, alt float64, t time.Time) (float64, float64), altRate float64) (float64, float64, float64) { + // Helper for RK4 integration step + toRad := math.Pi / 180.0 + toDeg := 180.0 / math.Pi + R := func(alt float64) float64 { return 6371009.0 + alt } + + f := func(lat, lon, alt float64, t time.Time) (float64, float64, float64) { + windU, windV := windFunc(lat, lon, alt, t) + Rnow := R(alt) + dlat := toDeg * windV / Rnow + dlon := toDeg * windU / (Rnow * math.Cos(lat*toRad)) + return dlat, dlon, altRate + } + + k1_lat, k1_lon, k1_alt := f(lat, lon, alt, t) + k2_lat, k2_lon, k2_alt := f(lat+0.5*k1_lat*dt, lon+0.5*k1_lon*dt, alt+0.5*k1_alt*dt, t.Add(time.Duration(0.5*dt)*time.Second)) + k3_lat, k3_lon, k3_alt := f(lat+0.5*k2_lat*dt, lon+0.5*k2_lon*dt, alt+0.5*k2_alt*dt, t.Add(time.Duration(0.5*dt)*time.Second)) + k4_lat, k4_lon, k4_alt := f(lat+k3_lat*dt, lon+k3_lon*dt, alt+k3_alt*dt, t.Add(time.Duration(dt)*time.Second)) + + latNew := lat + (dt/6.0)*(k1_lat+2*k2_lat+2*k3_lat+k4_lat) + lonNew := lon + (dt/6.0)*(k1_lon+2*k2_lon+2*k3_lon+k4_lon) + altNew := alt + (dt/6.0)*(k1_alt+2*k2_alt+2*k3_alt+k4_alt) + return latNew, lonNew, altNew +} + +func (s *Service) simulateAscent(ctx context.Context, params ds.PredictionParameters, ascentRate, targetAltitude float64, customCurve *CustomCurve) []ds.PredicitonResult { + const dt = 10.0 // simulation step in seconds + const outputInterval = 60.0 // output every 60 seconds + + lat := *params.LaunchLatitude + lon := *params.LaunchLongitude + alt := *params.LaunchAltitude + timeCur := *params.LaunchDatetime + + results := make([]ds.PredicitonResult, 0, 1000) + + latCopy := lat + lonCopy := lon + altCopy := alt + timeCopy := timeCur + wind := [2]float64{0, 0} + windU := wind[0] + windV := wind[1] + results = append(results, ds.PredicitonResult{ + Latitude: &latCopy, + Longitude: &lonCopy, + Altitude: &altCopy, + Timestamp: &timeCopy, + WindU: &windU, + WindV: &windV, + }) + + nextOutputTime := timeCur.Add(time.Duration(outputInterval) * time.Second) + windFunc := func(lat, lon, alt float64, t time.Time) (float64, float64) { + w, err := s.ExtractWind(ctx, lat, lon, alt, t) + if err != nil { + log.Ctx(ctx).Warn("Wind extraction failed during ascent", zap.Error(err)) + return 0, 0 + } + return w[0], w[1] + } + + for alt < targetAltitude { + altRate := ascentRate + if customCurve != nil { + altRate = s.getCustomAltitudeRate(customCurve, alt, ascentRate) + } + latNew, lonNew, altNew := rk4Step(lat, lon, alt, timeCur, dt, windFunc, altRate) + timeCur = timeCur.Add(time.Duration(dt) * time.Second) + lat = latNew + lon = lonNew + alt = altNew + + if alt >= targetAltitude { + break + } + + if !timeCur.Before(nextOutputTime) { + wU, wV := windFunc(lat, lon, alt, timeCur) + latCopy := lat + lonCopy := lon + altCopy := alt + timeCopy := timeCur + windU := wU + windV := wV + results = append(results, ds.PredicitonResult{ + Latitude: &latCopy, + Longitude: &lonCopy, + Altitude: &altCopy, + Timestamp: &timeCopy, + WindU: &windU, + WindV: &windV, + }) + nextOutputTime = nextOutputTime.Add(time.Duration(outputInterval) * time.Second) + } + } + + return results +} + +func (s *Service) simulateDescent(ctx context.Context, params ds.PredictionParameters, descentRate, targetAltitude float64, customCurve *CustomCurve) []ds.PredicitonResult { + const dt = 10.0 // simulation step in seconds + const outputInterval = 60.0 // output every 60 seconds + + lat := *params.LaunchLatitude + lon := *params.LaunchLongitude + alt := *params.LaunchAltitude + timeCur := *params.LaunchDatetime + + results := make([]ds.PredicitonResult, 0, 1000) + + latCopy := lat + lonCopy := lon + altCopy := alt + timeCopy := timeCur + wind := [2]float64{0, 0} + windU := wind[0] + windV := wind[1] + results = append(results, ds.PredicitonResult{ + Latitude: &latCopy, + Longitude: &lonCopy, + Altitude: &altCopy, + Timestamp: &timeCopy, + WindU: &windU, + WindV: &windV, + }) + + nextOutputTime := timeCur.Add(time.Duration(outputInterval) * time.Second) + windFunc := func(lat, lon, alt float64, t time.Time) (float64, float64) { + w, err := s.ExtractWind(ctx, lat, lon, alt, t) + if err != nil { + log.Ctx(ctx).Warn("Wind extraction failed during descent", zap.Error(err)) + return 0, 0 + } + return w[0], w[1] + } + + for alt > targetAltitude { + altRate := -descentRate + if customCurve != nil { + altRate = -s.getCustomAltitudeRate(customCurve, alt, descentRate) + } + latNew, lonNew, altNew := rk4Step(lat, lon, alt, timeCur, dt, windFunc, altRate) + timeCur = timeCur.Add(time.Duration(dt) * time.Second) + lat = latNew + lon = lonNew + alt = altNew + + if alt <= targetAltitude { + break + } + + if !timeCur.Before(nextOutputTime) { + wU, wV := windFunc(lat, lon, alt, timeCur) + latCopy := lat + lonCopy := lon + altCopy := alt + timeCopy := timeCur + windU := wU + windV := wV + results = append(results, ds.PredicitonResult{ + Latitude: &latCopy, + Longitude: &lonCopy, + Altitude: &altCopy, + Timestamp: &timeCopy, + WindU: &windU, + WindV: &windV, + }) + nextOutputTime = nextOutputTime.Add(time.Duration(outputInterval) * time.Second) + } + } + + return results +} + +func (s *Service) simulateFloat(ctx context.Context, startResult ds.PredicitonResult, duration time.Duration) []ds.PredicitonResult { + const dt = 10.0 // simulation step in seconds + const outputInterval = 60.0 // output every 60 seconds + + lat := *startResult.Latitude + lon := *startResult.Longitude + alt := *startResult.Altitude + timeCur := *startResult.Timestamp + endTime := timeCur.Add(duration) + + results := make([]ds.PredicitonResult, 0, 1000) + + // Always include the initial float point + latCopy := lat + lonCopy := lon + altCopy := alt + timeCopy := timeCur + wind := [2]float64{0, 0} + windU := wind[0] + windV := wind[1] + results = append(results, ds.PredicitonResult{ + Latitude: &latCopy, + Longitude: &lonCopy, + Altitude: &altCopy, + Timestamp: &timeCopy, + WindU: &windU, + WindV: &windV, + }) + + var nextOutputTime = timeCur.Add(time.Duration(outputInterval) * time.Second) + + for timeCur.Before(endTime) { + wind, err := s.ExtractWind(ctx, lat, lon, alt, timeCur) + if err != nil { + log.Ctx(ctx).Warn("Wind extraction failed during float", zap.Error(err)) + break + } + + latDot := (wind[1] / 111320.0) + lonDot := (wind[0] / (40075000.0 * math.Cos(lat*math.Pi/180) / 360.0)) + + lat += latDot * dt + lon += lonDot * dt + // alt remains constant during float + timeCur = timeCur.Add(time.Duration(dt) * time.Second) + + if !timeCur.Before(nextOutputTime) { + latCopy := lat + lonCopy := lon + altCopy := alt + timeCopy := timeCur + windU := wind[0] + windV := wind[1] + results = append(results, ds.PredicitonResult{ + Latitude: &latCopy, + Longitude: &lonCopy, + Altitude: &altCopy, + Timestamp: &timeCopy, + WindU: &windU, + WindV: &windV, + }) + nextOutputTime = nextOutputTime.Add(time.Duration(outputInterval) * time.Second) + } + } + + return results +} + +func (s *Service) simulateCustomAscent(ctx context.Context, params ds.PredictionParameters, curve *CustomCurve) []ds.PredicitonResult { + // Implementation for custom ascent curve + // This would interpolate the altitude rate from the custom curve + return s.simulateAscent(ctx, params, 5.0, 30000.0, curve) +} + +func (s *Service) simulateCustomDescent(ctx context.Context, params ds.PredictionParameters, curve *CustomCurve) []ds.PredicitonResult { + // Implementation for custom descent curve + // This would interpolate the altitude rate from the custom curve + return s.simulateDescent(ctx, params, 5.0, 0.0, curve) +} + +func (s *Service) getCustomAltitudeRate(curve *CustomCurve, currentAltitude, defaultRate float64) float64 { + if curve == nil || len(curve.Altitude) < 2 { + return defaultRate + } + + // Find the two points in the curve that bracket the current altitude + for i := 0; i < len(curve.Altitude)-1; i++ { + if curve.Altitude[i] <= currentAltitude && currentAltitude <= curve.Altitude[i+1] { + // Linear interpolation + alt1, alt2 := curve.Altitude[i], curve.Altitude[i+1] + time1, time2 := curve.Time[i], curve.Time[i+1] + + if alt2 == alt1 { + return defaultRate + } + + // Calculate rate (change in altitude per second) + if time2 > time1 { + return (alt2 - alt1) / (time2 - time1) + } + return defaultRate + } + } + + return defaultRate +} + +func parseCustomCurve(base64Data string) (*CustomCurve, error) { + data, err := base64.StdEncoding.DecodeString(base64Data) + if err != nil { + return nil, err + } + + var curve CustomCurve + if err := json.Unmarshal(data, &curve); err != nil { + return nil, err + } + + return &curve, nil +} diff --git a/internal/service/service.go b/internal/service/service.go index 4ccd1d4..3496f96 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -2,244 +2,59 @@ package service import ( "context" - "os" - "path/filepath" - "sort" - "strings" - "sync" "time" - "predictor-refactored/internal/dataset" - "predictor-refactored/internal/downloader" - "predictor-refactored/internal/elevation" - - "go.uber.org/zap" + "git.intra.yksa.space/gsn/predictor/internal/pkg/log" ) -// Service orchestrates the dataset lifecycle and provides prediction capabilities. type Service struct { - mu sync.RWMutex - ds *dataset.File - elev *elevation.Dataset - cfg *downloader.Config - dl *downloader.Downloader - log *zap.Logger - updating sync.Mutex // prevents concurrent downloads + grib Grib } -// New creates a new Service. -func New(cfg *downloader.Config, log *zap.Logger) *Service { - return &Service{ - cfg: cfg, - dl: downloader.NewDownloader(cfg, log), - log: log, +func New(gribService Grib) (*Service, error) { + svc := &Service{ + grib: gribService, } + + return svc, nil } -// LoadElevation loads the ruaumoko-compatible elevation dataset from path. -// If the file doesn't exist, elevation termination is disabled (falls back to sea level). -func (s *Service) LoadElevation(path string) { - ds, err := elevation.Open(path) - if err != nil { - s.log.Warn("elevation dataset not available, using sea-level termination", - zap.String("path", path), zap.Error(err)) - return - } - s.elev = ds - s.log.Info("elevation dataset loaded", zap.String("path", path)) +// UpdateWeatherData updates weather forecast data using the configured grib service +func (s *Service) UpdateWeatherData(ctx context.Context) error { + return s.grib.Update(ctx) } -// Elevation returns the elevation dataset (may be nil). -func (s *Service) Elevation() *elevation.Dataset { - return s.elev +// ExtractWind extracts wind data for given coordinates and time +func (s *Service) ExtractWind(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error) { + return s.grib.Extract(ctx, lat, lon, alt, ts) } -// Ready returns true if the service has a loaded dataset. -func (s *Service) Ready() bool { - s.mu.RLock() - defer s.mu.RUnlock() - return s.ds != nil -} - -// DatasetTime returns the forecast time of the currently loaded dataset. -func (s *Service) DatasetTime() (time.Time, bool) { - s.mu.RLock() - defer s.mu.RUnlock() - if s.ds == nil { - return time.Time{}, false - } - return s.ds.DSTime, true -} - -// Dataset returns the current dataset for reading. -func (s *Service) Dataset() *dataset.File { - s.mu.RLock() - defer s.mu.RUnlock() - return s.ds -} - -// Update checks for and downloads new forecast data if needed. +// Update updates the GRIB data (implements updater.GribService) func (s *Service) Update(ctx context.Context) error { - if !s.updating.TryLock() { - s.log.Info("update already in progress, skipping") - return nil - } - defer s.updating.Unlock() - - // Check if current dataset is still fresh - if dsTime, ok := s.DatasetTime(); ok { - if time.Since(dsTime) < s.cfg.DatasetTTL { - s.log.Info("dataset still fresh, skipping update", - zap.Time("dataset_time", dsTime), - zap.Duration("age", time.Since(dsTime))) - return nil - } - } - - // Try loading an existing dataset from disk first - if err := s.loadExistingDataset(); err == nil { - return nil - } - - // Find latest available model run - run, err := s.dl.FindLatestRun(ctx) - if err != nil { - return err - } - - // Download and assemble - path, err := s.dl.Download(ctx, run) - if err != nil { - return err - } - - // Open the new dataset - ds, err := dataset.Open(path, run) - if err != nil { - return err - } - - // Swap in the new dataset - s.setDataset(ds) - s.log.Info("dataset loaded", zap.Time("run", run), zap.String("path", path)) - - // Clean old datasets - s.cleanOldDatasets(path) - - return nil + return s.UpdateWeatherData(ctx) } -// loadExistingDataset tries to find and load an existing dataset from the data directory. -func (s *Service) loadExistingDataset() error { - entries, err := os.ReadDir(s.cfg.DataDir) - if err != nil { - return err - } - - // Collect valid dataset files (name is YYYYMMDDHH, no extension, correct size) - type candidate struct { - name string - path string - run time.Time - } - var candidates []candidate - - for _, e := range entries { - if e.IsDir() || strings.Contains(e.Name(), ".") { - continue - } - if len(e.Name()) != 10 { - continue - } - - run, err := time.Parse("2006010215", e.Name()) - if err != nil { - continue - } - - path := filepath.Join(s.cfg.DataDir, e.Name()) - info, err := os.Stat(path) - if err != nil || info.Size() != dataset.DatasetSize { - continue - } - - if time.Since(run) > s.cfg.DatasetTTL { - continue - } - - candidates = append(candidates, candidate{name: e.Name(), path: path, run: run}) - } - - if len(candidates) == 0 { - return os.ErrNotExist - } - - // Pick the newest - sort.Slice(candidates, func(i, j int) bool { - return candidates[i].run.After(candidates[j].run) - }) - - best := candidates[0] - ds, err := dataset.Open(best.path, best.run) - if err != nil { - return err - } - - s.setDataset(ds) - s.log.Info("loaded existing dataset", - zap.Time("run", best.run), - zap.String("path", best.path)) - return nil +// Start starts the service +func (s *Service) Start() { + log.Ctx(context.Background()).Info("service started") } -// setDataset swaps the current dataset with a new one, closing the old one. -func (s *Service) setDataset(ds *dataset.File) { - s.mu.Lock() - old := s.ds - s.ds = ds - s.mu.Unlock() - - if old != nil { - if err := old.Close(); err != nil { - s.log.Error("failed to close old dataset", zap.Error(err)) - } - } +// Stop stops the service +func (s *Service) Stop() { + log.Ctx(context.Background()).Info("service stopped") } -// cleanOldDatasets removes dataset files other than the one at keepPath. -func (s *Service) cleanOldDatasets(keepPath string) { - entries, err := os.ReadDir(s.cfg.DataDir) - if err != nil { - return - } - - for _, e := range entries { - if e.IsDir() { - continue - } - path := filepath.Join(s.cfg.DataDir, e.Name()) - if path == keepPath { - continue - } - // Remove old datasets and temp files - if len(e.Name()) == 10 || strings.HasSuffix(e.Name(), ".downloading") { - if err := os.Remove(path); err != nil { - s.log.Warn("failed to remove old file", zap.String("path", path), zap.Error(err)) - } else { - s.log.Info("removed old dataset", zap.String("path", path)) - } - } - } -} - -// Close releases all resources. +// Close closes the service and releases resources func (s *Service) Close() error { - s.mu.Lock() - defer s.mu.Unlock() - if s.ds != nil { - err := s.ds.Close() - s.ds = nil - return err - } + s.Stop() return nil } + +func (s *Service) GetGribStatus(ctx context.Context) (ready bool, lastUpdate time.Time, isFresh bool, errMsg string) { + if gribStatus, ok := s.grib.(interface { + GetStatus() (ready bool, lastUpdate time.Time, isFresh bool, errMsg string) + }); ok { + return gribStatus.GetStatus() + } + return false, time.Time{}, false, "grib service does not implement GetStatus" +} diff --git a/internal/transport/middleware/log.go b/internal/transport/middleware/log.go index fbbbbc1..ebf2003 100644 --- a/internal/transport/middleware/log.go +++ b/internal/transport/middleware/log.go @@ -3,28 +3,42 @@ package middleware import ( "time" + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + "git.intra.yksa.space/gsn/predictor/internal/pkg/log" "github.com/ogen-go/ogen/middleware" "go.uber.org/zap" ) -// Logging returns an ogen middleware that logs request duration. -func Logging(log *zap.Logger) middleware.Middleware { +func Logging() middleware.Middleware { return func(req middleware.Request, next func(req middleware.Request) (middleware.Response, error)) (middleware.Response, error) { - lg := log.With(zap.String("operation", req.OperationID)) + lg := log.Ctx(req.Context).With( + zap.String("operationId", req.OperationID), + ) + + lg.Info("started request") + + req.Context = log.ToCtx(req.Context, lg) start := time.Now() resp, err := next(req) - dur := time.Since(start) + dur := time.Since(start).Microseconds() if err != nil { - lg.Error("request failed", - zap.Duration("duration", dur), - zap.Error(err)) - } else { - lg.Info("request completed", - zap.Duration("duration", dur)) + if errcode, ok := err.(*errcodes.ErrorCode); ok { + lg.Error("request error", + zap.Int("status_code", errcode.StatusCode), + zap.String("message", errcode.Message), + zap.String("details", errcode.Details), + ) + } else { + lg.Error("request internal error", + zap.Error(err), + ) + } } + lg.Info("done request", zap.Float64("duration_ms", float64(dur)/float64(1000))) + return resp, err } } diff --git a/internal/transport/rest/config.go b/internal/transport/rest/config.go new file mode 100644 index 0000000..1b55959 --- /dev/null +++ b/internal/transport/rest/config.go @@ -0,0 +1,24 @@ +package rest + +import ( + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + env "github.com/caarlos0/env/v11" +) + +type Config struct { + Host string `env:"HOST" envDefault:"0.0.0.0"` + Port int `env:"PORT" envDefault:"8080"` + ReadTimeout string `env:"READ_TIMEOUT" envDefault:"30s"` + WriteTimeout string `env:"WRITE_TIMEOUT" envDefault:"30s"` + IdleTimeout string `env:"IDLE_TIMEOUT" envDefault:"60s"` +} + +func NewConfig() (*Config, error) { + cfg := &Config{} + if err := env.ParseWithOptions(cfg, env.Options{ + PrefixTagName: "GSN_PREDICTOR_REST_", + }); err != nil { + return nil, errcodes.Wrap(err, "failed to parse REST config") + } + return cfg, nil +} diff --git a/internal/transport/rest/handler/deps.go b/internal/transport/rest/handler/deps.go index f81a3b8..2d930cb 100644 --- a/internal/transport/rest/handler/deps.go +++ b/internal/transport/rest/handler/deps.go @@ -1,16 +1,14 @@ package handler import ( + "context" "time" - "predictor-refactored/internal/dataset" - "predictor-refactored/internal/elevation" + "git.intra.yksa.space/gsn/predictor/internal/pkg/ds" ) -// Service defines the interface the handler needs from the service layer. type Service interface { - Ready() bool - DatasetTime() (time.Time, bool) - Dataset() *dataset.File - Elevation() *elevation.Dataset + UpdateWeatherData(ctx context.Context) error + ExtractWind(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error) + PerformPrediction(ctx context.Context, params ds.PredictionParameters) ([]ds.PredicitonResult, error) } diff --git a/internal/transport/rest/handler/handler.go b/internal/transport/rest/handler/handler.go index fc1f693..5b7141d 100644 --- a/internal/transport/rest/handler/handler.go +++ b/internal/transport/rest/handler/handler.go @@ -5,212 +5,190 @@ import ( "net/http" "time" - "predictor-refactored/internal/prediction" - api "predictor-refactored/pkg/rest" - - "go.uber.org/zap" + "git.intra.yksa.space/gsn/predictor/internal/pkg/ds" + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + api "git.intra.yksa.space/gsn/predictor/pkg/rest" ) -var _ api.Handler = (*Handler)(nil) +var ( + _ api.Handler = (*Handler)(nil) +) -// Handler implements the ogen-generated api.Handler interface. type Handler struct { svc Service - log *zap.Logger } -// New creates a new Handler. -func New(svc Service, log *zap.Logger) *Handler { - return &Handler{svc: svc, log: log} +func New(svc Service) *Handler { + return &Handler{ + svc: svc, + } } -// PerformPrediction implements the prediction endpoint. -func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredictionParams) (*api.PredictionResponse, error) { - if !h.svc.Ready() { - return nil, newError(http.StatusServiceUnavailable, "no dataset loaded, service is starting up") +func (h *Handler) PerformPrediction(ctx context.Context, params api.PerformPredictionParams) (*api.PredictionResult, error) { + internalParams := ds.ConvertFlatPredictionParams(params) + if internalParams == nil { + return nil, errcodes.New(http.StatusBadRequest, "invalid or missing parameters") + } + results, err := h.svc.PerformPrediction(ctx, *internalParams) + if err != nil { + return nil, err + } + if len(results) == 0 { + return nil, errcodes.New(http.StatusInternalServerError, "no prediction results") } - ds := h.svc.Dataset() - if ds == nil { - return nil, newError(http.StatusServiceUnavailable, "dataset unavailable") - } + // Group results into stages (ascent and descent) + stages := h.groupResultsIntoStages(results) - dsEpoch := float64(ds.DSTime.Unix()) + // Map to OpenAPI schema + var predictionItems []api.PredictionResultPredictionItem - // Parse parameters with defaults - profile := "standard_profile" - if p, ok := params.Profile.Get(); ok { - profile = string(p) - } + for _, stage := range stages { + var trajectory []api.PredictionResultPredictionItemTrajectoryItem - ascentRate := 5.0 - if v, ok := params.AscentRate.Get(); ok { - ascentRate = v - } - - burstAltitude := 28000.0 - if v, ok := params.BurstAltitude.Get(); ok { - burstAltitude = v - } - - descentRate := 5.0 - if v, ok := params.DescentRate.Get(); ok { - descentRate = v - } - - launchAlt := 0.0 - if v, ok := params.LaunchAltitude.Get(); ok { - launchAlt = v - } - - // Normalize longitude to [0, 360) - lng := params.LaunchLongitude - if lng < 0 { - lng += 360.0 - } - - launchTime := float64(params.LaunchDatetime.Unix()) - - warnings := &prediction.Warnings{} - - // Build profile chain - elev := h.svc.Elevation() - var stages []prediction.Stage - switch profile { - case "standard_profile": - stages = prediction.StandardProfile( - ascentRate, burstAltitude, descentRate, - ds, dsEpoch, warnings, elev) - case "float_profile": - floatAlt := 25000.0 - if v, ok := params.FloatAltitude.Get(); ok { - floatAlt = v - } - stopTime := params.LaunchDatetime.Add(24 * time.Hour) - if v, ok := params.StopDatetime.Get(); ok { - stopTime = v - } - stages = prediction.FloatProfile( - ascentRate, floatAlt, stopTime, - ds, dsEpoch, warnings) - default: - return nil, newError(http.StatusBadRequest, "unknown profile: "+profile) - } - - // Run prediction - startTime := time.Now().UTC() - results := prediction.RunPrediction(launchTime, params.LaunchLatitude, lng, launchAlt, stages) - completeTime := time.Now().UTC() - - // Build response - stageNames := []string{"ascent", "descent"} - if profile == "float_profile" { - stageNames = []string{"ascent", "float"} - } - - var predItems []api.PredictionResponsePredictionItem - for i, sr := range results { - stageName := "ascent" - if i < len(stageNames) { - stageName = stageNames[i] - } - - var stageEnum api.PredictionResponsePredictionItemStage - switch stageName { - case "ascent": - stageEnum = api.PredictionResponsePredictionItemStageAscent - case "descent": - stageEnum = api.PredictionResponsePredictionItemStageDescent - case "float": - stageEnum = api.PredictionResponsePredictionItemStageFloat - } - - var traj []api.PredictionResponsePredictionItemTrajectoryItem - for _, pt := range sr.Points { - ptLng := pt.Lng - if ptLng > 180 { - ptLng -= 360 + for _, result := range stage.Results { + traj := api.PredictionResultPredictionItemTrajectoryItem{ + Datetime: *result.Timestamp, + Latitude: *result.Latitude, + Longitude: *result.Longitude, + Altitude: *result.Altitude, } - traj = append(traj, api.PredictionResponsePredictionItemTrajectoryItem{ - Datetime: time.Unix(int64(pt.T), 0).UTC(), - Latitude: pt.Lat, - Longitude: ptLng, - Altitude: pt.Alt, - }) + trajectory = append(trajectory, traj) } - predItems = append(predItems, api.PredictionResponsePredictionItem{ - Stage: stageEnum, - Trajectory: traj, + item := api.PredictionResultPredictionItem{ + Stage: stage.Stage, + Trajectory: trajectory, + } + predictionItems = append(predictionItems, item) + } + + metadata := api.PredictionResultMetadata{ + StartDatetime: *results[0].Timestamp, + CompleteDatetime: *results[len(results)-1].Timestamp, + } + + resp := &api.PredictionResult{ + Metadata: metadata, + Prediction: predictionItems, + } + return resp, nil +} + +// StageResult represents a stage with its results +type StageResult struct { + Stage api.PredictionResultPredictionItemStage + Results []ds.PredicitonResult +} + +// groupResultsIntoStages groups the prediction results into ascent and descent stages +func (h *Handler) groupResultsIntoStages(results []ds.PredicitonResult) []StageResult { + if len(results) == 0 { + return nil + } + + var stages []StageResult + var currentStage []ds.PredicitonResult + var currentStageType api.PredictionResultPredictionItemStage + + // Determine if we're in ascent or descent based on altitude changes + prevAlt := *results[0].Altitude + currentStage = append(currentStage, results[0]) + currentStageType = api.PredictionResultPredictionItemStageAscent + + for i := 1; i < len(results); i++ { + result := results[i] + currentAlt := *result.Altitude + + // Determine if we're still in the same stage + var stageType api.PredictionResultPredictionItemStage + if currentAlt > prevAlt { + stageType = api.PredictionResultPredictionItemStageAscent + } else if currentAlt < prevAlt { + stageType = api.PredictionResultPredictionItemStageDescent + } else { + // Same altitude - continue with current stage + stageType = currentStageType + } + + // If stage type changed, finalize current stage and start new one + if stageType != currentStageType && len(currentStage) > 0 { + stages = append(stages, StageResult{ + Stage: currentStageType, + Results: currentStage, + }) + currentStage = nil + currentStageType = stageType + } + + currentStage = append(currentStage, result) + prevAlt = currentAlt + } + + // Add the final stage + if len(currentStage) > 0 { + stages = append(stages, StageResult{ + Stage: currentStageType, + Results: currentStage, }) } - resp := &api.PredictionResponse{ - Prediction: predItems, - Metadata: api.PredictionResponseMetadata{ - StartDatetime: startTime, - CompleteDatetime: completeTime, - }, - } - - // Echo request - resp.Request = api.NewOptPredictionResponseRequest(api.PredictionResponseRequest{ - Dataset: api.NewOptString(ds.DSTime.Format("2006-01-02T15:04:05Z")), - LaunchLatitude: api.NewOptFloat64(params.LaunchLatitude), - LaunchLongitude: api.NewOptFloat64(params.LaunchLongitude), - LaunchDatetime: api.NewOptString(params.LaunchDatetime.Format(time.RFC3339)), - LaunchAltitude: params.LaunchAltitude, - }) - - // Warnings - warnMap := warnings.ToMap() - if len(warnMap) > 0 { - resp.Warnings = api.NewOptPredictionResponseWarnings(api.PredictionResponseWarnings{}) - } - - h.log.Info("prediction complete", - zap.String("profile", profile), - zap.Int("stages", len(results)), - zap.Duration("elapsed", completeTime.Sub(startTime))) - - return resp, nil + return stages } -// ReadinessCheck implements the health check endpoint. -func (h *Handler) ReadinessCheck(ctx context.Context) (*api.ReadinessResponse, error) { - resp := &api.ReadinessResponse{} +func (h *Handler) NewError(ctx context.Context, err error) *api.ErrorStatusCode { + if errcode, ok := err.(*errcodes.ErrorCode); ok { + resp := api.Error{ + Message: errcode.Message, + } - if h.svc.Ready() { - resp.Status = api.ReadinessResponseStatusOk - if dsTime, ok := h.svc.DatasetTime(); ok { - resp.DatasetTime = api.NewOptDateTime(dsTime) + if errcode.Details != "" { + resp.Details = api.NewOptString(errcode.Details) + } + + return &api.ErrorStatusCode{ + StatusCode: errcode.StatusCode, + Response: resp, + } + } + + return &api.ErrorStatusCode{ + StatusCode: http.StatusInternalServerError, + Response: api.Error{ + Message: "undefined internal error", + Details: api.NewOptString(err.Error()), + }, + } +} + +func (h *Handler) ReadinessCheck(ctx context.Context) (*api.ReadinessResponse, error) { + status := api.ReadinessResponseStatusNotReady + var lastUpdate time.Time + var isFresh bool + var errMsg string + + if s, ok := h.svc.(interface { + GetGribStatus(ctx context.Context) (ready bool, lastUpdate time.Time, isFresh bool, errMsg string) + }); ok { + ready, lu, fresh, em := s.GetGribStatus(ctx) + lastUpdate = lu + isFresh = fresh + errMsg = em + if ready { + status = api.ReadinessResponseStatusOk + } else if em != "" { + status = api.ReadinessResponseStatusError } } else { - resp.Status = api.ReadinessResponseStatusNotReady - resp.ErrorMessage = api.NewOptString("no dataset loaded") + errMsg = "service does not implement GetGribStatus" + status = api.ReadinessResponseStatusError } + resp := &api.ReadinessResponse{ + Status: status, + IsFresh: api.NewOptBool(isFresh), + LastUpdate: api.NewOptDateTime(lastUpdate), + ErrorMessage: api.NewOptString(errMsg), + } return resp, nil } - -// NewError creates an ErrorStatusCode from an error returned by a handler. -func (h *Handler) NewError(ctx context.Context, err error) *api.ErrorStatusCode { - if statusErr, ok := err.(*api.ErrorStatusCode); ok { - return statusErr - } - - h.log.Error("unhandled error", zap.Error(err)) - return newError(http.StatusInternalServerError, err.Error()) -} - -func newError(status int, description string) *api.ErrorStatusCode { - return &api.ErrorStatusCode{ - StatusCode: status, - Response: api.Error{ - Error: api.ErrorError{ - Type: http.StatusText(status), - Description: description, - }, - }, - } -} diff --git a/internal/transport/rest/transport.go b/internal/transport/rest/transport.go index 3744270..c89cf8f 100644 --- a/internal/transport/rest/transport.go +++ b/internal/transport/rest/transport.go @@ -5,71 +5,43 @@ import ( "fmt" "net/http" - "predictor-refactored/internal/transport/middleware" - "predictor-refactored/internal/transport/rest/handler" - api "predictor-refactored/pkg/rest" - - "go.uber.org/zap" + "git.intra.yksa.space/gsn/predictor/internal/pkg/log" + "git.intra.yksa.space/gsn/predictor/internal/transport/middleware" + handler "git.intra.yksa.space/gsn/predictor/internal/transport/rest/handler" + api "git.intra.yksa.space/gsn/predictor/pkg/rest" + "github.com/rs/cors" ) -// Transport wraps the ogen HTTP server. type Transport struct { + cfg *Config srv *api.Server handler *handler.Handler - port int - log *zap.Logger } -// New creates a new REST transport. -func New(h *handler.Handler, port int, log *zap.Logger) (*Transport, error) { +func New(handler *handler.Handler, cfg *Config) (*Transport, error) { srv, err := api.NewServer( - h, - api.WithMiddleware(middleware.Logging(log)), + handler, + api.WithMiddleware(middleware.Logging()), ) if err != nil { - return nil, fmt.Errorf("create ogen server: %w", err) + return nil, err } return &Transport{ srv: srv, - handler: h, - port: port, - log: log, + cfg: cfg, + handler: handler, }, nil } -// Run starts the HTTP server. Blocks until the server stops. -func (t *Transport) Run() error { +func (t *Transport) Run() { + log.Ctx(context.Background()).Info("started") + mux := http.NewServeMux() mux.Handle("/", t.srv) + cors.AllowAll().Handler(mux) - httpSrv := &http.Server{ - Addr: fmt.Sprintf(":%d", t.port), - Handler: corsMiddleware(mux), + if err := http.ListenAndServe(fmt.Sprintf(":%d", t.cfg.Port), t.srv); err != nil { + panic(err) } - - t.log.Info("starting HTTP server", zap.Int("port", t.port)) - return httpSrv.ListenAndServe() -} - -// Shutdown gracefully stops the HTTP server. -func (t *Transport) Shutdown(ctx context.Context) error { - // The ogen server doesn't have a shutdown method; - // shutdown is handled by the http.Server in main.go - return nil -} - -func corsMiddleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") - w.Header().Set("Access-Control-Allow-Headers", "Content-Type") - - if r.Method == http.MethodOptions { - w.WriteHeader(http.StatusNoContent) - return - } - - next.ServeHTTP(w, r) - }) } diff --git a/pkg/rest/oas_cfg_gen.go b/pkg/rest/oas_cfg_gen.go index 9d5a059..1845c4f 100644 --- a/pkg/rest/oas_cfg_gen.go +++ b/pkg/rest/oas_cfg_gen.go @@ -1,19 +1,18 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn import ( "net/http" - "strings" + + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/metric" + "go.opentelemetry.io/otel/trace" ht "github.com/ogen-go/ogen/http" "github.com/ogen-go/ogen/middleware" "github.com/ogen-go/ogen/ogenerrors" "github.com/ogen-go/ogen/otelogen" - "go.opentelemetry.io/otel" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/metric" - "go.opentelemetry.io/otel/trace" ) var ( @@ -33,7 +32,6 @@ type otelConfig struct { Tracer trace.Tracer MeterProvider metric.MeterProvider Meter metric.Meter - Attributes []attribute.KeyValue } func (cfg *otelConfig) initOTEL() { @@ -83,8 +81,18 @@ func (o otelOptionFunc) applyServer(c *serverConfig) { func newServerConfig(opts ...ServerOption) serverConfig { cfg := serverConfig{ - NotFound: http.NotFound, - MethodNotAllowed: nil, + NotFound: http.NotFound, + MethodNotAllowed: func(w http.ResponseWriter, r *http.Request, allowed string) { + status := http.StatusMethodNotAllowed + if r.Method == "OPTIONS" { + w.Header().Set("Access-Control-Allow-Methods", allowed) + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + status = http.StatusNoContent + } else { + w.Header().Set("Allow", allowed) + } + w.WriteHeader(status) + }, ErrorHandler: ogenerrors.DefaultErrorHandler, Middleware: nil, MaxMultipartMemory: 32 << 20, // 32 MB @@ -107,44 +115,8 @@ func (s baseServer) notFound(w http.ResponseWriter, r *http.Request) { s.cfg.NotFound(w, r) } -type notAllowedParams struct { - allowedMethods string - allowedHeaders map[string]string - acceptPost string - acceptPatch string -} - -func (s baseServer) notAllowed(w http.ResponseWriter, r *http.Request, params notAllowedParams) { - h := w.Header() - isOptions := r.Method == "OPTIONS" - if isOptions { - h.Set("Access-Control-Allow-Methods", params.allowedMethods) - if params.allowedHeaders != nil { - m := r.Header.Get("Access-Control-Request-Method") - if m != "" { - allowedHeaders, ok := params.allowedHeaders[strings.ToUpper(m)] - if ok { - h.Set("Access-Control-Allow-Headers", allowedHeaders) - } - } - } - if params.acceptPost != "" { - h.Set("Accept-Post", params.acceptPost) - } - if params.acceptPatch != "" { - h.Set("Accept-Patch", params.acceptPatch) - } - } - if s.cfg.MethodNotAllowed != nil { - s.cfg.MethodNotAllowed(w, r, params.allowedMethods) - return - } - status := http.StatusNoContent - if !isOptions { - h.Set("Allow", params.allowedMethods) - status = http.StatusMethodNotAllowed - } - w.WriteHeader(status) +func (s baseServer) notAllowed(w http.ResponseWriter, r *http.Request, allowed string) { + s.cfg.MethodNotAllowed(w, r, allowed) } func (cfg serverConfig) baseServer() (s baseServer, err error) { @@ -243,13 +215,6 @@ func WithMeterProvider(provider metric.MeterProvider) Option { }) } -// WithAttributes specifies default otel attributes. -func WithAttributes(attributes ...attribute.KeyValue) Option { - return otelOptionFunc(func(cfg *otelConfig) { - cfg.Attributes = attributes - }) -} - // WithClient specifies http client to use. func WithClient(client ht.Client) ClientOption { return optionFunc[clientConfig](func(cfg *clientConfig) { diff --git a/pkg/rest/oas_client_gen.go b/pkg/rest/oas_client_gen.go index 02e95c3..19822e8 100644 --- a/pkg/rest/oas_client_gen.go +++ b/pkg/rest/oas_client_gen.go @@ -1,6 +1,6 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn import ( "context" @@ -9,15 +9,16 @@ import ( "time" "github.com/go-faster/errors" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" + "go.opentelemetry.io/otel/trace" + "github.com/ogen-go/ogen/conv" ht "github.com/ogen-go/ogen/http" "github.com/ogen-go/ogen/otelogen" "github.com/ogen-go/ogen/uri" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/metric" - semconv "go.opentelemetry.io/otel/semconv/v1.39.0" - "go.opentelemetry.io/otel/trace" ) func trimTrailingSlashes(u *url.URL) { @@ -32,7 +33,7 @@ type Invoker interface { // Perform prediction. // // GET /api/v1/prediction - PerformPrediction(ctx context.Context, params PerformPredictionParams) (*PredictionResponse, error) + PerformPrediction(ctx context.Context, params PerformPredictionParams) (*PredictionResult, error) // ReadinessCheck invokes readinessCheck operation. // // Readiness check. @@ -46,6 +47,14 @@ type Client struct { serverURL *url.URL baseClient } +type errorHandler interface { + NewError(ctx context.Context, err error) *ErrorStatusCode +} + +var _ Handler = struct { + errorHandler + *Client +}{} // NewClient initializes new Client defined by OAS. func NewClient(serverURL string, opts ...ClientOption) (*Client, error) { @@ -85,18 +94,17 @@ func (c *Client) requestURL(ctx context.Context) *url.URL { // Perform prediction. // // GET /api/v1/prediction -func (c *Client) PerformPrediction(ctx context.Context, params PerformPredictionParams) (*PredictionResponse, error) { +func (c *Client) PerformPrediction(ctx context.Context, params PerformPredictionParams) (*PredictionResult, error) { res, err := c.sendPerformPrediction(ctx, params) return res, err } -func (c *Client) sendPerformPrediction(ctx context.Context, params PerformPredictionParams) (res *PredictionResponse, err error) { +func (c *Client) sendPerformPrediction(ctx context.Context, params PerformPredictionParams) (res *PredictionResult, err error) { otelAttrs := []attribute.KeyValue{ otelogen.OperationID("performPrediction"), semconv.HTTPRequestMethodKey.String("GET"), - semconv.URLTemplateKey.String("/api/v1/prediction"), + semconv.HTTPRouteKey.String("/api/v1/prediction"), } - otelAttrs = append(otelAttrs, c.cfg.Attributes...) // Run stopwatch. startTime := time.Now() @@ -142,7 +150,10 @@ func (c *Client) sendPerformPrediction(ctx context.Context, params PerformPredic } if err := q.EncodeParam(cfg, func(e uri.Encoder) error { - return e.EncodeValue(conv.Float64ToString(params.LaunchLatitude)) + if val, ok := params.LaunchLatitude.Get(); ok { + return e.EncodeValue(conv.Float64ToString(val)) + } + return nil }); err != nil { return res, errors.Wrap(err, "encode query") } @@ -156,7 +167,10 @@ func (c *Client) sendPerformPrediction(ctx context.Context, params PerformPredic } if err := q.EncodeParam(cfg, func(e uri.Encoder) error { - return e.EncodeValue(conv.Float64ToString(params.LaunchLongitude)) + if val, ok := params.LaunchLongitude.Get(); ok { + return e.EncodeValue(conv.Float64ToString(val)) + } + return nil }); err != nil { return res, errors.Wrap(err, "encode query") } @@ -170,7 +184,10 @@ func (c *Client) sendPerformPrediction(ctx context.Context, params PerformPredic } if err := q.EncodeParam(cfg, func(e uri.Encoder) error { - return e.EncodeValue(conv.DateTimeToString(params.LaunchDatetime)) + if val, ok := params.LaunchDatetime.Get(); ok { + return e.EncodeValue(conv.DateTimeToString(val)) + } + return nil }); err != nil { return res, errors.Wrap(err, "encode query") } @@ -294,6 +311,74 @@ func (c *Client) sendPerformPrediction(ctx context.Context, params PerformPredic return res, errors.Wrap(err, "encode query") } } + { + // Encode "ascent_curve" parameter. + cfg := uri.QueryParameterEncodingConfig{ + Name: "ascent_curve", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.EncodeParam(cfg, func(e uri.Encoder) error { + if val, ok := params.AscentCurve.Get(); ok { + return e.EncodeValue(conv.StringToString(val)) + } + return nil + }); err != nil { + return res, errors.Wrap(err, "encode query") + } + } + { + // Encode "descent_curve" parameter. + cfg := uri.QueryParameterEncodingConfig{ + Name: "descent_curve", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.EncodeParam(cfg, func(e uri.Encoder) error { + if val, ok := params.DescentCurve.Get(); ok { + return e.EncodeValue(conv.StringToString(val)) + } + return nil + }); err != nil { + return res, errors.Wrap(err, "encode query") + } + } + { + // Encode "interpolate" parameter. + cfg := uri.QueryParameterEncodingConfig{ + Name: "interpolate", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.EncodeParam(cfg, func(e uri.Encoder) error { + if val, ok := params.Interpolate.Get(); ok { + return e.EncodeValue(conv.BoolToString(val)) + } + return nil + }); err != nil { + return res, errors.Wrap(err, "encode query") + } + } + { + // Encode "format" parameter. + cfg := uri.QueryParameterEncodingConfig{ + Name: "format", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.EncodeParam(cfg, func(e uri.Encoder) error { + if val, ok := params.Format.Get(); ok { + return e.EncodeValue(conv.StringToString(string(val))) + } + return nil + }); err != nil { + return res, errors.Wrap(err, "encode query") + } + } { // Encode "dataset" parameter. cfg := uri.QueryParameterEncodingConfig{ @@ -324,8 +409,7 @@ func (c *Client) sendPerformPrediction(ctx context.Context, params PerformPredic if err != nil { return res, errors.Wrap(err, "do request") } - body := resp.Body - defer body.Close() + defer resp.Body.Close() stage = "DecodeResponse" result, err := decodePerformPredictionResponse(resp) @@ -350,9 +434,8 @@ func (c *Client) sendReadinessCheck(ctx context.Context) (res *ReadinessResponse otelAttrs := []attribute.KeyValue{ otelogen.OperationID("readinessCheck"), semconv.HTTPRequestMethodKey.String("GET"), - semconv.URLTemplateKey.String("/ready"), + semconv.HTTPRouteKey.String("/ready"), } - otelAttrs = append(otelAttrs, c.cfg.Attributes...) // Run stopwatch. startTime := time.Now() @@ -398,8 +481,7 @@ func (c *Client) sendReadinessCheck(ctx context.Context) (res *ReadinessResponse if err != nil { return res, errors.Wrap(err, "do request") } - body := resp.Body - defer body.Close() + defer resp.Body.Close() stage = "DecodeResponse" result, err := decodeReadinessCheckResponse(resp) diff --git a/pkg/rest/oas_handlers_gen.go b/pkg/rest/oas_handlers_gen.go index d41771a..76ab8ad 100644 --- a/pkg/rest/oas_handlers_gen.go +++ b/pkg/rest/oas_handlers_gen.go @@ -1,6 +1,6 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn import ( "context" @@ -8,15 +8,16 @@ import ( "time" "github.com/go-faster/errors" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/codes" + "go.opentelemetry.io/otel/metric" + semconv "go.opentelemetry.io/otel/semconv/v1.26.0" + "go.opentelemetry.io/otel/trace" + ht "github.com/ogen-go/ogen/http" "github.com/ogen-go/ogen/middleware" "github.com/ogen-go/ogen/ogenerrors" "github.com/ogen-go/ogen/otelogen" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/codes" - "go.opentelemetry.io/otel/metric" - semconv "go.opentelemetry.io/otel/semconv/v1.39.0" - "go.opentelemetry.io/otel/trace" ) type codeRecorder struct { @@ -29,10 +30,6 @@ func (c *codeRecorder) WriteHeader(status int) { c.ResponseWriter.WriteHeader(status) } -func (c *codeRecorder) Unwrap() http.ResponseWriter { - return c.ResponseWriter -} - // handlePerformPredictionRequest handles performPrediction operation. // // Perform prediction. @@ -46,8 +43,6 @@ func (s *Server) handlePerformPredictionRequest(args [0]string, argsEscaped bool semconv.HTTPRequestMethodKey.String("GET"), semconv.HTTPRouteKey.String("/api/v1/prediction"), } - // Add attributes from config. - otelAttrs = append(otelAttrs, s.cfg.Attributes...) // Start a span for this request. ctx, span := s.cfg.Tracer.Start(r.Context(), PerformPredictionOperation, @@ -91,7 +86,7 @@ func (s *Server) handlePerformPredictionRequest(args [0]string, argsEscaped bool // unless there was another error (e.g., network error receiving the response body; or 3xx codes with // max redirects exceeded), in which case status MUST be set to Error. code := statusWriter.status - if code < 100 || code >= 500 { + if code >= 100 && code < 500 { span.SetStatus(codes.Error, stage) } @@ -120,9 +115,7 @@ func (s *Server) handlePerformPredictionRequest(args [0]string, argsEscaped bool return } - var rawBody []byte - - var response *PredictionResponse + var response *PredictionResult if m := s.cfg.Middleware; m != nil { mreq := middleware.Request{ Context: ctx, @@ -130,7 +123,6 @@ func (s *Server) handlePerformPredictionRequest(args [0]string, argsEscaped bool OperationSummary: "Perform prediction", OperationID: "performPrediction", Body: nil, - RawBody: rawBody, Params: middleware.Parameters{ { Name: "launch_latitude", @@ -172,6 +164,22 @@ func (s *Server) handlePerformPredictionRequest(args [0]string, argsEscaped bool Name: "stop_datetime", In: "query", }: params.StopDatetime, + { + Name: "ascent_curve", + In: "query", + }: params.AscentCurve, + { + Name: "descent_curve", + In: "query", + }: params.DescentCurve, + { + Name: "interpolate", + In: "query", + }: params.Interpolate, + { + Name: "format", + In: "query", + }: params.Format, { Name: "dataset", In: "query", @@ -183,7 +191,7 @@ func (s *Server) handlePerformPredictionRequest(args [0]string, argsEscaped bool type ( Request = struct{} Params = PerformPredictionParams - Response = *PredictionResponse + Response = *PredictionResult ) response, err = middleware.HookMiddleware[ Request, @@ -240,8 +248,6 @@ func (s *Server) handleReadinessCheckRequest(args [0]string, argsEscaped bool, w semconv.HTTPRequestMethodKey.String("GET"), semconv.HTTPRouteKey.String("/ready"), } - // Add attributes from config. - otelAttrs = append(otelAttrs, s.cfg.Attributes...) // Start a span for this request. ctx, span := s.cfg.Tracer.Start(r.Context(), ReadinessCheckOperation, @@ -285,7 +291,7 @@ func (s *Server) handleReadinessCheckRequest(args [0]string, argsEscaped bool, w // unless there was another error (e.g., network error receiving the response body; or 3xx codes with // max redirects exceeded), in which case status MUST be set to Error. code := statusWriter.status - if code < 100 || code >= 500 { + if code >= 100 && code < 500 { span.SetStatus(codes.Error, stage) } @@ -300,8 +306,6 @@ func (s *Server) handleReadinessCheckRequest(args [0]string, argsEscaped bool, w err error ) - var rawBody []byte - var response *ReadinessResponse if m := s.cfg.Middleware; m != nil { mreq := middleware.Request{ @@ -310,7 +314,6 @@ func (s *Server) handleReadinessCheckRequest(args [0]string, argsEscaped bool, w OperationSummary: "Readiness check", OperationID: "readinessCheck", Body: nil, - RawBody: rawBody, Params: middleware.Parameters{}, Raw: r, } diff --git a/pkg/rest/oas_json_gen.go b/pkg/rest/oas_json_gen.go index 8fa8634..ea3d61c 100644 --- a/pkg/rest/oas_json_gen.go +++ b/pkg/rest/oas_json_gen.go @@ -1,6 +1,6 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn import ( "math/bits" @@ -9,6 +9,7 @@ import ( "github.com/go-faster/errors" "github.com/go-faster/jx" + "github.com/ogen-go/ogen/json" "github.com/ogen-go/ogen/validate" ) @@ -23,13 +24,20 @@ func (s *Error) Encode(e *jx.Encoder) { // encodeFields encodes fields. func (s *Error) encodeFields(e *jx.Encoder) { { - e.FieldStart("error") - s.Error.Encode(e) + e.FieldStart("message") + e.Str(s.Message) + } + { + if s.Details.Set { + e.FieldStart("details") + s.Details.Encode(e) + } } } -var jsonFieldsNameOfError = [1]string{ - 0: "error", +var jsonFieldsNameOfError = [2]string{ + 0: "message", + 1: "details", } // Decode decodes Error from json. @@ -41,15 +49,27 @@ func (s *Error) Decode(d *jx.Decoder) error { if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { switch string(k) { - case "error": + case "message": requiredBitSet[0] |= 1 << 0 if err := func() error { - if err := s.Error.Decode(d); err != nil { + v, err := d.Str() + s.Message = string(v) + if err != nil { return err } return nil }(); err != nil { - return errors.Wrap(err, "decode field \"error\"") + return errors.Wrap(err, "decode field \"message\"") + } + case "details": + if err := func() error { + s.Details.Reset() + if err := s.Details.Decode(d); err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"details\"") } default: return d.Skip() @@ -107,115 +127,37 @@ func (s *Error) UnmarshalJSON(data []byte) error { return s.Decode(d) } -// Encode implements json.Marshaler. -func (s *ErrorError) Encode(e *jx.Encoder) { - e.ObjStart() - s.encodeFields(e) - e.ObjEnd() +// Encode encodes bool as json. +func (o OptBool) Encode(e *jx.Encoder) { + if !o.Set { + return + } + e.Bool(bool(o.Value)) } -// encodeFields encodes fields. -func (s *ErrorError) encodeFields(e *jx.Encoder) { - { - e.FieldStart("type") - e.Str(s.Type) +// Decode decodes bool from json. +func (o *OptBool) Decode(d *jx.Decoder) error { + if o == nil { + return errors.New("invalid: unable to decode OptBool to nil") } - { - e.FieldStart("description") - e.Str(s.Description) + o.Set = true + v, err := d.Bool() + if err != nil { + return err } -} - -var jsonFieldsNameOfErrorError = [2]string{ - 0: "type", - 1: "description", -} - -// Decode decodes ErrorError from json. -func (s *ErrorError) Decode(d *jx.Decoder) error { - if s == nil { - return errors.New("invalid: unable to decode ErrorError to nil") - } - var requiredBitSet [1]uint8 - - if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { - switch string(k) { - case "type": - requiredBitSet[0] |= 1 << 0 - if err := func() error { - v, err := d.Str() - s.Type = string(v) - if err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"type\"") - } - case "description": - requiredBitSet[0] |= 1 << 1 - if err := func() error { - v, err := d.Str() - s.Description = string(v) - if err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"description\"") - } - default: - return d.Skip() - } - return nil - }); err != nil { - return errors.Wrap(err, "decode ErrorError") - } - // Validate required fields. - var failures []validate.FieldError - for i, mask := range [1]uint8{ - 0b00000011, - } { - if result := (requiredBitSet[i] & mask) ^ mask; result != 0 { - // Mask only required fields and check equality to mask using XOR. - // - // If XOR result is not zero, result is not equal to expected, so some fields are missed. - // Bits of fields which would be set are actually bits of missed fields. - missed := bits.OnesCount8(result) - for bitN := 0; bitN < missed; bitN++ { - bitIdx := bits.TrailingZeros8(result) - fieldIdx := i*8 + bitIdx - var name string - if fieldIdx < len(jsonFieldsNameOfErrorError) { - name = jsonFieldsNameOfErrorError[fieldIdx] - } else { - name = strconv.Itoa(fieldIdx) - } - failures = append(failures, validate.FieldError{ - Name: name, - Error: validate.ErrFieldRequired, - }) - // Reset bit. - result &^= 1 << bitIdx - } - } - } - if len(failures) > 0 { - return &validate.Error{Fields: failures} - } - + o.Value = bool(v) return nil } // MarshalJSON implements stdjson.Marshaler. -func (s *ErrorError) MarshalJSON() ([]byte, error) { +func (s OptBool) MarshalJSON() ([]byte, error) { e := jx.Encoder{} s.Encode(&e) return e.Bytes(), nil } // UnmarshalJSON implements stdjson.Unmarshaler. -func (s *ErrorError) UnmarshalJSON(data []byte) error { +func (s *OptBool) UnmarshalJSON(data []byte) error { d := jx.DecodeBytes(data) return s.Decode(d) } @@ -255,108 +197,6 @@ func (s *OptDateTime) UnmarshalJSON(data []byte) error { return s.Decode(d, json.DecodeDateTime) } -// Encode encodes float64 as json. -func (o OptFloat64) Encode(e *jx.Encoder) { - if !o.Set { - return - } - e.Float64(float64(o.Value)) -} - -// Decode decodes float64 from json. -func (o *OptFloat64) Decode(d *jx.Decoder) error { - if o == nil { - return errors.New("invalid: unable to decode OptFloat64 to nil") - } - o.Set = true - v, err := d.Float64() - if err != nil { - return err - } - o.Value = float64(v) - return nil -} - -// MarshalJSON implements stdjson.Marshaler. -func (s OptFloat64) MarshalJSON() ([]byte, error) { - e := jx.Encoder{} - s.Encode(&e) - return e.Bytes(), nil -} - -// UnmarshalJSON implements stdjson.Unmarshaler. -func (s *OptFloat64) UnmarshalJSON(data []byte) error { - d := jx.DecodeBytes(data) - return s.Decode(d) -} - -// Encode encodes PredictionResponseRequest as json. -func (o OptPredictionResponseRequest) Encode(e *jx.Encoder) { - if !o.Set { - return - } - o.Value.Encode(e) -} - -// Decode decodes PredictionResponseRequest from json. -func (o *OptPredictionResponseRequest) Decode(d *jx.Decoder) error { - if o == nil { - return errors.New("invalid: unable to decode OptPredictionResponseRequest to nil") - } - o.Set = true - if err := o.Value.Decode(d); err != nil { - return err - } - return nil -} - -// MarshalJSON implements stdjson.Marshaler. -func (s OptPredictionResponseRequest) MarshalJSON() ([]byte, error) { - e := jx.Encoder{} - s.Encode(&e) - return e.Bytes(), nil -} - -// UnmarshalJSON implements stdjson.Unmarshaler. -func (s *OptPredictionResponseRequest) UnmarshalJSON(data []byte) error { - d := jx.DecodeBytes(data) - return s.Decode(d) -} - -// Encode encodes PredictionResponseWarnings as json. -func (o OptPredictionResponseWarnings) Encode(e *jx.Encoder) { - if !o.Set { - return - } - o.Value.Encode(e) -} - -// Decode decodes PredictionResponseWarnings from json. -func (o *OptPredictionResponseWarnings) Decode(d *jx.Decoder) error { - if o == nil { - return errors.New("invalid: unable to decode OptPredictionResponseWarnings to nil") - } - o.Set = true - o.Value = make(PredictionResponseWarnings) - if err := o.Value.Decode(d); err != nil { - return err - } - return nil -} - -// MarshalJSON implements stdjson.Marshaler. -func (s OptPredictionResponseWarnings) MarshalJSON() ([]byte, error) { - e := jx.Encoder{} - s.Encode(&e) - return e.Bytes(), nil -} - -// UnmarshalJSON implements stdjson.Unmarshaler. -func (s *OptPredictionResponseWarnings) UnmarshalJSON(data []byte) error { - d := jx.DecodeBytes(data) - return s.Decode(d) -} - // Encode encodes string as json. func (o OptString) Encode(e *jx.Encoder) { if !o.Set { @@ -393,19 +233,17 @@ func (s *OptString) UnmarshalJSON(data []byte) error { } // Encode implements json.Marshaler. -func (s *PredictionResponse) Encode(e *jx.Encoder) { +func (s *PredictionResult) Encode(e *jx.Encoder) { e.ObjStart() s.encodeFields(e) e.ObjEnd() } // encodeFields encodes fields. -func (s *PredictionResponse) encodeFields(e *jx.Encoder) { +func (s *PredictionResult) encodeFields(e *jx.Encoder) { { - if s.Request.Set { - e.FieldStart("request") - s.Request.Encode(e) - } + e.FieldStart("metadata") + s.Metadata.Encode(e) } { e.FieldStart("prediction") @@ -415,50 +253,38 @@ func (s *PredictionResponse) encodeFields(e *jx.Encoder) { } e.ArrEnd() } - { - e.FieldStart("metadata") - s.Metadata.Encode(e) - } - { - if s.Warnings.Set { - e.FieldStart("warnings") - s.Warnings.Encode(e) - } - } } -var jsonFieldsNameOfPredictionResponse = [4]string{ - 0: "request", +var jsonFieldsNameOfPredictionResult = [2]string{ + 0: "metadata", 1: "prediction", - 2: "metadata", - 3: "warnings", } -// Decode decodes PredictionResponse from json. -func (s *PredictionResponse) Decode(d *jx.Decoder) error { +// Decode decodes PredictionResult from json. +func (s *PredictionResult) Decode(d *jx.Decoder) error { if s == nil { - return errors.New("invalid: unable to decode PredictionResponse to nil") + return errors.New("invalid: unable to decode PredictionResult to nil") } var requiredBitSet [1]uint8 if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { switch string(k) { - case "request": + case "metadata": + requiredBitSet[0] |= 1 << 0 if err := func() error { - s.Request.Reset() - if err := s.Request.Decode(d); err != nil { + if err := s.Metadata.Decode(d); err != nil { return err } return nil }(); err != nil { - return errors.Wrap(err, "decode field \"request\"") + return errors.Wrap(err, "decode field \"metadata\"") } case "prediction": requiredBitSet[0] |= 1 << 1 if err := func() error { - s.Prediction = make([]PredictionResponsePredictionItem, 0) + s.Prediction = make([]PredictionResultPredictionItem, 0) if err := d.Arr(func(d *jx.Decoder) error { - var elem PredictionResponsePredictionItem + var elem PredictionResultPredictionItem if err := elem.Decode(d); err != nil { return err } @@ -471,145 +297,12 @@ func (s *PredictionResponse) Decode(d *jx.Decoder) error { }(); err != nil { return errors.Wrap(err, "decode field \"prediction\"") } - case "metadata": - requiredBitSet[0] |= 1 << 2 - if err := func() error { - if err := s.Metadata.Decode(d); err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"metadata\"") - } - case "warnings": - if err := func() error { - s.Warnings.Reset() - if err := s.Warnings.Decode(d); err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"warnings\"") - } default: return d.Skip() } return nil }); err != nil { - return errors.Wrap(err, "decode PredictionResponse") - } - // Validate required fields. - var failures []validate.FieldError - for i, mask := range [1]uint8{ - 0b00000110, - } { - if result := (requiredBitSet[i] & mask) ^ mask; result != 0 { - // Mask only required fields and check equality to mask using XOR. - // - // If XOR result is not zero, result is not equal to expected, so some fields are missed. - // Bits of fields which would be set are actually bits of missed fields. - missed := bits.OnesCount8(result) - for bitN := 0; bitN < missed; bitN++ { - bitIdx := bits.TrailingZeros8(result) - fieldIdx := i*8 + bitIdx - var name string - if fieldIdx < len(jsonFieldsNameOfPredictionResponse) { - name = jsonFieldsNameOfPredictionResponse[fieldIdx] - } else { - name = strconv.Itoa(fieldIdx) - } - failures = append(failures, validate.FieldError{ - Name: name, - Error: validate.ErrFieldRequired, - }) - // Reset bit. - result &^= 1 << bitIdx - } - } - } - if len(failures) > 0 { - return &validate.Error{Fields: failures} - } - - return nil -} - -// MarshalJSON implements stdjson.Marshaler. -func (s *PredictionResponse) MarshalJSON() ([]byte, error) { - e := jx.Encoder{} - s.Encode(&e) - return e.Bytes(), nil -} - -// UnmarshalJSON implements stdjson.Unmarshaler. -func (s *PredictionResponse) UnmarshalJSON(data []byte) error { - d := jx.DecodeBytes(data) - return s.Decode(d) -} - -// Encode implements json.Marshaler. -func (s *PredictionResponseMetadata) Encode(e *jx.Encoder) { - e.ObjStart() - s.encodeFields(e) - e.ObjEnd() -} - -// encodeFields encodes fields. -func (s *PredictionResponseMetadata) encodeFields(e *jx.Encoder) { - { - e.FieldStart("start_datetime") - json.EncodeDateTime(e, s.StartDatetime) - } - { - e.FieldStart("complete_datetime") - json.EncodeDateTime(e, s.CompleteDatetime) - } -} - -var jsonFieldsNameOfPredictionResponseMetadata = [2]string{ - 0: "start_datetime", - 1: "complete_datetime", -} - -// Decode decodes PredictionResponseMetadata from json. -func (s *PredictionResponseMetadata) Decode(d *jx.Decoder) error { - if s == nil { - return errors.New("invalid: unable to decode PredictionResponseMetadata to nil") - } - var requiredBitSet [1]uint8 - - if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { - switch string(k) { - case "start_datetime": - requiredBitSet[0] |= 1 << 0 - if err := func() error { - v, err := json.DecodeDateTime(d) - s.StartDatetime = v - if err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"start_datetime\"") - } - case "complete_datetime": - requiredBitSet[0] |= 1 << 1 - if err := func() error { - v, err := json.DecodeDateTime(d) - s.CompleteDatetime = v - if err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"complete_datetime\"") - } - default: - return d.Skip() - } - return nil - }); err != nil { - return errors.Wrap(err, "decode PredictionResponseMetadata") + return errors.Wrap(err, "decode PredictionResult") } // Validate required fields. var failures []validate.FieldError @@ -626,8 +319,8 @@ func (s *PredictionResponseMetadata) Decode(d *jx.Decoder) error { bitIdx := bits.TrailingZeros8(result) fieldIdx := i*8 + bitIdx var name string - if fieldIdx < len(jsonFieldsNameOfPredictionResponseMetadata) { - name = jsonFieldsNameOfPredictionResponseMetadata[fieldIdx] + if fieldIdx < len(jsonFieldsNameOfPredictionResult) { + name = jsonFieldsNameOfPredictionResult[fieldIdx] } else { name = strconv.Itoa(fieldIdx) } @@ -648,27 +341,140 @@ func (s *PredictionResponseMetadata) Decode(d *jx.Decoder) error { } // MarshalJSON implements stdjson.Marshaler. -func (s *PredictionResponseMetadata) MarshalJSON() ([]byte, error) { +func (s *PredictionResult) MarshalJSON() ([]byte, error) { e := jx.Encoder{} s.Encode(&e) return e.Bytes(), nil } // UnmarshalJSON implements stdjson.Unmarshaler. -func (s *PredictionResponseMetadata) UnmarshalJSON(data []byte) error { +func (s *PredictionResult) UnmarshalJSON(data []byte) error { d := jx.DecodeBytes(data) return s.Decode(d) } // Encode implements json.Marshaler. -func (s *PredictionResponsePredictionItem) Encode(e *jx.Encoder) { +func (s *PredictionResultMetadata) Encode(e *jx.Encoder) { e.ObjStart() s.encodeFields(e) e.ObjEnd() } // encodeFields encodes fields. -func (s *PredictionResponsePredictionItem) encodeFields(e *jx.Encoder) { +func (s *PredictionResultMetadata) encodeFields(e *jx.Encoder) { + { + e.FieldStart("complete_datetime") + json.EncodeDateTime(e, s.CompleteDatetime) + } + { + e.FieldStart("start_datetime") + json.EncodeDateTime(e, s.StartDatetime) + } +} + +var jsonFieldsNameOfPredictionResultMetadata = [2]string{ + 0: "complete_datetime", + 1: "start_datetime", +} + +// Decode decodes PredictionResultMetadata from json. +func (s *PredictionResultMetadata) Decode(d *jx.Decoder) error { + if s == nil { + return errors.New("invalid: unable to decode PredictionResultMetadata to nil") + } + var requiredBitSet [1]uint8 + + if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { + switch string(k) { + case "complete_datetime": + requiredBitSet[0] |= 1 << 0 + if err := func() error { + v, err := json.DecodeDateTime(d) + s.CompleteDatetime = v + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"complete_datetime\"") + } + case "start_datetime": + requiredBitSet[0] |= 1 << 1 + if err := func() error { + v, err := json.DecodeDateTime(d) + s.StartDatetime = v + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"start_datetime\"") + } + default: + return d.Skip() + } + return nil + }); err != nil { + return errors.Wrap(err, "decode PredictionResultMetadata") + } + // Validate required fields. + var failures []validate.FieldError + for i, mask := range [1]uint8{ + 0b00000011, + } { + if result := (requiredBitSet[i] & mask) ^ mask; result != 0 { + // Mask only required fields and check equality to mask using XOR. + // + // If XOR result is not zero, result is not equal to expected, so some fields are missed. + // Bits of fields which would be set are actually bits of missed fields. + missed := bits.OnesCount8(result) + for bitN := 0; bitN < missed; bitN++ { + bitIdx := bits.TrailingZeros8(result) + fieldIdx := i*8 + bitIdx + var name string + if fieldIdx < len(jsonFieldsNameOfPredictionResultMetadata) { + name = jsonFieldsNameOfPredictionResultMetadata[fieldIdx] + } else { + name = strconv.Itoa(fieldIdx) + } + failures = append(failures, validate.FieldError{ + Name: name, + Error: validate.ErrFieldRequired, + }) + // Reset bit. + result &^= 1 << bitIdx + } + } + } + if len(failures) > 0 { + return &validate.Error{Fields: failures} + } + + return nil +} + +// MarshalJSON implements stdjson.Marshaler. +func (s *PredictionResultMetadata) MarshalJSON() ([]byte, error) { + e := jx.Encoder{} + s.Encode(&e) + return e.Bytes(), nil +} + +// UnmarshalJSON implements stdjson.Unmarshaler. +func (s *PredictionResultMetadata) UnmarshalJSON(data []byte) error { + d := jx.DecodeBytes(data) + return s.Decode(d) +} + +// Encode implements json.Marshaler. +func (s *PredictionResultPredictionItem) Encode(e *jx.Encoder) { + e.ObjStart() + s.encodeFields(e) + e.ObjEnd() +} + +// encodeFields encodes fields. +func (s *PredictionResultPredictionItem) encodeFields(e *jx.Encoder) { { e.FieldStart("stage") s.Stage.Encode(e) @@ -683,15 +489,15 @@ func (s *PredictionResponsePredictionItem) encodeFields(e *jx.Encoder) { } } -var jsonFieldsNameOfPredictionResponsePredictionItem = [2]string{ +var jsonFieldsNameOfPredictionResultPredictionItem = [2]string{ 0: "stage", 1: "trajectory", } -// Decode decodes PredictionResponsePredictionItem from json. -func (s *PredictionResponsePredictionItem) Decode(d *jx.Decoder) error { +// Decode decodes PredictionResultPredictionItem from json. +func (s *PredictionResultPredictionItem) Decode(d *jx.Decoder) error { if s == nil { - return errors.New("invalid: unable to decode PredictionResponsePredictionItem to nil") + return errors.New("invalid: unable to decode PredictionResultPredictionItem to nil") } var requiredBitSet [1]uint8 @@ -710,9 +516,9 @@ func (s *PredictionResponsePredictionItem) Decode(d *jx.Decoder) error { case "trajectory": requiredBitSet[0] |= 1 << 1 if err := func() error { - s.Trajectory = make([]PredictionResponsePredictionItemTrajectoryItem, 0) + s.Trajectory = make([]PredictionResultPredictionItemTrajectoryItem, 0) if err := d.Arr(func(d *jx.Decoder) error { - var elem PredictionResponsePredictionItemTrajectoryItem + var elem PredictionResultPredictionItemTrajectoryItem if err := elem.Decode(d); err != nil { return err } @@ -730,7 +536,7 @@ func (s *PredictionResponsePredictionItem) Decode(d *jx.Decoder) error { } return nil }); err != nil { - return errors.Wrap(err, "decode PredictionResponsePredictionItem") + return errors.Wrap(err, "decode PredictionResultPredictionItem") } // Validate required fields. var failures []validate.FieldError @@ -747,8 +553,8 @@ func (s *PredictionResponsePredictionItem) Decode(d *jx.Decoder) error { bitIdx := bits.TrailingZeros8(result) fieldIdx := i*8 + bitIdx var name string - if fieldIdx < len(jsonFieldsNameOfPredictionResponsePredictionItem) { - name = jsonFieldsNameOfPredictionResponsePredictionItem[fieldIdx] + if fieldIdx < len(jsonFieldsNameOfPredictionResultPredictionItem) { + name = jsonFieldsNameOfPredictionResultPredictionItem[fieldIdx] } else { name = strconv.Itoa(fieldIdx) } @@ -769,69 +575,67 @@ func (s *PredictionResponsePredictionItem) Decode(d *jx.Decoder) error { } // MarshalJSON implements stdjson.Marshaler. -func (s *PredictionResponsePredictionItem) MarshalJSON() ([]byte, error) { +func (s *PredictionResultPredictionItem) MarshalJSON() ([]byte, error) { e := jx.Encoder{} s.Encode(&e) return e.Bytes(), nil } // UnmarshalJSON implements stdjson.Unmarshaler. -func (s *PredictionResponsePredictionItem) UnmarshalJSON(data []byte) error { +func (s *PredictionResultPredictionItem) UnmarshalJSON(data []byte) error { d := jx.DecodeBytes(data) return s.Decode(d) } -// Encode encodes PredictionResponsePredictionItemStage as json. -func (s PredictionResponsePredictionItemStage) Encode(e *jx.Encoder) { +// Encode encodes PredictionResultPredictionItemStage as json. +func (s PredictionResultPredictionItemStage) Encode(e *jx.Encoder) { e.Str(string(s)) } -// Decode decodes PredictionResponsePredictionItemStage from json. -func (s *PredictionResponsePredictionItemStage) Decode(d *jx.Decoder) error { +// Decode decodes PredictionResultPredictionItemStage from json. +func (s *PredictionResultPredictionItemStage) Decode(d *jx.Decoder) error { if s == nil { - return errors.New("invalid: unable to decode PredictionResponsePredictionItemStage to nil") + return errors.New("invalid: unable to decode PredictionResultPredictionItemStage to nil") } v, err := d.StrBytes() if err != nil { return err } // Try to use constant string. - switch PredictionResponsePredictionItemStage(v) { - case PredictionResponsePredictionItemStageAscent: - *s = PredictionResponsePredictionItemStageAscent - case PredictionResponsePredictionItemStageDescent: - *s = PredictionResponsePredictionItemStageDescent - case PredictionResponsePredictionItemStageFloat: - *s = PredictionResponsePredictionItemStageFloat + switch PredictionResultPredictionItemStage(v) { + case PredictionResultPredictionItemStageAscent: + *s = PredictionResultPredictionItemStageAscent + case PredictionResultPredictionItemStageDescent: + *s = PredictionResultPredictionItemStageDescent default: - *s = PredictionResponsePredictionItemStage(v) + *s = PredictionResultPredictionItemStage(v) } return nil } // MarshalJSON implements stdjson.Marshaler. -func (s PredictionResponsePredictionItemStage) MarshalJSON() ([]byte, error) { +func (s PredictionResultPredictionItemStage) MarshalJSON() ([]byte, error) { e := jx.Encoder{} s.Encode(&e) return e.Bytes(), nil } // UnmarshalJSON implements stdjson.Unmarshaler. -func (s *PredictionResponsePredictionItemStage) UnmarshalJSON(data []byte) error { +func (s *PredictionResultPredictionItemStage) UnmarshalJSON(data []byte) error { d := jx.DecodeBytes(data) return s.Decode(d) } // Encode implements json.Marshaler. -func (s *PredictionResponsePredictionItemTrajectoryItem) Encode(e *jx.Encoder) { +func (s *PredictionResultPredictionItemTrajectoryItem) Encode(e *jx.Encoder) { e.ObjStart() s.encodeFields(e) e.ObjEnd() } // encodeFields encodes fields. -func (s *PredictionResponsePredictionItemTrajectoryItem) encodeFields(e *jx.Encoder) { +func (s *PredictionResultPredictionItemTrajectoryItem) encodeFields(e *jx.Encoder) { { e.FieldStart("datetime") json.EncodeDateTime(e, s.Datetime) @@ -850,17 +654,17 @@ func (s *PredictionResponsePredictionItemTrajectoryItem) encodeFields(e *jx.Enco } } -var jsonFieldsNameOfPredictionResponsePredictionItemTrajectoryItem = [4]string{ +var jsonFieldsNameOfPredictionResultPredictionItemTrajectoryItem = [4]string{ 0: "datetime", 1: "latitude", 2: "longitude", 3: "altitude", } -// Decode decodes PredictionResponsePredictionItemTrajectoryItem from json. -func (s *PredictionResponsePredictionItemTrajectoryItem) Decode(d *jx.Decoder) error { +// Decode decodes PredictionResultPredictionItemTrajectoryItem from json. +func (s *PredictionResultPredictionItemTrajectoryItem) Decode(d *jx.Decoder) error { if s == nil { - return errors.New("invalid: unable to decode PredictionResponsePredictionItemTrajectoryItem to nil") + return errors.New("invalid: unable to decode PredictionResultPredictionItemTrajectoryItem to nil") } var requiredBitSet [1]uint8 @@ -919,7 +723,7 @@ func (s *PredictionResponsePredictionItemTrajectoryItem) Decode(d *jx.Decoder) e } return nil }); err != nil { - return errors.Wrap(err, "decode PredictionResponsePredictionItemTrajectoryItem") + return errors.Wrap(err, "decode PredictionResultPredictionItemTrajectoryItem") } // Validate required fields. var failures []validate.FieldError @@ -936,8 +740,8 @@ func (s *PredictionResponsePredictionItemTrajectoryItem) Decode(d *jx.Decoder) e bitIdx := bits.TrailingZeros8(result) fieldIdx := i*8 + bitIdx var name string - if fieldIdx < len(jsonFieldsNameOfPredictionResponsePredictionItemTrajectoryItem) { - name = jsonFieldsNameOfPredictionResponsePredictionItemTrajectoryItem[fieldIdx] + if fieldIdx < len(jsonFieldsNameOfPredictionResultPredictionItemTrajectoryItem) { + name = jsonFieldsNameOfPredictionResultPredictionItemTrajectoryItem[fieldIdx] } else { name = strconv.Itoa(fieldIdx) } @@ -958,271 +762,14 @@ func (s *PredictionResponsePredictionItemTrajectoryItem) Decode(d *jx.Decoder) e } // MarshalJSON implements stdjson.Marshaler. -func (s *PredictionResponsePredictionItemTrajectoryItem) MarshalJSON() ([]byte, error) { +func (s *PredictionResultPredictionItemTrajectoryItem) MarshalJSON() ([]byte, error) { e := jx.Encoder{} s.Encode(&e) return e.Bytes(), nil } // UnmarshalJSON implements stdjson.Unmarshaler. -func (s *PredictionResponsePredictionItemTrajectoryItem) UnmarshalJSON(data []byte) error { - d := jx.DecodeBytes(data) - return s.Decode(d) -} - -// Encode implements json.Marshaler. -func (s *PredictionResponseRequest) Encode(e *jx.Encoder) { - e.ObjStart() - s.encodeFields(e) - e.ObjEnd() -} - -// encodeFields encodes fields. -func (s *PredictionResponseRequest) encodeFields(e *jx.Encoder) { - { - if s.Dataset.Set { - e.FieldStart("dataset") - s.Dataset.Encode(e) - } - } - { - if s.LaunchLatitude.Set { - e.FieldStart("launch_latitude") - s.LaunchLatitude.Encode(e) - } - } - { - if s.LaunchLongitude.Set { - e.FieldStart("launch_longitude") - s.LaunchLongitude.Encode(e) - } - } - { - if s.LaunchDatetime.Set { - e.FieldStart("launch_datetime") - s.LaunchDatetime.Encode(e) - } - } - { - if s.LaunchAltitude.Set { - e.FieldStart("launch_altitude") - s.LaunchAltitude.Encode(e) - } - } - { - if s.Profile.Set { - e.FieldStart("profile") - s.Profile.Encode(e) - } - } - { - if s.AscentRate.Set { - e.FieldStart("ascent_rate") - s.AscentRate.Encode(e) - } - } - { - if s.BurstAltitude.Set { - e.FieldStart("burst_altitude") - s.BurstAltitude.Encode(e) - } - } - { - if s.DescentRate.Set { - e.FieldStart("descent_rate") - s.DescentRate.Encode(e) - } - } -} - -var jsonFieldsNameOfPredictionResponseRequest = [9]string{ - 0: "dataset", - 1: "launch_latitude", - 2: "launch_longitude", - 3: "launch_datetime", - 4: "launch_altitude", - 5: "profile", - 6: "ascent_rate", - 7: "burst_altitude", - 8: "descent_rate", -} - -// Decode decodes PredictionResponseRequest from json. -func (s *PredictionResponseRequest) Decode(d *jx.Decoder) error { - if s == nil { - return errors.New("invalid: unable to decode PredictionResponseRequest to nil") - } - - if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { - switch string(k) { - case "dataset": - if err := func() error { - s.Dataset.Reset() - if err := s.Dataset.Decode(d); err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"dataset\"") - } - case "launch_latitude": - if err := func() error { - s.LaunchLatitude.Reset() - if err := s.LaunchLatitude.Decode(d); err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"launch_latitude\"") - } - case "launch_longitude": - if err := func() error { - s.LaunchLongitude.Reset() - if err := s.LaunchLongitude.Decode(d); err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"launch_longitude\"") - } - case "launch_datetime": - if err := func() error { - s.LaunchDatetime.Reset() - if err := s.LaunchDatetime.Decode(d); err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"launch_datetime\"") - } - case "launch_altitude": - if err := func() error { - s.LaunchAltitude.Reset() - if err := s.LaunchAltitude.Decode(d); err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"launch_altitude\"") - } - case "profile": - if err := func() error { - s.Profile.Reset() - if err := s.Profile.Decode(d); err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"profile\"") - } - case "ascent_rate": - if err := func() error { - s.AscentRate.Reset() - if err := s.AscentRate.Decode(d); err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"ascent_rate\"") - } - case "burst_altitude": - if err := func() error { - s.BurstAltitude.Reset() - if err := s.BurstAltitude.Decode(d); err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"burst_altitude\"") - } - case "descent_rate": - if err := func() error { - s.DescentRate.Reset() - if err := s.DescentRate.Decode(d); err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrap(err, "decode field \"descent_rate\"") - } - default: - return d.Skip() - } - return nil - }); err != nil { - return errors.Wrap(err, "decode PredictionResponseRequest") - } - - return nil -} - -// MarshalJSON implements stdjson.Marshaler. -func (s *PredictionResponseRequest) MarshalJSON() ([]byte, error) { - e := jx.Encoder{} - s.Encode(&e) - return e.Bytes(), nil -} - -// UnmarshalJSON implements stdjson.Unmarshaler. -func (s *PredictionResponseRequest) UnmarshalJSON(data []byte) error { - d := jx.DecodeBytes(data) - return s.Decode(d) -} - -// Encode implements json.Marshaler. -func (s PredictionResponseWarnings) Encode(e *jx.Encoder) { - e.ObjStart() - s.encodeFields(e) - e.ObjEnd() -} - -// encodeFields implements json.Marshaler. -func (s PredictionResponseWarnings) encodeFields(e *jx.Encoder) { - for k, elem := range s { - e.FieldStart(k) - - if len(elem) != 0 { - e.Raw(elem) - } - } -} - -// Decode decodes PredictionResponseWarnings from json. -func (s *PredictionResponseWarnings) Decode(d *jx.Decoder) error { - if s == nil { - return errors.New("invalid: unable to decode PredictionResponseWarnings to nil") - } - m := s.init() - if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { - var elem jx.Raw - if err := func() error { - v, err := d.RawAppend(nil) - elem = jx.Raw(v) - if err != nil { - return err - } - return nil - }(); err != nil { - return errors.Wrapf(err, "decode field %q", k) - } - m[string(k)] = elem - return nil - }); err != nil { - return errors.Wrap(err, "decode PredictionResponseWarnings") - } - - return nil -} - -// MarshalJSON implements stdjson.Marshaler. -func (s PredictionResponseWarnings) MarshalJSON() ([]byte, error) { - e := jx.Encoder{} - s.Encode(&e) - return e.Bytes(), nil -} - -// UnmarshalJSON implements stdjson.Unmarshaler. -func (s *PredictionResponseWarnings) UnmarshalJSON(data []byte) error { +func (s *PredictionResultPredictionItemTrajectoryItem) UnmarshalJSON(data []byte) error { d := jx.DecodeBytes(data) return s.Decode(d) } @@ -1241,9 +788,15 @@ func (s *ReadinessResponse) encodeFields(e *jx.Encoder) { s.Status.Encode(e) } { - if s.DatasetTime.Set { - e.FieldStart("dataset_time") - s.DatasetTime.Encode(e, json.EncodeDateTime) + if s.LastUpdate.Set { + e.FieldStart("last_update") + s.LastUpdate.Encode(e, json.EncodeDateTime) + } + } + { + if s.IsFresh.Set { + e.FieldStart("is_fresh") + s.IsFresh.Encode(e) } } { @@ -1254,10 +807,11 @@ func (s *ReadinessResponse) encodeFields(e *jx.Encoder) { } } -var jsonFieldsNameOfReadinessResponse = [3]string{ +var jsonFieldsNameOfReadinessResponse = [4]string{ 0: "status", - 1: "dataset_time", - 2: "error_message", + 1: "last_update", + 2: "is_fresh", + 3: "error_message", } // Decode decodes ReadinessResponse from json. @@ -1279,15 +833,25 @@ func (s *ReadinessResponse) Decode(d *jx.Decoder) error { }(); err != nil { return errors.Wrap(err, "decode field \"status\"") } - case "dataset_time": + case "last_update": if err := func() error { - s.DatasetTime.Reset() - if err := s.DatasetTime.Decode(d, json.DecodeDateTime); err != nil { + s.LastUpdate.Reset() + if err := s.LastUpdate.Decode(d, json.DecodeDateTime); err != nil { return err } return nil }(); err != nil { - return errors.Wrap(err, "decode field \"dataset_time\"") + return errors.Wrap(err, "decode field \"last_update\"") + } + case "is_fresh": + if err := func() error { + s.IsFresh.Reset() + if err := s.IsFresh.Decode(d); err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"is_fresh\"") } case "error_message": if err := func() error { diff --git a/pkg/rest/oas_labeler_gen.go b/pkg/rest/oas_labeler_gen.go index b726eef..47bce6a 100644 --- a/pkg/rest/oas_labeler_gen.go +++ b/pkg/rest/oas_labeler_gen.go @@ -1,6 +1,6 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn import ( "context" diff --git a/pkg/rest/oas_middleware_gen.go b/pkg/rest/oas_middleware_gen.go index 57ea1b4..9d62f34 100644 --- a/pkg/rest/oas_middleware_gen.go +++ b/pkg/rest/oas_middleware_gen.go @@ -1,6 +1,6 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn import ( "github.com/ogen-go/ogen/middleware" diff --git a/pkg/rest/oas_operations_gen.go b/pkg/rest/oas_operations_gen.go index 68097b0..873f44a 100644 --- a/pkg/rest/oas_operations_gen.go +++ b/pkg/rest/oas_operations_gen.go @@ -1,6 +1,6 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn // OperationName is the ogen operation name type OperationName = string diff --git a/pkg/rest/oas_parameters_gen.go b/pkg/rest/oas_parameters_gen.go index c3be508..23cc5d8 100644 --- a/pkg/rest/oas_parameters_gen.go +++ b/pkg/rest/oas_parameters_gen.go @@ -1,12 +1,13 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn import ( "net/http" "time" "github.com/go-faster/errors" + "github.com/ogen-go/ogen/conv" "github.com/ogen-go/ogen/middleware" "github.com/ogen-go/ogen/ogenerrors" @@ -16,17 +17,21 @@ import ( // PerformPredictionParams is parameters of performPrediction operation. type PerformPredictionParams struct { - LaunchLatitude float64 - LaunchLongitude float64 - LaunchDatetime time.Time - LaunchAltitude OptFloat64 `json:",omitempty,omitzero"` - Profile OptPerformPredictionProfile `json:",omitempty,omitzero"` - AscentRate OptFloat64 `json:",omitempty,omitzero"` - BurstAltitude OptFloat64 `json:",omitempty,omitzero"` - DescentRate OptFloat64 `json:",omitempty,omitzero"` - FloatAltitude OptFloat64 `json:",omitempty,omitzero"` - StopDatetime OptDateTime `json:",omitempty,omitzero"` - Dataset OptDateTime `json:",omitempty,omitzero"` + LaunchLatitude OptFloat64 + LaunchLongitude OptFloat64 + LaunchDatetime OptDateTime + LaunchAltitude OptFloat64 + Profile OptPerformPredictionProfile + AscentRate OptFloat64 + BurstAltitude OptFloat64 + DescentRate OptFloat64 + FloatAltitude OptFloat64 + StopDatetime OptDateTime + AscentCurve OptString + DescentCurve OptString + Interpolate OptBool + Format OptPerformPredictionFormat + Dataset OptDateTime } func unpackPerformPredictionParams(packed middleware.Parameters) (params PerformPredictionParams) { @@ -35,21 +40,27 @@ func unpackPerformPredictionParams(packed middleware.Parameters) (params Perform Name: "launch_latitude", In: "query", } - params.LaunchLatitude = packed[key].(float64) + if v, ok := packed[key]; ok { + params.LaunchLatitude = v.(OptFloat64) + } } { key := middleware.ParameterKey{ Name: "launch_longitude", In: "query", } - params.LaunchLongitude = packed[key].(float64) + if v, ok := packed[key]; ok { + params.LaunchLongitude = v.(OptFloat64) + } } { key := middleware.ParameterKey{ Name: "launch_datetime", In: "query", } - params.LaunchDatetime = packed[key].(time.Time) + if v, ok := packed[key]; ok { + params.LaunchDatetime = v.(OptDateTime) + } } { key := middleware.ParameterKey{ @@ -114,6 +125,42 @@ func unpackPerformPredictionParams(packed middleware.Parameters) (params Perform params.StopDatetime = v.(OptDateTime) } } + { + key := middleware.ParameterKey{ + Name: "ascent_curve", + In: "query", + } + if v, ok := packed[key]; ok { + params.AscentCurve = v.(OptString) + } + } + { + key := middleware.ParameterKey{ + Name: "descent_curve", + In: "query", + } + if v, ok := packed[key]; ok { + params.DescentCurve = v.(OptString) + } + } + { + key := middleware.ParameterKey{ + Name: "interpolate", + In: "query", + } + if v, ok := packed[key]; ok { + params.Interpolate = v.(OptBool) + } + } + { + key := middleware.ParameterKey{ + Name: "format", + In: "query", + } + if v, ok := packed[key]; ok { + params.Format = v.(OptPerformPredictionFormat) + } + } { key := middleware.ParameterKey{ Name: "dataset", @@ -138,31 +185,43 @@ func decodePerformPredictionParams(args [0]string, argsEscaped bool, r *http.Req if err := q.HasParam(cfg); err == nil { if err := q.DecodeParam(cfg, func(d uri.Decoder) error { - val, err := d.DecodeValue() - if err != nil { + var paramsDotLaunchLatitudeVal float64 + if err := func() error { + val, err := d.DecodeValue() + if err != nil { + return err + } + + c, err := conv.ToFloat64(val) + if err != nil { + return err + } + + paramsDotLaunchLatitudeVal = c + return nil + }(); err != nil { return err } - - c, err := conv.ToFloat64(val) - if err != nil { - return err - } - - params.LaunchLatitude = c + params.LaunchLatitude.SetTo(paramsDotLaunchLatitudeVal) return nil }); err != nil { return err } if err := func() error { - if err := (validate.Float{}).Validate(float64(params.LaunchLatitude)); err != nil { - return errors.Wrap(err, "float") + if value, ok := params.LaunchLatitude.Get(); ok { + if err := func() error { + if err := (validate.Float{}).Validate(float64(value)); err != nil { + return errors.Wrap(err, "float") + } + return nil + }(); err != nil { + return err + } } return nil }(); err != nil { return err } - } else { - return err } return nil }(); err != nil { @@ -182,31 +241,43 @@ func decodePerformPredictionParams(args [0]string, argsEscaped bool, r *http.Req if err := q.HasParam(cfg); err == nil { if err := q.DecodeParam(cfg, func(d uri.Decoder) error { - val, err := d.DecodeValue() - if err != nil { + var paramsDotLaunchLongitudeVal float64 + if err := func() error { + val, err := d.DecodeValue() + if err != nil { + return err + } + + c, err := conv.ToFloat64(val) + if err != nil { + return err + } + + paramsDotLaunchLongitudeVal = c + return nil + }(); err != nil { return err } - - c, err := conv.ToFloat64(val) - if err != nil { - return err - } - - params.LaunchLongitude = c + params.LaunchLongitude.SetTo(paramsDotLaunchLongitudeVal) return nil }); err != nil { return err } if err := func() error { - if err := (validate.Float{}).Validate(float64(params.LaunchLongitude)); err != nil { - return errors.Wrap(err, "float") + if value, ok := params.LaunchLongitude.Get(); ok { + if err := func() error { + if err := (validate.Float{}).Validate(float64(value)); err != nil { + return errors.Wrap(err, "float") + } + return nil + }(); err != nil { + return err + } } return nil }(); err != nil { return err } - } else { - return err } return nil }(); err != nil { @@ -226,23 +297,28 @@ func decodePerformPredictionParams(args [0]string, argsEscaped bool, r *http.Req if err := q.HasParam(cfg); err == nil { if err := q.DecodeParam(cfg, func(d uri.Decoder) error { - val, err := d.DecodeValue() - if err != nil { + var paramsDotLaunchDatetimeVal time.Time + if err := func() error { + val, err := d.DecodeValue() + if err != nil { + return err + } + + c, err := conv.ToDateTime(val) + if err != nil { + return err + } + + paramsDotLaunchDatetimeVal = c + return nil + }(); err != nil { return err } - - c, err := conv.ToDateTime(val) - if err != nil { - return err - } - - params.LaunchDatetime = c + params.LaunchDatetime.SetTo(paramsDotLaunchDatetimeVal) return nil }); err != nil { return err } - } else { - return err } return nil }(); err != nil { @@ -308,11 +384,6 @@ func decodePerformPredictionParams(args [0]string, argsEscaped bool, r *http.Req Err: err, } } - // Set default value for query: profile. - { - val := PerformPredictionProfile("standard_profile") - params.Profile.SetTo(val) - } // Decode query: profile. if err := func() error { cfg := uri.QueryParameterDecodingConfig{ @@ -634,6 +705,185 @@ func decodePerformPredictionParams(args [0]string, argsEscaped bool, r *http.Req Err: err, } } + // Decode query: ascent_curve. + if err := func() error { + cfg := uri.QueryParameterDecodingConfig{ + Name: "ascent_curve", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.HasParam(cfg); err == nil { + if err := q.DecodeParam(cfg, func(d uri.Decoder) error { + var paramsDotAscentCurveVal string + if err := func() error { + val, err := d.DecodeValue() + if err != nil { + return err + } + + c, err := conv.ToString(val) + if err != nil { + return err + } + + paramsDotAscentCurveVal = c + return nil + }(); err != nil { + return err + } + params.AscentCurve.SetTo(paramsDotAscentCurveVal) + return nil + }); err != nil { + return err + } + } + return nil + }(); err != nil { + return params, &ogenerrors.DecodeParamError{ + Name: "ascent_curve", + In: "query", + Err: err, + } + } + // Decode query: descent_curve. + if err := func() error { + cfg := uri.QueryParameterDecodingConfig{ + Name: "descent_curve", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.HasParam(cfg); err == nil { + if err := q.DecodeParam(cfg, func(d uri.Decoder) error { + var paramsDotDescentCurveVal string + if err := func() error { + val, err := d.DecodeValue() + if err != nil { + return err + } + + c, err := conv.ToString(val) + if err != nil { + return err + } + + paramsDotDescentCurveVal = c + return nil + }(); err != nil { + return err + } + params.DescentCurve.SetTo(paramsDotDescentCurveVal) + return nil + }); err != nil { + return err + } + } + return nil + }(); err != nil { + return params, &ogenerrors.DecodeParamError{ + Name: "descent_curve", + In: "query", + Err: err, + } + } + // Decode query: interpolate. + if err := func() error { + cfg := uri.QueryParameterDecodingConfig{ + Name: "interpolate", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.HasParam(cfg); err == nil { + if err := q.DecodeParam(cfg, func(d uri.Decoder) error { + var paramsDotInterpolateVal bool + if err := func() error { + val, err := d.DecodeValue() + if err != nil { + return err + } + + c, err := conv.ToBool(val) + if err != nil { + return err + } + + paramsDotInterpolateVal = c + return nil + }(); err != nil { + return err + } + params.Interpolate.SetTo(paramsDotInterpolateVal) + return nil + }); err != nil { + return err + } + } + return nil + }(); err != nil { + return params, &ogenerrors.DecodeParamError{ + Name: "interpolate", + In: "query", + Err: err, + } + } + // Decode query: format. + if err := func() error { + cfg := uri.QueryParameterDecodingConfig{ + Name: "format", + Style: uri.QueryStyleForm, + Explode: true, + } + + if err := q.HasParam(cfg); err == nil { + if err := q.DecodeParam(cfg, func(d uri.Decoder) error { + var paramsDotFormatVal PerformPredictionFormat + if err := func() error { + val, err := d.DecodeValue() + if err != nil { + return err + } + + c, err := conv.ToString(val) + if err != nil { + return err + } + + paramsDotFormatVal = PerformPredictionFormat(c) + return nil + }(); err != nil { + return err + } + params.Format.SetTo(paramsDotFormatVal) + return nil + }); err != nil { + return err + } + if err := func() error { + if value, ok := params.Format.Get(); ok { + if err := func() error { + if err := value.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return err + } + } + return nil + }(); err != nil { + return err + } + } + return nil + }(); err != nil { + return params, &ogenerrors.DecodeParamError{ + Name: "format", + In: "query", + Err: err, + } + } // Decode query: dataset. if err := func() error { cfg := uri.QueryParameterDecodingConfig{ diff --git a/pkg/rest/oas_request_decoders_gen.go b/pkg/rest/oas_request_decoders_gen.go index 1ad6008..d99e2f0 100644 --- a/pkg/rest/oas_request_decoders_gen.go +++ b/pkg/rest/oas_request_decoders_gen.go @@ -1,3 +1,3 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn diff --git a/pkg/rest/oas_request_encoders_gen.go b/pkg/rest/oas_request_encoders_gen.go index 1ad6008..d99e2f0 100644 --- a/pkg/rest/oas_request_encoders_gen.go +++ b/pkg/rest/oas_request_encoders_gen.go @@ -1,3 +1,3 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn diff --git a/pkg/rest/oas_response_decoders_gen.go b/pkg/rest/oas_response_decoders_gen.go index 842583d..3c148fd 100644 --- a/pkg/rest/oas_response_decoders_gen.go +++ b/pkg/rest/oas_response_decoders_gen.go @@ -1,6 +1,6 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn import ( "io" @@ -9,11 +9,12 @@ import ( "github.com/go-faster/errors" "github.com/go-faster/jx" + "github.com/ogen-go/ogen/ogenerrors" "github.com/ogen-go/ogen/validate" ) -func decodePerformPredictionResponse(resp *http.Response) (res *PredictionResponse, _ error) { +func decodePerformPredictionResponse(resp *http.Response) (res *PredictionResult, _ error) { switch resp.StatusCode { case 200: // Code 200. @@ -29,7 +30,7 @@ func decodePerformPredictionResponse(resp *http.Response) (res *PredictionRespon } d := jx.DecodeBytes(buf) - var response PredictionResponse + var response PredictionResult if err := func() error { if err := response.Decode(d); err != nil { return err diff --git a/pkg/rest/oas_response_encoders_gen.go b/pkg/rest/oas_response_encoders_gen.go index 37892a3..8f24cd5 100644 --- a/pkg/rest/oas_response_encoders_gen.go +++ b/pkg/rest/oas_response_encoders_gen.go @@ -1,18 +1,19 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn import ( "net/http" "github.com/go-faster/errors" "github.com/go-faster/jx" - ht "github.com/ogen-go/ogen/http" "go.opentelemetry.io/otel/codes" "go.opentelemetry.io/otel/trace" + + ht "github.com/ogen-go/ogen/http" ) -func encodePerformPredictionResponse(response *PredictionResponse, w http.ResponseWriter, span trace.Span) error { +func encodePerformPredictionResponse(response *PredictionResult, w http.ResponseWriter, span trace.Span) error { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(200) span.SetStatus(codes.Ok, http.StatusText(200)) diff --git a/pkg/rest/oas_router_gen.go b/pkg/rest/oas_router_gen.go index ac8879a..1eea998 100644 --- a/pkg/rest/oas_router_gen.go +++ b/pkg/rest/oas_router_gen.go @@ -1,6 +1,6 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn import ( "net/http" @@ -74,12 +74,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "GET": s.handlePerformPredictionRequest([0]string{}, elemIsEscaped, w, r) default: - s.notAllowed(w, r, notAllowedParams{ - allowedMethods: "GET", - allowedHeaders: nil, - acceptPost: "", - acceptPatch: "", - }) + s.notAllowed(w, r, "GET") } return @@ -99,12 +94,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { case "GET": s.handleReadinessCheckRequest([0]string{}, elemIsEscaped, w, r) default: - s.notAllowed(w, r, notAllowedParams{ - allowedMethods: "GET", - allowedHeaders: nil, - acceptPost: "", - acceptPatch: "", - }) + s.notAllowed(w, r, "GET") } return @@ -119,13 +109,12 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // Route is route object. type Route struct { - name string - summary string - operationID string - operationGroup string - pathPattern string - count int - args [0]string + name string + summary string + operationID string + pathPattern string + count int + args [0]string } // Name returns ogen operation name. @@ -145,11 +134,6 @@ func (r Route) OperationID() string { return r.operationID } -// OperationGroup returns the x-ogen-operation-group value. -func (r Route) OperationGroup() string { - return r.operationGroup -} - // PathPattern returns OpenAPI path. func (r Route) PathPattern() string { return r.pathPattern @@ -225,7 +209,6 @@ func (s *Server) FindPath(method string, u *url.URL) (r Route, _ bool) { r.name = PerformPredictionOperation r.summary = "Perform prediction" r.operationID = "performPrediction" - r.operationGroup = "" r.pathPattern = "/api/v1/prediction" r.args = args r.count = 0 @@ -250,7 +233,6 @@ func (s *Server) FindPath(method string, u *url.URL) (r Route, _ bool) { r.name = ReadinessCheckOperation r.summary = "Readiness check" r.operationID = "readinessCheck" - r.operationGroup = "" r.pathPattern = "/ready" r.args = args r.count = 0 diff --git a/pkg/rest/oas_schemas_gen.go b/pkg/rest/oas_schemas_gen.go index 28faae8..26808cc 100644 --- a/pkg/rest/oas_schemas_gen.go +++ b/pkg/rest/oas_schemas_gen.go @@ -1,13 +1,12 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn import ( "fmt" "time" "github.com/go-faster/errors" - "github.com/go-faster/jx" ) func (s *ErrorStatusCode) Error() string { @@ -16,42 +15,28 @@ func (s *ErrorStatusCode) Error() string { // Ref: #/components/schemas/Error type Error struct { - Error ErrorError `json:"error"` + Message string `json:"message"` + Details OptString `json:"details"` } -// GetError returns the value of Error. -func (s *Error) GetError() ErrorError { - return s.Error +// GetMessage returns the value of Message. +func (s *Error) GetMessage() string { + return s.Message } -// SetError sets the value of Error. -func (s *Error) SetError(val ErrorError) { - s.Error = val +// GetDetails returns the value of Details. +func (s *Error) GetDetails() OptString { + return s.Details } -type ErrorError struct { - Type string `json:"type"` - Description string `json:"description"` +// SetMessage sets the value of Message. +func (s *Error) SetMessage(val string) { + s.Message = val } -// GetType returns the value of Type. -func (s *ErrorError) GetType() string { - return s.Type -} - -// GetDescription returns the value of Description. -func (s *ErrorError) GetDescription() string { - return s.Description -} - -// SetType sets the value of Type. -func (s *ErrorError) SetType(val string) { - s.Type = val -} - -// SetDescription sets the value of Description. -func (s *ErrorError) SetDescription(val string) { - s.Description = val +// SetDetails sets the value of Details. +func (s *Error) SetDetails(val OptString) { + s.Details = val } // ErrorStatusCode wraps Error with StatusCode. @@ -80,6 +65,52 @@ func (s *ErrorStatusCode) SetResponse(val Error) { s.Response = val } +// NewOptBool returns new OptBool with value set to v. +func NewOptBool(v bool) OptBool { + return OptBool{ + Value: v, + Set: true, + } +} + +// OptBool is optional bool. +type OptBool struct { + Value bool + Set bool +} + +// IsSet returns true if OptBool was set. +func (o OptBool) IsSet() bool { return o.Set } + +// Reset unsets value. +func (o *OptBool) Reset() { + var v bool + o.Value = v + o.Set = false +} + +// SetTo sets value to v. +func (o *OptBool) SetTo(v bool) { + o.Set = true + o.Value = v +} + +// Get returns value and boolean that denotes whether value was set. +func (o OptBool) Get() (v bool, ok bool) { + if !o.Set { + return v, false + } + return o.Value, true +} + +// Or returns value if set, or given parameter if does not. +func (o OptBool) Or(d bool) bool { + if v, ok := o.Get(); ok { + return v + } + return d +} + // NewOptDateTime returns new OptDateTime with value set to v. func NewOptDateTime(v time.Time) OptDateTime { return OptDateTime{ @@ -172,6 +203,52 @@ func (o OptFloat64) Or(d float64) float64 { return d } +// NewOptPerformPredictionFormat returns new OptPerformPredictionFormat with value set to v. +func NewOptPerformPredictionFormat(v PerformPredictionFormat) OptPerformPredictionFormat { + return OptPerformPredictionFormat{ + Value: v, + Set: true, + } +} + +// OptPerformPredictionFormat is optional PerformPredictionFormat. +type OptPerformPredictionFormat struct { + Value PerformPredictionFormat + Set bool +} + +// IsSet returns true if OptPerformPredictionFormat was set. +func (o OptPerformPredictionFormat) IsSet() bool { return o.Set } + +// Reset unsets value. +func (o *OptPerformPredictionFormat) Reset() { + var v PerformPredictionFormat + o.Value = v + o.Set = false +} + +// SetTo sets value to v. +func (o *OptPerformPredictionFormat) SetTo(v PerformPredictionFormat) { + o.Set = true + o.Value = v +} + +// Get returns value and boolean that denotes whether value was set. +func (o OptPerformPredictionFormat) Get() (v PerformPredictionFormat, ok bool) { + if !o.Set { + return v, false + } + return o.Value, true +} + +// Or returns value if set, or given parameter if does not. +func (o OptPerformPredictionFormat) Or(d PerformPredictionFormat) PerformPredictionFormat { + if v, ok := o.Get(); ok { + return v + } + return d +} + // NewOptPerformPredictionProfile returns new OptPerformPredictionProfile with value set to v. func NewOptPerformPredictionProfile(v PerformPredictionProfile) OptPerformPredictionProfile { return OptPerformPredictionProfile{ @@ -218,98 +295,6 @@ func (o OptPerformPredictionProfile) Or(d PerformPredictionProfile) PerformPredi return d } -// NewOptPredictionResponseRequest returns new OptPredictionResponseRequest with value set to v. -func NewOptPredictionResponseRequest(v PredictionResponseRequest) OptPredictionResponseRequest { - return OptPredictionResponseRequest{ - Value: v, - Set: true, - } -} - -// OptPredictionResponseRequest is optional PredictionResponseRequest. -type OptPredictionResponseRequest struct { - Value PredictionResponseRequest - Set bool -} - -// IsSet returns true if OptPredictionResponseRequest was set. -func (o OptPredictionResponseRequest) IsSet() bool { return o.Set } - -// Reset unsets value. -func (o *OptPredictionResponseRequest) Reset() { - var v PredictionResponseRequest - o.Value = v - o.Set = false -} - -// SetTo sets value to v. -func (o *OptPredictionResponseRequest) SetTo(v PredictionResponseRequest) { - o.Set = true - o.Value = v -} - -// Get returns value and boolean that denotes whether value was set. -func (o OptPredictionResponseRequest) Get() (v PredictionResponseRequest, ok bool) { - if !o.Set { - return v, false - } - return o.Value, true -} - -// Or returns value if set, or given parameter if does not. -func (o OptPredictionResponseRequest) Or(d PredictionResponseRequest) PredictionResponseRequest { - if v, ok := o.Get(); ok { - return v - } - return d -} - -// NewOptPredictionResponseWarnings returns new OptPredictionResponseWarnings with value set to v. -func NewOptPredictionResponseWarnings(v PredictionResponseWarnings) OptPredictionResponseWarnings { - return OptPredictionResponseWarnings{ - Value: v, - Set: true, - } -} - -// OptPredictionResponseWarnings is optional PredictionResponseWarnings. -type OptPredictionResponseWarnings struct { - Value PredictionResponseWarnings - Set bool -} - -// IsSet returns true if OptPredictionResponseWarnings was set. -func (o OptPredictionResponseWarnings) IsSet() bool { return o.Set } - -// Reset unsets value. -func (o *OptPredictionResponseWarnings) Reset() { - var v PredictionResponseWarnings - o.Value = v - o.Set = false -} - -// SetTo sets value to v. -func (o *OptPredictionResponseWarnings) SetTo(v PredictionResponseWarnings) { - o.Set = true - o.Value = v -} - -// Get returns value and boolean that denotes whether value was set. -func (o OptPredictionResponseWarnings) Get() (v PredictionResponseWarnings, ok bool) { - if !o.Set { - return v, false - } - return o.Value, true -} - -// Or returns value if set, or given parameter if does not. -func (o OptPredictionResponseWarnings) Or(d PredictionResponseWarnings) PredictionResponseWarnings { - if v, ok := o.Get(); ok { - return v - } - return d -} - // NewOptString returns new OptString with value set to v. func NewOptString(v string) OptString { return OptString{ @@ -356,11 +341,47 @@ func (o OptString) Or(d string) string { return d } +type PerformPredictionFormat string + +const ( + PerformPredictionFormatJSON PerformPredictionFormat = "json" +) + +// AllValues returns all PerformPredictionFormat values. +func (PerformPredictionFormat) AllValues() []PerformPredictionFormat { + return []PerformPredictionFormat{ + PerformPredictionFormatJSON, + } +} + +// MarshalText implements encoding.TextMarshaler. +func (s PerformPredictionFormat) MarshalText() ([]byte, error) { + switch s { + case PerformPredictionFormatJSON: + return []byte(s), nil + default: + return nil, errors.Errorf("invalid value: %q", s) + } +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (s *PerformPredictionFormat) UnmarshalText(data []byte) error { + switch PerformPredictionFormat(data) { + case PerformPredictionFormatJSON: + *s = PerformPredictionFormatJSON + return nil + default: + return errors.Errorf("invalid value: %q", data) + } +} + type PerformPredictionProfile string const ( PerformPredictionProfileStandardProfile PerformPredictionProfile = "standard_profile" PerformPredictionProfileFloatProfile PerformPredictionProfile = "float_profile" + PerformPredictionProfileReverseProfile PerformPredictionProfile = "reverse_profile" + PerformPredictionProfileCustomProfile PerformPredictionProfile = "custom_profile" ) // AllValues returns all PerformPredictionProfile values. @@ -368,6 +389,8 @@ func (PerformPredictionProfile) AllValues() []PerformPredictionProfile { return []PerformPredictionProfile{ PerformPredictionProfileStandardProfile, PerformPredictionProfileFloatProfile, + PerformPredictionProfileReverseProfile, + PerformPredictionProfileCustomProfile, } } @@ -378,6 +401,10 @@ func (s PerformPredictionProfile) MarshalText() ([]byte, error) { return []byte(s), nil case PerformPredictionProfileFloatProfile: return []byte(s), nil + case PerformPredictionProfileReverseProfile: + return []byte(s), nil + case PerformPredictionProfileCustomProfile: + return []byte(s), nil default: return nil, errors.Errorf("invalid value: %q", s) } @@ -392,134 +419,114 @@ func (s *PerformPredictionProfile) UnmarshalText(data []byte) error { case PerformPredictionProfileFloatProfile: *s = PerformPredictionProfileFloatProfile return nil + case PerformPredictionProfileReverseProfile: + *s = PerformPredictionProfileReverseProfile + return nil + case PerformPredictionProfileCustomProfile: + *s = PerformPredictionProfileCustomProfile + return nil default: return errors.Errorf("invalid value: %q", data) } } -// Ref: #/components/schemas/PredictionResponse -type PredictionResponse struct { - Request OptPredictionResponseRequest `json:"request"` - Prediction []PredictionResponsePredictionItem `json:"prediction"` - Metadata PredictionResponseMetadata `json:"metadata"` - Warnings OptPredictionResponseWarnings `json:"warnings"` -} - -// GetRequest returns the value of Request. -func (s *PredictionResponse) GetRequest() OptPredictionResponseRequest { - return s.Request -} - -// GetPrediction returns the value of Prediction. -func (s *PredictionResponse) GetPrediction() []PredictionResponsePredictionItem { - return s.Prediction +// Ref: #/components/schemas/PredictionResult +type PredictionResult struct { + Metadata PredictionResultMetadata `json:"metadata"` + Prediction []PredictionResultPredictionItem `json:"prediction"` } // GetMetadata returns the value of Metadata. -func (s *PredictionResponse) GetMetadata() PredictionResponseMetadata { +func (s *PredictionResult) GetMetadata() PredictionResultMetadata { return s.Metadata } -// GetWarnings returns the value of Warnings. -func (s *PredictionResponse) GetWarnings() OptPredictionResponseWarnings { - return s.Warnings -} - -// SetRequest sets the value of Request. -func (s *PredictionResponse) SetRequest(val OptPredictionResponseRequest) { - s.Request = val -} - -// SetPrediction sets the value of Prediction. -func (s *PredictionResponse) SetPrediction(val []PredictionResponsePredictionItem) { - s.Prediction = val +// GetPrediction returns the value of Prediction. +func (s *PredictionResult) GetPrediction() []PredictionResultPredictionItem { + return s.Prediction } // SetMetadata sets the value of Metadata. -func (s *PredictionResponse) SetMetadata(val PredictionResponseMetadata) { +func (s *PredictionResult) SetMetadata(val PredictionResultMetadata) { s.Metadata = val } -// SetWarnings sets the value of Warnings. -func (s *PredictionResponse) SetWarnings(val OptPredictionResponseWarnings) { - s.Warnings = val +// SetPrediction sets the value of Prediction. +func (s *PredictionResult) SetPrediction(val []PredictionResultPredictionItem) { + s.Prediction = val } -type PredictionResponseMetadata struct { - StartDatetime time.Time `json:"start_datetime"` +type PredictionResultMetadata struct { CompleteDatetime time.Time `json:"complete_datetime"` -} - -// GetStartDatetime returns the value of StartDatetime. -func (s *PredictionResponseMetadata) GetStartDatetime() time.Time { - return s.StartDatetime + StartDatetime time.Time `json:"start_datetime"` } // GetCompleteDatetime returns the value of CompleteDatetime. -func (s *PredictionResponseMetadata) GetCompleteDatetime() time.Time { +func (s *PredictionResultMetadata) GetCompleteDatetime() time.Time { return s.CompleteDatetime } -// SetStartDatetime sets the value of StartDatetime. -func (s *PredictionResponseMetadata) SetStartDatetime(val time.Time) { - s.StartDatetime = val +// GetStartDatetime returns the value of StartDatetime. +func (s *PredictionResultMetadata) GetStartDatetime() time.Time { + return s.StartDatetime } // SetCompleteDatetime sets the value of CompleteDatetime. -func (s *PredictionResponseMetadata) SetCompleteDatetime(val time.Time) { +func (s *PredictionResultMetadata) SetCompleteDatetime(val time.Time) { s.CompleteDatetime = val } -type PredictionResponsePredictionItem struct { - Stage PredictionResponsePredictionItemStage `json:"stage"` - Trajectory []PredictionResponsePredictionItemTrajectoryItem `json:"trajectory"` +// SetStartDatetime sets the value of StartDatetime. +func (s *PredictionResultMetadata) SetStartDatetime(val time.Time) { + s.StartDatetime = val +} + +type PredictionResultPredictionItem struct { + Stage PredictionResultPredictionItemStage `json:"stage"` + Trajectory []PredictionResultPredictionItemTrajectoryItem `json:"trajectory"` } // GetStage returns the value of Stage. -func (s *PredictionResponsePredictionItem) GetStage() PredictionResponsePredictionItemStage { +func (s *PredictionResultPredictionItem) GetStage() PredictionResultPredictionItemStage { return s.Stage } // GetTrajectory returns the value of Trajectory. -func (s *PredictionResponsePredictionItem) GetTrajectory() []PredictionResponsePredictionItemTrajectoryItem { +func (s *PredictionResultPredictionItem) GetTrajectory() []PredictionResultPredictionItemTrajectoryItem { return s.Trajectory } // SetStage sets the value of Stage. -func (s *PredictionResponsePredictionItem) SetStage(val PredictionResponsePredictionItemStage) { +func (s *PredictionResultPredictionItem) SetStage(val PredictionResultPredictionItemStage) { s.Stage = val } // SetTrajectory sets the value of Trajectory. -func (s *PredictionResponsePredictionItem) SetTrajectory(val []PredictionResponsePredictionItemTrajectoryItem) { +func (s *PredictionResultPredictionItem) SetTrajectory(val []PredictionResultPredictionItemTrajectoryItem) { s.Trajectory = val } -type PredictionResponsePredictionItemStage string +type PredictionResultPredictionItemStage string const ( - PredictionResponsePredictionItemStageAscent PredictionResponsePredictionItemStage = "ascent" - PredictionResponsePredictionItemStageDescent PredictionResponsePredictionItemStage = "descent" - PredictionResponsePredictionItemStageFloat PredictionResponsePredictionItemStage = "float" + PredictionResultPredictionItemStageAscent PredictionResultPredictionItemStage = "ascent" + PredictionResultPredictionItemStageDescent PredictionResultPredictionItemStage = "descent" ) -// AllValues returns all PredictionResponsePredictionItemStage values. -func (PredictionResponsePredictionItemStage) AllValues() []PredictionResponsePredictionItemStage { - return []PredictionResponsePredictionItemStage{ - PredictionResponsePredictionItemStageAscent, - PredictionResponsePredictionItemStageDescent, - PredictionResponsePredictionItemStageFloat, +// AllValues returns all PredictionResultPredictionItemStage values. +func (PredictionResultPredictionItemStage) AllValues() []PredictionResultPredictionItemStage { + return []PredictionResultPredictionItemStage{ + PredictionResultPredictionItemStageAscent, + PredictionResultPredictionItemStageDescent, } } // MarshalText implements encoding.TextMarshaler. -func (s PredictionResponsePredictionItemStage) MarshalText() ([]byte, error) { +func (s PredictionResultPredictionItemStage) MarshalText() ([]byte, error) { switch s { - case PredictionResponsePredictionItemStageAscent: + case PredictionResultPredictionItemStageAscent: return []byte(s), nil - case PredictionResponsePredictionItemStageDescent: - return []byte(s), nil - case PredictionResponsePredictionItemStageFloat: + case PredictionResultPredictionItemStageDescent: return []byte(s), nil default: return nil, errors.Errorf("invalid value: %q", s) @@ -527,23 +534,20 @@ func (s PredictionResponsePredictionItemStage) MarshalText() ([]byte, error) { } // UnmarshalText implements encoding.TextUnmarshaler. -func (s *PredictionResponsePredictionItemStage) UnmarshalText(data []byte) error { - switch PredictionResponsePredictionItemStage(data) { - case PredictionResponsePredictionItemStageAscent: - *s = PredictionResponsePredictionItemStageAscent +func (s *PredictionResultPredictionItemStage) UnmarshalText(data []byte) error { + switch PredictionResultPredictionItemStage(data) { + case PredictionResultPredictionItemStageAscent: + *s = PredictionResultPredictionItemStageAscent return nil - case PredictionResponsePredictionItemStageDescent: - *s = PredictionResponsePredictionItemStageDescent - return nil - case PredictionResponsePredictionItemStageFloat: - *s = PredictionResponsePredictionItemStageFloat + case PredictionResultPredictionItemStageDescent: + *s = PredictionResultPredictionItemStageDescent return nil default: return errors.Errorf("invalid value: %q", data) } } -type PredictionResponsePredictionItemTrajectoryItem struct { +type PredictionResultPredictionItemTrajectoryItem struct { Datetime time.Time `json:"datetime"` Latitude float64 `json:"latitude"` Longitude float64 `json:"longitude"` @@ -551,162 +555,50 @@ type PredictionResponsePredictionItemTrajectoryItem struct { } // GetDatetime returns the value of Datetime. -func (s *PredictionResponsePredictionItemTrajectoryItem) GetDatetime() time.Time { +func (s *PredictionResultPredictionItemTrajectoryItem) GetDatetime() time.Time { return s.Datetime } // GetLatitude returns the value of Latitude. -func (s *PredictionResponsePredictionItemTrajectoryItem) GetLatitude() float64 { +func (s *PredictionResultPredictionItemTrajectoryItem) GetLatitude() float64 { return s.Latitude } // GetLongitude returns the value of Longitude. -func (s *PredictionResponsePredictionItemTrajectoryItem) GetLongitude() float64 { +func (s *PredictionResultPredictionItemTrajectoryItem) GetLongitude() float64 { return s.Longitude } // GetAltitude returns the value of Altitude. -func (s *PredictionResponsePredictionItemTrajectoryItem) GetAltitude() float64 { +func (s *PredictionResultPredictionItemTrajectoryItem) GetAltitude() float64 { return s.Altitude } // SetDatetime sets the value of Datetime. -func (s *PredictionResponsePredictionItemTrajectoryItem) SetDatetime(val time.Time) { +func (s *PredictionResultPredictionItemTrajectoryItem) SetDatetime(val time.Time) { s.Datetime = val } // SetLatitude sets the value of Latitude. -func (s *PredictionResponsePredictionItemTrajectoryItem) SetLatitude(val float64) { +func (s *PredictionResultPredictionItemTrajectoryItem) SetLatitude(val float64) { s.Latitude = val } // SetLongitude sets the value of Longitude. -func (s *PredictionResponsePredictionItemTrajectoryItem) SetLongitude(val float64) { +func (s *PredictionResultPredictionItemTrajectoryItem) SetLongitude(val float64) { s.Longitude = val } // SetAltitude sets the value of Altitude. -func (s *PredictionResponsePredictionItemTrajectoryItem) SetAltitude(val float64) { +func (s *PredictionResultPredictionItemTrajectoryItem) SetAltitude(val float64) { s.Altitude = val } -type PredictionResponseRequest struct { - Dataset OptString `json:"dataset"` - LaunchLatitude OptFloat64 `json:"launch_latitude"` - LaunchLongitude OptFloat64 `json:"launch_longitude"` - LaunchDatetime OptString `json:"launch_datetime"` - LaunchAltitude OptFloat64 `json:"launch_altitude"` - Profile OptString `json:"profile"` - AscentRate OptFloat64 `json:"ascent_rate"` - BurstAltitude OptFloat64 `json:"burst_altitude"` - DescentRate OptFloat64 `json:"descent_rate"` -} - -// GetDataset returns the value of Dataset. -func (s *PredictionResponseRequest) GetDataset() OptString { - return s.Dataset -} - -// GetLaunchLatitude returns the value of LaunchLatitude. -func (s *PredictionResponseRequest) GetLaunchLatitude() OptFloat64 { - return s.LaunchLatitude -} - -// GetLaunchLongitude returns the value of LaunchLongitude. -func (s *PredictionResponseRequest) GetLaunchLongitude() OptFloat64 { - return s.LaunchLongitude -} - -// GetLaunchDatetime returns the value of LaunchDatetime. -func (s *PredictionResponseRequest) GetLaunchDatetime() OptString { - return s.LaunchDatetime -} - -// GetLaunchAltitude returns the value of LaunchAltitude. -func (s *PredictionResponseRequest) GetLaunchAltitude() OptFloat64 { - return s.LaunchAltitude -} - -// GetProfile returns the value of Profile. -func (s *PredictionResponseRequest) GetProfile() OptString { - return s.Profile -} - -// GetAscentRate returns the value of AscentRate. -func (s *PredictionResponseRequest) GetAscentRate() OptFloat64 { - return s.AscentRate -} - -// GetBurstAltitude returns the value of BurstAltitude. -func (s *PredictionResponseRequest) GetBurstAltitude() OptFloat64 { - return s.BurstAltitude -} - -// GetDescentRate returns the value of DescentRate. -func (s *PredictionResponseRequest) GetDescentRate() OptFloat64 { - return s.DescentRate -} - -// SetDataset sets the value of Dataset. -func (s *PredictionResponseRequest) SetDataset(val OptString) { - s.Dataset = val -} - -// SetLaunchLatitude sets the value of LaunchLatitude. -func (s *PredictionResponseRequest) SetLaunchLatitude(val OptFloat64) { - s.LaunchLatitude = val -} - -// SetLaunchLongitude sets the value of LaunchLongitude. -func (s *PredictionResponseRequest) SetLaunchLongitude(val OptFloat64) { - s.LaunchLongitude = val -} - -// SetLaunchDatetime sets the value of LaunchDatetime. -func (s *PredictionResponseRequest) SetLaunchDatetime(val OptString) { - s.LaunchDatetime = val -} - -// SetLaunchAltitude sets the value of LaunchAltitude. -func (s *PredictionResponseRequest) SetLaunchAltitude(val OptFloat64) { - s.LaunchAltitude = val -} - -// SetProfile sets the value of Profile. -func (s *PredictionResponseRequest) SetProfile(val OptString) { - s.Profile = val -} - -// SetAscentRate sets the value of AscentRate. -func (s *PredictionResponseRequest) SetAscentRate(val OptFloat64) { - s.AscentRate = val -} - -// SetBurstAltitude sets the value of BurstAltitude. -func (s *PredictionResponseRequest) SetBurstAltitude(val OptFloat64) { - s.BurstAltitude = val -} - -// SetDescentRate sets the value of DescentRate. -func (s *PredictionResponseRequest) SetDescentRate(val OptFloat64) { - s.DescentRate = val -} - -type PredictionResponseWarnings map[string]jx.Raw - -func (s *PredictionResponseWarnings) init() PredictionResponseWarnings { - m := *s - if m == nil { - m = map[string]jx.Raw{} - *s = m - } - return m -} - // Ref: #/components/schemas/ReadinessResponse type ReadinessResponse struct { Status ReadinessResponseStatus `json:"status"` - DatasetTime OptDateTime `json:"dataset_time"` + LastUpdate OptDateTime `json:"last_update"` + IsFresh OptBool `json:"is_fresh"` ErrorMessage OptString `json:"error_message"` } @@ -715,9 +607,14 @@ func (s *ReadinessResponse) GetStatus() ReadinessResponseStatus { return s.Status } -// GetDatasetTime returns the value of DatasetTime. -func (s *ReadinessResponse) GetDatasetTime() OptDateTime { - return s.DatasetTime +// GetLastUpdate returns the value of LastUpdate. +func (s *ReadinessResponse) GetLastUpdate() OptDateTime { + return s.LastUpdate +} + +// GetIsFresh returns the value of IsFresh. +func (s *ReadinessResponse) GetIsFresh() OptBool { + return s.IsFresh } // GetErrorMessage returns the value of ErrorMessage. @@ -730,9 +627,14 @@ func (s *ReadinessResponse) SetStatus(val ReadinessResponseStatus) { s.Status = val } -// SetDatasetTime sets the value of DatasetTime. -func (s *ReadinessResponse) SetDatasetTime(val OptDateTime) { - s.DatasetTime = val +// SetLastUpdate sets the value of LastUpdate. +func (s *ReadinessResponse) SetLastUpdate(val OptDateTime) { + s.LastUpdate = val +} + +// SetIsFresh sets the value of IsFresh. +func (s *ReadinessResponse) SetIsFresh(val OptBool) { + s.IsFresh = val } // SetErrorMessage sets the value of ErrorMessage. diff --git a/pkg/rest/oas_server_gen.go b/pkg/rest/oas_server_gen.go index 7b6c592..9ef451e 100644 --- a/pkg/rest/oas_server_gen.go +++ b/pkg/rest/oas_server_gen.go @@ -1,6 +1,6 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn import ( "context" @@ -13,7 +13,7 @@ type Handler interface { // Perform prediction. // // GET /api/v1/prediction - PerformPrediction(ctx context.Context, params PerformPredictionParams) (*PredictionResponse, error) + PerformPrediction(ctx context.Context, params PerformPredictionParams) (*PredictionResult, error) // ReadinessCheck implements readinessCheck operation. // // Readiness check. diff --git a/pkg/rest/oas_unimplemented_gen.go b/pkg/rest/oas_unimplemented_gen.go index 8c3d8be..9662197 100644 --- a/pkg/rest/oas_unimplemented_gen.go +++ b/pkg/rest/oas_unimplemented_gen.go @@ -1,6 +1,6 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn import ( "context" @@ -18,7 +18,7 @@ var _ Handler = UnimplementedHandler{} // Perform prediction. // // GET /api/v1/prediction -func (UnimplementedHandler) PerformPrediction(ctx context.Context, params PerformPredictionParams) (r *PredictionResponse, _ error) { +func (UnimplementedHandler) PerformPrediction(ctx context.Context, params PerformPredictionParams) (r *PredictionResult, _ error) { return r, ht.ErrNotImplemented } diff --git a/pkg/rest/oas_validators_gen.go b/pkg/rest/oas_validators_gen.go index 33b3d41..0cd02b6 100644 --- a/pkg/rest/oas_validators_gen.go +++ b/pkg/rest/oas_validators_gen.go @@ -1,49 +1,45 @@ // Code generated by ogen, DO NOT EDIT. -package rest +package gsn import ( "fmt" "github.com/go-faster/errors" + "github.com/ogen-go/ogen/validate" ) +func (s PerformPredictionFormat) Validate() error { + switch s { + case "json": + return nil + default: + return errors.Errorf("invalid value: %v", s) + } +} + func (s PerformPredictionProfile) Validate() error { switch s { case "standard_profile": return nil case "float_profile": return nil + case "reverse_profile": + return nil + case "custom_profile": + return nil default: return errors.Errorf("invalid value: %v", s) } } -func (s *PredictionResponse) Validate() error { +func (s *PredictionResult) Validate() error { if s == nil { return validate.ErrNilPointer } var failures []validate.FieldError - if err := func() error { - if value, ok := s.Request.Get(); ok { - if err := func() error { - if err := value.Validate(); err != nil { - return err - } - return nil - }(); err != nil { - return err - } - } - return nil - }(); err != nil { - failures = append(failures, validate.FieldError{ - Name: "request", - Error: err, - }) - } if err := func() error { if s.Prediction == nil { return errors.New("nil is invalid value") @@ -78,7 +74,7 @@ func (s *PredictionResponse) Validate() error { return nil } -func (s *PredictionResponsePredictionItem) Validate() error { +func (s *PredictionResultPredictionItem) Validate() error { if s == nil { return validate.ErrNilPointer } @@ -129,20 +125,18 @@ func (s *PredictionResponsePredictionItem) Validate() error { return nil } -func (s PredictionResponsePredictionItemStage) Validate() error { +func (s PredictionResultPredictionItemStage) Validate() error { switch s { case "ascent": return nil case "descent": return nil - case "float": - return nil default: return errors.Errorf("invalid value: %v", s) } } -func (s *PredictionResponsePredictionItemTrajectoryItem) Validate() error { +func (s *PredictionResultPredictionItemTrajectoryItem) Validate() error { if s == nil { return validate.ErrNilPointer } @@ -187,126 +181,6 @@ func (s *PredictionResponsePredictionItemTrajectoryItem) Validate() error { return nil } -func (s *PredictionResponseRequest) Validate() error { - if s == nil { - return validate.ErrNilPointer - } - - var failures []validate.FieldError - if err := func() error { - if value, ok := s.LaunchLatitude.Get(); ok { - if err := func() error { - if err := (validate.Float{}).Validate(float64(value)); err != nil { - return errors.Wrap(err, "float") - } - return nil - }(); err != nil { - return err - } - } - return nil - }(); err != nil { - failures = append(failures, validate.FieldError{ - Name: "launch_latitude", - Error: err, - }) - } - if err := func() error { - if value, ok := s.LaunchLongitude.Get(); ok { - if err := func() error { - if err := (validate.Float{}).Validate(float64(value)); err != nil { - return errors.Wrap(err, "float") - } - return nil - }(); err != nil { - return err - } - } - return nil - }(); err != nil { - failures = append(failures, validate.FieldError{ - Name: "launch_longitude", - Error: err, - }) - } - if err := func() error { - if value, ok := s.LaunchAltitude.Get(); ok { - if err := func() error { - if err := (validate.Float{}).Validate(float64(value)); err != nil { - return errors.Wrap(err, "float") - } - return nil - }(); err != nil { - return err - } - } - return nil - }(); err != nil { - failures = append(failures, validate.FieldError{ - Name: "launch_altitude", - Error: err, - }) - } - if err := func() error { - if value, ok := s.AscentRate.Get(); ok { - if err := func() error { - if err := (validate.Float{}).Validate(float64(value)); err != nil { - return errors.Wrap(err, "float") - } - return nil - }(); err != nil { - return err - } - } - return nil - }(); err != nil { - failures = append(failures, validate.FieldError{ - Name: "ascent_rate", - Error: err, - }) - } - if err := func() error { - if value, ok := s.BurstAltitude.Get(); ok { - if err := func() error { - if err := (validate.Float{}).Validate(float64(value)); err != nil { - return errors.Wrap(err, "float") - } - return nil - }(); err != nil { - return err - } - } - return nil - }(); err != nil { - failures = append(failures, validate.FieldError{ - Name: "burst_altitude", - Error: err, - }) - } - if err := func() error { - if value, ok := s.DescentRate.Get(); ok { - if err := func() error { - if err := (validate.Float{}).Validate(float64(value)); err != nil { - return errors.Wrap(err, "float") - } - return nil - }(); err != nil { - return err - } - } - return nil - }(); err != nil { - failures = append(failures, validate.FieldError{ - Name: "descent_rate", - Error: err, - }) - } - if len(failures) > 0 { - return &validate.Error{Fields: failures} - } - return nil -} - func (s *ReadinessResponse) Validate() error { if s == nil { return validate.ErrNilPointer diff --git a/pkg/scheduler/config.go b/pkg/scheduler/config.go new file mode 100644 index 0000000..c7b2881 --- /dev/null +++ b/pkg/scheduler/config.go @@ -0,0 +1,20 @@ +package scheduler + +import ( + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + env "github.com/caarlos0/env/v11" +) + +type Config struct { + Enabled bool `env:"ENABLED" envDefault:"true"` +} + +func NewConfig() (*Config, error) { + cfg := &Config{} + if err := env.ParseWithOptions(cfg, env.Options{ + PrefixTagName: "GSN_PREDICTOR_SCHEDULER_", + }); err != nil { + return nil, errcodes.Wrap(err, "failed to parse scheduler config") + } + return cfg, nil +} diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go new file mode 100644 index 0000000..d51980c --- /dev/null +++ b/pkg/scheduler/scheduler.go @@ -0,0 +1,97 @@ +package scheduler + +import ( + "context" + "time" + + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + "git.intra.yksa.space/gsn/predictor/internal/pkg/log" + "github.com/go-co-op/gocron" + "go.uber.org/zap" +) + +type Job interface { + GetInterval() time.Duration + GetTimeout() time.Duration + GetCount() int + GetAsync() bool + Execute(context.Context) error +} + +type Scheduler struct { + scheduler *gocron.Scheduler +} + +func New() *Scheduler { + scheduler := gocron.NewScheduler(time.UTC) + return &Scheduler{ + scheduler: scheduler, + } +} + +func (s *Scheduler) AddJob(job Job) error { + interval := job.GetInterval() + timeout := job.GetTimeout() + count := job.GetCount() + async := job.GetAsync() + + // Validate job parameters + if !async && count != 1 { + return errcodes.ErrSchedulerInvalidJob + } + if timeout > interval { + return errcodes.ErrSchedulerTimeoutTooLong + } + + // Create job function with timeout + jobFunc := func() { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + logger := log.Ctx(ctx) + if err := job.Execute(ctx); err != nil { + logger.Error("job execution failed", + zap.Error(err), + zap.Duration("interval", interval), + zap.Duration("timeout", timeout)) + } else { + logger.Debug("job executed successfully", + zap.Duration("interval", interval), + zap.Duration("timeout", timeout)) + } + } + + // Add job to scheduler + schedulerJob := s.scheduler.Every(interval) + + if !async { + schedulerJob = schedulerJob.SingletonMode() + } + + if count > 0 { + schedulerJob = schedulerJob.LimitRunsTo(count) + } + + schedulerJob.Do(jobFunc) + + log.Ctx(context.Background()).Info("job added to scheduler", + zap.Duration("interval", interval), + zap.Duration("timeout", timeout), + zap.Int("count", count), + zap.Bool("async", async)) + + return nil +} + +func (s *Scheduler) Start() { + s.scheduler.StartAsync() + log.Ctx(context.Background()).Info("scheduler started") +} + +func (s *Scheduler) Stop() { + s.scheduler.Stop() + log.Ctx(context.Background()).Info("scheduler stopped") +} + +func (s *Scheduler) IsRunning() bool { + return s.scheduler.IsRunning() +} diff --git a/scripts/build_elevation.py b/scripts/build_elevation.py deleted file mode 100644 index 8f4244a..0000000 --- a/scripts/build_elevation.py +++ /dev/null @@ -1,155 +0,0 @@ -#!/usr/bin/env python3 -""" -Download ETOPO 2022 30-arc-second elevation data and convert to ruaumoko-compatible -binary format (int16 little-endian, 21601 lat x 43200 lon, south-to-north). - -Output: ~1.74 GiB binary file. - -Usage: python3 build_elevation.py [output_path] - Default output: /srv/ruaumoko-dataset -""" - -import sys -import os -import struct -import tempfile -import numpy as np - -CELLS_PER_DEGREE = 120 -NUM_LATS = 180 * CELLS_PER_DEGREE + 1 # 21601 -NUM_LONS = 360 * CELLS_PER_DEGREE # 43200 -EXPECTED_SIZE = NUM_LATS * NUM_LONS * 2 # 1,866,326,400 - -ETOPO_URL = "https://www.ngdc.noaa.gov/thredds/fileServer/global/ETOPO2022/30s/30s_surface_elev_netcdf/ETOPO_2022_v1_30s_N90W180_surface.nc" - - -def download_etopo(output_path): - """Download ETOPO 2022 NetCDF and convert to ruaumoko binary format.""" - try: - import xarray as xr - except ImportError: - print("ERROR: xarray is required. Install with: pip install xarray netcdf4") - sys.exit(1) - - # Check if we can download directly or need a local file - nc_path = os.environ.get("ETOPO_NC_PATH") - if nc_path and os.path.exists(nc_path): - print(f"Using local ETOPO file: {nc_path}") - else: - print(f"Downloading ETOPO 2022 30-second data (~1.1 GB)...") - print(f" URL: {ETOPO_URL}") - print(f" (Set ETOPO_NC_PATH env var to use a pre-downloaded file)") - - import urllib.request - with tempfile.NamedTemporaryFile(suffix=".nc", delete=False) as f: - nc_path = f.name - - try: - urllib.request.urlretrieve(ETOPO_URL, nc_path, _progress) - print() - except Exception as e: - os.unlink(nc_path) - print(f"\nDownload failed: {e}") - print("\nAlternative: manually download ETOPO 2022 30s NetCDF from:") - print(" https://www.ncei.noaa.gov/products/etopo-global-relief-model") - print(f" Then set ETOPO_NC_PATH=/path/to/file.nc and re-run") - sys.exit(1) - - print(f"Opening NetCDF dataset...") - ds = xr.open_dataset(nc_path) - - # ETOPO 2022 30s has: - # - lat: -90 to +90, 21601 points (south to north) - # - lon: -180 to +180, 43201 points - # We need: - # - lat: -90 to +90, 21601 points (south to north) ← same - # - lon: 0 to 360 (exclusive), 43200 points ← need to shift and drop last - - z = ds["z"] # elevation variable - print(f" Shape: {z.shape}") - print(f" Lat range: {float(z.lat.min())} to {float(z.lat.max())}") - print(f" Lon range: {float(z.lon.min())} to {float(z.lon.max())}") - - # Sort latitude south-to-north (should already be, but ensure) - z = z.sortby("lat") - - # Shift longitude from [-180, 180] to [0, 360) - print("Shifting longitude to [0, 360)...") - z = z.assign_coords(lon=(z.lon % 360)) - z = z.sortby("lon") - - data = z.values - print(f" Raw shape after sort: {data.shape}") - - # Handle longitude dimension: drop last col if it wraps (43201 → 43200) - if data.shape[1] == NUM_LONS + 1: - data = data[:, :NUM_LONS] - elif data.shape[1] != NUM_LONS: - print(f"ERROR: unexpected lon dimension: {data.shape[1]}, expected {NUM_LONS} or {NUM_LONS+1}") - sys.exit(1) - - # Handle latitude dimension: ETOPO 2022 is cell-centered (21600 rows), - # ruaumoko expects grid-registered (21601 rows including both poles). - # Pad by duplicating edge rows for the poles. - if data.shape[0] == NUM_LATS - 1: - print(f" Padding latitude from {data.shape[0]} to {NUM_LATS} (adding north pole row)") - north_pole = data[-1:, :] # duplicate +89.99... as +90 - data = np.concatenate([data, north_pole], axis=0) - elif data.shape[0] != NUM_LATS: - print(f"ERROR: unexpected lat dimension: {data.shape[0]}, expected {NUM_LATS} or {NUM_LATS-1}") - sys.exit(1) - - print(f"Final grid shape: {data.shape}") - print(f"Elevation range: {data.min():.1f} to {data.max():.1f} metres") - - # Write as int16 little-endian - print(f"Writing to {output_path}...") - elev_int16 = np.clip(data, -32768, 32767).astype(np.dtype(" 0: - pct = int(block_num * block_size * 100 / total_size) - if pct != _last_pct and pct % 5 == 0: - _last_pct = pct - print(f" {pct}%...", end="", flush=True) - - -if __name__ == "__main__": - output = sys.argv[1] if len(sys.argv) > 1 else "/srv/ruaumoko-dataset" - download_etopo(output) diff --git a/scripts/test_predictor_vs_reference.py b/scripts/test_predictor_vs_reference.py new file mode 100644 index 0000000..e2de9ae --- /dev/null +++ b/scripts/test_predictor_vs_reference.py @@ -0,0 +1,303 @@ +#!/usr/bin/env python3 +import subprocess +import sys +import time +import requests +import json +from typing import Any +import base64 +import math + +# --- Config --- +REFERENCE_API_URL = "https://fly.stratonautica.ru/api/v2/?profile=standard_profile&pred_type=single&launch_datetime=2025-06-25T20%3A45%3A00Z&launch_latitude=56.6992&launch_longitude=38.8247&launch_altitude=0&ascent_rate=5&burst_altitude=30000&descent_rate=5" +LOCAL_API_URL = "http://localhost:8080/api/v1/prediction?profile=standard_profile&pred_type=single&launch_datetime=2025-06-25T20%3A45%3A00Z&launch_latitude=56.6992&launch_longitude=38.8247&launch_altitude=0&ascent_rate=5&burst_altitude=30000&descent_rate=5" + +LOCAL_API_PAYLOAD = { + "launch_latitude": 56.6992, + "launch_longitude": 38.8247, + "launch_datetime": "2025-06-25T20-45-000Z", + "launch_altitude": 0, + "profile": "standard_profile", + "ascent_rate": 5, + "burst_altitude": 30000, + "descent_rate": 5, + "format": "json" +} +READY_URL = "http://localhost:8080/ready" + +# --- Utility functions --- +def run_compose_up(): + print("[INFO] Running docker-compose down --remove-orphans ...") + result = subprocess.run(["docker-compose", "down", "--remove-orphans"], capture_output=True) + if result.returncode != 0: + print("[ERROR] docker-compose down failed:", result.stderr.decode()) + sys.exit(1) + print("[INFO] docker-compose down completed.") + print("[INFO] Running docker-compose up -d ...") + result = subprocess.run(["docker-compose", "up", "-d"], capture_output=True) + if result.returncode != 0: + print("[ERROR] docker-compose up failed:", result.stderr.decode()) + sys.exit(1) + print("[INFO] docker-compose up -d completed.") + return True + +def wait_for_ready(timeout=900): + print(f"[INFO] Waiting for {READY_URL} to be ready ...") + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(READY_URL, timeout=10) + if resp.status_code == 200: + data = resp.json() + if data.get("status") == "ok": + print("[INFO] Service is ready.") + return + else: + print(f"[INFO] Not ready yet: {data}") + else: + print(f"[INFO] /ready returned status {resp.status_code}") + except Exception as e: + print(f"[INFO] Exception while polling /ready: {e}") + time.sleep(10) + print(f"[ERROR] Service did not become ready in {timeout} seconds.") + sys.exit(1) + +def fetch_reference(): + print(f"[INFO] Fetching reference prediction from {REFERENCE_API_URL}") + resp = requests.get(REFERENCE_API_URL, timeout=60) + if resp.status_code != 200: + print(f"[ERROR] Reference API returned {resp.status_code}: {resp.text}") + sys.exit(1) + return resp.json() + +def fetch_local(): + print(f"[INFO] Fetching local prediction from {LOCAL_API_URL}") + resp = requests.get(LOCAL_API_URL, timeout=60) + if resp.status_code != 200: + print(f"[ERROR] Local API returned {resp.status_code}: {resp.text}") + sys.exit(1) + return resp.json() + +def haversine(lat1, lon1, lat2, lon2): + """Calculate the great-circle distance between two points on the Earth (specified in decimal degrees). Returns distance in kilometers.""" + R = 6371.0 # Earth radius in kilometers + lat1, lon1, lat2, lon2 = map(math.radians, [lat1, lon1, lat2, lon2]) + dlat = lat2 - lat1 + dlon = lon2 - lon1 + a = math.sin(dlat/2)**2 + math.cos(lat1) * math.cos(lat2) * math.sin(dlon/2)**2 + c = 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a)) + return R * c + +def compare_results(reference_data, local_data): + """Compare prediction results between reference and local APIs.""" + print("[INFO] Comparing results ...") + + # Extract trajectory data + ref_trajectory = reference_data.get('prediction', [{}])[0].get('trajectory', []) + local_trajectory = local_data.get('prediction', [{}])[0].get('trajectory', []) + + print(f"[DEBUG] Reference trajectory length: {len(ref_trajectory)}") + print(f"[DEBUG] Local trajectory length: {len(local_trajectory)}") + + # Show first 3 points from both APIs + print("\n[DEBUG] First 3 points - Reference API:") + for i, point in enumerate(ref_trajectory[:3]): + print(f" [{i}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}") + + print("\n[DEBUG] First 3 points - Local API:") + for i, point in enumerate(local_trajectory[:3]): + print(f" [{i}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}") + + # Show last 3 points from both APIs + print("\n[DEBUG] Last 3 points - Reference API:") + for i, point in enumerate(ref_trajectory[-3:]): + idx = len(ref_trajectory) - 3 + i + print(f" [{idx}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}") + + print("\n[DEBUG] Last 3 points - Local API:") + for i, point in enumerate(local_trajectory[-3:]): + idx = len(local_trajectory) - 3 + i + print(f" [{idx}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}") + + # Compare trajectory lengths + if len(ref_trajectory) != len(local_trajectory): + print(f"[DIFF] Trajectory length mismatch: {len(local_trajectory)} vs {len(ref_trajectory)}") + return False + + # Compare trajectory points and calculate drift + min_len = min(len(ref_trajectory), len(local_trajectory)) + max_drift = 0.0 + max_drift_idx = -1 + drift_list = [] + print("\n[DRIFT] Trajectory point-by-point distance (km):") + for i in range(min_len): + ref_point = ref_trajectory[i] + local_point = local_trajectory[i] + ref_lat = ref_point.get('latitude') + ref_lon = ref_point.get('longitude') + local_lat = local_point.get('latitude') + local_lon = local_point.get('longitude') + drift_km = None + if None not in (ref_lat, ref_lon, local_lat, local_lon): + drift_km = haversine(ref_lat, ref_lon, local_lat, local_lon) + drift_list.append(drift_km) + if drift_km > max_drift: + max_drift = drift_km + max_drift_idx = i + print(f" [{i}] Drift: {drift_km:.3f} km") + else: + print(f" [{i}] Drift: N/A (missing lat/lon)") + if drift_list: + mean_drift = sum(drift_list) / len(drift_list) + print(f"\n[DRIFT] Max drift: {max_drift:.3f} km at idx {max_drift_idx}") + print(f"[DRIFT] Mean drift: {mean_drift:.3f} km over {len(drift_list)} points") + else: + print("[DRIFT] No valid drift data to report.") + # Continue with original comparison for altitude, etc. + for i in range(min_len): + ref_point = ref_trajectory[i] + local_point = local_trajectory[i] + for key in ['altitude', 'latitude', 'longitude']: + ref_val = ref_point.get(key) + local_val = local_point.get(key) + if ref_val is not None and local_val is not None: + if abs(ref_val - local_val) > 0.1: + print(f"[DIFF] At idx {i}, key {key}: {local_val} != {ref_val}") + return False + print("[SUCCESS] Results match!") + return True + +def test_custom_profile(): + """Test custom profile with base64 encoded curve.""" + print("\n[TEST] Testing custom_profile...") + # Create a simple custom ascent curve (altitude vs time in seconds) + curve_data = { + "altitude": [0, 30000], + "time": [0, 6000] + } + curve_b64 = base64.b64encode(json.dumps(curve_data).encode()).decode() + # Test parameters for custom profile + params = { + "launch_latitude": 56.6992, + "launch_longitude": 38.8247, + "launch_datetime": "2025-06-25T13:28:00Z", + "launch_altitude": 0, + "profile": "custom_profile", + "ascent_curve": curve_b64 + } + try: + # Test local API (use GET) + local_resp = requests.get( + "http://localhost:8080/api/v1/prediction", + params=params, + timeout=30 + ) + local_resp.raise_for_status() + local_data = local_resp.json() + print(f"[INFO] Custom profile test - Local API returned {len(local_data.get('prediction', [{}])[0].get('trajectory', []))} trajectory points") + return True + except Exception as e: + print(f"[ERROR] Custom profile test failed: {e}") + return False + +def test_all_profiles(): + """Test all available profiles.""" + profiles = [ + ("standard_profile", "Standard profile test"), + ("float_profile", "Float profile test"), + ("reverse_profile", "Reverse profile test"), + ("custom_profile", "Custom profile test") + ] + + results = {} + + for profile, description in profiles: + print(f"\n[TEST] {description}...") + + if profile == "custom_profile": + success = test_custom_profile() + else: + success = test_single_profile(profile) + + results[profile] = success + print(f"[RESULT] {profile}: {'PASS' if success else 'FAIL'}") + + # Print summary + print("\n" + "="*50) + print("TEST SUMMARY") + print("="*50) + for profile, success in results.items(): + status = "PASS" if success else "FAIL" + print(f"{profile:20} : {status}") + + total_tests = len(results) + passed_tests = sum(results.values()) + print(f"\nTotal tests: {total_tests}, Passed: {passed_tests}, Failed: {total_tests - passed_tests}") + + return all(results.values()) + +def test_single_profile(profile): + """Test a single profile against reference API.""" + # Test parameters + params = { + "launch_latitude": 56.6992, + "launch_longitude": 38.8247, + "launch_datetime": "2025-06-25T13:28:00Z", + "launch_altitude": 0, + "profile": profile, + "ascent_rate": 5, + "burst_altitude": 30000, + "descent_rate": 5 + } + # Add float altitude for float profile + if profile == "float_profile": + params["float_altitude"] = 25000 + try: + # Test local API (use GET) + local_resp = requests.get( + "http://localhost:8080/api/v1/prediction", + params=params, + timeout=30 + ) + local_resp.raise_for_status() + local_data = local_resp.json() + print(f"[INFO] {profile} - Local API returned {len(local_data.get('prediction', [{}])[0].get('trajectory', []))} trajectory points") + return True + except Exception as e: + print(f"[ERROR] {profile} test failed: {e}") + return False + +def main(): + """Main test function.""" + print("[INFO] Starting comprehensive predictor API tests...") + + # Run the original standard profile test + print("\n[TEST] Running original standard_profile test...") + run_compose_up() + wait_for_ready() + ref = fetch_reference() + local = fetch_local() + + print("[INFO] Comparing results ...") + original_success = compare_results(ref, local) + + if original_success: + print("[SUCCESS] Original standard_profile test passed!") + else: + print("[FAIL] Original standard_profile test failed!") + + # Test all profiles + print("\n[TEST] Running all profile tests...") + all_profiles_success = test_all_profiles() + + # Final result + overall_success = original_success and all_profiles_success + print(f"\n[FINAL RESULT] Overall: {'PASS' if overall_success else 'FAIL'}") + + if overall_success: + sys.exit(0) + else: + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/test_s3_download.go b/scripts/test_s3_download.go new file mode 100644 index 0000000..0391cd6 --- /dev/null +++ b/scripts/test_s3_download.go @@ -0,0 +1,89 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "time" + + "git.intra.yksa.space/gsn/predictor/internal/pkg/grib" +) + +func main() { + ctx := context.Background() + + // Create S3 downloader + downloader, err := grib.NewS3Downloader( + "/tmp/grib_test", + 4, // parallel downloads + "noaa-gfs-bdp-pds", + "us-east-1", + ) + if err != nil { + log.Fatalf("Failed to create S3 downloader: %v", err) + } + + // Ensure directory exists + if err := os.MkdirAll("/tmp/grib_test", 0o755); err != nil { + log.Fatalf("Failed to create directory: %v", err) + } + + // Find nearest run (6-hour intervals: 00, 06, 12, 18 UTC) + now := time.Now().UTC() + hour := now.Hour() - (now.Hour() % 6) + // Use data from 6 hours ago to ensure it's available + run := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC).Add(-6 * time.Hour) + + fmt.Printf("Testing S3 download for run: %s\n", run.Format("2006-01-02 15:04 MST")) + + // List available files first + runStr := run.Format("20060102") + fmt.Printf("Listing available files for %s/%02d...\n", runStr, run.Hour()) + files, err := downloader.ListAvailableFiles(ctx, runStr, run.Hour()) + if err != nil { + log.Fatalf("Failed to list files: %v", err) + } + + fmt.Printf("Found %d files in S3:\n", len(files)) + if len(files) > 0 { + // Show first 5 files + for i, file := range files { + if i >= 5 { + fmt.Printf("... and %d more files\n", len(files)-5) + break + } + fmt.Printf(" - %s\n", file) + } + } + + // Try downloading just first 3 forecast hours (f000, f001, f002) + fmt.Println("\nTesting download of first 3 forecast hours...") + testRun := run + + // Create a timeout context for the download + downloadCtx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + + if err := downloader.Run(downloadCtx, testRun); err != nil { + log.Fatalf("Failed to download: %v", err) + } + + fmt.Println("\nDownload completed successfully!") + + // Check downloaded files + entries, err := os.ReadDir("/tmp/grib_test") + if err != nil { + log.Fatalf("Failed to read directory: %v", err) + } + + fmt.Printf("\nDownloaded %d files:\n", len(entries)) + for i, entry := range entries { + if i >= 10 { + fmt.Printf("... and %d more files\n", len(entries)-10) + break + } + info, _ := entry.Info() + fmt.Printf(" - %s (%.2f MB)\n", entry.Name(), float64(info.Size())/1024/1024) + } +} diff --git a/scripts/test_s3_simple.go b/scripts/test_s3_simple.go new file mode 100644 index 0000000..2ae8414 --- /dev/null +++ b/scripts/test_s3_simple.go @@ -0,0 +1,68 @@ +package main + +import ( + "context" + "fmt" + "io" + "log" + "os" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/s3" +) + +func main() { + ctx := context.Background() + + // Create AWS config with anonymous credentials + cfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion("us-east-1"), + config.WithCredentialsProvider(aws.AnonymousCredentials{}), + ) + if err != nil { + log.Fatalf("Failed to load config: %v", err) + } + + client := s3.NewFromConfig(cfg) + + // Try to download a single file + bucket := "noaa-gfs-bdp-pds" + key := "gfs.20251020/00/atmos/gfs.t00z.pgrb2.0p50.f000" + + fmt.Printf("Downloading: s3://%s/%s\n", bucket, key) + + input := &s3.GetObjectInput{ + Bucket: aws.String(bucket), + Key: aws.String(key), + } + + result, err := client.GetObject(ctx, input) + if err != nil { + log.Fatalf("Failed to get object: %v", err) + } + defer result.Body.Close() + + // Create output file + outFile := "/tmp/test_grib.part" + f, err := os.Create(outFile) + if err != nil { + log.Fatalf("Failed to create file: %v", err) + } + defer f.Close() + + // Copy data + written, err := io.Copy(f, result.Body) + if err != nil { + log.Fatalf("Failed to copy data: %v (wrote %d bytes)", err, written) + } + + fmt.Printf("Successfully downloaded %d bytes\n", written) + + // Rename + if err := os.Rename(outFile, "/tmp/test_grib"); err != nil { + log.Fatalf("Failed to rename: %v", err) + } + + fmt.Println("Download complete!") +}