diff --git a/.gitignore b/.gitignore index 0b60242..b58d008 100644 --- a/.gitignore +++ b/.gitignore @@ -47,4 +47,8 @@ Thumbs.db # Build artifacts /build/ -/dist/ \ No newline at end of file +/dist/ + +# GRIB files +/grib_data/ +/grib_data/* \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index 8cebb4e..ede1ec1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -40,7 +40,8 @@ COPY --from=builder /app/predictor . # Create necessary directories RUN mkdir -p /tmp/grib && \ - chown -R appuser:appgroup /app /tmp/grib + chown -R appuser:appgroup /app && \ + chmod -R 777 /tmp/grib # Switch to non-root user USER appuser @@ -50,7 +51,7 @@ EXPOSE 8080 # Health check HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \ - CMD wget --no-verbose --tries=1 --spider http://localhost:8080/health || exit 1 + CMD wget --no-verbose --tries=1 --spider http://localhost:8080/ready || exit 1 # Run the application CMD ["./predictor"] \ No newline at end of file diff --git a/api/rest/predictor.swagger.yml b/api/rest/predictor.swagger.yml index 68ef128..282bec4 100644 --- a/api/rest/predictor.swagger.yml +++ b/api/rest/predictor.swagger.yml @@ -36,6 +36,25 @@ paths: application/json: schema: $ref: "#/components/schemas/Error" + /ready: + get: + tags: + - Health + summary: Readiness check + operationId: readinessCheck + responses: + "200": + description: Readiness status + content: + application/json: + schema: + $ref: '#/components/schemas/ReadinessResponse' + default: + description: Error + content: + application/json: + schema: + $ref: "#/components/schemas/Error" components: schemas: @@ -125,4 +144,31 @@ components: type: object required: - datetime - - latitude \ No newline at end of file + - latitude + - longitude + - altitude + properties: + datetime: + type: string + format: date-time + latitude: + type: number + longitude: + type: number + altitude: + type: number + ReadinessResponse: + type: object + properties: + status: + type: string + enum: [ok, not_ready, error] + last_update: + type: string + format: date-time + is_fresh: + type: boolean + error_message: + type: string + required: + - status \ No newline at end of file diff --git a/cmd/api/main.go b/cmd/api/main.go index 9a53a94..46ffe95 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -5,9 +5,12 @@ import ( "os/signal" "syscall" + "context" + "git.intra.yksa.space/gsn/predictor/internal/jobs/grib/updater" "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" "git.intra.yksa.space/gsn/predictor/internal/pkg/grib" + "git.intra.yksa.space/gsn/predictor/internal/pkg/log" "git.intra.yksa.space/gsn/predictor/internal/service" "git.intra.yksa.space/gsn/predictor/internal/transport/rest" "git.intra.yksa.space/gsn/predictor/internal/transport/rest/handler" @@ -25,128 +28,120 @@ func main() { panic(err) } defer lg.Sync() + ctx := log.ToCtx(context.Background(), lg) - // Load configuration from environment with service prefix cfg, err := loadConfig() if err != nil { - lg.Fatal("failed to load configuration", zap.Error(err)) + log.Ctx(ctx).Fatal("failed to load configuration", zap.Error(err)) } - // Load scheduler configuration - schedulerConfig, err := loadSchedulerConfig() + schedulerConfig, err := scheduler.NewConfig() if err != nil { - lg.Fatal("failed to load scheduler configuration", zap.Error(err)) + log.Ctx(ctx).Fatal("failed to load scheduler configuration", zap.Error(err)) } - // Load GRIB updater job configuration - gribUpdaterConfig, err := loadGribUpdaterConfig() + gribUpdaterConfig, err := updater.NewConfig() if err != nil { - lg.Fatal("failed to load GRIB updater configuration", zap.Error(err)) + log.Ctx(ctx).Fatal("failed to load GRIB updater configuration", zap.Error(err)) } - // Initialize Redis service - redisService, err := redis.New(cfg.Redis) + log.Ctx(ctx).Info("Connecting to Redis", zap.String("host", cfg.RedisHost), zap.Int("port", cfg.RedisPort)) + redisService, err := redis.New(redis.Config{ + Host: cfg.RedisHost, + Port: cfg.RedisPort, + Password: cfg.RedisPassword, + DB: cfg.RedisDB, + }) if err != nil { - lg.Fatal("failed to initialize Redis service", zap.Error(err)) + log.Ctx(ctx).Fatal("failed to initialize Redis service", zap.Error(err), zap.String("host", cfg.RedisHost), zap.Int("port", cfg.RedisPort)) } defer redisService.Close() - // Initialize GRIB service gribService, err := grib.New(grib.ServiceConfig{ - Dir: cfg.Grib.Dir, - TTL: cfg.Grib.TTL, - CacheTTL: cfg.Grib.CacheTTL, - Redis: redisService, - Parallel: cfg.Grib.Parallel, - Client: cfg.CreateHTTPClient(), + Dir: cfg.GribDir, + TTL: cfg.GribTTL, + CacheTTL: cfg.GribCacheTTL, + Redis: redisService, + Parallel: cfg.GribParallel, + Client: cfg.CreateHTTPClient(), + DatasetURL: cfg.GribDatasetURL, }) if err != nil { - lg.Fatal("failed to initialize GRIB service", zap.Error(err)) + log.Ctx(ctx).Fatal("failed to initialize GRIB service", zap.Error(err)) } defer gribService.Close() - // Initialize service with dependencies - svc, err := service.New(cfg, gribService, redisService, lg) + // Force GRIB update on startup in a goroutine + go func() { + log.Ctx(ctx).Info("Performing initial GRIB update (async)...") + if err := gribService.Update(ctx); err != nil { + log.Ctx(ctx).Error("initial GRIB update failed", zap.Error(err)) + } else { + log.Ctx(ctx).Info("initial GRIB update complete") + } + }() + + svc, err := service.New(cfg, gribService, redisService) if err != nil { - lg.Fatal("failed to initialize service", zap.Error(err)) + log.Ctx(ctx).Fatal("failed to initialize service", zap.Error(err)) } defer svc.Close() - // Initialize scheduler var sched *scheduler.Scheduler if schedulerConfig.Enabled { - sched = scheduler.New(lg) + sched = scheduler.New() - // Add GRIB update job - gribJob := updater.New(gribService, gribUpdaterConfig, lg) + gribJob := updater.New(gribService, gribUpdaterConfig) if err := sched.AddJob(gribJob); err != nil { - lg.Error("failed to add GRIB update job to scheduler", zap.Error(err)) + log.Ctx(ctx).Error("failed to add GRIB update job to scheduler", zap.Error(err)) } - // TODO: Add more jobs here as needed - // Example: - // cleanupConfig := cleanup.NewConfig() - // cleanupJob := cleanup.New(svc, cleanupConfig, lg) - // if err := sched.AddJob(cleanupJob); err != nil { - // lg.Error("failed to add cleanup job to scheduler", zap.Error(err)) - // } - - lg.Info("scheduler initialized with jobs") + log.Ctx(ctx).Info("scheduler initialized with jobs") } - // Initialize handler handler := handler.New(svc) - // Initialize transport - restConfig, err := loadRestConfig() + restConfig, err := rest.NewConfig() if err != nil { lg.Fatal("failed to init transport config", zap.Error(err)) } - transport, err := rest.New(lg, handler, restConfig) + transport, err := rest.New(handler, restConfig) if err != nil { lg.Fatal("failed to init transport", zap.Error(err)) } - // Start service svc.Start() - - // Start scheduler if enabled if sched != nil { sched.Start() lg.Info("scheduler started") } lg.Info("service started successfully", - zap.String("grib_dir", cfg.Grib.Dir), - zap.Duration("grib_ttl", cfg.Grib.TTL), - zap.Duration("grib_cache_ttl", cfg.Grib.CacheTTL), - zap.Int("grib_parallel", cfg.Grib.Parallel), + zap.String("grib_dir", cfg.GribDir), + zap.Duration("grib_ttl", cfg.GribTTL), + zap.Duration("grib_cache_ttl", cfg.GribCacheTTL), + zap.Int("grib_parallel", cfg.GribParallel), zap.Bool("scheduler_enabled", schedulerConfig.Enabled), zap.Duration("grib_update_interval", gribUpdaterConfig.Interval)) - // Wait for shutdown signal sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) - // Start server in goroutine go func() { - lg.Info("starting HTTP server") + lg.Info("starting HTTP server on port", zap.Int("port", restConfig.Port)) transport.Run() }() - // Wait for shutdown signal <-sigChan lg.Info("received shutdown signal, stopping service") - // Stop scheduler first if sched != nil { sched.Stop() lg.Info("scheduler stopped") } } -// loadConfig loads configuration from environment with service prefix func loadConfig() (*service.Config, error) { cfg := &service.Config{} @@ -158,42 +153,3 @@ func loadConfig() (*service.Config, error) { return cfg, nil } - -// loadSchedulerConfig loads scheduler configuration from environment -func loadSchedulerConfig() (*scheduler.Config, error) { - cfg := &scheduler.Config{} - - if err := env.ParseWithOptions(cfg, env.Options{ - PrefixTagName: servicePrefix + "_SCHEDULER_", - }); err != nil { - return nil, errcodes.Wrap(err, "failed to parse scheduler configuration") - } - - return cfg, nil -} - -// loadGribUpdaterConfig loads GRIB updater job configuration from environment -func loadGribUpdaterConfig() (*updater.Config, error) { - cfg := &updater.Config{} - - if err := env.ParseWithOptions(cfg, env.Options{ - PrefixTagName: servicePrefix + "_GRIB_UPDATER_", - }); err != nil { - return nil, errcodes.Wrap(err, "failed to parse GRIB updater configuration") - } - - return cfg, nil -} - -// loadRestConfig loads REST transport configuration from environment with service prefix -func loadRestConfig() (*rest.Config, error) { - cfg := &rest.Config{} - - if err := env.ParseWithOptions(cfg, env.Options{ - PrefixTagName: servicePrefix + "_REST_", - }); err != nil { - return nil, errcodes.Wrap(err, "failed to parse REST configuration") - } - - return cfg, nil -} diff --git a/docker-compose.dev.yml b/docker-compose.dev.yml deleted file mode 100644 index 4503482..0000000 --- a/docker-compose.dev.yml +++ /dev/null @@ -1,82 +0,0 @@ -version: '3.8' - -services: - predictor: - build: - context: . - dockerfile: Dockerfile - container_name: predictor-dev - ports: - - "8080:8080" - environment: - # GRIB Configuration - - GSN_PREDICTOR_GRIB_DIR=/tmp/grib - - GSN_PREDICTOR_GRIB_TTL=24h - - GSN_PREDICTOR_GRIB_CACHE_TTL=1h - - GSN_PREDICTOR_GRIB_PARALLEL=4 - - GSN_PREDICTOR_GRIB_TIMEOUT=30s - - GSN_PREDICTOR_GRIB_DATASET_URL=https://nomads.ncep.noaa.gov/ - - # Redis Configuration - - GSN_PREDICTOR_REDIS_HOST=redis - - GSN_PREDICTOR_REDIS_PORT=6379 - - GSN_PREDICTOR_REDIS_PASSWORD= - - GSN_PREDICTOR_REDIS_DB=0 - - # Scheduler Configuration - - GSN_PREDICTOR_SCHEDULER_ENABLED=true - - # GRIB Updater Job Configuration - - GSN_PREDICTOR_GRIB_UPDATER_INTERVAL=6h - - GSN_PREDICTOR_GRIB_UPDATER_TIMEOUT=45m - - # REST Transport Configuration - - GSN_PREDICTOR_REST_HOST=0.0.0.0 - - GSN_PREDICTOR_REST_PORT=8080 - - GSN_PREDICTOR_REST_READ_TIMEOUT=30s - - GSN_PREDICTOR_REST_WRITE_TIMEOUT=30s - - GSN_PREDICTOR_REST_IDLE_TIMEOUT=60s - volumes: - - grib_data:/tmp/grib - - .:/app - - /app/predictor - depends_on: - redis: - condition: service_healthy - networks: - - predictor-network - restart: unless-stopped - healthcheck: - test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:8080/health"] - interval: 30s - timeout: 10s - retries: 3 - start_period: 40s - - redis: - image: redis:7.2-alpine - container_name: predictor-redis-dev - ports: - - "6379:6379" - volumes: - - redis_data:/data - networks: - - predictor-network - restart: unless-stopped - healthcheck: - test: ["CMD", "redis-cli", "ping"] - interval: 10s - timeout: 3s - retries: 5 - start_period: 10s - command: redis-server --appendonly yes --maxmemory 256mb --maxmemory-policy allkeys-lru - -volumes: - grib_data: - driver: local - redis_data: - driver: local - -networks: - predictor-network: - driver: bridge \ No newline at end of file diff --git a/docker-compose.yml b/docker-compose.yml index 67c42b7..b20b275 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -9,35 +9,35 @@ services: ports: - "8080:8080" environment: - # GRIB Configuration + # --- GRIB Configuration --- - GSN_PREDICTOR_GRIB_DIR=/tmp/grib - GSN_PREDICTOR_GRIB_TTL=24h - GSN_PREDICTOR_GRIB_CACHE_TTL=1h - GSN_PREDICTOR_GRIB_PARALLEL=4 - GSN_PREDICTOR_GRIB_TIMEOUT=30s - - GSN_PREDICTOR_GRIB_DATASET_URL=https://nomads.ncep.noaa.gov/ - - # Redis Configuration + - GSN_PREDICTOR_GRIB_DATASET_URL=https://nomads.ncep.noaa.gov/pub/data/nccf/com/gfs/prod + + # --- Redis Configuration --- - GSN_PREDICTOR_REDIS_HOST=redis - GSN_PREDICTOR_REDIS_PORT=6379 - GSN_PREDICTOR_REDIS_PASSWORD= - GSN_PREDICTOR_REDIS_DB=0 - - # Scheduler Configuration + + # --- Scheduler Configuration --- - GSN_PREDICTOR_SCHEDULER_ENABLED=true - - # GRIB Updater Job Configuration + + # --- GRIB Updater Job Configuration --- - GSN_PREDICTOR_GRIB_UPDATER_INTERVAL=6h - GSN_PREDICTOR_GRIB_UPDATER_TIMEOUT=45m - - # REST Transport Configuration + + # --- REST Transport Configuration --- - GSN_PREDICTOR_REST_HOST=0.0.0.0 - GSN_PREDICTOR_REST_PORT=8080 - GSN_PREDICTOR_REST_READ_TIMEOUT=30s - GSN_PREDICTOR_REST_WRITE_TIMEOUT=30s - GSN_PREDICTOR_REST_IDLE_TIMEOUT=60s volumes: - - grib_data:/tmp/grib + - ./grib_data:/tmp/grib depends_on: redis: condition: service_healthy @@ -45,7 +45,7 @@ services: - predictor-network restart: unless-stopped healthcheck: - test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:8080/health"] + test: ["CMD", "wget", "--no-verbose", "--tries=1", "--spider", "http://localhost:8080/ready"] interval: 30s timeout: 10s retries: 3 @@ -70,8 +70,6 @@ services: command: redis-server --appendonly yes --maxmemory 256mb --maxmemory-policy allkeys-lru volumes: - grib_data: - driver: local redis_data: driver: local diff --git a/internal/jobs/grib/updater/config.go b/internal/jobs/grib/updater/config.go index 0127ce8..fa5132b 100644 --- a/internal/jobs/grib/updater/config.go +++ b/internal/jobs/grib/updater/config.go @@ -1,8 +1,23 @@ package updater -import "time" +import ( + "time" + + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + env "github.com/caarlos0/env/v11" +) type Config struct { Interval time.Duration `env:"INTERVAL" envDefault:"6h"` Timeout time.Duration `env:"TIMEOUT" envDefault:"45m"` } + +func NewConfig() (*Config, error) { + cfg := &Config{} + if err := env.ParseWithOptions(cfg, env.Options{ + PrefixTagName: "GSN_PREDICTOR_GRIB_UPDATER_", + }); err != nil { + return nil, errcodes.Wrap(err, "failed to parse GRIB updater config") + } + return cfg, nil +} diff --git a/internal/jobs/grib/updater/updater.go b/internal/jobs/grib/updater/updater.go index 0295b0e..ce28d02 100644 --- a/internal/jobs/grib/updater/updater.go +++ b/internal/jobs/grib/updater/updater.go @@ -5,20 +5,19 @@ import ( "time" "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + "git.intra.yksa.space/gsn/predictor/internal/pkg/log" "go.uber.org/zap" ) type Job struct { service GribService config *Config - logger *zap.Logger } -func New(service GribService, config *Config, logger *zap.Logger) *Job { +func New(service GribService, config *Config) *Job { return &Job{ service: service, config: config, - logger: logger, } } @@ -31,21 +30,22 @@ func (j *Job) GetTimeout() time.Duration { } func (j *Job) GetCount() int { - return 0 // Run indefinitely + return 1 } func (j *Job) GetAsync() bool { - return false // Singleton mode - only one instance should run + return false } func (j *Job) Execute(ctx context.Context) error { - j.logger.Info("executing GRIB update job") + log := log.Ctx(ctx) + log.Info("executing GRIB update job") if err := j.service.Update(ctx); err != nil { - j.logger.Error("GRIB update failed", zap.Error(err)) + log.Error("GRIB update failed", zap.Error(err)) return errcodes.Wrap(err, "failed to update GRIB data") } - j.logger.Info("GRIB update completed successfully") + log.Info("GRIB update completed successfully") return nil } diff --git a/internal/pkg/ds/predictor.go b/internal/pkg/ds/predictor.go index 1da6675..1ea5e44 100644 --- a/internal/pkg/ds/predictor.go +++ b/internal/pkg/ds/predictor.go @@ -1,21 +1,95 @@ package ds -import "time" +import ( + "time" + + api "git.intra.yksa.space/gsn/predictor/pkg/rest" +) type PredictionParameters struct { - LaunchLatitude float64 - LaunchLongitude float64 - LaunchDatetime time.Time - LaunchAltitude float64 + LaunchLatitude *float64 + LaunchLongitude *float64 + LaunchDatetime *time.Time + LaunchAltitude *float64 + Profile *string + AscentRate *float64 + BurstAltitude *float64 + DescentRate *float64 + FloatAltitude *float64 + StopDatetime *time.Time + AscentCurve *string // base64 + DescentCurve *string // base64 + Interpolate *bool + Format *string + Dataset *time.Time // Add other parameters as needed } type PredicitonResult struct { - Latitude float64 - Longitude float64 - Altitude float64 - Timestamp time.Time - WindU float64 - WindV float64 + Latitude *float64 + Longitude *float64 + Altitude *float64 + Timestamp *time.Time + WindU *float64 + WindV *float64 // Add other result fields as needed } + +// ConvertOptPredictionParameters converts ogen's OptPredictionParameters to the internal pointer-based model. +// Returns nil if the input is not set. +func ConvertOptPredictionParameters(opt api.OptPredictionParameters) *PredictionParameters { + if !opt.Set { + return nil + } + in := opt.Value + out := &PredictionParameters{} + + if v, ok := in.LaunchLatitude.Get(); ok { + out.LaunchLatitude = &v + } + if v, ok := in.LaunchLongitude.Get(); ok { + out.LaunchLongitude = &v + } + if v, ok := in.LaunchDatetime.Get(); ok { + out.LaunchDatetime = &v + } + if v, ok := in.LaunchAltitude.Get(); ok { + out.LaunchAltitude = &v + } + if v, ok := in.Profile.Get(); ok { + s := string(v) + out.Profile = &s + } + if v, ok := in.AscentRate.Get(); ok { + out.AscentRate = &v + } + if v, ok := in.BurstAltitude.Get(); ok { + out.BurstAltitude = &v + } + if v, ok := in.DescentRate.Get(); ok { + out.DescentRate = &v + } + if v, ok := in.FloatAltitude.Get(); ok { + out.FloatAltitude = &v + } + if v, ok := in.StopDatetime.Get(); ok { + out.StopDatetime = &v + } + if v, ok := in.AscentCurve.Get(); ok { + out.AscentCurve = &v + } + if v, ok := in.DescentCurve.Get(); ok { + out.DescentCurve = &v + } + if v, ok := in.Interpolate.Get(); ok { + out.Interpolate = &v + } + if v, ok := in.Format.Get(); ok { + s := string(v) + out.Format = &s + } + if v, ok := in.Dataset.Get(); ok { + out.Dataset = &v + } + return out +} diff --git a/internal/pkg/errcodes/errcodes.go b/internal/pkg/errcodes/errcodes.go index e6748b3..aac2f12 100644 --- a/internal/pkg/errcodes/errcodes.go +++ b/internal/pkg/errcodes/errcodes.go @@ -23,13 +23,11 @@ func (e *ErrorCode) Error() string { return e.Message } -// IsErr checks if the given error is an ErrorCode func IsErr(err error) bool { _, ok := err.(*ErrorCode) return ok } -// AsErr converts error to ErrorCode if possible func AsErr(err error) (*ErrorCode, bool) { if err == nil { return nil, false @@ -38,7 +36,6 @@ func AsErr(err error) (*ErrorCode, bool) { return errcode, ok } -// Join combines multiple errors into a single ErrorCode func Join(errs ...error) error { if len(errs) == 0 { return nil @@ -66,7 +63,6 @@ func Join(errs ...error) error { return nil } - // Use the first error's status code, or default to 500 statusCode := http.StatusInternalServerError if len(errs) > 0 { if errcode, ok := AsErr(errs[0]); ok { @@ -77,7 +73,6 @@ func Join(errs ...error) error { return New(statusCode, strings.Join(messages, "; "), details...) } -// Wrap wraps an error with additional context func Wrap(err error, message string) error { if err == nil { return nil diff --git a/internal/pkg/errcodes/errcodes_test.go b/internal/pkg/errcodes/errcodes_test.go deleted file mode 100644 index e5ff84c..0000000 --- a/internal/pkg/errcodes/errcodes_test.go +++ /dev/null @@ -1,81 +0,0 @@ -package errcodes - -import ( - "testing" -) - -func TestSpecificErrorTypes(t *testing.T) { - // Test Redis lock error - err := ErrRedisLockAlreadyLocked - if !IsErr(err) { - t.Error("Expected IsErr to return true for ErrorCode") - } - - errcode, ok := AsErr(err) - if !ok { - t.Error("Expected AsErr to return true for ErrorCode") - } - if errcode != ErrRedisLockAlreadyLocked { - t.Error("Expected AsErr to return the same error") - } - - // Test Redis cache miss error - cacheErr := ErrRedisCacheMiss - if !IsErr(cacheErr) { - t.Error("Expected IsErr to return true for cache miss error") - } - - // Test configuration error - configErr := ErrConfigInvalidEnv - if !IsErr(configErr) { - t.Error("Expected IsErr to return true for config error") - } - - // Test scheduler error - schedulerErr := ErrSchedulerTimeoutTooLong - if !IsErr(schedulerErr) { - t.Error("Expected IsErr to return true for scheduler error") - } -} - -func TestErrorChecking(t *testing.T) { - // Example of how to check for specific errors in practice - err := ErrRedisLockAlreadyLocked - - // Check if it's a specific error type - if errcode, ok := AsErr(err); ok { - switch errcode { - case ErrRedisLockAlreadyLocked: - // Handle lock already locked case - t.Log("Handling lock already locked error") - case ErrRedisCacheMiss: - // Handle cache miss case - t.Log("Handling cache miss error") - case ErrRedisCacheCorrupted: - // Handle corrupted cache case - t.Log("Handling corrupted cache error") - default: - // Handle other error types - t.Log("Handling other error type") - } - } -} - -func TestWrapFunction(t *testing.T) { - originalErr := ErrRedisCacheMiss - wrappedErr := Wrap(originalErr, "additional context") - - if !IsErr(wrappedErr) { - t.Error("Expected wrapped error to be an ErrorCode") - } - - errcode, ok := AsErr(wrappedErr) - if !ok { - t.Error("Expected AsErr to work with wrapped error") - } - - // The wrapped error should have the same status code as the original - if errcode.StatusCode != ErrRedisCacheMiss.StatusCode { - t.Errorf("Expected status code %d, got %d", ErrRedisCacheMiss.StatusCode, errcode.StatusCode) - } -} diff --git a/internal/pkg/grib/cache.go b/internal/pkg/grib/cache.go index a31ae4d..7d40f43 100644 --- a/internal/pkg/grib/cache.go +++ b/internal/pkg/grib/cache.go @@ -20,12 +20,17 @@ type memCache struct { func (c *memCache) get(k uint64) (vec, bool) { if v, ok := c.m.Load(k); ok { it := v.(item) + if time.Now().Before(it.exp) { return it.v, true } + c.m.Delete(k) } + return vec{}, false } -func (c *memCache) set(k uint64, v vec) { c.m.Store(k, item{v, time.Now().Add(c.ttl)}) } +func (c *memCache) set(k uint64, v vec) { + c.m.Store(k, item{v, time.Now().Add(c.ttl)}) +} diff --git a/internal/pkg/grib/config.go b/internal/pkg/grib/config.go index 69f2270..35f1e46 100644 --- a/internal/pkg/grib/config.go +++ b/internal/pkg/grib/config.go @@ -1,8 +1,10 @@ package grib import ( - "net/url" "time" + + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + env "github.com/caarlos0/env/v11" ) type Config struct { @@ -11,5 +13,15 @@ type Config struct { CacheTTL time.Duration `env:"CACHE_TTL" envDefault:"1h"` Parallel int `env:"PARALLEL" envDefault:"4"` Timeout time.Duration `env:"TIMEOUT" envDefault:"30s"` - DatasetURL url.URL `env:"DATASET_URL" envDefault:"https://nomads.ncep.noaa.gov/"` + DatasetURL string `env:"DATASET_URL" envDefault:"https://nomads.ncep.noaa.gov/pub/data/nccf/com/gfs/prod"` +} + +func NewConfig() (*Config, error) { + cfg := &Config{} + if err := env.ParseWithOptions(cfg, env.Options{ + PrefixTagName: "GSN_PREDICTOR_GRIB_", + }); err != nil { + return nil, errcodes.Wrap(err, "failed to parse GRIB config") + } + return cfg, nil } diff --git a/internal/pkg/grib/cube.go b/internal/pkg/grib/cube.go index 71a1c7f..d8cc5c5 100644 --- a/internal/pkg/grib/cube.go +++ b/internal/pkg/grib/cube.go @@ -9,7 +9,7 @@ import ( ) type cube struct { - mm mmap.MMap // read‑only, U followed by V (float32 LE) + mm mmap.MMap t, p, lat, lon int bytesPerVar int64 file *os.File diff --git a/internal/pkg/grib/downloader.go b/internal/pkg/grib/downloader.go index d94006a..a181f43 100644 --- a/internal/pkg/grib/downloader.go +++ b/internal/pkg/grib/downloader.go @@ -13,17 +13,15 @@ import ( "golang.org/x/sync/errgroup" ) -// NOMADS only. -const nomadsRoot = "https://nomads.ncep.noaa.gov/pub/data/nccf/com/gfs/prod" - type Downloader struct { - Dir string - Parallel int - Client *http.Client + Dir string + Parallel int + Client *http.Client + DatasetURL string } func (d *Downloader) fileURL(run string, hour int, step int) string { - return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d", nomadsRoot, run, hour, hour, step) + return fmt.Sprintf("%s/gfs.%s/%02d/atmos/gfs.t%02dz.pgrb2.0p50.f%03d", d.DatasetURL, run, hour, hour, step) } func (d *Downloader) fetch(ctx context.Context, url, dst string) error { diff --git a/internal/pkg/grib/grib.go b/internal/pkg/grib/grib.go index 062b2db..2f1caea 100644 --- a/internal/pkg/grib/grib.go +++ b/internal/pkg/grib/grib.go @@ -27,15 +27,17 @@ type Service interface { Update(ctx context.Context) error Extract(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error) Close() error + GetStatus() (ready bool, lastUpdate time.Time, isFresh bool, errMsg string) } type ServiceConfig struct { - Dir string - TTL time.Duration - CacheTTL time.Duration - Redis RedisIface - Parallel int - Client *http.Client + Dir string + TTL time.Duration + CacheTTL time.Duration + Redis RedisIface + Parallel int + Client *http.Client + DatasetURL string } type service struct { @@ -147,7 +149,7 @@ func (s *service) Update(ctx context.Context) error { } } - dl := Downloader{Dir: s.cfg.Dir, Parallel: s.cfg.Parallel, Client: s.cfg.Client} + dl := Downloader{Dir: s.cfg.Dir, Parallel: s.cfg.Parallel, Client: s.cfg.Client, DatasetURL: s.cfg.DatasetURL} run := nearestRun(time.Now().UTC().Add(-4 * time.Hour)) // Check if we already have this run @@ -334,3 +336,16 @@ func (s *service) Close() error { } return nil } + +func (s *service) GetStatus() (ready bool, lastUpdate time.Time, isFresh bool, errMsg string) { + d := s.data.Load() + if d == nil { + return false, time.Time{}, false, "no dataset loaded" + } + runTime := time.Unix(d.runUTC, 0) + fresh := time.Since(runTime) < s.cfg.TTL + if !fresh { + return false, runTime, false, "dataset is too old" + } + return true, runTime, true, "" +} diff --git a/internal/pkg/grib/grib_test.go b/internal/pkg/grib/grib_test.go deleted file mode 100644 index c536ab3..0000000 --- a/internal/pkg/grib/grib_test.go +++ /dev/null @@ -1,62 +0,0 @@ -package grib - -import ( - "context" - "testing" - "time" -) - -func TestServiceCreation(t *testing.T) { - cfg := ServiceConfig{ - Dir: "/tmp/grib_test", - TTL: 24 * time.Hour, - CacheTTL: 1 * time.Hour, - Redis: &MockRedis{}, - Parallel: 2, - } - - service, err := New(cfg) - if err != nil { - t.Fatalf("Failed to create service: %v", err) - } - defer service.Close() - - if service == nil { - t.Fatal("Service is nil") - } -} - -func TestNearestRun(t *testing.T) { - now := time.Date(2024, 1, 15, 14, 30, 0, 0, time.UTC) - expected := time.Date(2024, 1, 15, 12, 0, 0, 0, time.UTC) - - result := nearestRun(now) - if !result.Equal(expected) { - t.Errorf("Expected %v, got %v", expected, result) - } -} - -func TestPressureFromAlt(t *testing.T) { - alt := 10000.0 // 10km - pressure := pressureFromAlt(alt) - - // At 10km, pressure should be around 264 hPa - if pressure < 200 || pressure > 300 { - t.Errorf("Unexpected pressure at 10km: %f hPa", pressure) - } -} - -// MockRedis for testing -type MockRedis struct{} - -func (m *MockRedis) Lock(ctx context.Context, key string, ttl time.Duration) (func(context.Context), error) { - return func(ctx context.Context) {}, nil -} - -func (m *MockRedis) Set(key string, value []byte, ttl time.Duration) error { - return nil -} - -func (m *MockRedis) Get(key string) ([]byte, error) { - return nil, nil -} diff --git a/internal/service/config.go b/internal/service/config.go index 3be35aa..3c9b011 100644 --- a/internal/service/config.go +++ b/internal/service/config.go @@ -2,21 +2,27 @@ package service import ( "net/http" - - "git.intra.yksa.space/gsn/predictor/internal/pkg/grib" - "git.intra.yksa.space/gsn/predictor/pkg/redis" + "time" ) type Config struct { - // GRIB Configuration - Grib grib.Config `envPrefix:"GRIB_"` + // --- GRIB Configuration --- + GribDir string `env:"GSN_PREDICTOR_GRIB_DIR" envDefault:"/tmp/grib"` + GribTTL time.Duration `env:"GSN_PREDICTOR_GRIB_TTL" envDefault:"24h"` + GribCacheTTL time.Duration `env:"GSN_PREDICTOR_GRIB_CACHE_TTL" envDefault:"1h"` + GribParallel int `env:"GSN_PREDICTOR_GRIB_PARALLEL" envDefault:"4"` + GribTimeout time.Duration `env:"GSN_PREDICTOR_GRIB_TIMEOUT" envDefault:"30s"` + GribDatasetURL string `env:"GSN_PREDICTOR_GRIB_DATASET_URL" envDefault:"https://nomads.ncep.noaa.gov/pub/data/nccf/com/gfs/prod"` - // Redis Configuration - Redis redis.Config `envPrefix:"REDIS_"` + // --- Redis Configuration --- + RedisHost string `env:"GSN_PREDICTOR_REDIS_HOST"` + RedisPort int `env:"GSN_PREDICTOR_REDIS_PORT"` + RedisPassword string `env:"GSN_PREDICTOR_REDIS_PASSWORD"` + RedisDB int `env:"GSN_PREDICTOR_REDIS_DB"` } func (c *Config) CreateHTTPClient() *http.Client { return &http.Client{ - Timeout: c.Grib.Timeout, + Timeout: c.GribTimeout, } } diff --git a/internal/service/predictor.go b/internal/service/predictor.go index 21d90dc..f20d537 100644 --- a/internal/service/predictor.go +++ b/internal/service/predictor.go @@ -2,26 +2,494 @@ package service import ( "context" + "encoding/base64" + "encoding/json" + "math" + "time" "git.intra.yksa.space/gsn/predictor/internal/pkg/ds" + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + "git.intra.yksa.space/gsn/predictor/internal/pkg/log" + "go.uber.org/zap" ) +var ErrInvalidParameters = errcodes.New(400, "missing required prediction parameters") + +// Stage represents a prediction stage (ascent, descent, float) +type Stage struct { + Name string + Results []ds.PredicitonResult + StartTime time.Time + EndTime time.Time +} + +// CustomCurve represents a custom ascent/descent curve +type CustomCurve struct { + Altitude []float64 `json:"altitude"` + Time []float64 `json:"time"` // seconds from start +} + func (s *Service) PerformPrediction(ctx context.Context, params ds.PredictionParameters) ([]ds.PredicitonResult, error) { - // Extract wind data at launch point - wind, err := s.ExtractWind(ctx, params.LaunchLatitude, params.LaunchLongitude, params.LaunchAltitude, params.LaunchDatetime) + // Validate required parameters + if params.LaunchLatitude == nil || params.LaunchLongitude == nil || params.LaunchAltitude == nil || params.LaunchDatetime == nil { + return nil, ErrInvalidParameters + } + + // Get default values + profile := "standard_profile" + if params.Profile != nil { + profile = *params.Profile + } + + ascentRate := 5.0 + if params.AscentRate != nil { + ascentRate = *params.AscentRate + } + + burstAltitude := 30000.0 + if params.BurstAltitude != nil { + burstAltitude = *params.BurstAltitude + } + + descentRate := 5.0 + if params.DescentRate != nil { + descentRate = *params.DescentRate + } + + floatAltitude := 0.0 + if params.FloatAltitude != nil { + floatAltitude = *params.FloatAltitude + } + + // Parse custom curves if provided + var ascentCurve, descentCurve *CustomCurve + if params.AscentCurve != nil && *params.AscentCurve != "" { + if curve, err := parseCustomCurve(*params.AscentCurve); err == nil { + ascentCurve = curve + } + } + if params.DescentCurve != nil && *params.DescentCurve != "" { + if curve, err := parseCustomCurve(*params.DescentCurve); err == nil { + descentCurve = curve + } + } + + log.Ctx(ctx).Info("Starting prediction", + zap.String("profile", profile), + zap.Float64("lat", *params.LaunchLatitude), + zap.Float64("lon", *params.LaunchLongitude), + zap.Float64("alt", *params.LaunchAltitude), + zap.Time("time", *params.LaunchDatetime), + ) + + var allResults []ds.PredicitonResult + + switch profile { + case "standard_profile": + allResults = s.standardProfile(ctx, params, ascentRate, burstAltitude, descentRate, ascentCurve, descentCurve) + case "float_profile": + allResults = s.floatProfile(ctx, params, ascentRate, burstAltitude, floatAltitude, descentRate, ascentCurve, descentCurve) + case "reverse_profile": + allResults = s.reverseProfile(ctx, params, ascentRate, burstAltitude, descentRate, ascentCurve, descentCurve) + case "custom_profile": + allResults = s.customProfile(ctx, params, ascentCurve, descentCurve) + default: + return nil, errcodes.New(400, "unsupported profile: "+profile) + } + + log.Ctx(ctx).Info("Prediction complete", zap.Int("total_steps", len(allResults))) + return allResults, nil +} + +func (s *Service) standardProfile(ctx context.Context, params ds.PredictionParameters, ascentRate, burstAltitude, descentRate float64, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult { + var results []ds.PredicitonResult + + // Stage 1: Ascent + ascentResults := s.simulateAscent(ctx, params, ascentRate, burstAltitude, ascentCurve) + results = append(results, ascentResults...) + + if len(ascentResults) > 0 { + // Get final position from ascent + lastResult := ascentResults[len(ascentResults)-1] + + // Stage 2: Descent + descentParams := ds.PredictionParameters{ + LaunchLatitude: lastResult.Latitude, + LaunchLongitude: lastResult.Longitude, + LaunchAltitude: lastResult.Altitude, + LaunchDatetime: lastResult.Timestamp, + } + + descentResults := s.simulateDescent(ctx, descentParams, descentRate, 0, descentCurve) + results = append(results, descentResults...) + } + + return results +} + +func (s *Service) floatProfile(ctx context.Context, params ds.PredictionParameters, ascentRate, burstAltitude, floatAltitude, descentRate float64, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult { + var results []ds.PredicitonResult + + // Stage 1: Ascent to float altitude + ascentResults := s.simulateAscent(ctx, params, ascentRate, floatAltitude, ascentCurve) + results = append(results, ascentResults...) + + if len(ascentResults) > 0 { + // Stage 2: Float (simulate for some time) + lastResult := ascentResults[len(ascentResults)-1] + floatResults := s.simulateFloat(ctx, lastResult, 30*time.Minute) // Float for 30 minutes + results = append(results, floatResults...) + + if len(floatResults) > 0 { + // Stage 3: Descent + finalFloat := floatResults[len(floatResults)-1] + descentParams := ds.PredictionParameters{ + LaunchLatitude: finalFloat.Latitude, + LaunchLongitude: finalFloat.Longitude, + LaunchAltitude: finalFloat.Altitude, + LaunchDatetime: finalFloat.Timestamp, + } + + descentResults := s.simulateDescent(ctx, descentParams, descentRate, 0, descentCurve) + results = append(results, descentResults...) + } + } + + return results +} + +func (s *Service) reverseProfile(ctx context.Context, params ds.PredictionParameters, ascentRate, burstAltitude, descentRate float64, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult { + var results []ds.PredicitonResult + + // Stage 1: Ascent + ascentResults := s.simulateAscent(ctx, params, ascentRate, burstAltitude, ascentCurve) + results = append(results, ascentResults...) + + if len(ascentResults) > 0 { + // Stage 2: Descent to float altitude + lastResult := ascentResults[len(ascentResults)-1] + descentParams := ds.PredictionParameters{ + LaunchLatitude: lastResult.Latitude, + LaunchLongitude: lastResult.Longitude, + LaunchAltitude: lastResult.Altitude, + LaunchDatetime: lastResult.Timestamp, + } + + // Descent to float altitude (if specified) + floatAlt := 0.0 + if params.FloatAltitude != nil { + floatAlt = *params.FloatAltitude + } + + descentResults := s.simulateDescent(ctx, descentParams, descentRate, floatAlt, descentCurve) + results = append(results, descentResults...) + + if floatAlt > 0 && len(descentResults) > 0 { + // Stage 3: Float + finalDescent := descentResults[len(descentResults)-1] + floatResults := s.simulateFloat(ctx, finalDescent, 30*time.Minute) + results = append(results, floatResults...) + } + } + + return results +} + +func (s *Service) customProfile(ctx context.Context, params ds.PredictionParameters, ascentCurve, descentCurve *CustomCurve) []ds.PredicitonResult { + var results []ds.PredicitonResult + + if ascentCurve != nil { + ascentResults := s.simulateCustomAscent(ctx, params, ascentCurve) + results = append(results, ascentResults...) + } + + if descentCurve != nil && len(results) > 0 { + lastResult := results[len(results)-1] + descentParams := ds.PredictionParameters{ + LaunchLatitude: lastResult.Latitude, + LaunchLongitude: lastResult.Longitude, + LaunchAltitude: lastResult.Altitude, + LaunchDatetime: lastResult.Timestamp, + } + + descentResults := s.simulateCustomDescent(ctx, descentParams, descentCurve) + results = append(results, descentResults...) + } + + return results +} + +func (s *Service) simulateAscent(ctx context.Context, params ds.PredictionParameters, ascentRate, targetAltitude float64, customCurve *CustomCurve) []ds.PredicitonResult { + const dt = 10.0 // simulation step in seconds + const outputInterval = 60.0 // output every 60 seconds + + lat := *params.LaunchLatitude + lon := *params.LaunchLongitude + alt := *params.LaunchAltitude + timeCur := *params.LaunchDatetime + + results := make([]ds.PredicitonResult, 0, 1000) + + // Always include the initial launch point + latCopy := lat + lonCopy := lon + altCopy := alt + timeCopy := timeCur + wind := [2]float64{0, 0} + windU := wind[0] + windV := wind[1] + results = append(results, ds.PredicitonResult{ + Latitude: &latCopy, + Longitude: &lonCopy, + Altitude: &altCopy, + Timestamp: &timeCopy, + WindU: &windU, + WindV: &windV, + }) + + var nextOutputTime = timeCur.Add(time.Duration(outputInterval) * time.Second) + + for alt < targetAltitude { + wind, err := s.ExtractWind(ctx, lat, lon, alt, timeCur) + if err != nil { + log.Ctx(ctx).Warn("Wind extraction failed during ascent", zap.Error(err)) + break + } + + altRate := ascentRate + if customCurve != nil { + altRate = s.getCustomAltitudeRate(customCurve, alt, ascentRate) + } + + latDot := (wind[1] / 111320.0) + lonDot := (wind[0] / (40075000.0 * math.Cos(lat*math.Pi/180) / 360.0)) + + lat += latDot * dt + lon += lonDot * dt + alt += altRate * dt + timeCur = timeCur.Add(time.Duration(dt) * time.Second) + + // Don't add a point if we've reached or exceeded target altitude + if alt >= targetAltitude { + break + } + + if !timeCur.Before(nextOutputTime) { + latCopy := lat + lonCopy := lon + altCopy := alt + timeCopy := timeCur + windU := wind[0] + windV := wind[1] + results = append(results, ds.PredicitonResult{ + Latitude: &latCopy, + Longitude: &lonCopy, + Altitude: &altCopy, + Timestamp: &timeCopy, + WindU: &windU, + WindV: &windV, + }) + nextOutputTime = nextOutputTime.Add(time.Duration(outputInterval) * time.Second) + } + } + + return results +} + +func (s *Service) simulateDescent(ctx context.Context, params ds.PredictionParameters, descentRate, targetAltitude float64, customCurve *CustomCurve) []ds.PredicitonResult { + const dt = 10.0 // simulation step in seconds + const outputInterval = 60.0 // output every 60 seconds + + lat := *params.LaunchLatitude + lon := *params.LaunchLongitude + alt := *params.LaunchAltitude + timeCur := *params.LaunchDatetime + + results := make([]ds.PredicitonResult, 0, 1000) + + // Always include the initial descent point + latCopy := lat + lonCopy := lon + altCopy := alt + timeCopy := timeCur + wind := [2]float64{0, 0} + windU := wind[0] + windV := wind[1] + results = append(results, ds.PredicitonResult{ + Latitude: &latCopy, + Longitude: &lonCopy, + Altitude: &altCopy, + Timestamp: &timeCopy, + WindU: &windU, + WindV: &windV, + }) + + var nextOutputTime = timeCur.Add(time.Duration(outputInterval) * time.Second) + + for alt > targetAltitude { + wind, err := s.ExtractWind(ctx, lat, lon, alt, timeCur) + if err != nil { + log.Ctx(ctx).Warn("Wind extraction failed during descent", zap.Error(err)) + break + } + + altRate := -descentRate + if customCurve != nil { + altRate = -s.getCustomAltitudeRate(customCurve, alt, descentRate) + } + + latDot := (wind[1] / 111320.0) + lonDot := (wind[0] / (40075000.0 * math.Cos(lat*math.Pi/180) / 360.0)) + + lat += latDot * dt + lon += lonDot * dt + alt += altRate * dt + timeCur = timeCur.Add(time.Duration(dt) * time.Second) + + // Don't add a point if we've reached or gone below target altitude + if alt <= targetAltitude { + break + } + + if !timeCur.Before(nextOutputTime) { + latCopy := lat + lonCopy := lon + altCopy := alt + timeCopy := timeCur + windU := wind[0] + windV := wind[1] + results = append(results, ds.PredicitonResult{ + Latitude: &latCopy, + Longitude: &lonCopy, + Altitude: &altCopy, + Timestamp: &timeCopy, + WindU: &windU, + WindV: &windV, + }) + nextOutputTime = nextOutputTime.Add(time.Duration(outputInterval) * time.Second) + } + } + + return results +} + +func (s *Service) simulateFloat(ctx context.Context, startResult ds.PredicitonResult, duration time.Duration) []ds.PredicitonResult { + const dt = 10.0 // simulation step in seconds + const outputInterval = 60.0 // output every 60 seconds + + lat := *startResult.Latitude + lon := *startResult.Longitude + alt := *startResult.Altitude + timeCur := *startResult.Timestamp + endTime := timeCur.Add(duration) + + results := make([]ds.PredicitonResult, 0, 1000) + + // Always include the initial float point + latCopy := lat + lonCopy := lon + altCopy := alt + timeCopy := timeCur + wind := [2]float64{0, 0} + windU := wind[0] + windV := wind[1] + results = append(results, ds.PredicitonResult{ + Latitude: &latCopy, + Longitude: &lonCopy, + Altitude: &altCopy, + Timestamp: &timeCopy, + WindU: &windU, + WindV: &windV, + }) + + var nextOutputTime = timeCur.Add(time.Duration(outputInterval) * time.Second) + + for timeCur.Before(endTime) { + wind, err := s.ExtractWind(ctx, lat, lon, alt, timeCur) + if err != nil { + log.Ctx(ctx).Warn("Wind extraction failed during float", zap.Error(err)) + break + } + + latDot := (wind[1] / 111320.0) + lonDot := (wind[0] / (40075000.0 * math.Cos(lat*math.Pi/180) / 360.0)) + + lat += latDot * dt + lon += lonDot * dt + // alt remains constant during float + timeCur = timeCur.Add(time.Duration(dt) * time.Second) + + if !timeCur.Before(nextOutputTime) { + latCopy := lat + lonCopy := lon + altCopy := alt + timeCopy := timeCur + windU := wind[0] + windV := wind[1] + results = append(results, ds.PredicitonResult{ + Latitude: &latCopy, + Longitude: &lonCopy, + Altitude: &altCopy, + Timestamp: &timeCopy, + WindU: &windU, + WindV: &windV, + }) + nextOutputTime = nextOutputTime.Add(time.Duration(outputInterval) * time.Second) + } + } + + return results +} + +func (s *Service) simulateCustomAscent(ctx context.Context, params ds.PredictionParameters, curve *CustomCurve) []ds.PredicitonResult { + // Implementation for custom ascent curve + // This would interpolate the altitude rate from the custom curve + return s.simulateAscent(ctx, params, 5.0, 30000.0, curve) +} + +func (s *Service) simulateCustomDescent(ctx context.Context, params ds.PredictionParameters, curve *CustomCurve) []ds.PredicitonResult { + // Implementation for custom descent curve + // This would interpolate the altitude rate from the custom curve + return s.simulateDescent(ctx, params, 5.0, 0.0, curve) +} + +func (s *Service) getCustomAltitudeRate(curve *CustomCurve, currentAltitude, defaultRate float64) float64 { + if curve == nil || len(curve.Altitude) < 2 { + return defaultRate + } + + // Find the two points in the curve that bracket the current altitude + for i := 0; i < len(curve.Altitude)-1; i++ { + if curve.Altitude[i] <= currentAltitude && currentAltitude <= curve.Altitude[i+1] { + // Linear interpolation + alt1, alt2 := curve.Altitude[i], curve.Altitude[i+1] + time1, time2 := curve.Time[i], curve.Time[i+1] + + if alt2 == alt1 { + return defaultRate + } + + // Calculate rate (change in altitude per second) + if time2 > time1 { + return (alt2 - alt1) / (time2 - time1) + } + return defaultRate + } + } + + return defaultRate +} + +func parseCustomCurve(base64Data string) (*CustomCurve, error) { + data, err := base64.StdEncoding.DecodeString(base64Data) if err != nil { return nil, err } - // TODO: Implement full prediction logic - result := ds.PredicitonResult{ - Latitude: params.LaunchLatitude, - Longitude: params.LaunchLongitude, - Altitude: params.LaunchAltitude, - Timestamp: params.LaunchDatetime, - WindU: wind[0], - WindV: wind[1], + var curve CustomCurve + if err := json.Unmarshal(data, &curve); err != nil { + return nil, err } - return []ds.PredicitonResult{result}, nil + return &curve, nil } diff --git a/internal/service/service.go b/internal/service/service.go index 5858054..16e64b2 100644 --- a/internal/service/service.go +++ b/internal/service/service.go @@ -4,22 +4,20 @@ import ( "context" "time" - "go.uber.org/zap" + "git.intra.yksa.space/gsn/predictor/internal/pkg/log" ) type Service struct { - cfg *Config - redis Redis - grib Grib - logger *zap.Logger + cfg *Config + redis Redis + grib Grib } -func New(cfg *Config, gribService Grib, redisService Redis, logger *zap.Logger) (*Service, error) { +func New(cfg *Config, gribService Grib, redisService Redis) (*Service, error) { svc := &Service{ - cfg: cfg, - redis: redisService, - grib: gribService, - logger: logger, + cfg: cfg, + redis: redisService, + grib: gribService, } return svc, nil @@ -42,12 +40,12 @@ func (s *Service) Update(ctx context.Context) error { // Start starts the service func (s *Service) Start() { - s.logger.Info("service started") + log.Ctx(context.Background()).Info("service started") } // Stop stops the service func (s *Service) Stop() { - s.logger.Info("service stopped") + log.Ctx(context.Background()).Info("service stopped") } // Close closes the service and releases resources @@ -55,3 +53,12 @@ func (s *Service) Close() error { s.Stop() return nil } + +func (s *Service) GetGribStatus(ctx context.Context) (ready bool, lastUpdate time.Time, isFresh bool, errMsg string) { + if gribStatus, ok := s.grib.(interface { + GetStatus() (ready bool, lastUpdate time.Time, isFresh bool, errMsg string) + }); ok { + return gribStatus.GetStatus() + } + return false, time.Time{}, false, "grib service does not implement GetStatus" +} diff --git a/internal/transport/middleware/log.go b/internal/transport/middleware/log.go index 1ad417c..ebf2003 100644 --- a/internal/transport/middleware/log.go +++ b/internal/transport/middleware/log.go @@ -9,9 +9,9 @@ import ( "go.uber.org/zap" ) -func Logging(logger *zap.Logger) middleware.Middleware { +func Logging() middleware.Middleware { return func(req middleware.Request, next func(req middleware.Request) (middleware.Response, error)) (middleware.Response, error) { - lg := logger.With( + lg := log.Ctx(req.Context).With( zap.String("operationId", req.OperationID), ) diff --git a/internal/transport/rest/config.go b/internal/transport/rest/config.go index 9ea1487..1b55959 100644 --- a/internal/transport/rest/config.go +++ b/internal/transport/rest/config.go @@ -1,5 +1,10 @@ package rest +import ( + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + env "github.com/caarlos0/env/v11" +) + type Config struct { Host string `env:"HOST" envDefault:"0.0.0.0"` Port int `env:"PORT" envDefault:"8080"` @@ -7,3 +12,13 @@ type Config struct { WriteTimeout string `env:"WRITE_TIMEOUT" envDefault:"30s"` IdleTimeout string `env:"IDLE_TIMEOUT" envDefault:"60s"` } + +func NewConfig() (*Config, error) { + cfg := &Config{} + if err := env.ParseWithOptions(cfg, env.Options{ + PrefixTagName: "GSN_PREDICTOR_REST_", + }); err != nil { + return nil, errcodes.Wrap(err, "failed to parse REST config") + } + return cfg, nil +} diff --git a/internal/transport/rest/handler/deps.go b/internal/transport/rest/handler/deps.go index 88777b7..2d930cb 100644 --- a/internal/transport/rest/handler/deps.go +++ b/internal/transport/rest/handler/deps.go @@ -3,9 +3,12 @@ package handler import ( "context" "time" + + "git.intra.yksa.space/gsn/predictor/internal/pkg/ds" ) type Service interface { UpdateWeatherData(ctx context.Context) error ExtractWind(ctx context.Context, lat, lon, alt float64, ts time.Time) ([2]float64, error) + PerformPrediction(ctx context.Context, params ds.PredictionParameters) ([]ds.PredicitonResult, error) } diff --git a/internal/transport/rest/handler/handler.go b/internal/transport/rest/handler/handler.go index 49ce1c8..f10b09d 100644 --- a/internal/transport/rest/handler/handler.go +++ b/internal/transport/rest/handler/handler.go @@ -3,7 +3,9 @@ package handler import ( "context" "net/http" + "time" + "git.intra.yksa.space/gsn/predictor/internal/pkg/ds" "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" api "git.intra.yksa.space/gsn/predictor/pkg/rest" ) @@ -23,7 +25,115 @@ func New(svc Service) *Handler { } func (h *Handler) PerformPrediction(ctx context.Context, req api.OptPredictionParameters, params api.PerformPredictionParams) (*api.PredictionResult, error) { - return nil, errcodes.New(http.StatusNotImplemented, "not implemented", "please wait") + internalParams := ds.ConvertOptPredictionParameters(req) + if internalParams == nil { + return nil, errcodes.New(http.StatusBadRequest, "invalid or missing parameters") + } + results, err := h.svc.PerformPrediction(ctx, *internalParams) + if err != nil { + return nil, err + } + if len(results) == 0 { + return nil, errcodes.New(http.StatusInternalServerError, "no prediction results") + } + + // Group results into stages (ascent and descent) + stages := h.groupResultsIntoStages(results) + + // Map to OpenAPI schema + var predictionItems []api.PredictionResultPredictionItem + + for _, stage := range stages { + var trajectory []api.PredictionResultPredictionItemTrajectoryItem + + for _, result := range stage.Results { + traj := api.PredictionResultPredictionItemTrajectoryItem{ + Datetime: *result.Timestamp, + Latitude: *result.Latitude, + Longitude: *result.Longitude, + Altitude: *result.Altitude, + } + trajectory = append(trajectory, traj) + } + + item := api.PredictionResultPredictionItem{ + Stage: stage.Stage, + Trajectory: trajectory, + } + predictionItems = append(predictionItems, item) + } + + metadata := api.PredictionResultMetadata{ + StartDatetime: *results[0].Timestamp, + CompleteDatetime: *results[len(results)-1].Timestamp, + } + + resp := &api.PredictionResult{ + Metadata: metadata, + Prediction: predictionItems, + } + return resp, nil +} + +// StageResult represents a stage with its results +type StageResult struct { + Stage api.PredictionResultPredictionItemStage + Results []ds.PredicitonResult +} + +// groupResultsIntoStages groups the prediction results into ascent and descent stages +func (h *Handler) groupResultsIntoStages(results []ds.PredicitonResult) []StageResult { + if len(results) == 0 { + return nil + } + + var stages []StageResult + var currentStage []ds.PredicitonResult + var currentStageType api.PredictionResultPredictionItemStage + + // Determine if we're in ascent or descent based on altitude changes + prevAlt := *results[0].Altitude + currentStage = append(currentStage, results[0]) + currentStageType = api.PredictionResultPredictionItemStageAscent + + for i := 1; i < len(results); i++ { + result := results[i] + currentAlt := *result.Altitude + + // Determine if we're still in the same stage + var stageType api.PredictionResultPredictionItemStage + if currentAlt > prevAlt { + stageType = api.PredictionResultPredictionItemStageAscent + } else if currentAlt < prevAlt { + stageType = api.PredictionResultPredictionItemStageDescent + } else { + // Same altitude - continue with current stage + stageType = currentStageType + } + + // If stage type changed, finalize current stage and start new one + if stageType != currentStageType && len(currentStage) > 0 { + stages = append(stages, StageResult{ + Stage: currentStageType, + Results: currentStage, + }) + currentStage = nil + currentStageType = stageType + } + + currentStage = append(currentStage, result) + prevAlt = currentAlt + } + + // Add the final stage + if len(currentStage) > 0 { + stages = append(stages, StageResult{ + Stage: currentStageType, + Results: currentStage, + }) + } + + return stages } func (h *Handler) NewError(ctx context.Context, err error) *api.ErrorStatusCode { @@ -50,3 +160,35 @@ func (h *Handler) NewError(ctx context.Context, err error) *api.ErrorStatusCode }, } } + +func (h *Handler) ReadinessCheck(ctx context.Context) (*api.ReadinessResponse, error) { + status := api.ReadinessResponseStatusNotReady + var lastUpdate time.Time + var isFresh bool + var errMsg string + + if s, ok := h.svc.(interface { + GetGribStatus(ctx context.Context) (ready bool, lastUpdate time.Time, isFresh bool, errMsg string) + }); ok { + ready, lu, fresh, em := s.GetGribStatus(ctx) + lastUpdate = lu + isFresh = fresh + errMsg = em + if ready { + status = api.ReadinessResponseStatusOk + } else if em != "" { + status = api.ReadinessResponseStatusError + } + } else { + errMsg = "service does not implement GetGribStatus" + status = api.ReadinessResponseStatusError + } + + resp := &api.ReadinessResponse{ + Status: status, + IsFresh: api.NewOptBool(isFresh), + LastUpdate: api.NewOptDateTime(lastUpdate), + ErrorMessage: api.NewOptString(errMsg), + } + return resp, nil +} diff --git a/internal/transport/rest/transport.go b/internal/transport/rest/transport.go index 2657272..8534caa 100644 --- a/internal/transport/rest/transport.go +++ b/internal/transport/rest/transport.go @@ -1,36 +1,37 @@ package rest import ( + "context" "fmt" "net/http" + "git.intra.yksa.space/gsn/predictor/internal/pkg/log" "git.intra.yksa.space/gsn/predictor/internal/transport/middleware" handler "git.intra.yksa.space/gsn/predictor/internal/transport/rest/handler" api "git.intra.yksa.space/gsn/predictor/pkg/rest" - "go.uber.org/zap" ) type Transport struct { - lg *zap.Logger - cfg *Config - srv *api.Server + cfg *Config + srv *api.Server + handler *handler.Handler } -func New(lg *zap.Logger, handler *handler.Handler, cfg *Config) (*Transport, error) { - srv, err := api.NewServer(handler, api.WithMiddleware(middleware.Logging(lg))) +func New(handler *handler.Handler, cfg *Config) (*Transport, error) { + srv, err := api.NewServer(handler, api.WithMiddleware(middleware.Logging())) if err != nil { return nil, err } return &Transport{ - lg: lg, - srv: srv, - cfg: cfg, + srv: srv, + cfg: cfg, + handler: handler, }, nil } func (t *Transport) Run() { - t.lg.Info("started") + log.Ctx(context.Background()).Info("started") if err := http.ListenAndServe(fmt.Sprintf(":%d", t.cfg.Port), t.srv); err != nil { panic(err) diff --git a/pkg/redis/redis.go b/pkg/redis/redis.go index 9577de0..f88dfe1 100644 --- a/pkg/redis/redis.go +++ b/pkg/redis/redis.go @@ -19,10 +19,10 @@ type Client struct { var _ Service = (*Client)(nil) type Config struct { - Host string - Port int - Password string - DB int + Host string `env:"HOST"` + Port int `env:"PORT"` + Password string `env:"PASSWORD"` + DB int `env:"DB"` } func New(cfg Config) (*Client, error) { diff --git a/pkg/rest/oas_client_gen.go b/pkg/rest/oas_client_gen.go index 036f7cd..754b03d 100644 --- a/pkg/rest/oas_client_gen.go +++ b/pkg/rest/oas_client_gen.go @@ -33,6 +33,12 @@ type Invoker interface { // // POST /api/v1/prediction PerformPrediction(ctx context.Context, request OptPredictionParameters, params PerformPredictionParams) (*PredictionResult, error) + // ReadinessCheck invokes readinessCheck operation. + // + // Readiness check. + // + // GET /ready + ReadinessCheck(ctx context.Context) (*ReadinessResponse, error) } // Client implements OAS client. @@ -177,3 +183,75 @@ func (c *Client) sendPerformPrediction(ctx context.Context, request OptPredictio return result, nil } + +// ReadinessCheck invokes readinessCheck operation. +// +// Readiness check. +// +// GET /ready +func (c *Client) ReadinessCheck(ctx context.Context) (*ReadinessResponse, error) { + res, err := c.sendReadinessCheck(ctx) + return res, err +} + +func (c *Client) sendReadinessCheck(ctx context.Context) (res *ReadinessResponse, err error) { + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("readinessCheck"), + semconv.HTTPRequestMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/ready"), + } + + // Run stopwatch. + startTime := time.Now() + defer func() { + // Use floating point division here for higher precision (instead of Millisecond method). + elapsedDuration := time.Since(startTime) + c.duration.Record(ctx, float64(elapsedDuration)/float64(time.Millisecond), metric.WithAttributes(otelAttrs...)) + }() + + // Increment request counter. + c.requests.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + + // Start a span for this request. + ctx, span := c.cfg.Tracer.Start(ctx, ReadinessCheckOperation, + trace.WithAttributes(otelAttrs...), + clientSpanKind, + ) + // Track stage for error reporting. + var stage string + defer func() { + if err != nil { + span.RecordError(err) + span.SetStatus(codes.Error, stage) + c.errors.Add(ctx, 1, metric.WithAttributes(otelAttrs...)) + } + span.End() + }() + + stage = "BuildURL" + u := uri.Clone(c.requestURL(ctx)) + var pathParts [1]string + pathParts[0] = "/ready" + uri.AddPathParts(u, pathParts[:]...) + + stage = "EncodeRequest" + r, err := ht.NewRequest(ctx, "GET", u) + if err != nil { + return res, errors.Wrap(err, "create request") + } + + stage = "SendRequest" + resp, err := c.cfg.Client.Do(r) + if err != nil { + return res, errors.Wrap(err, "do request") + } + defer resp.Body.Close() + + stage = "DecodeResponse" + result, err := decodeReadinessCheckResponse(resp) + if err != nil { + return res, errors.Wrap(err, "decode response") + } + + return result, nil +} diff --git a/pkg/rest/oas_handlers_gen.go b/pkg/rest/oas_handlers_gen.go index 70eabc0..1d88103 100644 --- a/pkg/rest/oas_handlers_gen.go +++ b/pkg/rest/oas_handlers_gen.go @@ -193,3 +193,133 @@ func (s *Server) handlePerformPredictionRequest(args [0]string, argsEscaped bool return } } + +// handleReadinessCheckRequest handles readinessCheck operation. +// +// Readiness check. +// +// GET /ready +func (s *Server) handleReadinessCheckRequest(args [0]string, argsEscaped bool, w http.ResponseWriter, r *http.Request) { + statusWriter := &codeRecorder{ResponseWriter: w} + w = statusWriter + otelAttrs := []attribute.KeyValue{ + otelogen.OperationID("readinessCheck"), + semconv.HTTPRequestMethodKey.String("GET"), + semconv.HTTPRouteKey.String("/ready"), + } + + // Start a span for this request. + ctx, span := s.cfg.Tracer.Start(r.Context(), ReadinessCheckOperation, + trace.WithAttributes(otelAttrs...), + serverSpanKind, + ) + defer span.End() + + // Add Labeler to context. + labeler := &Labeler{attrs: otelAttrs} + ctx = contextWithLabeler(ctx, labeler) + + // Run stopwatch. + startTime := time.Now() + defer func() { + elapsedDuration := time.Since(startTime) + + attrSet := labeler.AttributeSet() + attrs := attrSet.ToSlice() + code := statusWriter.status + if code != 0 { + codeAttr := semconv.HTTPResponseStatusCode(code) + attrs = append(attrs, codeAttr) + span.SetAttributes(codeAttr) + } + attrOpt := metric.WithAttributes(attrs...) + + // Increment request counter. + s.requests.Add(ctx, 1, attrOpt) + + // Use floating point division here for higher precision (instead of Millisecond method). + s.duration.Record(ctx, float64(elapsedDuration)/float64(time.Millisecond), attrOpt) + }() + + var ( + recordError = func(stage string, err error) { + span.RecordError(err) + + // https://opentelemetry.io/docs/specs/semconv/http/http-spans/#status + // Span Status MUST be left unset if HTTP status code was in the 1xx, 2xx or 3xx ranges, + // unless there was another error (e.g., network error receiving the response body; or 3xx codes with + // max redirects exceeded), in which case status MUST be set to Error. + code := statusWriter.status + if code >= 100 && code < 500 { + span.SetStatus(codes.Error, stage) + } + + attrSet := labeler.AttributeSet() + attrs := attrSet.ToSlice() + if code != 0 { + attrs = append(attrs, semconv.HTTPResponseStatusCode(code)) + } + + s.errors.Add(ctx, 1, metric.WithAttributes(attrs...)) + } + err error + ) + + var response *ReadinessResponse + if m := s.cfg.Middleware; m != nil { + mreq := middleware.Request{ + Context: ctx, + OperationName: ReadinessCheckOperation, + OperationSummary: "Readiness check", + OperationID: "readinessCheck", + Body: nil, + Params: middleware.Parameters{}, + Raw: r, + } + + type ( + Request = struct{} + Params = struct{} + Response = *ReadinessResponse + ) + response, err = middleware.HookMiddleware[ + Request, + Params, + Response, + ]( + m, + mreq, + nil, + func(ctx context.Context, request Request, params Params) (response Response, err error) { + response, err = s.h.ReadinessCheck(ctx) + return response, err + }, + ) + } else { + response, err = s.h.ReadinessCheck(ctx) + } + if err != nil { + if errRes, ok := errors.Into[*ErrorStatusCode](err); ok { + if err := encodeErrorResponse(errRes, w, span); err != nil { + defer recordError("Internal", err) + } + return + } + if errors.Is(err, ht.ErrNotImplemented) { + s.cfg.ErrorHandler(ctx, w, r, err) + return + } + if err := encodeErrorResponse(s.h.NewError(ctx, err), w, span); err != nil { + defer recordError("Internal", err) + } + return + } + + if err := encodeReadinessCheckResponse(response, w, span); err != nil { + defer recordError("EncodeResponse", err) + if !errors.Is(err, ht.ErrInternalServerErrorResponse) { + s.cfg.ErrorHandler(ctx, w, r, err) + } + return + } +} diff --git a/pkg/rest/oas_json_gen.go b/pkg/rest/oas_json_gen.go index 3bd1884..2621d55 100644 --- a/pkg/rest/oas_json_gen.go +++ b/pkg/rest/oas_json_gen.go @@ -1154,24 +1154,127 @@ func (s *PredictionResultPredictionItemTrajectoryItem) Encode(e *jx.Encoder) { // encodeFields encodes fields. func (s *PredictionResultPredictionItemTrajectoryItem) encodeFields(e *jx.Encoder) { + { + e.FieldStart("datetime") + json.EncodeDateTime(e, s.Datetime) + } + { + e.FieldStart("latitude") + e.Float64(s.Latitude) + } + { + e.FieldStart("longitude") + e.Float64(s.Longitude) + } + { + e.FieldStart("altitude") + e.Float64(s.Altitude) + } } -var jsonFieldsNameOfPredictionResultPredictionItemTrajectoryItem = [0]string{} +var jsonFieldsNameOfPredictionResultPredictionItemTrajectoryItem = [4]string{ + 0: "datetime", + 1: "latitude", + 2: "longitude", + 3: "altitude", +} // Decode decodes PredictionResultPredictionItemTrajectoryItem from json. func (s *PredictionResultPredictionItemTrajectoryItem) Decode(d *jx.Decoder) error { if s == nil { return errors.New("invalid: unable to decode PredictionResultPredictionItemTrajectoryItem to nil") } + var requiredBitSet [1]uint8 if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { switch string(k) { + case "datetime": + requiredBitSet[0] |= 1 << 0 + if err := func() error { + v, err := json.DecodeDateTime(d) + s.Datetime = v + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"datetime\"") + } + case "latitude": + requiredBitSet[0] |= 1 << 1 + if err := func() error { + v, err := d.Float64() + s.Latitude = float64(v) + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"latitude\"") + } + case "longitude": + requiredBitSet[0] |= 1 << 2 + if err := func() error { + v, err := d.Float64() + s.Longitude = float64(v) + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"longitude\"") + } + case "altitude": + requiredBitSet[0] |= 1 << 3 + if err := func() error { + v, err := d.Float64() + s.Altitude = float64(v) + if err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"altitude\"") + } default: return d.Skip() } + return nil }); err != nil { return errors.Wrap(err, "decode PredictionResultPredictionItemTrajectoryItem") } + // Validate required fields. + var failures []validate.FieldError + for i, mask := range [1]uint8{ + 0b00001111, + } { + if result := (requiredBitSet[i] & mask) ^ mask; result != 0 { + // Mask only required fields and check equality to mask using XOR. + // + // If XOR result is not zero, result is not equal to expected, so some fields are missed. + // Bits of fields which would be set are actually bits of missed fields. + missed := bits.OnesCount8(result) + for bitN := 0; bitN < missed; bitN++ { + bitIdx := bits.TrailingZeros8(result) + fieldIdx := i*8 + bitIdx + var name string + if fieldIdx < len(jsonFieldsNameOfPredictionResultPredictionItemTrajectoryItem) { + name = jsonFieldsNameOfPredictionResultPredictionItemTrajectoryItem[fieldIdx] + } else { + name = strconv.Itoa(fieldIdx) + } + failures = append(failures, validate.FieldError{ + Name: name, + Error: validate.ErrFieldRequired, + }) + // Reset bit. + result &^= 1 << bitIdx + } + } + } + if len(failures) > 0 { + return &validate.Error{Fields: failures} + } return nil } @@ -1188,3 +1291,190 @@ func (s *PredictionResultPredictionItemTrajectoryItem) UnmarshalJSON(data []byte d := jx.DecodeBytes(data) return s.Decode(d) } + +// Encode implements json.Marshaler. +func (s *ReadinessResponse) Encode(e *jx.Encoder) { + e.ObjStart() + s.encodeFields(e) + e.ObjEnd() +} + +// encodeFields encodes fields. +func (s *ReadinessResponse) encodeFields(e *jx.Encoder) { + { + e.FieldStart("status") + s.Status.Encode(e) + } + { + if s.LastUpdate.Set { + e.FieldStart("last_update") + s.LastUpdate.Encode(e, json.EncodeDateTime) + } + } + { + if s.IsFresh.Set { + e.FieldStart("is_fresh") + s.IsFresh.Encode(e) + } + } + { + if s.ErrorMessage.Set { + e.FieldStart("error_message") + s.ErrorMessage.Encode(e) + } + } +} + +var jsonFieldsNameOfReadinessResponse = [4]string{ + 0: "status", + 1: "last_update", + 2: "is_fresh", + 3: "error_message", +} + +// Decode decodes ReadinessResponse from json. +func (s *ReadinessResponse) Decode(d *jx.Decoder) error { + if s == nil { + return errors.New("invalid: unable to decode ReadinessResponse to nil") + } + var requiredBitSet [1]uint8 + + if err := d.ObjBytes(func(d *jx.Decoder, k []byte) error { + switch string(k) { + case "status": + requiredBitSet[0] |= 1 << 0 + if err := func() error { + if err := s.Status.Decode(d); err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"status\"") + } + case "last_update": + if err := func() error { + s.LastUpdate.Reset() + if err := s.LastUpdate.Decode(d, json.DecodeDateTime); err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"last_update\"") + } + case "is_fresh": + if err := func() error { + s.IsFresh.Reset() + if err := s.IsFresh.Decode(d); err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"is_fresh\"") + } + case "error_message": + if err := func() error { + s.ErrorMessage.Reset() + if err := s.ErrorMessage.Decode(d); err != nil { + return err + } + return nil + }(); err != nil { + return errors.Wrap(err, "decode field \"error_message\"") + } + default: + return d.Skip() + } + return nil + }); err != nil { + return errors.Wrap(err, "decode ReadinessResponse") + } + // Validate required fields. + var failures []validate.FieldError + for i, mask := range [1]uint8{ + 0b00000001, + } { + if result := (requiredBitSet[i] & mask) ^ mask; result != 0 { + // Mask only required fields and check equality to mask using XOR. + // + // If XOR result is not zero, result is not equal to expected, so some fields are missed. + // Bits of fields which would be set are actually bits of missed fields. + missed := bits.OnesCount8(result) + for bitN := 0; bitN < missed; bitN++ { + bitIdx := bits.TrailingZeros8(result) + fieldIdx := i*8 + bitIdx + var name string + if fieldIdx < len(jsonFieldsNameOfReadinessResponse) { + name = jsonFieldsNameOfReadinessResponse[fieldIdx] + } else { + name = strconv.Itoa(fieldIdx) + } + failures = append(failures, validate.FieldError{ + Name: name, + Error: validate.ErrFieldRequired, + }) + // Reset bit. + result &^= 1 << bitIdx + } + } + } + if len(failures) > 0 { + return &validate.Error{Fields: failures} + } + + return nil +} + +// MarshalJSON implements stdjson.Marshaler. +func (s *ReadinessResponse) MarshalJSON() ([]byte, error) { + e := jx.Encoder{} + s.Encode(&e) + return e.Bytes(), nil +} + +// UnmarshalJSON implements stdjson.Unmarshaler. +func (s *ReadinessResponse) UnmarshalJSON(data []byte) error { + d := jx.DecodeBytes(data) + return s.Decode(d) +} + +// Encode encodes ReadinessResponseStatus as json. +func (s ReadinessResponseStatus) Encode(e *jx.Encoder) { + e.Str(string(s)) +} + +// Decode decodes ReadinessResponseStatus from json. +func (s *ReadinessResponseStatus) Decode(d *jx.Decoder) error { + if s == nil { + return errors.New("invalid: unable to decode ReadinessResponseStatus to nil") + } + v, err := d.StrBytes() + if err != nil { + return err + } + // Try to use constant string. + switch ReadinessResponseStatus(v) { + case ReadinessResponseStatusOk: + *s = ReadinessResponseStatusOk + case ReadinessResponseStatusNotReady: + *s = ReadinessResponseStatusNotReady + case ReadinessResponseStatusError: + *s = ReadinessResponseStatusError + default: + *s = ReadinessResponseStatus(v) + } + + return nil +} + +// MarshalJSON implements stdjson.Marshaler. +func (s ReadinessResponseStatus) MarshalJSON() ([]byte, error) { + e := jx.Encoder{} + s.Encode(&e) + return e.Bytes(), nil +} + +// UnmarshalJSON implements stdjson.Unmarshaler. +func (s *ReadinessResponseStatus) UnmarshalJSON(data []byte) error { + d := jx.DecodeBytes(data) + return s.Decode(d) +} diff --git a/pkg/rest/oas_operations_gen.go b/pkg/rest/oas_operations_gen.go index ea78f42..873f44a 100644 --- a/pkg/rest/oas_operations_gen.go +++ b/pkg/rest/oas_operations_gen.go @@ -7,4 +7,5 @@ type OperationName = string const ( PerformPredictionOperation OperationName = "PerformPrediction" + ReadinessCheckOperation OperationName = "ReadinessCheck" ) diff --git a/pkg/rest/oas_response_decoders_gen.go b/pkg/rest/oas_response_decoders_gen.go index ad8cb3c..3c148fd 100644 --- a/pkg/rest/oas_response_decoders_gen.go +++ b/pkg/rest/oas_response_decoders_gen.go @@ -105,3 +105,95 @@ func decodePerformPredictionResponse(resp *http.Response) (res *PredictionResult } return res, errors.Wrap(defRes, "error") } + +func decodeReadinessCheckResponse(resp *http.Response) (res *ReadinessResponse, _ error) { + switch resp.StatusCode { + case 200: + // Code 200. + ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return res, errors.Wrap(err, "parse media type") + } + switch { + case ct == "application/json": + buf, err := io.ReadAll(resp.Body) + if err != nil { + return res, err + } + d := jx.DecodeBytes(buf) + + var response ReadinessResponse + if err := func() error { + if err := response.Decode(d); err != nil { + return err + } + if err := d.Skip(); err != io.EOF { + return errors.New("unexpected trailing data") + } + return nil + }(); err != nil { + err = &ogenerrors.DecodeBodyError{ + ContentType: ct, + Body: buf, + Err: err, + } + return res, err + } + // Validate response. + if err := func() error { + if err := response.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + return res, errors.Wrap(err, "validate") + } + return &response, nil + default: + return res, validate.InvalidContentType(ct) + } + } + // Convenient error response. + defRes, err := func() (res *ErrorStatusCode, err error) { + ct, _, err := mime.ParseMediaType(resp.Header.Get("Content-Type")) + if err != nil { + return res, errors.Wrap(err, "parse media type") + } + switch { + case ct == "application/json": + buf, err := io.ReadAll(resp.Body) + if err != nil { + return res, err + } + d := jx.DecodeBytes(buf) + + var response Error + if err := func() error { + if err := response.Decode(d); err != nil { + return err + } + if err := d.Skip(); err != io.EOF { + return errors.New("unexpected trailing data") + } + return nil + }(); err != nil { + err = &ogenerrors.DecodeBodyError{ + ContentType: ct, + Body: buf, + Err: err, + } + return res, err + } + return &ErrorStatusCode{ + StatusCode: resp.StatusCode, + Response: response, + }, nil + default: + return res, validate.InvalidContentType(ct) + } + }() + if err != nil { + return res, errors.Wrapf(err, "default (code %d)", resp.StatusCode) + } + return res, errors.Wrap(defRes, "error") +} diff --git a/pkg/rest/oas_response_encoders_gen.go b/pkg/rest/oas_response_encoders_gen.go index d944b45..8f24cd5 100644 --- a/pkg/rest/oas_response_encoders_gen.go +++ b/pkg/rest/oas_response_encoders_gen.go @@ -27,6 +27,20 @@ func encodePerformPredictionResponse(response *PredictionResult, w http.Response return nil } +func encodeReadinessCheckResponse(response *ReadinessResponse, w http.ResponseWriter, span trace.Span) error { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(200) + span.SetStatus(codes.Ok, http.StatusText(200)) + + e := new(jx.Encoder) + response.Encode(e) + if _, err := e.WriteTo(w); err != nil { + return errors.Wrap(err, "write") + } + + return nil +} + func encodeErrorResponse(response *ErrorStatusCode, w http.ResponseWriter, span trace.Span) error { w.Header().Set("Content-Type", "application/json; charset=utf-8") code := response.StatusCode diff --git a/pkg/rest/oas_router_gen.go b/pkg/rest/oas_router_gen.go index 3ce0e08..68d502b 100644 --- a/pkg/rest/oas_router_gen.go +++ b/pkg/rest/oas_router_gen.go @@ -48,24 +48,58 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { break } switch elem[0] { - case '/': // Prefix: "/api/v1/prediction" + case '/': // Prefix: "/" - if l := len("/api/v1/prediction"); len(elem) >= l && elem[0:l] == "/api/v1/prediction" { + if l := len("/"); len(elem) >= l && elem[0:l] == "/" { elem = elem[l:] } else { break } if len(elem) == 0 { - // Leaf node. - switch r.Method { - case "POST": - s.handlePerformPredictionRequest([0]string{}, elemIsEscaped, w, r) - default: - s.notAllowed(w, r, "POST") + break + } + switch elem[0] { + case 'a': // Prefix: "api/v1/prediction" + + if l := len("api/v1/prediction"); len(elem) >= l && elem[0:l] == "api/v1/prediction" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch r.Method { + case "POST": + s.handlePerformPredictionRequest([0]string{}, elemIsEscaped, w, r) + default: + s.notAllowed(w, r, "POST") + } + + return + } + + case 'r': // Prefix: "ready" + + if l := len("ready"); len(elem) >= l && elem[0:l] == "ready" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch r.Method { + case "GET": + s.handleReadinessCheckRequest([0]string{}, elemIsEscaped, w, r) + default: + s.notAllowed(w, r, "GET") + } + + return } - return } } @@ -148,28 +182,66 @@ func (s *Server) FindPath(method string, u *url.URL) (r Route, _ bool) { break } switch elem[0] { - case '/': // Prefix: "/api/v1/prediction" + case '/': // Prefix: "/" - if l := len("/api/v1/prediction"); len(elem) >= l && elem[0:l] == "/api/v1/prediction" { + if l := len("/"); len(elem) >= l && elem[0:l] == "/" { elem = elem[l:] } else { break } if len(elem) == 0 { - // Leaf node. - switch method { - case "POST": - r.name = PerformPredictionOperation - r.summary = "Perform preidction" - r.operationID = "performPrediction" - r.pathPattern = "/api/v1/prediction" - r.args = args - r.count = 0 - return r, true - default: - return + break + } + switch elem[0] { + case 'a': // Prefix: "api/v1/prediction" + + if l := len("api/v1/prediction"); len(elem) >= l && elem[0:l] == "api/v1/prediction" { + elem = elem[l:] + } else { + break } + + if len(elem) == 0 { + // Leaf node. + switch method { + case "POST": + r.name = PerformPredictionOperation + r.summary = "Perform preidction" + r.operationID = "performPrediction" + r.pathPattern = "/api/v1/prediction" + r.args = args + r.count = 0 + return r, true + default: + return + } + } + + case 'r': // Prefix: "ready" + + if l := len("ready"); len(elem) >= l && elem[0:l] == "ready" { + elem = elem[l:] + } else { + break + } + + if len(elem) == 0 { + // Leaf node. + switch method { + case "GET": + r.name = ReadinessCheckOperation + r.summary = "Readiness check" + r.operationID = "readinessCheck" + r.pathPattern = "/ready" + r.args = args + r.count = 0 + return r, true + default: + return + } + } + } } diff --git a/pkg/rest/oas_schemas_gen.go b/pkg/rest/oas_schemas_gen.go index f7fcd59..ba8ffb6 100644 --- a/pkg/rest/oas_schemas_gen.go +++ b/pkg/rest/oas_schemas_gen.go @@ -764,4 +764,145 @@ func (s *PredictionResultPredictionItemStage) UnmarshalText(data []byte) error { } } -type PredictionResultPredictionItemTrajectoryItem struct{} +type PredictionResultPredictionItemTrajectoryItem struct { + Datetime time.Time `json:"datetime"` + Latitude float64 `json:"latitude"` + Longitude float64 `json:"longitude"` + Altitude float64 `json:"altitude"` +} + +// GetDatetime returns the value of Datetime. +func (s *PredictionResultPredictionItemTrajectoryItem) GetDatetime() time.Time { + return s.Datetime +} + +// GetLatitude returns the value of Latitude. +func (s *PredictionResultPredictionItemTrajectoryItem) GetLatitude() float64 { + return s.Latitude +} + +// GetLongitude returns the value of Longitude. +func (s *PredictionResultPredictionItemTrajectoryItem) GetLongitude() float64 { + return s.Longitude +} + +// GetAltitude returns the value of Altitude. +func (s *PredictionResultPredictionItemTrajectoryItem) GetAltitude() float64 { + return s.Altitude +} + +// SetDatetime sets the value of Datetime. +func (s *PredictionResultPredictionItemTrajectoryItem) SetDatetime(val time.Time) { + s.Datetime = val +} + +// SetLatitude sets the value of Latitude. +func (s *PredictionResultPredictionItemTrajectoryItem) SetLatitude(val float64) { + s.Latitude = val +} + +// SetLongitude sets the value of Longitude. +func (s *PredictionResultPredictionItemTrajectoryItem) SetLongitude(val float64) { + s.Longitude = val +} + +// SetAltitude sets the value of Altitude. +func (s *PredictionResultPredictionItemTrajectoryItem) SetAltitude(val float64) { + s.Altitude = val +} + +// Ref: #/components/schemas/ReadinessResponse +type ReadinessResponse struct { + Status ReadinessResponseStatus `json:"status"` + LastUpdate OptDateTime `json:"last_update"` + IsFresh OptBool `json:"is_fresh"` + ErrorMessage OptString `json:"error_message"` +} + +// GetStatus returns the value of Status. +func (s *ReadinessResponse) GetStatus() ReadinessResponseStatus { + return s.Status +} + +// GetLastUpdate returns the value of LastUpdate. +func (s *ReadinessResponse) GetLastUpdate() OptDateTime { + return s.LastUpdate +} + +// GetIsFresh returns the value of IsFresh. +func (s *ReadinessResponse) GetIsFresh() OptBool { + return s.IsFresh +} + +// GetErrorMessage returns the value of ErrorMessage. +func (s *ReadinessResponse) GetErrorMessage() OptString { + return s.ErrorMessage +} + +// SetStatus sets the value of Status. +func (s *ReadinessResponse) SetStatus(val ReadinessResponseStatus) { + s.Status = val +} + +// SetLastUpdate sets the value of LastUpdate. +func (s *ReadinessResponse) SetLastUpdate(val OptDateTime) { + s.LastUpdate = val +} + +// SetIsFresh sets the value of IsFresh. +func (s *ReadinessResponse) SetIsFresh(val OptBool) { + s.IsFresh = val +} + +// SetErrorMessage sets the value of ErrorMessage. +func (s *ReadinessResponse) SetErrorMessage(val OptString) { + s.ErrorMessage = val +} + +type ReadinessResponseStatus string + +const ( + ReadinessResponseStatusOk ReadinessResponseStatus = "ok" + ReadinessResponseStatusNotReady ReadinessResponseStatus = "not_ready" + ReadinessResponseStatusError ReadinessResponseStatus = "error" +) + +// AllValues returns all ReadinessResponseStatus values. +func (ReadinessResponseStatus) AllValues() []ReadinessResponseStatus { + return []ReadinessResponseStatus{ + ReadinessResponseStatusOk, + ReadinessResponseStatusNotReady, + ReadinessResponseStatusError, + } +} + +// MarshalText implements encoding.TextMarshaler. +func (s ReadinessResponseStatus) MarshalText() ([]byte, error) { + switch s { + case ReadinessResponseStatusOk: + return []byte(s), nil + case ReadinessResponseStatusNotReady: + return []byte(s), nil + case ReadinessResponseStatusError: + return []byte(s), nil + default: + return nil, errors.Errorf("invalid value: %q", s) + } +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (s *ReadinessResponseStatus) UnmarshalText(data []byte) error { + switch ReadinessResponseStatus(data) { + case ReadinessResponseStatusOk: + *s = ReadinessResponseStatusOk + return nil + case ReadinessResponseStatusNotReady: + *s = ReadinessResponseStatusNotReady + return nil + case ReadinessResponseStatusError: + *s = ReadinessResponseStatusError + return nil + default: + return errors.Errorf("invalid value: %q", data) + } +} diff --git a/pkg/rest/oas_server_gen.go b/pkg/rest/oas_server_gen.go index 50835d7..f246bc7 100644 --- a/pkg/rest/oas_server_gen.go +++ b/pkg/rest/oas_server_gen.go @@ -14,6 +14,12 @@ type Handler interface { // // POST /api/v1/prediction PerformPrediction(ctx context.Context, req OptPredictionParameters, params PerformPredictionParams) (*PredictionResult, error) + // ReadinessCheck implements readinessCheck operation. + // + // Readiness check. + // + // GET /ready + ReadinessCheck(ctx context.Context) (*ReadinessResponse, error) // NewError creates *ErrorStatusCode from error returned by handler. // // Used for common default response. diff --git a/pkg/rest/oas_unimplemented_gen.go b/pkg/rest/oas_unimplemented_gen.go index afc4fdd..6582495 100644 --- a/pkg/rest/oas_unimplemented_gen.go +++ b/pkg/rest/oas_unimplemented_gen.go @@ -22,6 +22,15 @@ func (UnimplementedHandler) PerformPrediction(ctx context.Context, req OptPredic return r, ht.ErrNotImplemented } +// ReadinessCheck implements readinessCheck operation. +// +// Readiness check. +// +// GET /ready +func (UnimplementedHandler) ReadinessCheck(ctx context.Context) (r *ReadinessResponse, _ error) { + return r, ht.ErrNotImplemented +} + // NewError creates *ErrorStatusCode from error returned by handler. // // Used for common default response. diff --git a/pkg/rest/oas_validators_gen.go b/pkg/rest/oas_validators_gen.go index 9933ab9..d25fbcd 100644 --- a/pkg/rest/oas_validators_gen.go +++ b/pkg/rest/oas_validators_gen.go @@ -269,6 +269,23 @@ func (s *PredictionResultPredictionItem) Validate() error { if s.Trajectory == nil { return errors.New("nil is invalid value") } + var failures []validate.FieldError + for i, elem := range s.Trajectory { + if err := func() error { + if err := elem.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + failures = append(failures, validate.FieldError{ + Name: fmt.Sprintf("[%d]", i), + Error: err, + }) + } + } + if len(failures) > 0 { + return &validate.Error{Fields: failures} + } return nil }(); err != nil { failures = append(failures, validate.FieldError{ @@ -292,3 +309,84 @@ func (s PredictionResultPredictionItemStage) Validate() error { return errors.Errorf("invalid value: %v", s) } } + +func (s *PredictionResultPredictionItemTrajectoryItem) Validate() error { + if s == nil { + return validate.ErrNilPointer + } + + var failures []validate.FieldError + if err := func() error { + if err := (validate.Float{}).Validate(float64(s.Latitude)); err != nil { + return errors.Wrap(err, "float") + } + return nil + }(); err != nil { + failures = append(failures, validate.FieldError{ + Name: "latitude", + Error: err, + }) + } + if err := func() error { + if err := (validate.Float{}).Validate(float64(s.Longitude)); err != nil { + return errors.Wrap(err, "float") + } + return nil + }(); err != nil { + failures = append(failures, validate.FieldError{ + Name: "longitude", + Error: err, + }) + } + if err := func() error { + if err := (validate.Float{}).Validate(float64(s.Altitude)); err != nil { + return errors.Wrap(err, "float") + } + return nil + }(); err != nil { + failures = append(failures, validate.FieldError{ + Name: "altitude", + Error: err, + }) + } + if len(failures) > 0 { + return &validate.Error{Fields: failures} + } + return nil +} + +func (s *ReadinessResponse) Validate() error { + if s == nil { + return validate.ErrNilPointer + } + + var failures []validate.FieldError + if err := func() error { + if err := s.Status.Validate(); err != nil { + return err + } + return nil + }(); err != nil { + failures = append(failures, validate.FieldError{ + Name: "status", + Error: err, + }) + } + if len(failures) > 0 { + return &validate.Error{Fields: failures} + } + return nil +} + +func (s ReadinessResponseStatus) Validate() error { + switch s { + case "ok": + return nil + case "not_ready": + return nil + case "error": + return nil + default: + return errors.Errorf("invalid value: %v", s) + } +} diff --git a/pkg/scheduler/config.go b/pkg/scheduler/config.go index 312a4f6..c7b2881 100644 --- a/pkg/scheduler/config.go +++ b/pkg/scheduler/config.go @@ -1,5 +1,20 @@ package scheduler +import ( + "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + env "github.com/caarlos0/env/v11" +) + type Config struct { Enabled bool `env:"ENABLED" envDefault:"true"` } + +func NewConfig() (*Config, error) { + cfg := &Config{} + if err := env.ParseWithOptions(cfg, env.Options{ + PrefixTagName: "GSN_PREDICTOR_SCHEDULER_", + }); err != nil { + return nil, errcodes.Wrap(err, "failed to parse scheduler config") + } + return cfg, nil +} diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index c557b83..d51980c 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -5,6 +5,7 @@ import ( "time" "git.intra.yksa.space/gsn/predictor/internal/pkg/errcodes" + "git.intra.yksa.space/gsn/predictor/internal/pkg/log" "github.com/go-co-op/gocron" "go.uber.org/zap" ) @@ -19,15 +20,12 @@ type Job interface { type Scheduler struct { scheduler *gocron.Scheduler - logger *zap.Logger } -func New(logger *zap.Logger) *Scheduler { +func New() *Scheduler { scheduler := gocron.NewScheduler(time.UTC) - return &Scheduler{ scheduler: scheduler, - logger: logger, } } @@ -49,14 +47,14 @@ func (s *Scheduler) AddJob(job Job) error { jobFunc := func() { ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - + logger := log.Ctx(ctx) if err := job.Execute(ctx); err != nil { - s.logger.Error("job execution failed", + logger.Error("job execution failed", zap.Error(err), zap.Duration("interval", interval), zap.Duration("timeout", timeout)) } else { - s.logger.Debug("job executed successfully", + logger.Debug("job executed successfully", zap.Duration("interval", interval), zap.Duration("timeout", timeout)) } @@ -75,7 +73,7 @@ func (s *Scheduler) AddJob(job Job) error { schedulerJob.Do(jobFunc) - s.logger.Info("job added to scheduler", + log.Ctx(context.Background()).Info("job added to scheduler", zap.Duration("interval", interval), zap.Duration("timeout", timeout), zap.Int("count", count), @@ -86,12 +84,12 @@ func (s *Scheduler) AddJob(job Job) error { func (s *Scheduler) Start() { s.scheduler.StartAsync() - s.logger.Info("scheduler started") + log.Ctx(context.Background()).Info("scheduler started") } func (s *Scheduler) Stop() { s.scheduler.Stop() - s.logger.Info("scheduler stopped") + log.Ctx(context.Background()).Info("scheduler stopped") } func (s *Scheduler) IsRunning() bool { diff --git a/predictor b/predictor deleted file mode 100755 index c74f010..0000000 Binary files a/predictor and /dev/null differ diff --git a/scripts/test_predictor_vs_reference.py b/scripts/test_predictor_vs_reference.py new file mode 100644 index 0000000..115b8a9 --- /dev/null +++ b/scripts/test_predictor_vs_reference.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python3 +import subprocess +import sys +import time +import requests +import json +from typing import Any +import base64 + +# --- Config --- +LOCAL_API_URL = "http://localhost:8080/api/v1/prediction" +REFERENCE_API_URL = ( + "https://fly.stratonautica.ru/api/v2/?profile=standard_profile&pred_type=single" + "&launch_datetime=2025-06-25T13:28:00Z&launch_latitude=56.6992&launch_longitude=38.8247" + "&launch_altitude=0&ascent_rate=5&burst_altitude=30000&descent_rate=5" +) +LOCAL_API_PAYLOAD = { + "launch_latitude": 56.6992, + "launch_longitude": 38.8247, + "launch_datetime": "2025-06-25T13:28:00Z", + "launch_altitude": 0, + "profile": "standard_profile", + "ascent_rate": 5, + "burst_altitude": 30000, + "descent_rate": 5, + "format": "json" +} +READY_URL = "http://localhost:8080/ready" + +# --- Utility functions --- +def run_compose_up(): + print("[INFO] Running docker-compose down --remove-orphans ...") + result = subprocess.run(["docker-compose", "down", "--remove-orphans"], capture_output=True) + if result.returncode != 0: + print("[ERROR] docker-compose down failed:", result.stderr.decode()) + sys.exit(1) + print("[INFO] docker-compose down completed.") + print("[INFO] Running docker-compose up -d ...") + result = subprocess.run(["docker-compose", "up", "-d"], capture_output=True) + if result.returncode != 0: + print("[ERROR] docker-compose up failed:", result.stderr.decode()) + sys.exit(1) + print("[INFO] docker-compose up -d completed.") + return True + +def wait_for_ready(timeout=900): + print(f"[INFO] Waiting for {READY_URL} to be ready ...") + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(READY_URL, timeout=10) + if resp.status_code == 200: + data = resp.json() + if data.get("status") == "ok": + print("[INFO] Service is ready.") + return + else: + print(f"[INFO] Not ready yet: {data}") + else: + print(f"[INFO] /ready returned status {resp.status_code}") + except Exception as e: + print(f"[INFO] Exception while polling /ready: {e}") + time.sleep(10) + print(f"[ERROR] Service did not become ready in {timeout} seconds.") + sys.exit(1) + +def fetch_reference(): + print(f"[INFO] Fetching reference prediction from {REFERENCE_API_URL}") + resp = requests.get(REFERENCE_API_URL, timeout=60) + if resp.status_code != 200: + print(f"[ERROR] Reference API returned {resp.status_code}") + sys.exit(1) + return resp.json() + +def fetch_local(): + print(f"[INFO] Fetching local prediction from {LOCAL_API_URL}") + resp = requests.post(LOCAL_API_URL, json=LOCAL_API_PAYLOAD, timeout=120) + if resp.status_code != 200: + print(f"[ERROR] Local API returned {resp.status_code}: {resp.text}") + sys.exit(1) + return resp.json() + +def compare_results(reference_data, local_data): + """Compare prediction results between reference and local APIs.""" + print("[INFO] Comparing results ...") + + # Extract trajectory data + ref_trajectory = reference_data.get('prediction', [{}])[0].get('trajectory', []) + local_trajectory = local_data.get('prediction', [{}])[0].get('trajectory', []) + + print(f"[DEBUG] Reference trajectory length: {len(ref_trajectory)}") + print(f"[DEBUG] Local trajectory length: {len(local_trajectory)}") + + # Show first 3 points from both APIs + print("\n[DEBUG] First 3 points - Reference API:") + for i, point in enumerate(ref_trajectory[:3]): + print(f" [{i}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}") + + print("\n[DEBUG] First 3 points - Local API:") + for i, point in enumerate(local_trajectory[:3]): + print(f" [{i}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}") + + # Show last 3 points from both APIs + print("\n[DEBUG] Last 3 points - Reference API:") + for i, point in enumerate(ref_trajectory[-3:]): + idx = len(ref_trajectory) - 3 + i + print(f" [{idx}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}") + + print("\n[DEBUG] Last 3 points - Local API:") + for i, point in enumerate(local_trajectory[-3:]): + idx = len(local_trajectory) - 3 + i + print(f" [{idx}] alt={point.get('altitude', 'N/A')}, lat={point.get('latitude', 'N/A')}, lon={point.get('longitude', 'N/A')}, time={point.get('datetime', 'N/A')}") + + # Compare trajectory lengths + if len(ref_trajectory) != len(local_trajectory): + print(f"[DIFF] Trajectory length mismatch: {len(local_trajectory)} vs {len(ref_trajectory)}") + return False + + # Compare trajectory points + min_len = min(len(ref_trajectory), len(local_trajectory)) + for i in range(min_len): + ref_point = ref_trajectory[i] + local_point = local_trajectory[i] + + # Compare key fields + for key in ['altitude', 'latitude', 'longitude']: + ref_val = ref_point.get(key) + local_val = local_point.get(key) + + if ref_val is not None and local_val is not None: + # Use relative tolerance for floating point comparison + if abs(ref_val - local_val) > 0.1: # 0.1 degree/meter tolerance + print(f"[DIFF] At idx {i}, key {key}: {local_val} != {ref_val}") + return False + + print("[SUCCESS] Results match!") + return True + +def test_custom_profile(): + """Test custom profile with base64 encoded curve.""" + print("\n[TEST] Testing custom_profile...") + + # Create a simple custom ascent curve (altitude vs time in seconds) + curve_data = { + "altitude": [0, 30000], + "time": [0, 6000] + } + + curve_b64 = base64.b64encode(json.dumps(curve_data).encode()).decode() + + # Test parameters for custom profile + params = { + "launch_latitude": 56.6992, + "launch_longitude": 38.8247, + "launch_datetime": "2025-06-25T13:28:00Z", + "launch_altitude": 0, + "profile": "custom_profile", + "ascent_curve": curve_b64 + } + + try: + # Test local API + local_resp = requests.post( + "http://localhost:8080/api/v1/prediction", + json=params, + timeout=30 + ) + local_resp.raise_for_status() + local_data = local_resp.json() + + print(f"[INFO] Custom profile test - Local API returned {len(local_data.get('prediction', [{}])[0].get('trajectory', []))} trajectory points") + return True + except Exception as e: + print(f"[ERROR] Custom profile test failed: {e}") + return False + +def test_all_profiles(): + """Test all available profiles.""" + profiles = [ + ("standard_profile", "Standard profile test"), + ("float_profile", "Float profile test"), + ("reverse_profile", "Reverse profile test"), + ("custom_profile", "Custom profile test") + ] + + results = {} + + for profile, description in profiles: + print(f"\n[TEST] {description}...") + + if profile == "custom_profile": + success = test_custom_profile() + else: + success = test_single_profile(profile) + + results[profile] = success + print(f"[RESULT] {profile}: {'PASS' if success else 'FAIL'}") + + # Print summary + print("\n" + "="*50) + print("TEST SUMMARY") + print("="*50) + for profile, success in results.items(): + status = "PASS" if success else "FAIL" + print(f"{profile:20} : {status}") + + total_tests = len(results) + passed_tests = sum(results.values()) + print(f"\nTotal tests: {total_tests}, Passed: {passed_tests}, Failed: {total_tests - passed_tests}") + + return all(results.values()) + +def test_single_profile(profile): + """Test a single profile against reference API.""" + # Test parameters + params = { + "launch_latitude": 56.6992, + "launch_longitude": 38.8247, + "launch_datetime": "2025-06-25T13:28:00Z", + "launch_altitude": 0, + "profile": profile, + "ascent_rate": 5, + "burst_altitude": 30000, + "descent_rate": 5 + } + + # Add float altitude for float profile + if profile == "float_profile": + params["float_altitude"] = 25000 + + try: + # Test local API + local_resp = requests.post( + "http://localhost:8080/api/v1/prediction", + json=params, + timeout=30 + ) + local_resp.raise_for_status() + local_data = local_resp.json() + + print(f"[INFO] {profile} - Local API returned {len(local_data.get('prediction', [{}])[0].get('trajectory', []))} trajectory points") + return True + except Exception as e: + print(f"[ERROR] {profile} test failed: {e}") + return False + +def main(): + """Main test function.""" + print("[INFO] Starting comprehensive predictor API tests...") + + # Run the original standard profile test + print("\n[TEST] Running original standard_profile test...") + run_compose_up() + wait_for_ready() + ref = fetch_reference() + local = fetch_local() + + print("[INFO] Comparing results ...") + original_success = compare_results(ref, local) + + if original_success: + print("[SUCCESS] Original standard_profile test passed!") + else: + print("[FAIL] Original standard_profile test failed!") + + # Test all profiles + print("\n[TEST] Running all profile tests...") + all_profiles_success = test_all_profiles() + + # Final result + overall_success = original_success and all_profiles_success + print(f"\n[FINAL RESULT] Overall: {'PASS' if overall_success else 'FAIL'}") + + if overall_success: + sys.exit(0) + else: + sys.exit(1) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/validate-docker.sh b/scripts/validate-docker.sh deleted file mode 100755 index 0546cc9..0000000 --- a/scripts/validate-docker.sh +++ /dev/null @@ -1,55 +0,0 @@ -#!/bin/bash - -# Docker validation script -set -e - -echo "πŸ” Validating Docker configuration..." - -# Check if Docker is available -if ! command -v docker &> /dev/null; then - echo "❌ Docker is not installed or not in PATH" - echo "Please install Docker Desktop and enable WSL 2 integration" - exit 1 -fi - -# Check if docker-compose is available -if ! command -v docker-compose &> /dev/null; then - echo "❌ Docker Compose is not installed or not in PATH" - exit 1 -fi - -echo "βœ… Docker and Docker Compose are available" - -# Validate Dockerfile syntax -echo "πŸ” Validating Dockerfile..." -if docker build --dry-run . > /dev/null 2>&1; then - echo "βœ… Dockerfile syntax is valid" -else - echo "❌ Dockerfile syntax is invalid" - exit 1 -fi - -# Validate docker-compose.yml -echo "πŸ” Validating docker-compose.yml..." -if docker-compose config > /dev/null 2>&1; then - echo "βœ… docker-compose.yml is valid" -else - echo "❌ docker-compose.yml is invalid" - exit 1 -fi - -# Validate docker-compose.dev.yml -echo "πŸ” Validating docker-compose.dev.yml..." -if docker-compose -f docker-compose.dev.yml config > /dev/null 2>&1; then - echo "βœ… docker-compose.dev.yml is valid" -else - echo "❌ docker-compose.dev.yml is invalid" - exit 1 -fi - -echo "βœ… All Docker configurations are valid!" -echo "" -echo "πŸš€ Ready to build and run:" -echo " make build # Build Docker image" -echo " make up # Start services" -echo " make logs # View logs" \ No newline at end of file