263 lines
7.8 KiB
Go
263 lines
7.8 KiB
Go
// Package config holds the service's runtime configuration, loaded by
|
|
// merging (in order of increasing precedence): built-in defaults, a YAML
|
|
// config file, environment variables, and command-line flags.
|
|
//
|
|
// Validation is performed once on load; downstream consumers receive an
|
|
// immutable struct.
|
|
package config
|
|
|
|
import (
	"flag"
	"fmt"
	"os"
	"strconv"
	"strings"
	"time"

	"gopkg.in/yaml.v2"
)
|
|
|
|
// Config is the top-level configuration tree.
|
|
type Config struct {
|
|
HTTP HTTPConfig `yaml:"http"`
|
|
Data DataConfig `yaml:"data"`
|
|
Download DownloadConfig `yaml:"download"`
|
|
Metrics MetricsConfig `yaml:"metrics"`
|
|
Log LogConfig `yaml:"log"`
|
|
}
|
|
|
|
// HTTPConfig holds the HTTP server settings, including the limits that apply
// to the asynchronous prediction endpoint.
type HTTPConfig struct {
	// Port is the HTTP listen port.
	Port int `yaml:"port"`

	// AsyncWorkers is the upper limit on prediction executions the async
	// endpoint runs at the same time.
	AsyncWorkers int `yaml:"async_workers"`

	// AsyncQueueSize is the bound on the async endpoint's pending queue.
	AsyncQueueSize int `yaml:"async_queue_size"`

	// AsyncResultTTL is the retention period for completed async results.
	AsyncResultTTL time.Duration `yaml:"async_result_ttl"`
}
|
|
|
|
// DataConfig configures dataset and elevation storage.
type DataConfig struct {
	// Dir is the directory for downloaded dataset files.
	Dir string `yaml:"dir"`

	// ElevationPath is the path to the ruaumoko elevation dataset.
	ElevationPath string `yaml:"elevation_path"`

	// Source is the dataset source identifier. The built-in default is
	// "gfs-0p50-3h"; Validate only requires it to be non-empty.
	// NOTE(review): confirm which identifiers downstream consumers actually
	// accept and, if the set is fixed, enforce it in Validate.
	Source string `yaml:"source"`
}
|
|
|
|
// DownloadConfig tunes the dataset downloader and its refresh scheduling.
type DownloadConfig struct {
	// Parallel is the maximum number of concurrent GRIB downloads.
	Parallel int `yaml:"parallel"`

	// BandwidthBytesPerSecond is the download bandwidth limit in bytes per
	// second; 0 means unlimited.
	BandwidthBytesPerSecond int64 `yaml:"bandwidth_bytes_per_second"`

	// UpdateInterval is the scheduler refresh interval.
	UpdateInterval time.Duration `yaml:"update_interval"`

	// FreshnessTTL is the maximum age before a dataset is considered stale.
	FreshnessTTL time.Duration `yaml:"freshness_ttl"`
}
|
|
|
|
// MetricsConfig controls the Prometheus-compatible metrics endpoint.
type MetricsConfig struct {
	// Enabled switches the metrics endpoint on or off.
	Enabled bool `yaml:"enabled"`

	// Path is the HTTP path of the metrics endpoint; Validate requires it to
	// be non-empty whenever Enabled is true.
	Path string `yaml:"path"`
}
|
|
|
|
// LogConfig holds logging settings.
type LogConfig struct {
	// Level selects the log level: one of "debug", "info", "warn" or "error".
	// Validate rejects any other value.
	Level string `yaml:"level"`
}
|
|
|
|
// Defaults returns a Config with reasonable default values.
|
|
func Defaults() Config {
|
|
return Config{
|
|
HTTP: HTTPConfig{
|
|
Port: 8080,
|
|
AsyncWorkers: 4,
|
|
AsyncQueueSize: 64,
|
|
AsyncResultTTL: time.Hour,
|
|
},
|
|
Data: DataConfig{
|
|
Dir: "/tmp/predictor-data",
|
|
ElevationPath: "/srv/ruaumoko-dataset",
|
|
Source: "gfs-0p50-3h",
|
|
},
|
|
Download: DownloadConfig{
|
|
Parallel: 8,
|
|
BandwidthBytesPerSecond: 0,
|
|
UpdateInterval: 6 * time.Hour,
|
|
FreshnessTTL: 48 * time.Hour,
|
|
},
|
|
Metrics: MetricsConfig{
|
|
Enabled: true,
|
|
Path: "/metrics",
|
|
},
|
|
Log: LogConfig{Level: "info"},
|
|
}
|
|
}
|
|
|
|
// Load resolves the configuration by merging built-in defaults with
|
|
// (in increasing precedence): a YAML file (path from PREDICTOR_CONFIG_FILE
|
|
// env var or --config flag), environment variables, and command-line flags.
|
|
//
|
|
// args is os.Args[1:] in production code; tests pass a custom slice.
|
|
func Load(args []string) (Config, error) {
|
|
cfg := Defaults()
|
|
|
|
fs := flag.NewFlagSet("predictor", flag.ContinueOnError)
|
|
// Surface a deterministic usage by suppressing the default output:
|
|
fs.SetOutput(os.Stderr)
|
|
|
|
var (
|
|
configPath = fs.String("config", os.Getenv("PREDICTOR_CONFIG_FILE"), "path to YAML config file")
|
|
// Flag-driven overrides. Empty / -1 means "not specified".
|
|
flagPort = fs.Int("port", -1, "HTTP listen port")
|
|
flagDataDir = fs.String("data-dir", "", "directory for dataset files")
|
|
flagElevation = fs.String("elevation", "", "path to ruaumoko elevation dataset")
|
|
flagParallel = fs.Int("download-parallel", -1, "max concurrent GRIB downloads")
|
|
flagBandwidth = fs.Int64("download-bandwidth", -1, "download bandwidth limit in bytes/sec (0 = unlimited)")
|
|
flagInterval = fs.Duration("update-interval", 0, "scheduler refresh interval")
|
|
flagTTL = fs.Duration("freshness-ttl", 0, "max age before a dataset is considered stale")
|
|
flagMetricsEnabled = fs.Bool("metrics", true, "enable Prometheus-compatible metrics endpoint")
|
|
flagMetricsPath = fs.String("metrics-path", "", "HTTP path for the metrics endpoint")
|
|
flagLogLevel = fs.String("log-level", "", "log level: debug|info|warn|error")
|
|
)
|
|
if err := fs.Parse(args); err != nil {
|
|
return Config{}, fmt.Errorf("parse flags: %w", err)
|
|
}
|
|
|
|
// 1. File.
|
|
if *configPath != "" {
|
|
data, err := os.ReadFile(*configPath)
|
|
if err != nil {
|
|
return Config{}, fmt.Errorf("read config %s: %w", *configPath, err)
|
|
}
|
|
if err := yaml.UnmarshalStrict(data, &cfg); err != nil {
|
|
return Config{}, fmt.Errorf("parse config %s: %w", *configPath, err)
|
|
}
|
|
}
|
|
|
|
// 2. Env vars.
|
|
applyEnv(&cfg)
|
|
|
|
// 3. Flags (only when explicitly set).
|
|
if *flagPort >= 0 {
|
|
cfg.HTTP.Port = *flagPort
|
|
}
|
|
if *flagDataDir != "" {
|
|
cfg.Data.Dir = *flagDataDir
|
|
}
|
|
if *flagElevation != "" {
|
|
cfg.Data.ElevationPath = *flagElevation
|
|
}
|
|
if *flagParallel >= 0 {
|
|
cfg.Download.Parallel = *flagParallel
|
|
}
|
|
if *flagBandwidth >= 0 {
|
|
cfg.Download.BandwidthBytesPerSecond = *flagBandwidth
|
|
}
|
|
if *flagInterval != 0 {
|
|
cfg.Download.UpdateInterval = *flagInterval
|
|
}
|
|
if *flagTTL != 0 {
|
|
cfg.Download.FreshnessTTL = *flagTTL
|
|
}
|
|
// flag.Bool defaults to true here so we only override if user explicitly disables it.
|
|
if isFlagSet(fs, "metrics") {
|
|
cfg.Metrics.Enabled = *flagMetricsEnabled
|
|
}
|
|
if *flagMetricsPath != "" {
|
|
cfg.Metrics.Path = *flagMetricsPath
|
|
}
|
|
if *flagLogLevel != "" {
|
|
cfg.Log.Level = *flagLogLevel
|
|
}
|
|
|
|
if err := cfg.Validate(); err != nil {
|
|
return Config{}, err
|
|
}
|
|
return cfg, nil
|
|
}
|
|
|
|
func isFlagSet(fs *flag.FlagSet, name string) bool {
|
|
set := false
|
|
fs.Visit(func(f *flag.Flag) {
|
|
if f.Name == name {
|
|
set = true
|
|
}
|
|
})
|
|
return set
|
|
}
|
|
|
|
// applyEnv overlays PREDICTOR_* environment variables onto cfg.
|
|
func applyEnv(cfg *Config) {
|
|
if v := os.Getenv("PREDICTOR_PORT"); v != "" {
|
|
if n, err := strconv.Atoi(v); err == nil {
|
|
cfg.HTTP.Port = n
|
|
}
|
|
}
|
|
if v := os.Getenv("PREDICTOR_DATA_DIR"); v != "" {
|
|
cfg.Data.Dir = v
|
|
}
|
|
if v := os.Getenv("PREDICTOR_ELEVATION_DATASET"); v != "" {
|
|
cfg.Data.ElevationPath = v
|
|
}
|
|
if v := os.Getenv("PREDICTOR_SOURCE"); v != "" {
|
|
cfg.Data.Source = v
|
|
}
|
|
if v := os.Getenv("PREDICTOR_DOWNLOAD_PARALLEL"); v != "" {
|
|
if n, err := strconv.Atoi(v); err == nil {
|
|
cfg.Download.Parallel = n
|
|
}
|
|
}
|
|
if v := os.Getenv("PREDICTOR_DOWNLOAD_BANDWIDTH"); v != "" {
|
|
if n, err := strconv.ParseInt(v, 10, 64); err == nil {
|
|
cfg.Download.BandwidthBytesPerSecond = n
|
|
}
|
|
}
|
|
if v := os.Getenv("PREDICTOR_UPDATE_INTERVAL"); v != "" {
|
|
if d, err := time.ParseDuration(v); err == nil {
|
|
cfg.Download.UpdateInterval = d
|
|
}
|
|
}
|
|
if v := os.Getenv("PREDICTOR_DATASET_TTL"); v != "" {
|
|
if d, err := time.ParseDuration(v); err == nil {
|
|
cfg.Download.FreshnessTTL = d
|
|
}
|
|
}
|
|
if v := os.Getenv("PREDICTOR_METRICS_ENABLED"); v != "" {
|
|
cfg.Metrics.Enabled = v == "1" || v == "true" || v == "yes"
|
|
}
|
|
if v := os.Getenv("PREDICTOR_METRICS_PATH"); v != "" {
|
|
cfg.Metrics.Path = v
|
|
}
|
|
if v := os.Getenv("PREDICTOR_LOG_LEVEL"); v != "" {
|
|
cfg.Log.Level = v
|
|
}
|
|
}
|
|
|
|
// Validate reports configuration errors.
|
|
func (c Config) Validate() error {
|
|
if c.HTTP.Port < 0 || c.HTTP.Port > 65535 {
|
|
return fmt.Errorf("http.port %d outside [0, 65535]", c.HTTP.Port)
|
|
}
|
|
if c.Data.Dir == "" {
|
|
return fmt.Errorf("data.dir is required")
|
|
}
|
|
if c.Data.Source == "" {
|
|
return fmt.Errorf("data.source is required")
|
|
}
|
|
if c.Download.Parallel <= 0 {
|
|
return fmt.Errorf("download.parallel must be > 0")
|
|
}
|
|
if c.Download.UpdateInterval <= 0 {
|
|
return fmt.Errorf("download.update_interval must be > 0")
|
|
}
|
|
if c.Download.FreshnessTTL <= 0 {
|
|
return fmt.Errorf("download.freshness_ttl must be > 0")
|
|
}
|
|
if c.Metrics.Enabled && c.Metrics.Path == "" {
|
|
return fmt.Errorf("metrics.path is required when metrics enabled")
|
|
}
|
|
switch c.Log.Level {
|
|
case "debug", "info", "warn", "error":
|
|
default:
|
|
return fmt.Errorf("log.level %q is not one of debug|info|warn|error", c.Log.Level)
|
|
}
|
|
return nil
|
|
}
|