// Package gfs implements datasets.Source for NOAA GFS 0.5-degree forecasts. package gfs import ( "context" "errors" "fmt" "io" "math" "net/http" "os" "sync" "time" "github.com/nilsmagnus/grib/griblib" "go.uber.org/zap" "golang.org/x/sync/errgroup" "predictor-refactored/internal/datasets" "predictor-refactored/internal/weather" wgfs "predictor-refactored/internal/weather/gfs" ) // Source is the GFS implementation of datasets.Source. type Source struct { Parallel int // max concurrent step downloads Client *http.Client // optional; defaults to a 2-minute-timeout client Log *zap.Logger } // NewSource returns a default Source. func NewSource(log *zap.Logger) *Source { return &Source{ Parallel: 8, Client: &http.Client{Timeout: 2 * time.Minute}, Log: log, } } // ID returns the source identifier. func (s *Source) ID() string { return "noaa-gfs-0p50" } func (s *Source) log() *zap.Logger { if s.Log == nil { return zap.NewNop() } return s.Log } func (s *Source) client() *http.Client { if s.Client == nil { return &http.Client{Timeout: 2 * time.Minute} } return s.Client } func (s *Source) parallel() int { if s.Parallel <= 0 { return 8 } return s.Parallel } // LatestEpoch returns the most recent run NOAA has finished publishing, // determined by HEAD-ing the .idx for the final forecast hour. Walks back // up to 8 runs (48 hours) before giving up. func (s *Source) LatestEpoch(ctx context.Context) (time.Time, error) { now := time.Now().UTC() hour := now.Hour() - (now.Hour() % 6) current := time.Date(now.Year(), now.Month(), now.Day(), hour, 0, 0, 0, time.UTC) for range 8 { date := current.Format("20060102") url := wgfs.GribURL(date, current.Hour(), wgfs.MaxHour) + ".idx" req, err := http.NewRequestWithContext(ctx, http.MethodHead, url, nil) if err == nil { resp, err := s.client().Do(req) if err == nil { resp.Body.Close() if resp.StatusCode == http.StatusOK { s.log().Info("latest GFS run discovered", zap.Time("run", current), zap.String("verified_url", url)) return current, nil } } } current = current.Add(-6 * time.Hour) } return time.Time{}, fmt.Errorf("no recent GFS run found (checked 8 runs)") } // Open loads a stored dataset as a WindField. func (s *Source) Open(_ context.Context, epoch time.Time, store datasets.Storage) (weather.WindField, error) { if !store.Exists(epoch) { return nil, fmt.Errorf("epoch %s not found", epoch.Format(time.RFC3339)) } file, err := wgfs.Open(store.Path(epoch), epoch.UTC()) if err != nil { return nil, err } return wgfs.NewWind(file), nil } // neededVariables is the GRIB variable set we extract. var neededVariables = map[string]bool{"HGT": true, "UGRD": true, "VGRD": true} // Download fetches the full dataset for epoch in parallel, resuming any // previously-completed work units. Honours ctx cancellation and prog // (which may be nil). func (s *Source) Download(ctx context.Context, epoch time.Time, store datasets.Storage, prog datasets.ProgressSink, throttle datasets.Throttle) error { if prog == nil { prog = noopSink{} } handle, err := store.BeginWrite(epoch) if err != nil { return fmt.Errorf("begin write: %w", err) } manifest := handle.Manifest() // Open or create the temp file. If a previous attempt left a partial // file of the right size, reuse it (resume); otherwise Create. file, err := openOrCreateCube(handle.Path()) if err != nil { _ = handle.Abort() return err } date := epoch.UTC().Format("20060102") runHour := epoch.UTC().Hour() steps := wgfs.Hours() totalUnits := len(steps) * 2 prog.SetTotal(totalUnits) // Pre-count already-done units so progress is accurate on resume. for _, u := range manifest.Units() { _ = u prog.StepComplete() } start := time.Now() g, ctx := errgroup.WithContext(ctx) g.SetLimit(s.parallel()) // fileMu serialises concurrent BlitGribData calls because the underlying // mmap is shared and SetVal isn't atomic. var fileMu sync.Mutex for _, step := range steps { hourIdx := wgfs.HourIndex(step) if hourIdx < 0 { continue } for _, ls := range []wgfs.LevelSet{wgfs.LevelSetA, wgfs.LevelSetB} { unit := unitKey(step, ls) if manifest.Has(unit) { continue } g.Go(func() error { var url string switch ls { case wgfs.LevelSetA: url = wgfs.GribURL(date, runHour, step) case wgfs.LevelSetB: url = wgfs.GribURLB(date, runHour, step) } if err := s.downloadAndBlit(ctx, file, &fileMu, url, hourIdx, ls, prog, throttle); err != nil { return fmt.Errorf("step %d %s: %w", step, levelSetLabel(ls), err) } if err := manifest.Mark(unit); err != nil { return fmt.Errorf("mark unit: %w", err) } prog.StepComplete() return nil }) } } if err := g.Wait(); err != nil { _ = file.Close() // Don't Abort on context cancellation — preserve progress for resume. if errors.Is(err, context.Canceled) { return err } // Other errors: abort if no progress was made; otherwise leave for resume. if len(manifest.Units()) == 0 { _ = handle.Abort() } return err } if err := file.Flush(); err != nil { _ = file.Close() return fmt.Errorf("flush: %w", err) } if err := file.Close(); err != nil { return fmt.Errorf("close: %w", err) } if err := handle.Commit(); err != nil { return fmt.Errorf("commit: %w", err) } s.log().Info("download complete", zap.Time("epoch", epoch), zap.Duration("elapsed", time.Since(start))) return nil } // openOrCreateCube returns a writable cube file at path, creating it if the // file does not exist or has the wrong size. func openOrCreateCube(path string) (*wgfs.File, error) { info, err := os.Stat(path) if err == nil && info.Size() == wgfs.DatasetSize { return wgfs.OpenWritable(path) } if err != nil && !errors.Is(err, os.ErrNotExist) { return nil, fmt.Errorf("stat cube: %w", err) } // Wrong-size or missing — truncate-create. return wgfs.Create(path) } // downloadAndBlit fetches and decodes one (URL, level-set) chunk and writes // it into the dataset. func (s *Source) downloadAndBlit( ctx context.Context, file *wgfs.File, fileMu *sync.Mutex, baseURL string, hourIdx int, ls wgfs.LevelSet, prog datasets.ProgressSink, throttle datasets.Throttle, ) error { idxBody, err := s.httpGet(ctx, baseURL+".idx", throttle, prog) if err != nil { return fmt.Errorf("idx: %w", err) } entries := ParseIdx(idxBody) filtered := FilterIdx(entries, neededVariables) var relevant []IdxEntry for _, e := range filtered { set, ok := wgfs.PressureLevelSet(e.LevelMB) if ok && set == ls { relevant = append(relevant, e) } } if len(relevant) == 0 { return nil } ranges := EntriesToRanges(relevant) tmp, err := os.CreateTemp("", "gfs-msg-*.tmp") if err != nil { return fmt.Errorf("temp: %w", err) } tmpPath := tmp.Name() defer os.Remove(tmpPath) for _, r := range ranges { body, err := s.httpGetRange(ctx, baseURL, r.Start, r.End, throttle, prog) if err != nil { tmp.Close() return fmt.Errorf("range %d-%d: %w", r.Start, r.End, err) } if _, err := tmp.Write(body); err != nil { tmp.Close() return fmt.Errorf("write tmp: %w", err) } } if err := tmp.Close(); err != nil { return err } f, err := os.Open(tmpPath) if err != nil { return err } messages, err := griblib.ReadMessages(f) f.Close() if err != nil { return fmt.Errorf("read grib: %w", err) } for _, msg := range messages { if msg.Section4.ProductDefinitionTemplateNumber != 0 { continue } p := msg.Section4.ProductDefinitionTemplate varIdx := wgfs.VariableIndex(int(p.ParameterCategory), int(p.ParameterNumber)) if varIdx < 0 { continue } if p.FirstSurface.Type != 100 { // isobaric only continue } pressureMB := int(math.Round(float64(p.FirstSurface.Value) / 100.0)) levelIdx := wgfs.PressureIndex(pressureMB) if levelIdx < 0 { continue } data := msg.Data() fileMu.Lock() err := file.BlitGribData(hourIdx, levelIdx, varIdx, data) fileMu.Unlock() if err != nil { return fmt.Errorf("blit: %w", err) } } return nil } // httpGet downloads a URL body with 3 retries and optional throttling. func (s *Source) httpGet(ctx context.Context, url string, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) { var lastErr error for attempt := range 3 { if attempt > 0 { select { case <-time.After(time.Duration(attempt*2) * time.Second): case <-ctx.Done(): return nil, ctx.Err() } } req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } resp, err := s.client().Do(req) if err != nil { lastErr = err continue } body, err := readThrottled(ctx, resp.Body, throttle, prog) resp.Body.Close() if resp.StatusCode != http.StatusOK { lastErr = fmt.Errorf("HTTP %d for %s", resp.StatusCode, url) continue } if err != nil { lastErr = err continue } return body, nil } return nil, fmt.Errorf("after 3 attempts: %w", lastErr) } // httpGetRange downloads an inclusive byte range with 3 retries and throttling. func (s *Source) httpGetRange(ctx context.Context, url string, start, end int64, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) { var lastErr error for attempt := range 3 { if attempt > 0 { select { case <-time.After(time.Duration(attempt*2) * time.Second): case <-ctx.Done(): return nil, ctx.Err() } } req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, err } req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", start, end)) resp, err := s.client().Do(req) if err != nil { lastErr = err continue } body, err := readThrottled(ctx, resp.Body, throttle, prog) resp.Body.Close() if resp.StatusCode != http.StatusPartialContent && resp.StatusCode != http.StatusOK { lastErr = fmt.Errorf("HTTP %d for range %d-%d of %s", resp.StatusCode, start, end, url) continue } if err != nil { lastErr = err continue } return body, nil } return nil, fmt.Errorf("after 3 attempts: %w", lastErr) } // readThrottled reads r into memory, consulting throttle (if non-nil) before // each chunk and reporting bytes to prog. func readThrottled(ctx context.Context, r io.Reader, throttle datasets.Throttle, prog datasets.ProgressSink) ([]byte, error) { buf := make([]byte, 0, 64*1024) chunk := make([]byte, 32*1024) for { if throttle != nil { if err := throttle.Wait(ctx, len(chunk)); err != nil { return nil, err } } n, err := r.Read(chunk) if n > 0 { buf = append(buf, chunk[:n]...) prog.Bytes(int64(n)) } if errors.Is(err, io.EOF) { return buf, nil } if err != nil { return nil, err } } } func unitKey(step int, ls wgfs.LevelSet) string { return fmt.Sprintf("step%03d-%s", step, levelSetLabel(ls)) } func levelSetLabel(ls wgfs.LevelSet) string { if ls == wgfs.LevelSetB { return "B" } return "A" } // noopSink discards progress events. type noopSink struct{} func (noopSink) SetTotal(int) {} func (noopSink) StepComplete() {} func (noopSink) Bytes(int64) {}