predictor/internal/config/config.go
2026-05-23 00:55:35 +09:00

263 lines
7.8 KiB
Go

// Package config holds the service's runtime configuration, loaded by
// merging (in order of increasing precedence): built-in defaults, a YAML
// config file, environment variables, and command-line flags.
//
// Validation is performed once on load. Load returns the Config by value, so
// downstream consumers work on their own copy of the validated settings.
package config
import (
"flag"
"fmt"
"os"
"strconv"
"time"
"gopkg.in/yaml.v2"
)
// Config is the top-level configuration tree. Each field groups the settings
// of one subsystem and is read from the YAML key named in its struct tag.
type Config struct {
// HTTP holds the HTTP server settings.
HTTP HTTPConfig `yaml:"http"`
// Data holds the dataset and elevation storage settings.
Data DataConfig `yaml:"data"`
// Download holds the dataset downloader settings.
Download DownloadConfig `yaml:"download"`
// Metrics holds the metrics endpoint settings.
Metrics MetricsConfig `yaml:"metrics"`
// Log holds the logging settings.
Log LogConfig `yaml:"log"`
}
// HTTPConfig configures the HTTP server.
//
// Port can be overridden by the PREDICTOR_PORT environment variable and the
// --port flag. The Async* fields have no environment variable or flag in this
// file; they come from the built-in defaults or the YAML file.
type HTTPConfig struct {
// Port is the HTTP listen port.
Port int `yaml:"port"`
// AsyncWorkers caps concurrent prediction executions for the async endpoint.
AsyncWorkers int `yaml:"async_workers"`
// AsyncQueueSize bounds the async pending queue.
AsyncQueueSize int `yaml:"async_queue_size"`
// AsyncResultTTL is how long completed async results are retained.
AsyncResultTTL time.Duration `yaml:"async_result_ttl"`
}
// DataConfig configures dataset and elevation storage.
type DataConfig struct {
// Dir is the directory for dataset files (PREDICTOR_DATA_DIR, --data-dir).
Dir string `yaml:"dir"`
// ElevationPath is the path to the ruaumoko elevation dataset
// (PREDICTOR_ELEVATION_DATASET, --elevation).
ElevationPath string `yaml:"elevation_path"`
// Source is the dataset source identifier; only "noaa-gfs-0p50" is supported today.
// It can be overridden by PREDICTOR_SOURCE; there is no flag for it.
// NOTE(review): Defaults sets Source to "gfs-0p50-3h", which does not match
// the identifier named above — confirm which value is current.
Source string `yaml:"source"`
}
// DownloadConfig configures the dataset downloader.
type DownloadConfig struct {
// Parallel is the maximum number of concurrent GRIB downloads
// (PREDICTOR_DOWNLOAD_PARALLEL, --download-parallel).
Parallel int `yaml:"parallel"`
// BandwidthBytesPerSecond is the download bandwidth limit in bytes/sec;
// 0 means unlimited (PREDICTOR_DOWNLOAD_BANDWIDTH, --download-bandwidth).
BandwidthBytesPerSecond int64 `yaml:"bandwidth_bytes_per_second"`
// UpdateInterval is the scheduler refresh interval
// (PREDICTOR_UPDATE_INTERVAL, --update-interval).
UpdateInterval time.Duration `yaml:"update_interval"`
// FreshnessTTL is the maximum age before a dataset is considered stale
// (PREDICTOR_DATASET_TTL, --freshness-ttl).
FreshnessTTL time.Duration `yaml:"freshness_ttl"`
}
// MetricsConfig configures the metrics endpoint.
type MetricsConfig struct {
// Enabled turns the Prometheus-compatible metrics endpoint on or off
// (PREDICTOR_METRICS_ENABLED, --metrics).
Enabled bool `yaml:"enabled"`
// Path is the HTTP path for the metrics endpoint (PREDICTOR_METRICS_PATH,
// --metrics-path). Validate requires it to be non-empty when Enabled is true.
Path string `yaml:"path"`
}
// LogConfig configures logging.
type LogConfig struct {
// Level is the log level (PREDICTOR_LOG_LEVEL, --log-level). Validate
// rejects any value other than the four listed on the field.
Level string `yaml:"level"` // "debug", "info", "warn", "error"
}
// Defaults returns the built-in configuration that Load starts from before
// layering the YAML file, environment variables, and flags on top.
func Defaults() Config {
	var cfg Config

	// HTTP server.
	cfg.HTTP.Port = 8080
	cfg.HTTP.AsyncWorkers = 4
	cfg.HTTP.AsyncQueueSize = 64
	cfg.HTTP.AsyncResultTTL = time.Hour

	// Dataset and elevation storage.
	cfg.Data.Dir = "/tmp/predictor-data"
	cfg.Data.ElevationPath = "/srv/ruaumoko-dataset"
	// NOTE(review): the DataConfig.Source comment names "noaa-gfs-0p50" as the
	// only supported source; confirm this default is still the intended value.
	cfg.Data.Source = "gfs-0p50-3h"

	// Downloader. A bandwidth of 0 means unlimited.
	cfg.Download.Parallel = 8
	cfg.Download.BandwidthBytesPerSecond = 0
	cfg.Download.UpdateInterval = 6 * time.Hour
	cfg.Download.FreshnessTTL = 48 * time.Hour

	// Metrics endpoint.
	cfg.Metrics.Enabled = true
	cfg.Metrics.Path = "/metrics"

	// Logging.
	cfg.Log.Level = "info"

	return cfg
}
// Load resolves the configuration by layering, in increasing precedence:
// built-in defaults, a YAML file, PREDICTOR_* environment variables, and
// command-line flags. The YAML file path comes from the --config flag, whose
// default is the PREDICTOR_CONFIG_FILE environment variable; when the path is
// empty no file is read.
//
// args is os.Args[1:] in production code; tests pass a custom slice.
//
// Load fails when flag parsing fails, the config file cannot be read or
// parsed, an environment variable holds a malformed value, or the merged
// result does not pass Validate. On failure the returned Config is the zero
// value.
func Load(args []string) (Config, error) {
	cfg := Defaults()

	fs := flag.NewFlagSet("predictor", flag.ContinueOnError)
	// Send parse errors and usage text to stderr. With ContinueOnError, Parse
	// returns the error to us instead of terminating the process.
	fs.SetOutput(os.Stderr)

	var (
		configPath = fs.String("config", os.Getenv("PREDICTOR_CONFIG_FILE"), "path to YAML config file")

		// Flag-driven overrides. The defaults are sentinels: an empty string,
		// -1, or a zero duration means "not specified", so passing exactly
		// that value on the command line is treated like omitting the flag.
		flagPort           = fs.Int("port", -1, "HTTP listen port")
		flagDataDir        = fs.String("data-dir", "", "directory for dataset files")
		flagElevation      = fs.String("elevation", "", "path to ruaumoko elevation dataset")
		flagParallel       = fs.Int("download-parallel", -1, "max concurrent GRIB downloads")
		flagBandwidth      = fs.Int64("download-bandwidth", -1, "download bandwidth limit in bytes/sec (0 = unlimited)")
		flagInterval       = fs.Duration("update-interval", 0, "scheduler refresh interval")
		flagTTL            = fs.Duration("freshness-ttl", 0, "max age before a dataset is considered stale")
		flagMetricsEnabled = fs.Bool("metrics", true, "enable Prometheus-compatible metrics endpoint")
		flagMetricsPath    = fs.String("metrics-path", "", "HTTP path for the metrics endpoint")
		flagLogLevel       = fs.String("log-level", "", "log level: debug|info|warn|error")
	)
	if err := fs.Parse(args); err != nil {
		return Config{}, fmt.Errorf("parse flags: %w", err)
	}

	// 1. File. Decoding into the defaults-populated cfg means keys absent
	// from the file keep their default values.
	if *configPath != "" {
		data, err := os.ReadFile(*configPath)
		if err != nil {
			return Config{}, fmt.Errorf("read config %s: %w", *configPath, err)
		}
		if err := yaml.UnmarshalStrict(data, &cfg); err != nil {
			return Config{}, fmt.Errorf("parse config %s: %w", *configPath, err)
		}
	}

	// 2. Env vars. A malformed value is a configuration error, just like a
	// malformed flag or YAML value; it must not silently fall back.
	if err := applyEnv(&cfg); err != nil {
		return Config{}, err
	}

	// 3. Flags (only when they differ from their "not specified" sentinel).
	if *flagPort >= 0 {
		cfg.HTTP.Port = *flagPort
	}
	if *flagDataDir != "" {
		cfg.Data.Dir = *flagDataDir
	}
	if *flagElevation != "" {
		cfg.Data.ElevationPath = *flagElevation
	}
	if *flagParallel >= 0 {
		cfg.Download.Parallel = *flagParallel
	}
	if *flagBandwidth >= 0 {
		cfg.Download.BandwidthBytesPerSecond = *flagBandwidth
	}
	if *flagInterval != 0 {
		cfg.Download.UpdateInterval = *flagInterval
	}
	if *flagTTL != 0 {
		cfg.Download.FreshnessTTL = *flagTTL
	}
	// A bool flag has no spare sentinel value, so explicit use is detected
	// via isFlagSet. Both --metrics and --metrics=false override cfg.
	if isFlagSet(fs, "metrics") {
		cfg.Metrics.Enabled = *flagMetricsEnabled
	}
	if *flagMetricsPath != "" {
		cfg.Metrics.Path = *flagMetricsPath
	}
	if *flagLogLevel != "" {
		cfg.Log.Level = *flagLogLevel
	}

	if err := cfg.Validate(); err != nil {
		return Config{}, err
	}
	return cfg, nil
}

// isFlagSet reports whether the flag called name was explicitly provided in
// the arguments fs parsed. It relies on FlagSet.Visit, which only visits
// flags that have been set.
func isFlagSet(fs *flag.FlagSet, name string) bool {
	set := false
	fs.Visit(func(f *flag.Flag) {
		if f.Name == name {
			set = true
		}
	})
	return set
}

// applyEnv overlays PREDICTOR_* environment variables onto cfg.
//
// A variable that is unset or empty leaves the corresponding field
// untouched. A non-empty value that cannot be parsed for its field (integer,
// duration, or boolean) yields an error naming the variable; cfg may already
// have been partially updated when that happens.
func applyEnv(cfg *Config) error {
	if v := os.Getenv("PREDICTOR_PORT"); v != "" {
		n, err := strconv.Atoi(v)
		if err != nil {
			return fmt.Errorf("env PREDICTOR_PORT: %w", err)
		}
		cfg.HTTP.Port = n
	}
	if v := os.Getenv("PREDICTOR_DATA_DIR"); v != "" {
		cfg.Data.Dir = v
	}
	if v := os.Getenv("PREDICTOR_ELEVATION_DATASET"); v != "" {
		cfg.Data.ElevationPath = v
	}
	if v := os.Getenv("PREDICTOR_SOURCE"); v != "" {
		cfg.Data.Source = v
	}
	if v := os.Getenv("PREDICTOR_DOWNLOAD_PARALLEL"); v != "" {
		n, err := strconv.Atoi(v)
		if err != nil {
			return fmt.Errorf("env PREDICTOR_DOWNLOAD_PARALLEL: %w", err)
		}
		cfg.Download.Parallel = n
	}
	if v := os.Getenv("PREDICTOR_DOWNLOAD_BANDWIDTH"); v != "" {
		n, err := strconv.ParseInt(v, 10, 64)
		if err != nil {
			return fmt.Errorf("env PREDICTOR_DOWNLOAD_BANDWIDTH: %w", err)
		}
		cfg.Download.BandwidthBytesPerSecond = n
	}
	if v := os.Getenv("PREDICTOR_UPDATE_INTERVAL"); v != "" {
		d, err := time.ParseDuration(v)
		if err != nil {
			return fmt.Errorf("env PREDICTOR_UPDATE_INTERVAL: %w", err)
		}
		cfg.Download.UpdateInterval = d
	}
	if v := os.Getenv("PREDICTOR_DATASET_TTL"); v != "" {
		d, err := time.ParseDuration(v)
		if err != nil {
			return fmt.Errorf("env PREDICTOR_DATASET_TTL: %w", err)
		}
		cfg.Download.FreshnessTTL = d
	}
	if v := os.Getenv("PREDICTOR_METRICS_ENABLED"); v != "" {
		// "yes" was accepted historically and "no" is its counterpart; every
		// other spelling goes through strconv.ParseBool so that values such
		// as "TRUE" or "True" are no longer read as false.
		switch v {
		case "yes":
			cfg.Metrics.Enabled = true
		case "no":
			cfg.Metrics.Enabled = false
		default:
			b, err := strconv.ParseBool(v)
			if err != nil {
				return fmt.Errorf("env PREDICTOR_METRICS_ENABLED: %w", err)
			}
			cfg.Metrics.Enabled = b
		}
	}
	if v := os.Getenv("PREDICTOR_METRICS_PATH"); v != "" {
		cfg.Metrics.Path = v
	}
	if v := os.Getenv("PREDICTOR_LOG_LEVEL"); v != "" {
		cfg.Log.Level = v
	}
	return nil
}
// Validate reports the first configuration error it finds, or nil when the
// configuration is acceptable.
//
// Checks, in order: the HTTP port range; non-negative async worker count,
// queue size, and result TTL; required data directory and source; a positive
// download parallelism; a non-negative bandwidth limit (0 means unlimited);
// positive update interval and freshness TTL; a metrics path when metrics
// are enabled; and a recognised log level.
func (c Config) Validate() error {
	if c.HTTP.Port < 0 || c.HTTP.Port > 65535 {
		return fmt.Errorf("http.port %d outside [0, 65535]", c.HTTP.Port)
	}
	// The async settings can only come from the defaults or the YAML file, so
	// this is the one place a negative value from the file can be caught.
	if c.HTTP.AsyncWorkers < 0 {
		return fmt.Errorf("http.async_workers %d must be >= 0", c.HTTP.AsyncWorkers)
	}
	if c.HTTP.AsyncQueueSize < 0 {
		return fmt.Errorf("http.async_queue_size %d must be >= 0", c.HTTP.AsyncQueueSize)
	}
	if c.HTTP.AsyncResultTTL < 0 {
		return fmt.Errorf("http.async_result_ttl %s must be >= 0", c.HTTP.AsyncResultTTL)
	}
	if c.Data.Dir == "" {
		return fmt.Errorf("data.dir is required")
	}
	if c.Data.Source == "" {
		return fmt.Errorf("data.source is required")
	}
	if c.Download.Parallel <= 0 {
		return fmt.Errorf("download.parallel must be > 0")
	}
	// 0 means unlimited; a negative limit has no defined meaning.
	if c.Download.BandwidthBytesPerSecond < 0 {
		return fmt.Errorf("download.bandwidth_bytes_per_second %d must be >= 0", c.Download.BandwidthBytesPerSecond)
	}
	if c.Download.UpdateInterval <= 0 {
		return fmt.Errorf("download.update_interval must be > 0")
	}
	if c.Download.FreshnessTTL <= 0 {
		return fmt.Errorf("download.freshness_ttl must be > 0")
	}
	if c.Metrics.Enabled && c.Metrics.Path == "" {
		return fmt.Errorf("metrics.path is required when metrics enabled")
	}
	switch c.Log.Level {
	case "debug", "info", "warn", "error":
	default:
		return fmt.Errorf("log.level %q is not one of debug|info|warn|error", c.Log.Level)
	}
	return nil
}